From 0886e68f247bc47be993959bfc7a5ea35418d21f Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 13:26:25 -0700 Subject: [PATCH 0001/1918] Initial Release --- flytestdlib/.gitignore | 113 +++ flytestdlib/.golangci.yml | 25 + flytestdlib/.travis.yml | 20 + flytestdlib/CODE_OF_CONDUCT.md | 2 + flytestdlib/Gopkg.lock | 516 ++++++++++++++ flytestdlib/Gopkg.toml | 72 ++ flytestdlib/LICENSE | 202 ++++++ flytestdlib/Makefile | 24 + flytestdlib/NOTICE | 21 + flytestdlib/README.rst | 35 + flytestdlib/atomic/atomic.go | 129 ++++ flytestdlib/atomic/atomic_test.go | 61 ++ flytestdlib/atomic/non_blocking_lock.go | 23 + flytestdlib/atomic/non_blocking_lock_test.go | 15 + .../lyft/golang_test_targets/Makefile | 31 + .../lyft/golang_test_targets/Readme.rst | 31 + .../lyft/golang_test_targets/goimports | 3 + flytestdlib/boilerplate/update.cfg | 2 + flytestdlib/cli/pflags/api/generator.go | 277 ++++++++ flytestdlib/cli/pflags/api/generator_test.go | 56 ++ flytestdlib/cli/pflags/api/pflag_provider.go | 90 +++ flytestdlib/cli/pflags/api/sample.go | 53 ++ flytestdlib/cli/pflags/api/tag.go | 66 ++ flytestdlib/cli/pflags/api/templates.go | 175 +++++ .../cli/pflags/api/testdata/testtype.go | 36 + .../cli/pflags/api/testdata/testtype_test.go | 520 ++++++++++++++ flytestdlib/cli/pflags/api/types.go | 36 + flytestdlib/cli/pflags/api/utils.go | 32 + flytestdlib/cli/pflags/cmd/root.go | 71 ++ flytestdlib/cli/pflags/cmd/version.go | 17 + flytestdlib/cli/pflags/main.go | 15 + flytestdlib/cli/pflags/readme.rst | 24 + flytestdlib/config/accessor.go | 61 ++ flytestdlib/config/accessor_test.go | 91 +++ flytestdlib/config/config_cmd.go | 113 +++ flytestdlib/config/config_cmd_test.go | 64 ++ flytestdlib/config/duration.go | 47 ++ flytestdlib/config/duration_test.go | 66 ++ flytestdlib/config/errors.go | 37 + flytestdlib/config/errors_test.go | 31 + flytestdlib/config/files/finder.go | 83 +++ flytestdlib/config/files/finder_test.go | 28 + .../config/files/testdata/config-1.yaml | 9 + .../config/files/testdata/config-2.yaml | 2 + .../config/files/testdata/other-group-1.yaml | 9 + .../config/files/testdata/other-group-2.yaml | 2 + flytestdlib/config/port.go | 56 ++ flytestdlib/config/port_test.go | 81 +++ flytestdlib/config/section.go | 230 +++++++ flytestdlib/config/section_test.go | 119 ++++ flytestdlib/config/testdata/config.yaml | 11 + flytestdlib/config/tests/accessor_test.go | 641 +++++++++++++++++ flytestdlib/config/tests/config_cmd_test.go | 92 +++ .../config/tests/testdata/array_configs.yaml | 7 + .../config/tests/testdata/bad_config.yaml | 13 + flytestdlib/config/tests/testdata/config.yaml | 11 + .../config/tests/testdata/nested_config.yaml | 11 + flytestdlib/config/tests/types_test.go | 66 ++ flytestdlib/config/url.go | 36 + flytestdlib/config/url_test.go | 59 ++ flytestdlib/config/utils.go | 69 ++ flytestdlib/config/utils_test.go | 39 ++ flytestdlib/config/viper/collection.go | 175 +++++ flytestdlib/config/viper/viper.go | 357 ++++++++++ flytestdlib/contextutils/context.go | 140 ++++ flytestdlib/contextutils/context_test.go | 113 +++ flytestdlib/internal/utils/parsers.go | 20 + flytestdlib/internal/utils/parsers_test.go | 25 + flytestdlib/ioutils/bytes.go | 21 + flytestdlib/ioutils/bytes_test.go | 17 + flytestdlib/ioutils/timed_readers.go | 17 + flytestdlib/ioutils/timed_readers_test.go | 20 + flytestdlib/logger/config.go | 81 +++ flytestdlib/logger/config_flags.go | 21 + flytestdlib/logger/config_flags_test.go | 190 ++++++ flytestdlib/logger/config_test.go | 20 + flytestdlib/logger/logger.go | 
337 +++++++++ flytestdlib/logger/logger_test.go | 643 ++++++++++++++++++ flytestdlib/pbhash/pbhash.go | 58 ++ flytestdlib/pbhash/pbhash_test.go | 145 ++++ flytestdlib/profutils/server.go | 116 ++++ flytestdlib/profutils/server_test.go | 103 +++ flytestdlib/promutils/labeled/counter.go | 65 ++ flytestdlib/promutils/labeled/counter_test.go | 31 + flytestdlib/promutils/labeled/keys.go | 47 ++ flytestdlib/promutils/labeled/keys_test.go | 24 + .../promutils/labeled/metric_option.go | 15 + .../promutils/labeled/metric_option_test.go | 13 + flytestdlib/promutils/labeled/stopwatch.go | 87 +++ .../promutils/labeled/stopwatch_test.go | 47 ++ .../promutils/labeled/timer_wrapper.go | 20 + .../promutils/labeled/timer_wrapper_test.go | 28 + flytestdlib/promutils/scope.go | 434 ++++++++++++ flytestdlib/promutils/scope_test.go | 151 ++++ flytestdlib/promutils/workqueue.go | 82 +++ flytestdlib/promutils/workqueue_test.go | 42 ++ flytestdlib/sets/generic_set.go | 195 ++++++ flytestdlib/sets/generic_set_test.go | 116 ++++ flytestdlib/storage/cached_rawstore.go | 123 ++++ flytestdlib/storage/cached_rawstore_test.go | 182 +++++ flytestdlib/storage/config.go | 85 +++ flytestdlib/storage/config_flags.go | 28 + flytestdlib/storage/config_flags_test.go | 344 ++++++++++ flytestdlib/storage/config_test.go | 45 ++ flytestdlib/storage/copy_impl.go | 60 ++ flytestdlib/storage/copy_impl_test.go | 82 +++ flytestdlib/storage/localstore.go | 48 ++ flytestdlib/storage/localstore_test.go | 66 ++ flytestdlib/storage/mem_store.go | 74 ++ flytestdlib/storage/mem_store_test.go | 78 +++ flytestdlib/storage/protobuf_store.go | 85 +++ flytestdlib/storage/protobuf_store_test.go | 41 ++ flytestdlib/storage/rawstores.go | 40 ++ flytestdlib/storage/s3store.go | 102 +++ flytestdlib/storage/s3stsore_test.go | 26 + flytestdlib/storage/storage.go | 95 +++ flytestdlib/storage/storage_test.go | 52 ++ flytestdlib/storage/stow_store.go | 174 +++++ flytestdlib/storage/stow_store_test.go | 133 ++++ flytestdlib/storage/testdata/config.yaml | 14 + flytestdlib/storage/url_path.go | 44 ++ flytestdlib/storage/url_path_test.go | 15 + flytestdlib/storage/utils.go | 34 + flytestdlib/tests/config_test.go | 78 +++ flytestdlib/tests/testdata/combined.yaml | 19 + flytestdlib/utils/auto_refresh_cache.go | 99 +++ flytestdlib/utils/auto_refresh_cache_test.go | 108 +++ .../utils/auto_refresh_example_test.go | 88 +++ flytestdlib/utils/rate_limiter.go | 35 + flytestdlib/utils/rate_limiter_test.go | 39 ++ flytestdlib/utils/sequencer.go | 39 ++ flytestdlib/utils/sequencer_test.go | 54 ++ flytestdlib/version/version.go | 29 + flytestdlib/version/version_test.go | 29 + flytestdlib/yamlutils/yaml_json.go | 17 + 135 files changed, 11697 insertions(+) create mode 100644 flytestdlib/.gitignore create mode 100644 flytestdlib/.golangci.yml create mode 100644 flytestdlib/.travis.yml create mode 100644 flytestdlib/CODE_OF_CONDUCT.md create mode 100644 flytestdlib/Gopkg.lock create mode 100644 flytestdlib/Gopkg.toml create mode 100644 flytestdlib/LICENSE create mode 100644 flytestdlib/Makefile create mode 100644 flytestdlib/NOTICE create mode 100644 flytestdlib/README.rst create mode 100644 flytestdlib/atomic/atomic.go create mode 100644 flytestdlib/atomic/atomic_test.go create mode 100644 flytestdlib/atomic/non_blocking_lock.go create mode 100644 flytestdlib/atomic/non_blocking_lock_test.go create mode 100644 flytestdlib/boilerplate/lyft/golang_test_targets/Makefile create mode 100644 flytestdlib/boilerplate/lyft/golang_test_targets/Readme.rst create mode 100755 
flytestdlib/boilerplate/lyft/golang_test_targets/goimports create mode 100644 flytestdlib/boilerplate/update.cfg create mode 100644 flytestdlib/cli/pflags/api/generator.go create mode 100644 flytestdlib/cli/pflags/api/generator_test.go create mode 100644 flytestdlib/cli/pflags/api/pflag_provider.go create mode 100644 flytestdlib/cli/pflags/api/sample.go create mode 100644 flytestdlib/cli/pflags/api/tag.go create mode 100644 flytestdlib/cli/pflags/api/templates.go create mode 100755 flytestdlib/cli/pflags/api/testdata/testtype.go create mode 100755 flytestdlib/cli/pflags/api/testdata/testtype_test.go create mode 100644 flytestdlib/cli/pflags/api/types.go create mode 100644 flytestdlib/cli/pflags/api/utils.go create mode 100644 flytestdlib/cli/pflags/cmd/root.go create mode 100644 flytestdlib/cli/pflags/cmd/version.go create mode 100644 flytestdlib/cli/pflags/main.go create mode 100644 flytestdlib/cli/pflags/readme.rst create mode 100644 flytestdlib/config/accessor.go create mode 100644 flytestdlib/config/accessor_test.go create mode 100644 flytestdlib/config/config_cmd.go create mode 100644 flytestdlib/config/config_cmd_test.go create mode 100644 flytestdlib/config/duration.go create mode 100644 flytestdlib/config/duration_test.go create mode 100644 flytestdlib/config/errors.go create mode 100644 flytestdlib/config/errors_test.go create mode 100644 flytestdlib/config/files/finder.go create mode 100644 flytestdlib/config/files/finder_test.go create mode 100644 flytestdlib/config/files/testdata/config-1.yaml create mode 100755 flytestdlib/config/files/testdata/config-2.yaml create mode 100644 flytestdlib/config/files/testdata/other-group-1.yaml create mode 100755 flytestdlib/config/files/testdata/other-group-2.yaml create mode 100644 flytestdlib/config/port.go create mode 100644 flytestdlib/config/port_test.go create mode 100644 flytestdlib/config/section.go create mode 100644 flytestdlib/config/section_test.go create mode 100755 flytestdlib/config/testdata/config.yaml create mode 100644 flytestdlib/config/tests/accessor_test.go create mode 100644 flytestdlib/config/tests/config_cmd_test.go create mode 100644 flytestdlib/config/tests/testdata/array_configs.yaml create mode 100644 flytestdlib/config/tests/testdata/bad_config.yaml create mode 100755 flytestdlib/config/tests/testdata/config.yaml create mode 100755 flytestdlib/config/tests/testdata/nested_config.yaml create mode 100644 flytestdlib/config/tests/types_test.go create mode 100644 flytestdlib/config/url.go create mode 100644 flytestdlib/config/url_test.go create mode 100644 flytestdlib/config/utils.go create mode 100644 flytestdlib/config/utils_test.go create mode 100644 flytestdlib/config/viper/collection.go create mode 100644 flytestdlib/config/viper/viper.go create mode 100644 flytestdlib/contextutils/context.go create mode 100644 flytestdlib/contextutils/context_test.go create mode 100644 flytestdlib/internal/utils/parsers.go create mode 100644 flytestdlib/internal/utils/parsers_test.go create mode 100644 flytestdlib/ioutils/bytes.go create mode 100644 flytestdlib/ioutils/bytes_test.go create mode 100644 flytestdlib/ioutils/timed_readers.go create mode 100644 flytestdlib/ioutils/timed_readers_test.go create mode 100644 flytestdlib/logger/config.go create mode 100755 flytestdlib/logger/config_flags.go create mode 100755 flytestdlib/logger/config_flags_test.go create mode 100644 flytestdlib/logger/config_test.go create mode 100644 flytestdlib/logger/logger.go create mode 100644 flytestdlib/logger/logger_test.go create mode 100644 
flytestdlib/pbhash/pbhash.go create mode 100644 flytestdlib/pbhash/pbhash_test.go create mode 100644 flytestdlib/profutils/server.go create mode 100644 flytestdlib/profutils/server_test.go create mode 100644 flytestdlib/promutils/labeled/counter.go create mode 100644 flytestdlib/promutils/labeled/counter_test.go create mode 100644 flytestdlib/promutils/labeled/keys.go create mode 100644 flytestdlib/promutils/labeled/keys_test.go create mode 100644 flytestdlib/promutils/labeled/metric_option.go create mode 100644 flytestdlib/promutils/labeled/metric_option_test.go create mode 100644 flytestdlib/promutils/labeled/stopwatch.go create mode 100644 flytestdlib/promutils/labeled/stopwatch_test.go create mode 100644 flytestdlib/promutils/labeled/timer_wrapper.go create mode 100644 flytestdlib/promutils/labeled/timer_wrapper_test.go create mode 100644 flytestdlib/promutils/scope.go create mode 100644 flytestdlib/promutils/scope_test.go create mode 100644 flytestdlib/promutils/workqueue.go create mode 100644 flytestdlib/promutils/workqueue_test.go create mode 100644 flytestdlib/sets/generic_set.go create mode 100644 flytestdlib/sets/generic_set_test.go create mode 100644 flytestdlib/storage/cached_rawstore.go create mode 100644 flytestdlib/storage/cached_rawstore_test.go create mode 100644 flytestdlib/storage/config.go create mode 100755 flytestdlib/storage/config_flags.go create mode 100755 flytestdlib/storage/config_flags_test.go create mode 100644 flytestdlib/storage/config_test.go create mode 100644 flytestdlib/storage/copy_impl.go create mode 100644 flytestdlib/storage/copy_impl_test.go create mode 100644 flytestdlib/storage/localstore.go create mode 100644 flytestdlib/storage/localstore_test.go create mode 100644 flytestdlib/storage/mem_store.go create mode 100644 flytestdlib/storage/mem_store_test.go create mode 100644 flytestdlib/storage/protobuf_store.go create mode 100644 flytestdlib/storage/protobuf_store_test.go create mode 100644 flytestdlib/storage/rawstores.go create mode 100644 flytestdlib/storage/s3store.go create mode 100644 flytestdlib/storage/s3stsore_test.go create mode 100644 flytestdlib/storage/storage.go create mode 100644 flytestdlib/storage/storage_test.go create mode 100644 flytestdlib/storage/stow_store.go create mode 100644 flytestdlib/storage/stow_store_test.go create mode 100755 flytestdlib/storage/testdata/config.yaml create mode 100644 flytestdlib/storage/url_path.go create mode 100644 flytestdlib/storage/url_path_test.go create mode 100644 flytestdlib/storage/utils.go create mode 100644 flytestdlib/tests/config_test.go create mode 100755 flytestdlib/tests/testdata/combined.yaml create mode 100644 flytestdlib/utils/auto_refresh_cache.go create mode 100644 flytestdlib/utils/auto_refresh_cache_test.go create mode 100644 flytestdlib/utils/auto_refresh_example_test.go create mode 100644 flytestdlib/utils/rate_limiter.go create mode 100644 flytestdlib/utils/rate_limiter_test.go create mode 100644 flytestdlib/utils/sequencer.go create mode 100644 flytestdlib/utils/sequencer_test.go create mode 100644 flytestdlib/version/version.go create mode 100644 flytestdlib/version/version_test.go create mode 100644 flytestdlib/yamlutils/yaml_json.go diff --git a/flytestdlib/.gitignore b/flytestdlib/.gitignore new file mode 100644 index 0000000000..00820a03e4 --- /dev/null +++ b/flytestdlib/.gitignore @@ -0,0 +1,113 @@ + +# Temporary Build Files +tmp/_output +tmp/_test + + +# Created by https://www.gitignore.io/api/go,vim,emacs,visualstudiocode + +### Emacs ### +# -*- mode: gitignore; 
-*- +*~ +\#*\# +/.emacs.desktop +/.emacs.desktop.lock +*.elc +auto-save-list +tramp +.\#* + +# Org-mode +.org-id-locations +*_archive + +# flymake-mode +*_flymake.* + +# eshell files +/eshell/history +/eshell/lastdir + +# elpa packages +/elpa/ + +# reftex files +*.rel + +# AUCTeX auto folder +/auto/ + +# cask packages +.cask/ +dist/ + +# Flycheck +flycheck_*.el + +# server auth directory +/server/ + +# projectiles files +.projectile +projectile-bookmarks.eld + +# directory configuration +.dir-locals.el + +# saveplace +places + +# url cache +url/cache/ + +# cedet +ede-projects.el + +# smex +smex-items + +# company-statistics +company-statistics-cache.el + +# anaconda-mode +anaconda-mode/ + +### Go ### +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +vendor/ + +# Test binary, build with 'go test -c' +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +### Vim ### +# swap +.sw[a-p] +.*.sw[a-p] +# session +Session.vim +# temporary +.netrwhist +# auto-generated tag files +tags + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +.history + +### GoLand ### +.idea/* + + +# End of https://www.gitignore.io/api/go,vim,emacs,visualstudiocode diff --git a/flytestdlib/.golangci.yml b/flytestdlib/.golangci.yml new file mode 100644 index 0000000000..dbfea73e09 --- /dev/null +++ b/flytestdlib/.golangci.yml @@ -0,0 +1,25 @@ +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/flytestdlib/.travis.yml b/flytestdlib/.travis.yml new file mode 100644 index 0000000000..91723384da --- /dev/null +++ b/flytestdlib/.travis.yml @@ -0,0 +1,20 @@ +sudo: required +language: go +go: + - "1.10" +services: + - docker +jobs: + include: + - stage: test + name: unit tests + install: make install + script: make test_unit + - stage: test + name: benchmark tests + install: make install + script: make test_benchmark + - stage: test + install: make install + name: lint + script: make lint diff --git a/flytestdlib/CODE_OF_CONDUCT.md b/flytestdlib/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..4c3a38cc48 --- /dev/null +++ b/flytestdlib/CODE_OF_CONDUCT.md @@ -0,0 +1,2 @@ +This project is governed by [Lyft's code of conduct](https://github.com/lyft/code-of-conduct). +All contributors and participants agree to abide by its terms. diff --git a/flytestdlib/Gopkg.lock b/flytestdlib/Gopkg.lock new file mode 100644 index 0000000000..a56ebd2dba --- /dev/null +++ b/flytestdlib/Gopkg.lock @@ -0,0 +1,516 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + digest = "1:7fbdc0ca5fc0b0bb66b81ec2fdca82fbe64416742267f11aceb8ed56e6ca3121" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", + "private/protocol/eventstream/eventstreamapi", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/s3", + "service/sts", + ] + pruneopts = "UT" + revision = "81f3829f5a9d041041bdf56e55926691309d7699" + version = "v1.16.26" + +[[projects]] + branch = "master" + digest = "1:a6609679ca468a89b711934f16b346e99f6ec344eadd2f7b00b1156785dd1236" + name = "github.com/benlaurie/objecthash" + packages = ["go/objecthash"] + pruneopts = "UT" + revision = "d1e3d6079fc16f8f542183fb5b2fdc11d9f00866" + +[[projects]] + branch = "master" + digest = "1:d6afaeed1502aa28e80a4ed0981d570ad91b2579193404256ce672ed0a609e0d" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "UT" + revision = "3a771d992973f24aa725d07868b467d1ddfceafb" + +[[projects]] + digest = "1:998cf998358a303ac2430c386ba3fd3398477d6013153d3c6e11432765cc9ae6" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "UT" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:04179a5bcbecdb18f06cca42e3808ae8560f86ad7fe470fde21206008f0c5e26" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "UT" + revision = "f3233c8095b26cd0dea0b136b931708c05defa08" + version = "v1.0.1" + +[[projects]] + digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "UT" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + branch = "master" + digest = "1:dc8bf44b7198605c83a4f2bb36a92c4d9f71eab2e8cf8094ce31b0297dd8ea89" + name = "github.com/ernesto-jimenez/gogen" + packages = [ + "gogenutil", + "imports", + ] + pruneopts = "UT" + revision = "d7d4131e6607813977e78297a6060f360f056a97" + +[[projects]] + digest = "1:865079840386857c809b72ce300be7580cb50d3d3129ce11bf9aa6ca2bc1934a" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "UT" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + digest = "1:060e2ff7ee3e51b4a0fadf46308033bfe3b8030af6a8078ec26916e2e9b2fdc3" + name = "github.com/fatih/structtag" + packages = ["."] + pruneopts = "UT" + revision = "76ae1d6d2117609598c7d4e8f3e938145f204e8f" + version = "v1.0.0" + +[[projects]] + branch = "master" + digest = "1:b9414457752702c53f6afd3838da3d89b9513ada40cdbe9603bdf54b1ceb5014" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "UT" + revision = "ccc981bf80385c528a65fbfdd49bf2d8da22aa23" + +[[projects]] + digest = "1:2cd7915ab26ede7d95b8749e6b1f933f1c6d5398030684e6505940a10f31cfda" + name = "github.com/ghodss/yaml" + packages = ["."] + pruneopts = "UT" + revision = 
"0ca9ea5df5451ffdf184b4428c902747c2c11cd7" + version = "v1.0.0" + +[[projects]] + branch = "master" + digest = "1:1ba1d79f2810270045c328ae5d674321db34e3aae468eb4233883b473c5c0467" + name = "github.com/golang/glog" + packages = ["."] + pruneopts = "UT" + revision = "23def4e6c14b4da8ac2ed8007337bc5eb5007998" + +[[projects]] + digest = "1:9d6dc4d6de69b330d0de86494d6db90c09848c003d5db748f40c925f865c8534" + name = "github.com/golang/protobuf" + packages = [ + "jsonpb", + "proto", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + ] + pruneopts = "UT" + revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" + version = "v1.2.0" + +[[projects]] + digest = "1:3dab0e385faed192353d2150f6a041f4607f04a0e885f4a5a824eee6b676b4b9" + name = "github.com/graymeta/stow" + packages = [ + ".", + "local", + "s3", + ] + pruneopts = "UT" + revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + +[[projects]] + digest = "1:c0d19ab64b32ce9fe5cf4ddceba78d5bc9807f0016db6b1183599da3dcc24d10" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "UT" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "UT" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:bb81097a5b62634f3e9fec1014657855610c82d19b9a40c17612e32651e35dca" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "UT" + revision = "c2b33e84" + +[[projects]] + digest = "1:0a69a1c0db3591fcefb47f115b224592c8dfa4368b7ba9fae509d5e16cdc95c8" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "UT" + revision = "5c8c8bd35d3832f5d134ae1e1e375b69a4d25242" + version = "v1.0.1" + +[[projects]] + digest = "1:53e8c5c79716437e601696140e8b1801aae4204f4ec54a504333702a49572c4f" + name = "github.com/magiconair/properties" + packages = [ + ".", + "assert", + ] + pruneopts = "UT" + revision = "c2353362d570a7bfa228149c62842019201cfb71" + version = "v1.8.0" + +[[projects]] + digest = "1:c658e84ad3916da105a761660dcaeb01e63416c8ec7bc62256a9b411a05fcd67" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "UT" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:0981502f9816113c9c8c4ac301583841855c8cf4da8c72f696b3ebedf6d0e4e5" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "UT" + revision = "6ca4dbf54d38eea1a992b3c722a76a5d1c4cb25c" + version = "v0.0.4" + +[[projects]] + digest = "1:ff5ebae34cfbf047d505ee150de27e60570e8c394b3b8fdbb720ff6ac71985fc" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "UT" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:53bc4cd4914cd7cd52139990d5170d6dc99067ae31c56530621b18b35fc30318" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "UT" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:95741de3af260a92cc5c7f3f3061e85273f5a81b5db20d4bd68da74bd521675e" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "UT" + revision = 
"c01d1270ff3e442a8a57cddc1c92dc1138598194" + version = "v1.2.0" + +[[projects]] + digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "UT" + revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "UT" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:93a746f1060a8acbcf69344862b2ceced80f854170e1caae089b2834c5fbf7f4" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + "prometheus/promhttp", + ] + pruneopts = "UT" + revision = "505eaef017263e299324067d40ca2c48f6a2cf50" + version = "v0.9.2" + +[[projects]] + branch = "master" + digest = "1:2d5cd61daa5565187e1d96bae64dbbc6080dacf741448e9629c64fd93203b0d4" + name = "github.com/prometheus/client_model" + packages = ["go"] + pruneopts = "UT" + revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + +[[projects]] + digest = "1:35cf6bdf68db765988baa9c4f10cc5d7dda1126a54bd62e252dbcd0b1fc8da90" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "UT" + revision = "cfeb6f9992ffa54aaa4f2170ade4067ee478b250" + version = "v0.2.0" + +[[projects]] + branch = "master" + digest = "1:5833c61ebbd625a6bad8e5a1ada2b3e13710cf3272046953a2c8915340fe60a3" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/util", + "nfs", + "xfs", + ] + pruneopts = "UT" + revision = "316cf8ccfec56d206735d46333ca162eb374da8b" + +[[projects]] + digest = "1:87c2e02fb01c27060ccc5ba7c5a407cc91147726f8f40b70cceeedbc52b1f3a8" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "UT" + revision = "e1e72e9de974bd926e5c56f83753fba2df402ce5" + version = "v1.3.0" + +[[projects]] + digest = "1:3e39bafd6c2f4bf3c76c3bfd16a2e09e016510ad5db90dc02b88e2f565d6d595" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "UT" + revision = "f4711e4db9e9a1d3887343acb72b2bbfc2f686f5" + version = "v1.2.1" + +[[projects]] + digest = "1:08d65904057412fc0270fc4812a1c90c594186819243160dc779a402d4b6d0bc" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "UT" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = "1:645cabccbb4fa8aab25a956cbcbdf6a6845ca736b2c64e197ca7cbb9d210b939" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "UT" + revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385" + version = "v0.0.3" + +[[projects]] + digest = "1:68ea4e23713989dc20b1bded5d9da2c5f9be14ff9885beef481848edd18c26cb" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "UT" + revision = "4a4406e478ca629068e7768fc33f3f044173c0a6" + version = "v1.0.0" + +[[projects]] + digest = "1:c1b1102241e7f645bc8e0c22ae352e8f0dc6484b6cb4d132fa9f24174e0119e2" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "UT" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:2532daa308722c7b65f4566e634dac2ddfaa0a398a17d8418e96ef2af3939e37" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "UT" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = 
"1:972c2427413d41a1e06ca4897e8528e5a1622894050e2f527b38ddf0f343f759" + name = "github.com/stretchr/testify" + packages = ["assert"] + pruneopts = "UT" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + +[[projects]] + branch = "master" + digest = "1:fde12c4da6237363bf36b81b59aa36a43d28061167ec4acb0d41fc49464e28b9" + name = "golang.org/x/crypto" + packages = ["ssh/terminal"] + pruneopts = "UT" + revision = "b01c7a72566457eb1420261cdafef86638fc3861" + +[[projects]] + branch = "master" + digest = "1:7941e2f16c0833b438cbef7fccfe4f8346f9f7876b42b29717a75d7e8c4800cb" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows", + ] + pruneopts = "UT" + revision = "aca44879d5644da7c5b8ec6a1115e9b6ea6c40d9" + +[[projects]] + digest = "1:8029e9743749d4be5bc9f7d42ea1659471767860f0cdc34d37c3111bd308a295" + name = "golang.org/x/text" + packages = [ + "internal/gen", + "internal/triegen", + "internal/ucd", + "transform", + "unicode/cldr", + "unicode/norm", + ] + pruneopts = "UT" + revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" + version = "v0.3.0" + +[[projects]] + branch = "master" + digest = "1:9fdc2b55e8e0fafe4b41884091e51e77344f7dc511c5acedcfd98200003bff90" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "UT" + revision = "85acf8d2951cb2a3bde7632f9ff273ef0379bcbd" + +[[projects]] + branch = "master" + digest = "1:86d002f2c67e364e097c5047f517ab38cdef342c3c20be53974c1bfd5b191d30" + name = "golang.org/x/tools" + packages = [ + "go/ast/astutil", + "go/gcexportdata", + "go/internal/cgo", + "go/internal/gcimporter", + "go/internal/packagesdriver", + "go/packages", + "go/types/typeutil", + "imports", + "internal/fastwalk", + "internal/gopathwalk", + "internal/module", + "internal/semver", + ] + pruneopts = "UT" + revision = "58ecf64b2ccd4e014267d2ea143d23c617ee7e4c" + +[[projects]] + digest = "1:4d2e5a73dc1500038e504a8d78b986630e3626dc027bc030ba5c75da257cdb96" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "UT" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + digest = "1:5922c4db083d03579c576df514f096003f422b602aeb30028aedd892b69a4876" + name = "k8s.io/apimachinery" + packages = [ + "pkg/util/clock", + "pkg/util/rand", + "pkg/util/runtime", + "pkg/util/wait", + ] + pruneopts = "UT" + revision = "103fd098999dc9c0c88536f5c9ad2e5da39373ae" + version = "kubernetes-1.11.2" + +[[projects]] + digest = "1:8d66fef1249b9b2105840377af3bab078604d3c298058f563685e88d2a9e6ad3" + name = "k8s.io/client-go" + packages = ["util/workqueue"] + pruneopts = "UT" + revision = "1f13a808da65775f22cbf47862c4e5898d8f4ca1" + version = "kubernetes-1.11.2" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/aws/aws-sdk-go/aws/awserr", + "github.com/aws/aws-sdk-go/service/s3", + "github.com/benlaurie/objecthash/go/objecthash", + "github.com/coocood/freecache", + "github.com/ernesto-jimenez/gogen/gogenutil", + "github.com/ernesto-jimenez/gogen/imports", + "github.com/fatih/color", + "github.com/fatih/structtag", + "github.com/fsnotify/fsnotify", + "github.com/ghodss/yaml", + "github.com/golang/protobuf/jsonpb", + "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/ptypes", + "github.com/golang/protobuf/ptypes/duration", + "github.com/golang/protobuf/ptypes/timestamp", + "github.com/graymeta/stow", + "github.com/graymeta/stow/local", + "github.com/graymeta/stow/s3", + "github.com/magiconair/properties/assert", + "github.com/mitchellh/mapstructure", + 
"github.com/pkg/errors", + "github.com/prometheus/client_golang/prometheus", + "github.com/prometheus/client_golang/prometheus/promhttp", + "github.com/sirupsen/logrus", + "github.com/spf13/cobra", + "github.com/spf13/pflag", + "github.com/spf13/viper", + "github.com/stretchr/testify/assert", + "golang.org/x/time/rate", + "golang.org/x/tools/imports", + "k8s.io/apimachinery/pkg/util/rand", + "k8s.io/apimachinery/pkg/util/wait", + "k8s.io/client-go/util/workqueue", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/flytestdlib/Gopkg.toml b/flytestdlib/Gopkg.toml new file mode 100644 index 0000000000..3d6cf08f6f --- /dev/null +++ b/flytestdlib/Gopkg.toml @@ -0,0 +1,72 @@ +[[constraint]] + name = "github.com/aws/aws-sdk-go" + version = "1.15.0" + +[[constraint]] + name = "github.com/coocood/freecache" + version = "1.0.1" + +[[constraint]] + branch = "master" + name = "github.com/ernesto-jimenez/gogen" + +[[constraint]] + name = "github.com/fatih/color" + version = "1.7.0" + +[[constraint]] + name = "github.com/fatih/structtag" + version = "1.0.0" + +[[constraint]] + branch = "master" + name = "github.com/fsnotify/fsnotify" + +[[constraint]] + name = "github.com/golang/protobuf" + version = "1.1.0" + +[[constraint]] + name = "github.com/mitchellh/mapstructure" + version = "1.1.2" + +[[constraint]] + name = "github.com/pkg/errors" + version = "0.8.0" + +[[constraint]] + name = "github.com/prometheus/client_golang" + version = "^0.9.0" + +[[constraint]] + name = "github.com/spf13/cobra" + version = "0.0.3" + +[[constraint]] + name = "github.com/spf13/pflag" + version = "1.0.1" + +[[constraint]] + name = "github.com/spf13/viper" + # Viper only fixed symlink config watching after this SHA. move to a proper semVer when one is available. + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[constraint]] + branch = "master" + name = "golang.org/x/time" + +[[constraint]] + name = "k8s.io/apimachinery" + version = "kubernetes-1.11.2" + +[[constraint]] + name = "k8s.io/client-go" + version = "kubernetes-1.11.2" + +[[constraint]] + name = "github.com/graymeta/stow" + revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + +[prune] + go-tests = true + unused-packages = true diff --git a/flytestdlib/LICENSE b/flytestdlib/LICENSE new file mode 100644 index 0000000000..bed437514f --- /dev/null +++ b/flytestdlib/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/flytestdlib/Makefile b/flytestdlib/Makefile new file mode 100644 index 0000000000..623663f5f1 --- /dev/null +++ b/flytestdlib/Makefile @@ -0,0 +1,24 @@ +export REPOSITORY=flytestdlib +include boilerplate/lyft/golang_test_targets/Makefile + +# Generate golden files. Add test packages that generate golden files here. 
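The "golden" target that follows passes -update to each listed test package so those packages rewrite their expected-output ("golden") files instead of asserting against them. A minimal sketch of the pattern such a test package is assumed to follow (the "update" flag registration and the testdata path are illustrative, not taken from this patch):

// Illustrative golden-file test pattern; the flag name and file path are
// assumptions, not code from this patch.
package sample_test

import (
	"flag"
	"io/ioutil"
	"testing"

	"github.com/stretchr/testify/assert"
)

// `go test <pkg> -update` (what the `golden` make target runs) flips this flag.
var update = flag.Bool("update", false, "rewrite golden files instead of comparing against them")

func TestRenderGolden(t *testing.T) {
	got := []byte("rendered output\n") // whatever the code under test produces
	golden := "testdata/render.golden"

	if *update {
		// Regenerate the golden file rather than asserting against it.
		assert.NoError(t, ioutil.WriteFile(golden, got, 0644))
		return
	}

	want, err := ioutil.ReadFile(golden)
	assert.NoError(t, err)
	assert.Equal(t, want, got)
}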
+golden:
+	go test ./cli/pflags/api -update
+	go test ./config -update
+	go test ./storage -update
+	go test ./tests -update
+
+
+generate:
+	@echo "************************ go generate **********************************"
+	go generate ./...
+
+# This is the only target that should be overridden by the project. Get your binary into ${GOREPO}/bin
+.PHONY: compile
+compile:
+	mkdir -p ./bin
+	go build -o pflags ./cli/pflags/main.go && mv ./pflags ./bin
+
+gen-config:
+	which pflags || (go get github.com/lyft/flytestdlib/cli/pflags)
+	@go generate ./...
diff --git a/flytestdlib/NOTICE b/flytestdlib/NOTICE
new file mode 100644
index 0000000000..9316928ad6
--- /dev/null
+++ b/flytestdlib/NOTICE
@@ -0,0 +1,21 @@
+flytestdlib
+Copyright 2019-2020 Lyft Inc.
+
+This product includes software developed at Lyft Inc.
+
+Notices for file(s):
+ promutils/workqueue.go contains work from https://github.com/kubernetes/kubernetes/
+ under the Apache2 license.
+
+/*
+Copyright 2016 The Kubernetes Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
diff --git a/flytestdlib/README.rst b/flytestdlib/README.rst
new file mode 100644
index 0000000000..f78785fa3d
--- /dev/null
+++ b/flytestdlib/README.rst
@@ -0,0 +1,35 @@
+K8s Standard Library
+=====================
+Shared components we found ourselves building time and time again, so we collected them in one place!
+
+This library consists of:
+ - config
+
+   Enables strongly typed config throughout your application. Offers a way to represent config in Go structs, and takes care of parsing, validating, and watching for changes to config.
+
+ - cli/pflags
+
+   Tool to generate pflags for all fields in a given struct.
+ - storage
+
+   Abstract storage library that uses stow behind the scenes to connect to s3/azure/gcs, but also offers a configurable factory, in-memory storage (for testing), as well as native protobuf support.
+ - contextutils
+
+   Wrapper around golang's context to set/get known keys.
+ - logger
+
+   Wrapper around logrus that's configurable, taggable and context-aware.
+ - profutils
+
+   Starts an http server that serves /metrics (exposes prometheus metrics), /healthcheck and /version endpoints.
+ - promutils
+
+   Exposes a Scope instance that's a more convenient way to construct prometheus metrics and scope them per component.
+ - atomic
+
+   Wrapper around the sync.atomic library to offer AtomicInt32 and other convenient types.
+ - sets
+
+   Offers strongly typed and convenient interface sets.
+ - utils + - version diff --git a/flytestdlib/atomic/atomic.go b/flytestdlib/atomic/atomic.go new file mode 100644 index 0000000000..26d04ab721 --- /dev/null +++ b/flytestdlib/atomic/atomic.go @@ -0,0 +1,129 @@ +package atomic + +import "sync/atomic" + +// This file contains some simplified atomics primitives that Golang default library does not offer +// like, Boolean + +// Takes in a uint32 and converts to bool by checking whether the last bit is set to 1 +func toBool(n uint32) bool { + return n&1 == 1 +} + +// Takes in a bool and returns a uint32 representation +func toInt(b bool) uint32 { + if b { + return 1 + } + return 0 +} + +// Bool is an atomic Boolean. +// It stores the bool as a uint32 internally. This is to use the uint32 atomic functions available in golang +type Bool struct{ v uint32 } + +// NewBool creates a Bool. +func NewBool(initial bool) Bool { + return Bool{v: toInt(initial)} +} + +// Load atomically loads the Boolean. +func (b *Bool) Load() bool { + return toBool(atomic.LoadUint32(&b.v)) +} + +// CAS is an atomic compare-and-swap. +func (b *Bool) CompareAndSwap(old, new bool) bool { + return atomic.CompareAndSwapUint32(&b.v, toInt(old), toInt(new)) +} + +// Store atomically stores the passed value. +func (b *Bool) Store(new bool) { + atomic.StoreUint32(&b.v, toInt(new)) +} + +// Swap sets the given value and returns the previous value. +func (b *Bool) Swap(new bool) bool { + return toBool(atomic.SwapUint32(&b.v, toInt(new))) +} + +// Toggle atomically negates the Boolean and returns the previous value. +func (b *Bool) Toggle() bool { + return toBool(atomic.AddUint32(&b.v, 1) - 1) +} + +type Uint32 struct { + v uint32 +} + +// Returns a loaded uint32 value +func (u *Uint32) Load() uint32 { + return atomic.LoadUint32(&u.v) +} + +// CAS is an atomic compare-and-swap. +func (u *Uint32) CompareAndSwap(old, new uint32) bool { + return atomic.CompareAndSwapUint32(&u.v, old, new) +} + +// Add a delta to the number +func (u *Uint32) Add(delta uint32) uint32 { + return atomic.AddUint32(&u.v, delta) +} + +// Increment the value +func (u *Uint32) Inc() uint32 { + return atomic.AddUint32(&u.v, 1) +} + +// Set the value +func (u *Uint32) Store(v uint32) { + atomic.StoreUint32(&u.v, v) +} + +func NewUint32(v uint32) Uint32 { + return Uint32{v: v} +} + +type Int32 struct { + v int32 +} + +// Returns a loaded uint32 value +func (i *Int32) Load() int32 { + return atomic.LoadInt32(&i.v) +} + +// CAS is an atomic compare-and-swap. 
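A minimal usage sketch of the Bool type defined above; the one-time-initialization scenario is illustrative, not taken from this patch:

package main

import (
	"fmt"

	"github.com/lyft/flytestdlib/atomic"
)

func main() {
	started := atomic.NewBool(false)

	// The first caller to flip false -> true wins and performs the one-time work.
	if started.CompareAndSwap(false, true) {
		fmt.Println("initializing")
	}

	// Later attempts see the flag already set and skip the work.
	if !started.CompareAndSwap(false, true) {
		fmt.Println("already initialized; current value:", started.Load())
	}
}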
+func (i *Int32) CompareAndSwap(old, new int32) bool { + return atomic.CompareAndSwapInt32(&i.v, old, new) +} + +// Add a delta to the number +func (i *Int32) Add(delta int32) int32 { + return atomic.AddInt32(&i.v, delta) +} + +// Subtract a delta from the number +func (i *Int32) Sub(delta int32) int32 { + return atomic.AddInt32(&i.v, -delta) +} + +// Increment the value +func (i *Int32) Inc() int32 { + return atomic.AddInt32(&i.v, 1) +} + +// Decrement the value +func (i *Int32) Dec() int32 { + return atomic.AddInt32(&i.v, -1) +} + +// Set the value +func (i *Int32) Store(v int32) { + atomic.StoreInt32(&i.v, v) +} + +func NewInt32(v int32) Int32 { + return Int32{v: v} +} diff --git a/flytestdlib/atomic/atomic_test.go b/flytestdlib/atomic/atomic_test.go new file mode 100644 index 0000000000..c441a88993 --- /dev/null +++ b/flytestdlib/atomic/atomic_test.go @@ -0,0 +1,61 @@ +package atomic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBool(t *testing.T) { + atom := NewBool(false) + assert.False(t, atom.Toggle(), "Expected swap to return False.") + assert.True(t, atom.Load(), "Unexpected state after swap. Expected True") + + assert.True(t, atom.CompareAndSwap(true, true), "CAS should swap when old matches") + assert.True(t, atom.Load(), "previous swap should have no effect") + assert.True(t, atom.CompareAndSwap(true, false), "CAS should swap when old matches") + assert.False(t, atom.Load(), "Post swap the value should be true") + assert.False(t, atom.CompareAndSwap(true, false), "CAS should fail on old mismatch") + assert.False(t, atom.Load(), "CAS should not have modified the value") + + atom.Store(false) + assert.False(t, atom.Load(), "Unexpected state after store.") + + prev := atom.Swap(false) + assert.False(t, prev, "Expected Swap to return previous value.") + + prev = atom.Swap(true) + assert.False(t, prev, "Expected Swap to return previous value.") +} + +func TestInt32(t *testing.T) { + atom := NewInt32(2) + assert.False(t, atom.CompareAndSwap(3, 4), "Expected swap to return False.") + assert.Equal(t, int32(2), atom.Load(), "Unexpected state after swap. Expected True") + + assert.True(t, atom.CompareAndSwap(2, 2), "CAS should swap when old matches") + assert.Equal(t, int32(2), atom.Load(), "previous swap should have no effect") + assert.True(t, atom.CompareAndSwap(2, 4), "CAS should swap when old matches") + assert.Equal(t, int32(4), atom.Load(), "Post swap the value should be true") + assert.False(t, atom.CompareAndSwap(2, 3), "CAS should fail on old mismatch") + assert.Equal(t, int32(4), atom.Load(), "CAS should not have modified the value") + + atom.Store(5) + assert.Equal(t, int32(5), atom.Load(), "Unexpected state after store.") +} + +func TestUint32(t *testing.T) { + atom := NewUint32(2) + assert.False(t, atom.CompareAndSwap(3, 4), "Expected swap to return False.") + assert.Equal(t, uint32(2), atom.Load(), "Unexpected state after swap. 
Expected True") + + assert.True(t, atom.CompareAndSwap(2, 2), "CAS should swap when old matches") + assert.Equal(t, uint32(2), atom.Load(), "previous swap should have no effect") + assert.True(t, atom.CompareAndSwap(2, 4), "CAS should swap when old matches") + assert.Equal(t, uint32(4), atom.Load(), "Post swap the value should be true") + assert.False(t, atom.CompareAndSwap(2, 3), "CAS should fail on old mismatch") + assert.Equal(t, uint32(4), atom.Load(), "CAS should not have modified the value") + + atom.Store(5) + assert.Equal(t, uint32(5), atom.Load(), "Unexpected state after store.") +} diff --git a/flytestdlib/atomic/non_blocking_lock.go b/flytestdlib/atomic/non_blocking_lock.go new file mode 100644 index 0000000000..449841a960 --- /dev/null +++ b/flytestdlib/atomic/non_blocking_lock.go @@ -0,0 +1,23 @@ +package atomic + +// Lock that provides TryLock method instead of blocking lock +type NonBlockingLock interface { + TryLock() bool + Release() +} + +func NewNonBlockingLock() NonBlockingLock { + return &nonBlockingLock{lock: NewBool(false)} +} + +type nonBlockingLock struct { + lock Bool +} + +func (n *nonBlockingLock) TryLock() bool { + return n.lock.CompareAndSwap(false, true) +} + +func (n *nonBlockingLock) Release() { + n.lock.Store(false) +} diff --git a/flytestdlib/atomic/non_blocking_lock_test.go b/flytestdlib/atomic/non_blocking_lock_test.go new file mode 100644 index 0000000000..ddd0a2123d --- /dev/null +++ b/flytestdlib/atomic/non_blocking_lock_test.go @@ -0,0 +1,15 @@ +package atomic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewNonBlockingLock(t *testing.T) { + lock := NewNonBlockingLock() + assert.True(t, lock.TryLock(), "Unexpected lock acquire failure") + assert.False(t, lock.TryLock(), "already-acquired lock acquired again") + lock.Release() + assert.True(t, lock.TryLock(), "Unexpected lock acquire failure") +} diff --git a/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile b/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile new file mode 100644 index 0000000000..1c6f893521 --- /dev/null +++ b/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,31 @@ +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.10 + golangci-lint run + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... -race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... 
-coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/flytestdlib/boilerplate/lyft/golang_test_targets/Readme.rst b/flytestdlib/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100644 index 0000000000..acc5744f59 --- /dev/null +++ b/flytestdlib/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` install golang dependencies. + +Provides a ``lint`` make target that uses golangci to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. + +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). + +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/flytestdlib/boilerplate/lyft/golang_test_targets/goimports b/flytestdlib/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 0000000000..11d3c9af06 --- /dev/null +++ b/flytestdlib/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/flytestdlib/boilerplate/update.cfg b/flytestdlib/boilerplate/update.cfg new file mode 100644 index 0000000000..f861a23ccd --- /dev/null +++ b/flytestdlib/boilerplate/update.cfg @@ -0,0 +1,2 @@ +lyft/golang_test_targets +lyft/golangci_file diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go new file mode 100644 index 0000000000..2e4dd30c54 --- /dev/null +++ b/flytestdlib/cli/pflags/api/generator.go @@ -0,0 +1,277 @@ +package api + +import ( + "context" + "fmt" + "go/types" + "path/filepath" + + "github.com/lyft/flytestdlib/logger" + + "go/importer" + + "github.com/ernesto-jimenez/gogen/gogenutil" +) + +const ( + indent = " " +) + +// PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields. +type PFlagProviderGenerator struct { + pkg *types.Package + st *types.Named +} + +// This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings. 
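For context on what this generator produces: a config struct is typically annotated with a go:generate directive that invokes the pflags binary built by the Makefile above, and the tool emits a companion file exposing a GetPFlagSet implementation for the struct's fields. The package, struct, and directive below are an illustrative assumption, not code from this patch:

package config

//go:generate pflags Config

// Config is an illustrative strongly typed configuration struct. The pflags
// tool is expected to generate a companion *_flags.go file with a
// GetPFlagSet(prefix string) *pflag.FlagSet method registering one flag per
// discovered basic-typed (possibly nested) field.
type Config struct {
	Host    string
	Retries int
}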
+// github.com/spf13/viper/viper.go:1016 +var allowedKinds = []types.Type{ + types.Typ[types.Int], + types.Typ[types.Int8], + types.Typ[types.Int16], + types.Typ[types.Int32], + types.Typ[types.Int64], + types.Typ[types.Bool], + types.Typ[types.String], +} + +type SliceOrArray interface { + Elem() types.Type +} + +func capitalize(s string) string { + if s[0] >= 'a' && s[0] <= 'z' { + return string(s[0]-'a'+'A') + s[1:] + } + + return s +} + +func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage, defaultValue string) (FieldInfo, error) { + strategy := SliceRaw + FlagMethodName := "StringSlice" + typ := types.NewSlice(types.Typ[types.String]) + emptyDefaultValue := `[]string{}` + if b, ok := t.Elem().(*types.Basic); !ok { + logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem()) + if !jsonUnmarshaler(t.Elem()) { + return FieldInfo{}, + fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported", + t.Elem().String()) + } + } else { + logger.Infof(ctx, "Elem of type [%v] is a basic type. Will use a pflag as a Slice.", b) + strategy = SliceJoined + FlagMethodName = fmt.Sprintf("%vSlice", capitalize(b.Name())) + typ = types.NewSlice(b) + emptyDefaultValue = fmt.Sprintf(`[]%v{}`, b.Name()) + } + + testValue := defaultValue + if len(defaultValue) == 0 { + defaultValue = emptyDefaultValue + testValue = `"1,1"` + } + + return FieldInfo{ + Name: name, + GoName: goName, + Typ: typ, + FlagMethodName: FlagMethodName, + DefaultValue: defaultValue, + UsageString: usage, + TestValue: testValue, + TestStrategy: strategy, + }, nil +} + +// Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is +// met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON. +func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo, error) { + logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", + typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) + + ctx = logger.WithIndent(ctx, indent) + + st := typ.Underlying().(*types.Struct) + fields := make([]FieldInfo, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + v := st.Field(i) + if !v.IsField() { + continue + } + + // Parses out the tag if one exists. + tag, err := ParseTag(st.Tag(i)) + if err != nil { + return nil, err + } + + if len(tag.Name) == 0 { + tag.Name = v.Name() + } + + typ := v.Type() + if ptr, isPtr := typ.(*types.Pointer); isPtr { + typ = ptr.Elem() + } + + switch t := typ.(type) { + case *types.Basic: + if len(tag.DefaultValue) == 0 { + tag.DefaultValue = fmt.Sprintf("*new(%v)", typ.String()) + } + + logger.Infof(ctx, "[%v] is of a basic type with default value [%v].", tag.Name, tag.DefaultValue) + + isAllowed := false + for _, k := range allowedKinds { + if t.String() == k.String() { + isAllowed = true + break + } + } + + if !isAllowed { + return nil, fmt.Errorf("only these basic kinds are allowed. given [%v] (Kind: [%v]. expected: [%+v]", + t.String(), t.Kind(), allowedKinds) + } + + fields = append(fields, FieldInfo{ + Name: tag.Name, + GoName: v.Name(), + Typ: t, + FlagMethodName: camelCase(t.String()), + DefaultValue: tag.DefaultValue, + UsageString: tag.Usage, + TestValue: `"1"`, + TestStrategy: JSON, + }) + case *types.Named: + if _, isStruct := t.Underlying().(*types.Struct); !isStruct { + // TODO: Add a more descriptive error message. 
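+			// For example, a field declared with a named non-struct type such as `type Region string`
+			// lands in this branch and fails generation even if it implements json.Unmarshaler;
+			// wrapping the value in a struct (as config.Duration wraps time.Duration) keeps it generable.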
+ return nil, fmt.Errorf("invalid type. it must be struct, received [%v] for field [%v]", t.Underlying().String(), tag.Name) + } + + // If the type has json unmarshaler, then stop the recursion and assume the type is string. config package + // will use json unmarshaler to fill in the final config object. + jsonUnmarshaler := jsonUnmarshaler(t) + + testValue := tag.DefaultValue + if len(tag.DefaultValue) == 0 { + tag.DefaultValue = `""` + testValue = `"1"` + } + + logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue) + + if jsonUnmarshaler { + logger.Infof(logger.WithIndent(ctx, indent), "Type is json unmarhslalable.") + + fields = append(fields, FieldInfo{ + Name: tag.Name, + GoName: v.Name(), + Typ: types.Typ[types.String], + FlagMethodName: "String", + DefaultValue: tag.DefaultValue, + UsageString: tag.Usage, + TestValue: testValue, + TestStrategy: JSON, + }) + } else { + logger.Infof(ctx, "Traversing fields in type.") + + nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t) + if err != nil { + return nil, err + } + + for _, subField := range nested { + fields = append(fields, FieldInfo{ + Name: fmt.Sprintf("%v.%v", tag.Name, subField.Name), + GoName: fmt.Sprintf("%v.%v", v.Name(), subField.GoName), + Typ: subField.Typ, + FlagMethodName: subField.FlagMethodName, + DefaultValue: subField.DefaultValue, + UsageString: subField.UsageString, + TestValue: subField.TestValue, + TestStrategy: subField.TestStrategy, + }) + } + } + case *types.Slice: + logger.Infof(ctx, "[%v] is of a slice type with default value [%v].", tag.Name, tag.DefaultValue) + + f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue) + if err != nil { + return nil, err + } + + fields = append(fields, f) + case *types.Array: + logger.Infof(ctx, "[%v] is of an array with default value [%v].", tag.Name, tag.DefaultValue) + + f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue) + if err != nil { + return nil, err + } + + fields = append(fields, f) + default: + return nil, fmt.Errorf("unexpected type %v", t.String()) + } + } + + return fields, nil +} + +// NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in, +// it's assumed to be current package (which is expected to be the common use case when invoking pflags from +// go:generate comments) +func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) { + var err error + // Resolve package path + if pkg == "" || pkg[0] == '.' 
{ + pkg, err = filepath.Abs(filepath.Clean(pkg)) + if err != nil { + return nil, err + } + pkg = gogenutil.StripGopath(pkg) + } + + targetPackage, err := importer.For("source", nil).Import(pkg) + if err != nil { + return nil, err + } + + obj := targetPackage.Scope().Lookup(targetTypeName) + if obj == nil { + return nil, fmt.Errorf("struct %s missing", targetTypeName) + } + + var st *types.Named + switch obj.Type().Underlying().(type) { + case *types.Struct: + st = obj.Type().(*types.Named) + default: + return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying()) + } + + return &PFlagProviderGenerator{ + st: st, + pkg: targetPackage, + }, nil +} + +func (g PFlagProviderGenerator) GetTargetPackage() *types.Package { + return g.pkg +} + +func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) { + fields, err := discoverFieldsRecursive(ctx, g.st) + if err != nil { + return PFlagProvider{}, err + } + + return newPflagProvider(g.pkg, g.st.Obj().Name(), fields), nil +} diff --git a/flytestdlib/cli/pflags/api/generator_test.go b/flytestdlib/cli/pflags/api/generator_test.go new file mode 100644 index 0000000000..edfab6c1ca --- /dev/null +++ b/flytestdlib/cli/pflags/api/generator_test.go @@ -0,0 +1,56 @@ +package api + +import ( + "context" + "flag" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! +var update = flag.Bool("update", false, "Updates testdata") + +func TestNewGenerator(t *testing.T) { + g, err := NewGenerator(".", "TestType") + assert.NoError(t, err) + + ctx := context.Background() + p, err := g.Generate(ctx) + assert.NoError(t, err) + + codeOutput, err := ioutil.TempFile("", "output-*.go") + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Remove(codeOutput.Name())) }() + + testOutput, err := ioutil.TempFile("", "output-*_test.go") + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Remove(testOutput.Name())) }() + + assert.NoError(t, p.WriteCodeFile(codeOutput.Name())) + assert.NoError(t, p.WriteTestFile(testOutput.Name())) + + codeBytes, err := ioutil.ReadFile(codeOutput.Name()) + assert.NoError(t, err) + + testBytes, err := ioutil.ReadFile(testOutput.Name()) + assert.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "testtype.go") + goldenTestFilePath := filepath.Join("testdata", "testtype_test.go") + if *update { + assert.NoError(t, ioutil.WriteFile(goldenFilePath, codeBytes, os.ModePerm)) + assert.NoError(t, ioutil.WriteFile(goldenTestFilePath, testBytes, os.ModePerm)) + } + + goldenOutput, err := ioutil.ReadFile(filepath.Clean(goldenFilePath)) + assert.NoError(t, err) + assert.Equal(t, goldenOutput, codeBytes) + + goldenTestOutput, err := ioutil.ReadFile(filepath.Clean(goldenTestFilePath)) + assert.NoError(t, err) + assert.Equal(t, string(goldenTestOutput), string(testBytes)) +} diff --git a/flytestdlib/cli/pflags/api/pflag_provider.go b/flytestdlib/cli/pflags/api/pflag_provider.go new file mode 100644 index 0000000000..e414398063 --- /dev/null +++ b/flytestdlib/cli/pflags/api/pflag_provider.go @@ -0,0 +1,90 @@ +package api + +import ( + "bytes" + "fmt" + "go/types" + "io/ioutil" + "os" + "time" + + "github.com/ernesto-jimenez/gogen/imports" + goimports "golang.org/x/tools/imports" +) + +type PFlagProvider struct { + typeName string + pkg *types.Package + fields []FieldInfo +} + +// Adds any needed imports for types not directly declared in 
this package. +func (p PFlagProvider) Imports() map[string]string { + imp := imports.New(p.pkg.Name()) + for _, m := range p.fields { + imp.AddImportsFrom(m.Typ) + } + + return imp.Imports() +} + +// Evaluates the main code file template and writes the output to outputFilePath +func (p PFlagProvider) WriteCodeFile(outputFilePath string) error { + buf := bytes.Buffer{} + err := p.generate(GenerateCodeFile, &buf, outputFilePath) + if err != nil { + return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String()) + } + + return p.writeToFile(&buf, outputFilePath) +} + +// Evaluates the test code file template and writes the output to outputFilePath +func (p PFlagProvider) WriteTestFile(outputFilePath string) error { + buf := bytes.Buffer{} + err := p.generate(GenerateTestFile, &buf, outputFilePath) + if err != nil { + return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String()) + } + + return p.writeToFile(&buf, outputFilePath) +} + +func (p PFlagProvider) writeToFile(buffer *bytes.Buffer, fileName string) error { + return ioutil.WriteFile(fileName, buffer.Bytes(), os.ModePerm) +} + +// Evaluates the generator and writes the output to buffer. targetFileName is used only to influence how imports are +// generated/optimized. +func (p PFlagProvider) generate(generator func(buffer *bytes.Buffer, info TypeInfo) error, buffer *bytes.Buffer, targetFileName string) error { + info := TypeInfo{ + Name: p.typeName, + Fields: p.fields, + Package: p.pkg.Name(), + Timestamp: time.Now(), + Imports: p.Imports(), + } + + if err := generator(buffer, info); err != nil { + return err + } + + // Update imports + newBytes, err := goimports.Process(targetFileName, buffer.Bytes(), nil) + if err != nil { + return err + } + + buffer.Reset() + _, err = buffer.Write(newBytes) + + return err +} + +func newPflagProvider(pkg *types.Package, typeName string, fields []FieldInfo) PFlagProvider { + return PFlagProvider{ + typeName: typeName, + pkg: pkg, + fields: fields, + } +} diff --git a/flytestdlib/cli/pflags/api/sample.go b/flytestdlib/cli/pflags/api/sample.go new file mode 100644 index 0000000000..b1ebb50684 --- /dev/null +++ b/flytestdlib/cli/pflags/api/sample.go @@ -0,0 +1,53 @@ +package api + +import ( + "encoding/json" + "errors" + + "github.com/lyft/flytestdlib/storage" +) + +type TestType struct { + StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""` + BoolValue bool `json:"bl" pflag:"true"` + NestedType NestedType `json:"nested"` + IntArray []int `json:"ints" pflag:"[]int{12%2C1}"` + StringArray []string `json:"strs" pflag:"[]string{\"12\"%2C\"1\"}"` + ComplexJSONArray []ComplexJSONType `json:"complexArr"` + StringToJSON ComplexJSONType `json:"c" pflag:",I'm a complex type but can be converted from string."` + StorageConfig storage.Config `json:"storage"` + IntValue *int `json:"i"` +} + +type NestedType struct { + IntValue int `json:"i" pflag:",this is an important flag"` +} + +type ComplexJSONType struct { + StringValue string `json:"str"` + IntValue int `json:"i"` +} + +func (c *ComplexJSONType) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + c.StringValue = "" + return nil + } + + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case string: + if len(value) == 0 { + c.StringValue = "" + } else { + c.StringValue = value + } + default: + return errors.New("invalid duration") + } + + return nil +} diff --git a/flytestdlib/cli/pflags/api/tag.go b/flytestdlib/cli/pflags/api/tag.go 
new file mode 100644 index 0000000000..5d4a2d8e57 --- /dev/null +++ b/flytestdlib/cli/pflags/api/tag.go @@ -0,0 +1,66 @@ +package api + +import ( + "fmt" + "net/url" + "strings" + + "github.com/fatih/structtag" +) + +const ( + TagName = "pflag" + JSONTagName = "json" +) + +// Represents parsed PFlag Go-struct tag. +// type Foo struct { +// StringValue string `json:"str" pflag:"\"hello world\",This is a string value"` +// } +// Name will be "str", Default value is "hello world" and Usage is "This is a string value" +type Tag struct { + Name string + DefaultValue string + Usage string +} + +// Parses tag. Name is computed from json tag, defaultvalue is the name of the pflag tag and usage is the concatenation +// of all options for pflag tag. +// e.g. `json:"name" pflag:"2,this is a useful param"` +func ParseTag(tag string) (t Tag, err error) { + tags, err := structtag.Parse(tag) + if err != nil { + return Tag{}, err + } + + t = Tag{} + + jsonTag, err := tags.Get(JSONTagName) + if err == nil { + t.Name = jsonTag.Name + } + + pflagTag, err := tags.Get(TagName) + if err == nil { + t.DefaultValue, err = url.QueryUnescape(pflagTag.Name) + if err != nil { + fmt.Printf("Failed to Query unescape tag name [%v], will use value as is. Error: %v", pflagTag.Name, err) + t.DefaultValue = pflagTag.Name + } + + t.Usage = strings.Join(pflagTag.Options, ", ") + if len(t.Usage) == 0 { + t.Usage = `""` + } + + if t.Usage[0] != '"' { + t.Usage = fmt.Sprintf(`"%v"`, t.Usage) + } + } else { + // We receive an error when the tag isn't present (or is malformed). Because there is no strongly-typed way to + // do that, we will just set Usage to empty string and move on. + t.Usage = `""` + } + + return t, nil +} diff --git a/flytestdlib/cli/pflags/api/templates.go b/flytestdlib/cli/pflags/api/templates.go new file mode 100644 index 0000000000..a7adba8922 --- /dev/null +++ b/flytestdlib/cli/pflags/api/templates.go @@ -0,0 +1,175 @@ +package api + +import ( + "bytes" + "text/template" +) + +func GenerateCodeFile(buffer *bytes.Buffer, info TypeInfo) error { + return mainTmpl.Execute(buffer, info) +} + +func GenerateTestFile(buffer *bytes.Buffer, info TypeInfo) error { + return testTmpl.Execute(buffer, info) +} + +var mainTmpl = template.Must(template.New("MainFile").Parse( + `// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package {{ .Package }} + +import ( + "github.com/spf13/pflag" + "fmt" +{{range $path, $name := .Imports}} + {{$name}} "{{$path}}"{{end}} +) + +// GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func ({{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError) + {{- range .Fields }} + cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }}) + {{- end }} + return cmdFlags +} +`)) + +var testTmpl = template.Must(template.New("TestFile").Parse( + `// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. 
+ +package {{ .Package }} + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +{{- range $path, $name := .Imports}} + {{$name}} "{{$path}}" +{{- end}} +) + +var dereferencableKinds{{ .Name }} = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElement{{ .Name }}(t reflect.Kind) bool { + _, exists := dereferencableKinds{{ .Name }}[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHook{{ .Name }}(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElement{{ .Name }}(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_{{ .Name }}(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHook{{ .Name }}, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_{{ .Name }}(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_{{ .Name }}(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_{{ .Name }}(val, result)) +} + +func testDecodeSlice_{{ .Name }}(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_{{ .Name }}(vStringSlice, result)) +} + +func Test{{ .Name }}_GetPFlagSet(t *testing.T) { + val := {{ .Name }}{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func Test{{ .Name }}_SetFlags(t *testing.T) { + actual := {{ .Name }}{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + {{ $ParentName := .Name }} + {{- range .Fields }} + t.Run("Test_{{ .Name }}", func(t *testing.T) { {{ $varName := print "v" .FlagMethodName }} + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil { + assert.Equal(t, {{ .Typ }}({{ .DefaultValue }}), {{ $varName }}) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + {{ if eq .TestStrategy "Json" }}testValue := {{ .TestValue }} + {{ else if eq .TestStrategy "SliceRaw" }}testValue := {{ .TestValue }} + {{ else 
}}testValue := join_{{ $ParentName }}({{ .TestValue }}, ",") + {{ end }} + cmdFlags.Set("{{ .Name }}", testValue) + if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil { + {{ if eq .TestStrategy "Json" }}testDecodeJson_{{ $ParentName }}(t, fmt.Sprintf("%v", {{ print "v" .FlagMethodName }}), &actual.{{ .GoName }}) + {{ else if eq .TestStrategy "SliceRaw" }}testDecodeSlice_{{ $ParentName }}(t, {{ print "v" .FlagMethodName }}, &actual.{{ .GoName }}) + {{ else }}testDecodeSlice_{{ $ParentName }}(t, join_{{ $ParentName }}({{ print "v" .FlagMethodName }}, ","), &actual.{{ .GoName }}) + {{ end }} + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + {{- end }} +} +`)) diff --git a/flytestdlib/cli/pflags/api/testdata/testtype.go b/flytestdlib/cli/pflags/api/testdata/testtype.go new file mode 100755 index 0000000000..87f5cb7dfe --- /dev/null +++ b/flytestdlib/cli/pflags/api/testdata/testtype.go @@ -0,0 +1,36 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package api + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in TestType and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (TestType) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("TestType", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), true, "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "nested.i"), *new(int), "this is an important flag") + cmdFlags.IntSlice(fmt.Sprintf("%v%v", prefix, "ints"), []int{12, 1}, "") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strs"), []string{"12", "1"}, "") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "complexArr"), []string{}, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), "", "I'm a complex type but can be converted from string.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), "", "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.secret-key"), *new(string), "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.region"), "us-east-1", "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "storage.connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.container"), *new(string), "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. 
If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), *new(int), "") + return cmdFlags +} diff --git a/flytestdlib/cli/pflags/api/testdata/testtype_test.go b/flytestdlib/cli/pflags/api/testdata/testtype_test.go new file mode 100755 index 0000000000..f8b81bbe11 --- /dev/null +++ b/flytestdlib/cli/pflags/api/testdata/testtype_test.go @@ -0,0 +1,520 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package api + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsTestType = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementTestType(t reflect.Kind) bool { + _, exists := dereferencableKindsTestType[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookTestType(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementTestType(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_TestType(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookTestType, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_TestType(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_TestType(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_TestType(val, result)) +} + +func testDecodeSlice_TestType(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_TestType(vStringSlice, result)) +} + +func TestTestType_GetPFlagSet(t *testing.T) { + val := TestType{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestTestType_SetFlags(t *testing.T) { + actual := TestType{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_str", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("str"); err == nil { + assert.Equal(t, string("hello world"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("str", testValue) + if vString, err := cmdFlags.GetString("str"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StringValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_bl", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("bl"); err == nil { + assert.Equal(t, bool(true), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("bl", testValue) + if vBool, err := cmdFlags.GetBool("bl"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vBool), &actual.BoolValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_nested.i", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("nested.i"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("nested.i", testValue) + if vInt, err := cmdFlags.GetInt("nested.i"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.NestedType.IntValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_ints", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vIntSlice, err := cmdFlags.GetIntSlice("ints"); err == nil { + assert.Equal(t, []int([]int{12, 1}), vIntSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := join_TestType([]int{12, 
1}, ",") + + cmdFlags.Set("ints", testValue) + if vIntSlice, err := cmdFlags.GetIntSlice("ints"); err == nil { + testDecodeSlice_TestType(t, join_TestType(vIntSlice, ","), &actual.IntArray) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_strs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vStringSlice, err := cmdFlags.GetStringSlice("strs"); err == nil { + assert.Equal(t, []string([]string{"12", "1"}), vStringSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := join_TestType([]string{"12", "1"}, ",") + + cmdFlags.Set("strs", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("strs"); err == nil { + testDecodeSlice_TestType(t, join_TestType(vStringSlice, ","), &actual.StringArray) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_complexArr", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vStringSlice, err := cmdFlags.GetStringSlice("complexArr"); err == nil { + assert.Equal(t, []string([]string{}), vStringSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1,1" + + cmdFlags.Set("complexArr", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("complexArr"); err == nil { + testDecodeSlice_TestType(t, vStringSlice, &actual.ComplexJSONArray) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_c", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("c"); err == nil { + assert.Equal(t, string(""), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("c", testValue) + if vString, err := cmdFlags.GetString("c"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StringToJSON) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.type"); err == nil { + assert.Equal(t, string("s3"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.type", testValue) + if vString, err := cmdFlags.GetString("storage.type"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { + assert.Equal(t, string(""), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.endpoint", testValue) + if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.auth-type", func(t *testing.T) { + t.Run("DefaultValue", func(t 
*testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.auth-type"); err == nil { + assert.Equal(t, string("iam"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.auth-type", testValue) + if vString, err := cmdFlags.GetString("storage.connection.auth-type"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.AuthType) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.access-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.access-key"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.access-key", testValue) + if vString, err := cmdFlags.GetString("storage.connection.access-key"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.AccessKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.secret-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.secret-key"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.secret-key", testValue) + if vString, err := cmdFlags.GetString("storage.connection.secret-key"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.SecretKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.region", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.region"); err == nil { + assert.Equal(t, string("us-east-1"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.region", testValue) + if vString, err := cmdFlags.GetString("storage.connection.region"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.Region) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.disable-ssl", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("storage.connection.disable-ssl"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.disable-ssl", testValue) + if vBool, err := cmdFlags.GetBool("storage.connection.disable-ssl"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vBool), &actual.StorageConfig.Connection.DisableSSL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.container", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test 
that default value is set properly + if vString, err := cmdFlags.GetString("storage.container"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.container", testValue) + if vString, err := cmdFlags.GetString("storage.container"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.InitContainer) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.cache.max_size_mbs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("storage.cache.max_size_mbs"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.cache.max_size_mbs", testValue) + if vInt, err := cmdFlags.GetInt("storage.cache.max_size_mbs"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.StorageConfig.Cache.MaxSizeMegabytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.cache.target_gc_percent", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("storage.cache.target_gc_percent"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.cache.target_gc_percent", testValue) + if vInt, err := cmdFlags.GetInt("storage.cache.target_gc_percent"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.StorageConfig.Cache.TargetGCPercent) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.limits.maxDownloadMBs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("storage.limits.maxDownloadMBs"); err == nil { + assert.Equal(t, int64(2), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.limits.maxDownloadMBs", testValue) + if vInt64, err := cmdFlags.GetInt64("storage.limits.maxDownloadMBs"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt64), &actual.StorageConfig.Limits.GetLimitMegabytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_i", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("i"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("i", testValue) + if vInt, err := cmdFlags.GetInt("i"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.IntValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flytestdlib/cli/pflags/api/types.go b/flytestdlib/cli/pflags/api/types.go new file mode 100644 index 0000000000..1e6c1297ff --- /dev/null +++ b/flytestdlib/cli/pflags/api/types.go @@ -0,0 +1,36 @@ +package api + +import ( + "go/types" + "time" +) + +// Determines how tests should be generated. 
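+// Based on how the generated test template exercises each flag (see templates.go), the strategies
+// roughly mean:
+//
+//	JSON        - the flag value is round-tripped through the json/mapstructure decode helper.
+//	SliceJoined - slice elements are joined with "," before being set on and decoded from the flag.
+//	SliceRaw    - the string-slice value is decoded as-is, without joining.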
+type TestStrategy string + +const ( + JSON TestStrategy = "Json" + SliceJoined TestStrategy = "SliceJoined" + SliceRaw TestStrategy = "SliceRaw" +) + +type FieldInfo struct { + Name string + GoName string + Typ types.Type + DefaultValue string + UsageString string + FlagMethodName string + TestValue string + TestStrategy TestStrategy +} + +// Holds the finalized information passed to the template for evaluation. +type TypeInfo struct { + Timestamp time.Time + Fields []FieldInfo + Package string + Name string + TypeRef string + Imports map[string]string +} diff --git a/flytestdlib/cli/pflags/api/utils.go b/flytestdlib/cli/pflags/api/utils.go new file mode 100644 index 0000000000..4c71fbb1c4 --- /dev/null +++ b/flytestdlib/cli/pflags/api/utils.go @@ -0,0 +1,32 @@ +package api + +import ( + "bytes" + "fmt" + "go/types" + "unicode" +) + +func camelCase(str string) string { + if len(str) == 0 { + return str + } + + firstRune := bytes.Runes([]byte(str))[0] + if unicode.IsLower(firstRune) { + return fmt.Sprintf("%v%v", string(unicode.ToUpper(firstRune)), str[1:]) + } + + return str +} + +func jsonUnmarshaler(t types.Type) bool { + mset := types.NewMethodSet(t) + jsonUnmarshaler := mset.Lookup(nil, "UnmarshalJSON") + if jsonUnmarshaler == nil { + mset = types.NewMethodSet(types.NewPointer(t)) + jsonUnmarshaler = mset.Lookup(nil, "UnmarshalJSON") + } + + return jsonUnmarshaler != nil +} diff --git a/flytestdlib/cli/pflags/cmd/root.go b/flytestdlib/cli/pflags/cmd/root.go new file mode 100644 index 0000000000..b6562d8a1b --- /dev/null +++ b/flytestdlib/cli/pflags/cmd/root.go @@ -0,0 +1,71 @@ +package cmd + +import ( + "bytes" + "context" + "flag" + "fmt" + "strings" + + "github.com/lyft/flytestdlib/cli/pflags/api" + "github.com/lyft/flytestdlib/logger" + "github.com/spf13/cobra" +) + +var ( + pkg = flag.String("pkg", ".", "what package to get the interface from") +) + +var root = cobra.Command{ + Use: "pflags MyStructName --package myproject/mypackage", + Args: cobra.ExactArgs(1), + RunE: generatePflagsProvider, + Example: ` +// go:generate pflags MyStruct +type MyStruct struct { + BoolValue bool ` + "`json:\"bl\" pflag:\"true\"`" + ` + NestedType NestedType ` + "`json:\"nested\"`" + ` + IntArray []int ` + "`json:\"ints\" pflag:\"[]int{12%2C1}\"`" + ` +} + `, +} + +func init() { + root.Flags().StringP("package", "p", ".", "Determines the source/destination package.") +} + +func Execute() error { + return root.Execute() +} + +func generatePflagsProvider(cmd *cobra.Command, args []string) error { + structName := args[0] + if structName == "" { + return fmt.Errorf("need to specify a struct name") + } + + ctx := context.Background() + gen, err := api.NewGenerator(*pkg, structName) + if err != nil { + return err + } + + provider, err := gen.Generate(ctx) + if err != nil { + return err + } + + var buf bytes.Buffer + defer buf.Reset() + + logger.Infof(ctx, "Generating PFlags for type [%v.%v.%v]\n", gen.GetTargetPackage().Path(), gen.GetTargetPackage().Name(), structName) + + outFilePath := fmt.Sprintf("%s_flags.go", strings.ToLower(structName)) + err = provider.WriteCodeFile(outFilePath) + if err != nil { + return err + } + + tOutFilePath := fmt.Sprintf("%s_flags_test.go", strings.ToLower(structName)) + return provider.WriteTestFile(tOutFilePath) +} diff --git a/flytestdlib/cli/pflags/cmd/version.go b/flytestdlib/cli/pflags/cmd/version.go new file mode 100644 index 0000000000..8ee00af2e4 --- /dev/null +++ b/flytestdlib/cli/pflags/cmd/version.go @@ -0,0 +1,17 @@ +package cmd + +import ( + 
"github.com/lyft/flytestdlib/version" + "github.com/spf13/cobra" +) + +var versionCmd = &cobra.Command{ + Aliases: []string{"version", "ver"}, + Run: func(cmd *cobra.Command, args []string) { + cmd.Printf("Version: %s\nBuildSHA: %s\nBuildTS: %s\n", version.Version, version.Build, version.BuildTime.String()) + }, +} + +func init() { + root.AddCommand(versionCmd) +} diff --git a/flytestdlib/cli/pflags/main.go b/flytestdlib/cli/pflags/main.go new file mode 100644 index 0000000000..e4c784a7b1 --- /dev/null +++ b/flytestdlib/cli/pflags/main.go @@ -0,0 +1,15 @@ +// Generates a Register method to automatically add pflags to a pflagSet for all fields in a given type. +package main + +import ( + "log" + + "github.com/lyft/flytestdlib/cli/pflags/cmd" +) + +func main() { + err := cmd.Execute() + if err != nil { + log.Fatal(err) + } +} diff --git a/flytestdlib/cli/pflags/readme.rst b/flytestdlib/cli/pflags/readme.rst new file mode 100644 index 0000000000..8a47d921f8 --- /dev/null +++ b/flytestdlib/cli/pflags/readme.rst @@ -0,0 +1,24 @@ +================ +Pflags Generator +================ + +This tool enables you to generate code to add pflags for all fields in a struct (recursively). In conjunction with the config package, this can be useful to generate cli flags that overrides configs while maintaing type safety and not having to deal with string typos. + +Getting Started +^^^^^^^^^^^^^^^ + - ``go get github.com/lyft/flytestdlib/cli/pflags`` + - call ``pflags --package `` OR + - add ``//go:generate pflags `` to the top of the file where the struct is declared. + has to be a struct type (it can't be, for instance, a slice type). + Supported fields' types within the struct: basic types (string, int8, int16, int32, int64, bool), json-unmarshalable types and other structs that conform to the same rules or slices of these types. + +This generates two files (struct_name_pflags.go and struct_name_pflags_test.go). If you open those, you will notice that all generated flags default to empty/zero values and no usage strings. That behavior can be customized using ``pflag`` tag. + +.. code-block:: + + type TestType struct { + StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""` + BoolValue bool `json:"bl" pflag:",This is a bool value that will default to false."` + } + +``pflag`` tag is a comma-separated list. First item represents default value. Second value is usage. diff --git a/flytestdlib/config/accessor.go b/flytestdlib/config/accessor.go new file mode 100644 index 0000000000..1b5a693316 --- /dev/null +++ b/flytestdlib/config/accessor.go @@ -0,0 +1,61 @@ +// A strongly-typed config library to parse configs from PFlags, Env Vars and Config files. +// Config package enables consumers to access (readonly for now) strongly typed configs without worrying about mismatching +// keys or casting to the wrong type. It supports basic types (e.g. int, string) as well as more complex structures through +// json encoding/decoding. +// +// Config package introduces the concept of Sections. Each section should be given a unique section key. The binary will +// not load if there is a conflict. Each section should be represented as a Go struct and registered at startup before +// config is loaded/parsed. +// +// Sections can be nested too. A new config section can be registered as a sub-section of an existing one. This allows +// dynamic grouping of sections while continuing to enforce strong-typed parsing of configs. 
+// +// Config data can be parsed from supported config file(s) (yaml, prop, toml), env vars, PFlags or a combination of these +// Precedence is (flags, env vars, config file, defaults). When data is read from config files, a file watcher is started +// to monitor for changes in those files. If the registrant of a section subscribes to changes then a handler is called +// when the relevant section has been updated. Sections within a single config file will be invoked after all sections +// from that particular config file are parsed. It follows that if there are inter-dependent sections (e.g. changing one +// MUST be followed by a change in another), then make sure those sections are placed in the same config file. +// +// A convenience tool is also provided in cli package (pflags) that generates an implementation for PFlagProvider interface +// based on json names of the fields. +package config + +import ( + "context" + "flag" + + "github.com/spf13/pflag" +) + +// Provides a simple config parser interface. +type Accessor interface { + // Gets a friendly identifier for the accessor. + ID() string + + // Initializes the config parser with golang's default flagset. + InitializeFlags(cmdFlags *flag.FlagSet) + + // Initializes the config parser with pflag's flagset. + InitializePflags(cmdFlags *pflag.FlagSet) + + // Parses and validates config file(s) discovered then updates the underlying config section with the results. + // Exercise caution when calling this because multiple invocations will overwrite each other's results. + UpdateConfig(ctx context.Context) error + + // Gets path(s) to the config file(s) used. + ConfigFilesUsed() []string +} + +// Options used to initialize a Config Accessor +type Options struct { + // Instructs parser to fail if any section/key in the config file read do not have a corresponding registered section. + StrictMode bool + + // Search paths to look for config file(s). If not specified, it searches for config.yaml under current directory as well + // as /etc/flyte/config directories. + SearchPaths []string + + // Defines the root section to use with the accessor. + RootSection Section +} diff --git a/flytestdlib/config/accessor_test.go b/flytestdlib/config/accessor_test.go new file mode 100644 index 0000000000..f386b30b9c --- /dev/null +++ b/flytestdlib/config/accessor_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "context" + "fmt" + "path/filepath" +) + +func Example() { + // This example demonstrates basic usage of config sections. + + //go:generate pflags OtherComponentConfig + + type OtherComponentConfig struct { + DurationValue Duration `json:"duration-value"` + URLValue URL `json:"url-value"` + StringValue string `json:"string-value"` + } + + // Each component should register their section in package init() or as a package var + section := MustRegisterSection("other-component", &OtherComponentConfig{}) + + // Override configpath to look for a custom location. + configPath := filepath.Join("testdata", "config.yaml") + + // Initialize an accessor. + var accessor Accessor + // e.g. + // accessor = viper.NewAccessor(viper.Options{ + // StrictMode: true, + // SearchPaths: []string{configPath, configPath2}, + // }) + + // Optionally bind to Pflags. 
+ // accessor.InitializePflags(flags) + + // Parse config from file or pass empty to rely on env variables and PFlags + err := accessor.UpdateConfig(context.Background()) + if err != nil { + fmt.Printf("Failed to validate config from [%v], error: %v", configPath, err) + return + } + + // Get parsed config value. + parsedConfig := section.GetConfig().(*OtherComponentConfig) + fmt.Printf("Config: %v", parsedConfig) +} + +func Example_nested() { + // This example demonstrates registering nested config sections dynamically. + + //go:generate pflags OtherComponentConfig + + type OtherComponentConfig struct { + DurationValue Duration `json:"duration-value"` + URLValue URL `json:"url-value"` + StringValue string `json:"string-value"` + } + + // Each component should register their section in package init() or as a package var + Section := MustRegisterSection("my-component", &MyComponentConfig{}) + + // Other packages can register their sections at the root level (like the above line) or as nested sections of other + // sections (like the below line) + NestedSection := Section.MustRegisterSection("nested", &OtherComponentConfig{}) + + // Override configpath to look for a custom location. + configPath := filepath.Join("testdata", "nested_config.yaml") + + // Initialize an accessor. + var accessor Accessor + // e.g. + // accessor = viper.NewAccessor(viper.Options{ + // StrictMode: true, + // SearchPaths: []string{configPath, configPath2}, + // }) + + // Optionally bind to Pflags. + // accessor.InitializePflags(flags) + + // Parse config from file or pass empty to rely on env variables and PFlags + err := accessor.UpdateConfig(context.Background()) + if err != nil { + fmt.Printf("Failed to validate config from [%v], error: %v", configPath, err) + return + } + + // Get parsed config value. + parsedConfig := NestedSection.GetConfig().(*OtherComponentConfig) + fmt.Printf("Config: %v", parsedConfig) +} diff --git a/flytestdlib/config/config_cmd.go b/flytestdlib/config/config_cmd.go new file mode 100644 index 0000000000..a5023ceb61 --- /dev/null +++ b/flytestdlib/config/config_cmd.go @@ -0,0 +1,113 @@ +package config + +import ( + "context" + "os" + "strings" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +const ( + PathFlag = "file" + StrictModeFlag = "strict" + CommandValidate = "validate" + CommandDiscover = "discover" +) + +type AccessorProvider func(options Options) Accessor + +type printer interface { + Printf(format string, i ...interface{}) + Println(i ...interface{}) +} + +func NewConfigCommand(accessorProvider AccessorProvider) *cobra.Command { + opts := Options{} + rootCmd := &cobra.Command{ + Use: "config", + Short: "Runs various config commands, look at the help of this command to get a list of available commands..", + ValidArgs: []string{CommandValidate, CommandDiscover}, + } + + validateCmd := &cobra.Command{ + Use: "validate", + Short: "Validates the loaded config.", + RunE: func(cmd *cobra.Command, args []string) error { + return validate(accessorProvider(opts), cmd) + }, + } + + discoverCmd := &cobra.Command{ + Use: "discover", + Short: "Searches for a config in one of the default search paths.", + RunE: func(cmd *cobra.Command, args []string) error { + return validate(accessorProvider(opts), cmd) + }, + } + + // Configure Root Command + rootCmd.PersistentFlags().StringArrayVar(&opts.SearchPaths, PathFlag, []string{}, `Passes the config file to load. 
+If empty, it'll first search for the config file path then, if found, will load config from there.`) + + rootCmd.AddCommand(validateCmd) + rootCmd.AddCommand(discoverCmd) + + // Configure Validate Command + validateCmd.Flags().BoolVar(&opts.StrictMode, StrictModeFlag, false, `Validates that all keys in loaded config +map to already registered sections.`) + + return rootCmd +} + +// Redirects Stdout to a string buffer until context is cancelled. +func redirectStdOut() (old, new *os.File) { + old = os.Stdout // keep backup of the real stdout + var err error + _, new, err = os.Pipe() + if err != nil { + panic(err) + } + + os.Stdout = new + + return +} + +func validate(accessor Accessor, p printer) error { + // Redirect stdout + old, n := redirectStdOut() + defer func() { + err := n.Close() + if err != nil { + panic(err) + } + }() + defer func() { os.Stdout = old }() + + err := accessor.UpdateConfig(context.Background()) + + printInfo(p, accessor) + if err == nil { + green := color.New(color.FgGreen).SprintFunc() + p.Println(green("Validated config file successfully.")) + } else { + red := color.New(color.FgRed).SprintFunc() + p.Println(red("Failed to validate config file.")) + } + + return err +} + +func printInfo(p printer, v Accessor) { + cfgFile := v.ConfigFilesUsed() + if len(cfgFile) != 0 { + green := color.New(color.FgGreen).SprintFunc() + + p.Printf("Config file(s) found at: %v\n", green(strings.Join(cfgFile, "\n"))) + } else { + red := color.New(color.FgRed).SprintFunc() + p.Println(red("Couldn't find a config file.")) + } +} diff --git a/flytestdlib/config/config_cmd_test.go b/flytestdlib/config/config_cmd_test.go new file mode 100644 index 0000000000..d8ef31d298 --- /dev/null +++ b/flytestdlib/config/config_cmd_test.go @@ -0,0 +1,64 @@ +package config + +import ( + "bytes" + "context" + "flag" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" +) + +type MockAccessor struct { +} + +func (MockAccessor) ID() string { + panic("implement me") +} + +func (MockAccessor) InitializeFlags(cmdFlags *flag.FlagSet) { +} + +func (MockAccessor) InitializePflags(cmdFlags *pflag.FlagSet) { +} + +func (MockAccessor) UpdateConfig(ctx context.Context) error { + return nil +} + +func (MockAccessor) ConfigFilesUsed() []string { + return []string{"test"} +} + +func (MockAccessor) RefreshFromConfig() error { + return nil +} + +func newMockAccessor(options Options) Accessor { + return MockAccessor{} +} + +func executeCommandC(root *cobra.Command, args ...string) (c *cobra.Command, output string, err error) { + buf := new(bytes.Buffer) + root.SetOutput(buf) + root.SetArgs(args) + + c, err = root.ExecuteC() + + return c, buf.String(), err +} + +func TestNewConfigCommand(t *testing.T) { + cmd := NewConfigCommand(newMockAccessor) + assert.NotNil(t, cmd) + + _, output, err := executeCommandC(cmd, CommandDiscover) + assert.NoError(t, err) + assert.Contains(t, output, "test") + + _, output, err = executeCommandC(cmd, CommandValidate) + assert.NoError(t, err) + assert.Contains(t, output, "test") +} diff --git a/flytestdlib/config/duration.go b/flytestdlib/config/duration.go new file mode 100644 index 0000000000..e1a978d5c8 --- /dev/null +++ b/flytestdlib/config/duration.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/json" + "errors" + "time" +) + +// A wrapper around time.Duration that enables Json Marshalling capabilities +type Duration struct { + time.Duration +} + +func (d Duration) MarshalJSON() ([]byte, error) { + return 
json.Marshal(d.String()) +} + +func (d *Duration) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + d.Duration = time.Duration(0) + return nil + } + + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) + return nil + case string: + if len(value) == 0 { + d.Duration = time.Duration(0) + } else { + var err error + d.Duration, err = time.ParseDuration(value) + if err != nil { + return err + } + } + default: + return errors.New("invalid duration") + } + + return nil +} diff --git a/flytestdlib/config/duration_test.go b/flytestdlib/config/duration_test.go new file mode 100644 index 0000000000..8e411987ac --- /dev/null +++ b/flytestdlib/config/duration_test.go @@ -0,0 +1,66 @@ +package config + +import ( + "encoding/json" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDuration_MarshalJSON(t *testing.T) { + t.Run("Valid", func(t *testing.T) { + expected := Duration{ + Duration: time.Second * 2, + } + + b, err := expected.MarshalJSON() + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(b) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(expected, actual)) + }) +} + +func TestDuration_UnmarshalJSON(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + actual := Duration{} + err := actual.UnmarshalJSON([]byte{}) + assert.NoError(t, err) + assert.Equal(t, time.Duration(0), actual.Duration) + }) + + t.Run("Invalid_string", func(t *testing.T) { + input := "blah" + raw, err := json.Marshal(input) + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) + + t.Run("Valid_float", func(t *testing.T) { + input := float64(12345) + raw, err := json.Marshal(input) + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(raw) + assert.NoError(t, err) + }) + + t.Run("Invalid_bool", func(t *testing.T) { + input := true + raw, err := json.Marshal(input) + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) +} diff --git a/flytestdlib/config/errors.go b/flytestdlib/config/errors.go new file mode 100644 index 0000000000..b46459a9f9 --- /dev/null +++ b/flytestdlib/config/errors.go @@ -0,0 +1,37 @@ +package config + +import "fmt" + +var ( + ErrStrictModeValidation = fmt.Errorf("failed strict mode check") + ErrChildConfigOverridesConfig = fmt.Errorf("child config attempts to override an existing native config property") +) + +// A helper object that collects errors. 
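+// A typical usage sketch (validate is an assumed helper returning an error or nil):
+//
+//	errs := ErrorCollection{}
+//	for _, section := range sections {
+//		errs.Append(validate(section))
+//	}
+//	return errs.ErrorOrDefault()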
+type ErrorCollection []error + +func (e ErrorCollection) Error() string { + res := "" + for _, err := range e { + res = fmt.Sprintf("%v\n%v", res, err.Error()) + } + + return res +} + +func (e ErrorCollection) ErrorOrDefault() error { + if len(e) == 0 { + return nil + } + + return e +} + +func (e *ErrorCollection) Append(err error) bool { + if err != nil { + *e = append(*e, err) + return true + } + + return false +} diff --git a/flytestdlib/config/errors_test.go b/flytestdlib/config/errors_test.go new file mode 100644 index 0000000000..4dc1e72876 --- /dev/null +++ b/flytestdlib/config/errors_test.go @@ -0,0 +1,31 @@ +package config + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorCollection_ErrorOrDefault(t *testing.T) { + errs := ErrorCollection{} + assert.Error(t, errs) + assert.NoError(t, errs.ErrorOrDefault()) +} + +func TestErrorCollection_Append(t *testing.T) { + errs := ErrorCollection{} + errs.Append(nil) + errs.Append(fmt.Errorf("this is an actual error")) + assert.Error(t, errs.ErrorOrDefault()) + assert.Len(t, errs, 1) + assert.Len(t, errs.ErrorOrDefault(), 1) +} + +func TestErrorCollection_Error(t *testing.T) { + errs := ErrorCollection{} + errs.Append(nil) + errs.Append(fmt.Errorf("this is an actual error")) + assert.Error(t, errs.ErrorOrDefault()) + assert.Contains(t, errs.ErrorOrDefault().Error(), "this is an actual error") +} diff --git a/flytestdlib/config/files/finder.go b/flytestdlib/config/files/finder.go new file mode 100644 index 0000000000..e389d88306 --- /dev/null +++ b/flytestdlib/config/files/finder.go @@ -0,0 +1,83 @@ +package files + +import ( + "os" + "path/filepath" +) + +const ( + configFileType = "yaml" + configFileName = "config" +) + +var configLocations = [][]string{ + {"."}, + {"/etc", "flyte", "config"}, + {os.ExpandEnv("$GOPATH"), "src", "github.com", "lyft", "flytestdlib"}, +} + +// Check if File / Directory Exists +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + + if os.IsNotExist(err) { + return false, nil + } + + return false, err +} + +func isFile(path string) (bool, error) { + s, err := os.Stat(path) + if err != nil { + return false, err + } + + return !s.IsDir(), nil +} + +func contains(slice []string, value string) bool { + for _, s := range slice { + if s == value { + return true + } + } + + return false +} + +// Finds config files in search paths. If searchPaths is empty, it'll look in default locations (see configLocations above) +// If searchPaths is not empty but no configs are found there, it'll still look into configLocations. +// If it found any config file in searchPaths, it'll stop the search. +// searchPaths can contain patterns to match (behavior is OS-dependent). And it'll try to Glob the pattern for any matching +// files. +func FindConfigFiles(searchPaths []string) []string { + res := make([]string, 0, 1) + + for _, location := range searchPaths { + matchedFiles, err := filepath.Glob(location) + if err != nil { + continue + } + + for _, matchedFile := range matchedFiles { + if file, err := isFile(matchedFile); err == nil && file && !contains(res, matchedFile) { + res = append(res, matchedFile) + } + } + } + + if len(res) == 0 { + for _, location := range configLocations { + pathToTest := filepath.Join(append(location, configFileName+"."+configFileType)...) 
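// With the default configLocations above, pathToTest resolves to e.g. "config.yaml",
// "/etc/flyte/config/config.yaml", or "$GOPATH/src/github.com/lyft/flytestdlib/config.yaml".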
+ if b, err := exists(pathToTest); err == nil && b && !contains(res, pathToTest) { + res = append(res, pathToTest) + } + } + } + + return res +} diff --git a/flytestdlib/config/files/finder_test.go b/flytestdlib/config/files/finder_test.go new file mode 100644 index 0000000000..833ac461f1 --- /dev/null +++ b/flytestdlib/config/files/finder_test.go @@ -0,0 +1,28 @@ +package files + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFindConfigFiles(t *testing.T) { + t.Run("Find config-* group", func(t *testing.T) { + files := FindConfigFiles([]string{filepath.Join("testdata", "config*.yaml")}) + assert.Equal(t, 2, len(files)) + }) + + t.Run("Find other-group-* group", func(t *testing.T) { + files := FindConfigFiles([]string{filepath.Join("testdata", "other-group*.yaml")}) + assert.Equal(t, 2, len(files)) + }) + + t.Run("Absolute path", func(t *testing.T) { + files := FindConfigFiles([]string{filepath.Join("testdata", "other-group-1.yaml")}) + assert.Equal(t, 1, len(files)) + + files = FindConfigFiles([]string{filepath.Join("testdata", "other-group-3.yaml")}) + assert.Equal(t, 0, len(files)) + }) +} diff --git a/flytestdlib/config/files/testdata/config-1.yaml b/flytestdlib/config/files/testdata/config-1.yaml new file mode 100644 index 0000000000..a5b6191d98 --- /dev/null +++ b/flytestdlib/config/files/testdata/config-1.yaml @@ -0,0 +1,9 @@ +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/flytestdlib/config/files/testdata/config-2.yaml b/flytestdlib/config/files/testdata/config-2.yaml new file mode 100755 index 0000000000..5c28d00a09 --- /dev/null +++ b/flytestdlib/config/files/testdata/config-2.yaml @@ -0,0 +1,2 @@ +my-component: + str: Hello World diff --git a/flytestdlib/config/files/testdata/other-group-1.yaml b/flytestdlib/config/files/testdata/other-group-1.yaml new file mode 100644 index 0000000000..a5b6191d98 --- /dev/null +++ b/flytestdlib/config/files/testdata/other-group-1.yaml @@ -0,0 +1,9 @@ +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/flytestdlib/config/files/testdata/other-group-2.yaml b/flytestdlib/config/files/testdata/other-group-2.yaml new file mode 100755 index 0000000000..5c28d00a09 --- /dev/null +++ b/flytestdlib/config/files/testdata/other-group-2.yaml @@ -0,0 +1,2 @@ +my-component: + str: Hello World diff --git a/flytestdlib/config/port.go b/flytestdlib/config/port.go new file mode 100644 index 0000000000..87bbc854e2 --- /dev/null +++ b/flytestdlib/config/port.go @@ -0,0 +1,56 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" +) + +// A common port struct that supports Json marshal/unmarshal into/from simple strings/floats. 
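For illustration, a minimal sketch of the Port type defined just below embedded in a config struct; serverConfig and the JSON literals are made-up examples, and encoding/json is the only import assumed:

type serverConfig struct {
	Port Port `json:"port"`
}

func exampleLoadPort() (int, error) {
	var cfg serverConfig
	// A quoted string parses the same as a bare number, e.g. {"port": 8080}.
	if err := json.Unmarshal([]byte(`{"port": "8080"}`), &cfg); err != nil {
		return 0, err
	}

	return cfg.Port.Port, nil // 8080
}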
+type Port struct { + Port int `json:"port,omitempty"` +} + +func (p Port) MarshalJSON() ([]byte, error) { + return json.Marshal(p.Port) +} + +func (p *Port) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + switch value := v.(type) { + case string: + u, err := parsePortString(value) + if err != nil { + return err + } + + p.Port = u + return nil + case float64: + if !validPortRange(value) { + return fmt.Errorf("port must be a valid number between 0 and 65535, inclusive") + } + + p.Port = int(value) + return nil + default: + return errors.New("invalid port") + } +} + +func parsePortString(port string) (int, error) { + if portInt, err := strconv.Atoi(port); err == nil && validPortRange(float64(portInt)) { + return portInt, nil + } + + return 0, fmt.Errorf("port must be a valid number between 1 and 65535, inclusive") +} + +func validPortRange(port float64) bool { + return 0 <= port && port <= 65535 +} diff --git a/flytestdlib/config/port_test.go b/flytestdlib/config/port_test.go new file mode 100644 index 0000000000..c69c09570d --- /dev/null +++ b/flytestdlib/config/port_test.go @@ -0,0 +1,81 @@ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type PortTestCase struct { + Expected Port + Input interface{} +} + +func TestPort_MarshalJSON(t *testing.T) { + validPorts := []PortTestCase{ + {Expected: Port{Port: 8080}, Input: 8080}, + {Expected: Port{Port: 1}, Input: 1}, + {Expected: Port{Port: 65535}, Input: "65535"}, + {Expected: Port{Port: 65535}, Input: 65535}, + } + + for i, validPort := range validPorts { + t.Run(fmt.Sprintf("Valid %v [%v]", i, validPort.Input), func(t *testing.T) { + b, err := json.Marshal(validPort.Input) + assert.NoError(t, err) + + actual := Port{} + err = actual.UnmarshalJSON(b) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(validPort.Expected, actual)) + }) + } +} + +func TestPort_UnmarshalJSON(t *testing.T) { + invalidValues := []interface{}{ + "%gh&%ij", + 1000000, + true, + } + + for i, invalidPort := range invalidValues { + t.Run(fmt.Sprintf("Invalid %v", i), func(t *testing.T) { + raw, err := json.Marshal(invalidPort) + assert.NoError(t, err) + + actual := URL{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) + } + + t.Run("Invalid json", func(t *testing.T) { + actual := Port{} + err := actual.UnmarshalJSON([]byte{}) + assert.Error(t, err) + }) + + t.Run("Invalid Range", func(t *testing.T) { + b, err := json.Marshal(float64(100000)) + assert.NoError(t, err) + + actual := Port{} + err = actual.UnmarshalJSON(b) + assert.Error(t, err) + }) + + t.Run("Unmarshal Empty", func(t *testing.T) { + p := Port{} + raw, err := json.Marshal(p) + assert.NoError(t, err) + + actual := Port{} + assert.NoError(t, actual.UnmarshalJSON(raw)) + assert.Equal(t, 0, actual.Port) + }) +} diff --git a/flytestdlib/config/section.go b/flytestdlib/config/section.go new file mode 100644 index 0000000000..41fddd0c1b --- /dev/null +++ b/flytestdlib/config/section.go @@ -0,0 +1,230 @@ +package config + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/lyft/flytestdlib/atomic" + + "github.com/spf13/pflag" +) + +type Section interface { + // Gets a cloned copy of the Config registered to this section. This config instance does not account for any child + // section registered. + GetConfig() Config + + // Gets a function pointer to call when the config has been updated. 
+ GetConfigUpdatedHandler() SectionUpdated + + // Sets the config and sets a bit indicating whether the new config is different when compared to the existing value. + SetConfig(config Config) error + + // Gets a value indicating whether the config has changed since the last call to GetConfigChangedAndClear and clears + // the changed bit. This operation is atomic. + GetConfigChangedAndClear() bool + + // Retrieves the loaded values for section key if one exists, or nil otherwise. + GetSection(key SectionKey) Section + + // Gets all child config sections. + GetSections() SectionMap + + // Registers a section with the config manager. Section keys are case insensitive and must be unique. + // The section object must be passed by reference since it'll be used to unmarshal into. It must also support json + // marshaling. If the section registered gets updated at runtime, the updatesFn will be invoked to handle the propagation + // of changes. + RegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) (Section, error) + + // Registers a section with the config manager. Section keys are case insensitive and must be unique. + // The section object must be passed by reference since it'll be used to unmarshal into. It must also support json + // marshaling. If the section registered gets updated at runtime, the updatesFn will be invoked to handle the propagation + // of changes. + MustRegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) Section + + // Registers a section with the config manager. Section keys are case insensitive and must be unique. + // The section object must be passed by reference since it'll be used to unmarshal into. It must also support json + // marshaling. + RegisterSection(key SectionKey, configSection Config) (Section, error) + + // Registers a section with the config manager. Section keys are case insensitive and must be unique. + // The section object must be passed by reference since it'll be used to unmarshal into. It must also support json + // marshaling. + MustRegisterSection(key SectionKey, configSection Config) Section +} + +type Config = interface{} +type SectionKey = string +type SectionMap map[SectionKey]Section + +// A section can optionally implements this interface to add its fields as cmdline arguments. +type PFlagProvider interface { + GetPFlagSet(prefix string) *pflag.FlagSet +} + +type SectionUpdated func(ctx context.Context, newValue Config) + +// Global section to use with any root-level config sections registered. +var rootSection = NewRootSection() + +type section struct { + config Config + handler SectionUpdated + isDirty atomic.Bool + sections SectionMap + lockObj sync.RWMutex +} + +// Gets the global root section. +func GetRootSection() Section { + return rootSection +} + +func MustRegisterSection(key SectionKey, configSection Config) Section { + s, err := RegisterSection(key, configSection) + if err != nil { + panic(err) + } + + return s +} + +func (r *section) MustRegisterSection(key SectionKey, configSection Config) Section { + s, err := r.RegisterSection(key, configSection) + if err != nil { + panic(err) + } + + return s +} + +// Registers a section with the config manager. Section keys are case insensitive and must be unique. +// The section object must be passed by reference since it'll be used to unmarshal into. It must also support json +// marshaling. 
+func RegisterSection(key SectionKey, configSection Config) (Section, error) { + return rootSection.RegisterSection(key, configSection) +} + +func (r *section) RegisterSection(key SectionKey, configSection Config) (Section, error) { + return r.RegisterSectionWithUpdates(key, configSection, nil) +} + +func MustRegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) Section { + s, err := RegisterSectionWithUpdates(key, configSection, updatesFn) + if err != nil { + panic(err) + } + + return s +} + +func (r *section) MustRegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) Section { + s, err := r.RegisterSectionWithUpdates(key, configSection, updatesFn) + if err != nil { + panic(err) + } + + return s +} + +// Registers a section with the config manager. Section keys are case insensitive and must be unique. +// The section object must be passed by reference since it'll be used to unmarshal into. It must also support json +// marshaling. If the section registered gets updated at runtime, the updatesFn will be invoked to handle the propagation +// of changes. +func RegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) (Section, error) { + return rootSection.RegisterSectionWithUpdates(key, configSection, updatesFn) +} + +func (r *section) RegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) (Section, error) { + r.lockObj.Lock() + defer r.lockObj.Unlock() + + key = strings.ToLower(key) + + if len(key) == 0 { + return nil, errors.New("key must be a non-zero string") + } + + if configSection == nil { + return nil, fmt.Errorf("configSection must be a non-nil pointer. SectionKey: %v", key) + } + + if reflect.TypeOf(configSection).Kind() != reflect.Ptr { + return nil, fmt.Errorf("section must be a Pointer. SectionKey: %v", key) + } + + if _, alreadyExists := r.sections[key]; alreadyExists { + return nil, fmt.Errorf("key already exists [%v]", key) + } + + section := NewSection(configSection, updatesFn) + r.sections[key] = section + return section, nil +} + +// Retrieves the loaded values for section key if one exists, or nil otherwise. 
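For illustration, a minimal sketch of the intended register-then-look-up flow; the section key, the exampleConfig struct, and its Endpoint field are made up, and reading values back assumes an Accessor has already run UpdateConfig:

type exampleConfig struct {
	Endpoint string `json:"endpoint"`
}

// Registration typically happens once, at package init time.
var exampleSection = MustRegisterSection("example-component", &exampleConfig{})

func currentExampleConfig() *exampleConfig {
	// GetConfig returns the most recently parsed copy; its concrete type is the pointer type registered above.
	// An equivalent lookup by key: GetSection("example-component").GetConfig().(*exampleConfig)
	return exampleSection.GetConfig().(*exampleConfig)
}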
+func GetSection(key SectionKey) Section { + return rootSection.GetSection(key) +} + +func (r *section) GetSection(key SectionKey) Section { + r.lockObj.RLock() + defer r.lockObj.RUnlock() + + key = strings.ToLower(key) + + if section, alreadyExists := r.sections[key]; alreadyExists { + return section + } + + return nil +} + +func (r *section) GetSections() SectionMap { + return r.sections +} + +func (r *section) GetConfig() Config { + r.lockObj.RLock() + defer r.lockObj.RUnlock() + + return r.config +} + +func (r *section) SetConfig(c Config) error { + r.lockObj.Lock() + defer r.lockObj.Unlock() + + if !DeepEqual(r.config, c) { + r.config = c + r.isDirty.Store(true) + } + + return nil +} + +func (r *section) GetConfigUpdatedHandler() SectionUpdated { + return r.handler +} + +func (r *section) GetConfigChangedAndClear() bool { + return r.isDirty.CompareAndSwap(true, false) +} + +func NewSection(configSection Config, updatesFn SectionUpdated) Section { + return §ion{ + config: configSection, + handler: updatesFn, + isDirty: atomic.NewBool(false), + sections: map[SectionKey]Section{}, + lockObj: sync.RWMutex{}, + } +} + +func NewRootSection() Section { + return NewSection(nil, nil) +} diff --git a/flytestdlib/config/section_test.go b/flytestdlib/config/section_test.go new file mode 100644 index 0000000000..5b314bb339 --- /dev/null +++ b/flytestdlib/config/section_test.go @@ -0,0 +1,119 @@ +package config + +import ( + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/ghodss/yaml" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/spf13/pflag" + + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! 
+var update = flag.Bool("update", false, "Updates testdata") + +type MyComponentConfig struct { + StringValue string `json:"str"` +} + +type OtherComponentConfig struct { + DurationValue Duration `json:"duration-value"` + URLValue URL `json:"url-value"` + StringValue string `json:"string-value"` + IntValue int `json:"int-val"` + StringArray []string `json:"strings"` +} + +func (MyComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") + return cmdFlags +} + +func (OtherComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "string-value"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "duration-value"), "20s", "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "int-val"), 4, "this is an important flag") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "url-value"), "http://blah.com", "Sets the type of storage to configure [s3/minio/local/mem].") + return cmdFlags +} + +type TestConfig struct { + MyComponentConfig MyComponentConfig `json:"my-component"` + OtherComponentConfig OtherComponentConfig `json:"other-component"` +} + +func TestMarshal(t *testing.T) { + expected := TestConfig{ + MyComponentConfig: MyComponentConfig{ + StringValue: "Hello World", + }, + OtherComponentConfig: OtherComponentConfig{ + StringValue: "Hey there!", + IntValue: 4, + URLValue: URL{URL: utils.MustParseURL("http://something.com")}, + DurationValue: Duration{Duration: time.Second * 20}, + StringArray: []string{"hello", "world", "!"}, + }, + } + + configPath := filepath.Join("testdata", "config.yaml") + if *update { + t.Log("Updating config file.") + raw, err := yaml.Marshal(expected) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(configPath, raw, os.ModePerm)) + } + + r := TestConfig{} + raw, err := ioutil.ReadFile(configPath) + assert.NoError(t, err) + assert.NoError(t, yaml.Unmarshal(raw, &r)) + assert.True(t, reflect.DeepEqual(expected, r)) +} + +func TestRegisterSection(t *testing.T) { + t.Run("New Section", func(t *testing.T) { + _, err := RegisterSection(rand.String(6), &TestConfig{}) + assert.NoError(t, err) + }) + + t.Run("Duplicate", func(t *testing.T) { + s := rand.String(6) + _, err := RegisterSection(s, &TestConfig{}) + assert.NoError(t, err) + _, err = RegisterSection(s, &TestConfig{}) + assert.Error(t, err) + }) + + t.Run("Register Nested", func(t *testing.T) { + root := NewRootSection() + s := rand.String(6) + _, err := root.RegisterSection(s, &TestConfig{}) + assert.NoError(t, err) + _, err = root.RegisterSection(s, &TestConfig{}) + assert.Error(t, err) + }) +} + +func TestGetSection(t *testing.T) { + sectionName := rand.String(6) + actual1, err := RegisterSection(sectionName, &TestConfig{}) + assert.NoError(t, err) + assert.Equal(t, reflect.TypeOf(&TestConfig{}), reflect.TypeOf(actual1.GetConfig())) + + actual2 := GetSection(sectionName) + assert.NotNil(t, actual2) + assert.Equal(t, reflect.TypeOf(&TestConfig{}), reflect.TypeOf(actual2.GetConfig())) +} diff --git a/flytestdlib/config/testdata/config.yaml b/flytestdlib/config/testdata/config.yaml new file mode 100755 index 0000000000..2f20ad97b5 --- /dev/null +++ b/flytestdlib/config/testdata/config.yaml @@ -0,0 +1,11 @@ +my-component: + str: Hello World +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey 
there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go new file mode 100644 index 0000000000..34d86237e9 --- /dev/null +++ b/flytestdlib/config/tests/accessor_test.go @@ -0,0 +1,641 @@ +package tests + +import ( + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/fsnotify/fsnotify" + + k8sRand "k8s.io/apimachinery/pkg/util/rand" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/spf13/pflag" + + "github.com/ghodss/yaml" + "github.com/stretchr/testify/assert" +) + +type accessorCreatorFn func(registry config.Section, configPath string) config.Accessor + +func getRandInt() uint64 { + c := 10 + b := make([]byte, c) + _, err := rand.Read(b) + if err != nil { + return 0 + } + + return binary.BigEndian.Uint64(b) +} + +func tempFileName(pattern string) string { + // TODO: Remove this hack after we use Go1.11 everywhere: + // https://github.com/golang/go/commit/191efbc419d7e5dec842c20841f6f716da4b561d + + var prefix, suffix string + if pos := strings.LastIndex(pattern, "*"); pos != -1 { + prefix, suffix = pattern[:pos], pattern[pos+1:] + } else { + prefix = pattern + } + + return filepath.Join(os.TempDir(), prefix+k8sRand.String(6)+suffix) +} + +func populateConfigData(configPath string) (TestConfig, error) { + expected := TestConfig{ + MyComponentConfig: MyComponentConfig{ + StringValue: fmt.Sprintf("Hello World %v", getRandInt()), + }, + OtherComponentConfig: OtherComponentConfig{ + StringValue: fmt.Sprintf("Hello World %v", getRandInt()), + URLValue: config.URL{URL: utils.MustParseURL("http://something.com")}, + DurationValue: config.Duration{Duration: time.Second * 20}, + }, + } + + raw, err := yaml.Marshal(expected) + if err != nil { + return TestConfig{}, err + } + + return expected, ioutil.WriteFile(configPath, raw, os.ModePerm) +} + +func TestGetEmptySection(t *testing.T) { + t.Run("empty", func(t *testing.T) { + r := config.GetSection("Empty") + assert.Nil(t, r) + }) +} + +type ComplexType struct { + IntValue int `json:"int-val"` +} + +type ComplexTypeArray []ComplexType + +type ConfigWithLists struct { + ListOfStuff []ComplexType `json:"list"` + StringValue string `json:"string-val"` +} + +type ConfigWithMaps struct { + MapOfStuff map[string]ComplexType `json:"m"` + MapWithoutJSON map[string]ComplexType +} + +type ConfigWithJSONTypes struct { + Duration config.Duration `json:"duration"` +} + +func TestAccessor_InitializePflags(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf("[%v] Unused flag", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ContinueOnError) + set.String("flag1", "123", "") + v.InitializePflags(set) + assert.NoError(t, v.UpdateConfig(context.TODO())) + }) + + t.Run(fmt.Sprintf("[%v] Override string value", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + 
SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ContinueOnError) + v.InitializePflags(set) + key := "MY_COMPONENT.STR2" + assert.NoError(t, os.Setenv(key, "123")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "123", r.StringValue2) + }) + + t.Run(fmt.Sprintf("[%v] Parse from config file", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = reg.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ExitOnError) + v.InitializePflags(set) + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + otherC := reg.GetSection(OtherComponentSectionKey).GetConfig().(*OtherComponentConfig) + assert.Equal(t, 4, otherC.IntValue) + assert.Equal(t, []string{"default value"}, otherC.StringArrayWithDefaults) + }) + } +} + +func TestStrictAccessor(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf("[%v] Bad config", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + v := provider(config.Options{ + StrictMode: true, + SearchPaths: []string{filepath.Join("testdata", "bad_config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ExitOnError) + v.InitializePflags(set) + assert.Error(t, v.UpdateConfig(context.TODO())) + }) + + t.Run(fmt.Sprintf("[%v] Set through env", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = reg.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + StrictMode: true, + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("other-component.string-value", pflag.ExitOnError) + v.InitializePflags(set) + + key := "OTHER_COMPONENT.STRING_VALUE" + assert.NoError(t, os.Setenv(key, "set from env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + }) + } +} + +func TestAccessor_UpdateConfig(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf("[%v] Static File", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + }) + + t.Run(fmt.Sprintf("[%v] Nested", provider(config.Options{}).ID()), func(t *testing.T) { + root := config.NewRootSection() + section, err := root.RegisterSection(MyComponentSectionKey, 
&MyComponentConfig{}) + assert.NoError(t, err) + + _, err = section.RegisterSection("nested", &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "nested_config.yaml")}, + RootSection: root, + }) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := root.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + + nested := section.GetSection("nested").GetConfig().(*OtherComponentConfig) + assert.Equal(t, "Hey there!", nested.StringValue) + }) + + t.Run(fmt.Sprintf("[%v] Array Configs", provider(config.Options{}).ID()), func(t *testing.T) { + root := config.NewRootSection() + section, err := root.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = section.RegisterSection("nested", &ComplexTypeArray{}) + assert.NoError(t, err) + + _, err = root.RegisterSection("array-config", &ComplexTypeArray{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "array_configs.yaml")}, + RootSection: root, + }) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := root.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + + nested := section.GetSection("nested").GetConfig().(*ComplexTypeArray) + assert.Len(t, *nested, 1) + assert.Equal(t, 1, (*nested)[0].IntValue) + + topLevel := root.GetSection("array-config").GetConfig().(*ComplexTypeArray) + assert.Len(t, *topLevel, 2) + assert.Equal(t, 4, (*topLevel)[1].IntValue) + }) + + t.Run(fmt.Sprintf("[%v] Override in Env Var", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + key := strings.ToUpper("my-component.str") + assert.NoError(t, os.Setenv(key, "Set From Env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Set From Env", r.StringValue) + }) + + t.Run(fmt.Sprintf("[%v] Override in Env Var no config file", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{RootSection: reg}) + key := strings.ToUpper("my-component.str3") + assert.NoError(t, os.Setenv(key, "Set From Env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Set From Env", r.StringValue3) + }) + + t.Run(fmt.Sprintf("[%v] Change handler", provider(config.Options{}).ID()), func(t *testing.T) { + configFile := tempFileName("config-*.yaml") + defer func() { assert.NoError(t, os.Remove(configFile)) }() + _, err := populateConfigData(configFile) + assert.NoError(t, err) + + reg := config.NewRootSection() + _, err = reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + opts := config.Options{ + SearchPaths: []string{configFile}, + RootSection: reg, + } + v := provider(opts) + err = 
v.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + firstValue := r.StringValue + + fileUpdated, err := beginWaitForFileChange(configFile) + assert.NoError(t, err) + + _, err = populateConfigData(configFile) + assert.NoError(t, err) + + // Simulate filewatcher event + assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) + + time.Sleep(2 * time.Second) + + r = reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + secondValue := r.StringValue + assert.NotEqual(t, firstValue, secondValue) + }) + + t.Run(fmt.Sprintf("[%v] Change handler k8s configmaps", provider(config.Options{}).ID()), func(t *testing.T) { + // 1. Create Dir structure + watchDir, configFile, cleanup := newSymlinkedConfigFile(t) + defer cleanup() + + // Independently watch for when symlink underlying change happens to know when do we expect accessor to have picked up + // the changes + fileUpdated, err := beginWaitForFileChange(configFile) + assert.NoError(t, err) + + // 2. Start accessor with the symlink as config location + reg := config.NewRootSection() + section, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + opts := config.Options{ + SearchPaths: []string{configFile}, + RootSection: reg, + } + v := provider(opts) + err = v.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + r := section.GetConfig().(*MyComponentConfig) + firstValue := r.StringValue + + // 3. Now update /data symlink to point to data2 + dataDir2 := path.Join(watchDir, "data2") + err = os.Mkdir(dataDir2, os.ModePerm) + assert.NoError(t, err) + + configFile2 := path.Join(dataDir2, "config.yaml") + _, err = populateConfigData(configFile2) + assert.NoError(t, err) + + // change the symlink using the `ln -sfn` command + err = changeSymLink(dataDir2, path.Join(watchDir, "data")) + assert.NoError(t, err) + + t.Logf("New config Location: %v", configFile2) + + // Wait for filewatcher event + assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) + + time.Sleep(2 * time.Second) + + r = section.GetConfig().(*MyComponentConfig) + secondValue := r.StringValue + // Make sure values have changed + assert.NotEqual(t, firstValue, secondValue) + }) + } +} + +func changeSymLink(targetPath, symLink string) error { + if runtime.GOOS == "windows" { + tmpLink := tempFileName("temp-sym-link-*") + err := exec.Command("mklink", filepath.Clean(tmpLink), filepath.Clean(targetPath)).Run() + if err != nil { + return err + } + + err = exec.Command("copy", "/l", "/y", filepath.Clean(tmpLink), filepath.Clean(symLink)).Run() + if err != nil { + return err + } + + return exec.Command("del", filepath.Clean(tmpLink)).Run() + } + + return exec.Command("ln", "-sfn", filepath.Clean(targetPath), filepath.Clean(symLink)).Run() +} + +// 1. 
Create Dir structure: +// |_ data1 +// |_ config.yaml +// |_ data (symlink for data1) +// |_ config.yaml (symlink for data/config.yaml -recursively a symlink of data1/config.yaml) +func newSymlinkedConfigFile(t *testing.T) (watchDir, configFile string, cleanup func()) { + watchDir, err := ioutil.TempDir("", "config-test-") + assert.NoError(t, err) + + dataDir1 := path.Join(watchDir, "data1") + err = os.Mkdir(dataDir1, os.ModePerm) + assert.NoError(t, err) + + realConfigFile := path.Join(dataDir1, "config.yaml") + t.Logf("Real config file location: %s\n", realConfigFile) + _, err = populateConfigData(realConfigFile) + assert.NoError(t, err) + + cleanup = func() { + assert.NoError(t, os.RemoveAll(watchDir)) + } + + // now, symlink the tm `data1` dir to `data` in the baseDir + assert.NoError(t, os.Symlink(dataDir1, path.Join(watchDir, "data"))) + + // and link the `/datadir1/config.yaml` to `/config.yaml` + configFile = path.Join(watchDir, "config.yaml") + assert.NoError(t, os.Symlink(path.Join(watchDir, "data", "config.yaml"), configFile)) + + t.Logf("Config file location: %s\n", path.Join(watchDir, "config.yaml")) + return watchDir, configFile, cleanup +} + +func waitForFileChangeOrTimeout(done chan error) error { + timeout := make(chan bool, 1) + go func() { + time.Sleep(5 * time.Second) + timeout <- true + }() + + for { + select { + case <-timeout: + return fmt.Errorf("timed out") + case err := <-done: + return err + } + } +} + +func beginWaitForFileChange(filename string) (done chan error, terminalErr error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + configFile := filepath.Clean(filename) + realConfigFile, err := filepath.EvalSymlinks(configFile) + if err != nil { + return nil, err + } + + configDir, _ := filepath.Split(configFile) + + done = make(chan error) + go func() { + for { + select { + case event := <-watcher.Events: + // we only care about the config file + currentConfigFile, err := filepath.EvalSymlinks(filename) + if err != nil { + closeErr := watcher.Close() + if closeErr != nil { + done <- closeErr + } else { + done <- err + } + + return + } + + // We only care about the config file with the following cases: + // 1 - if the config file was modified or created + // 2 - if the real path to the config file changed (eg: k8s ConfigMap replacement) + const writeOrCreateMask = fsnotify.Write | fsnotify.Create + if (filepath.Clean(event.Name) == configFile && + event.Op&writeOrCreateMask != 0) || + (currentConfigFile != "" && currentConfigFile != realConfigFile) { + realConfigFile = currentConfigFile + closeErr := watcher.Close() + if closeErr != nil { + fmt.Printf("Close Watcher error: %v\n", closeErr) + } else { + done <- nil + } + + return + } else if filepath.Clean(event.Name) == configFile && + event.Op&fsnotify.Remove&fsnotify.Remove != 0 { + closeErr := watcher.Close() + if closeErr != nil { + fmt.Printf("Close Watcher error: %v\n", closeErr) + } else { + done <- nil + } + + return + } + case err, ok := <-watcher.Errors: + if ok { + fmt.Printf("Watcher error: %v\n", err) + closeErr := watcher.Close() + if closeErr != nil { + fmt.Printf("Close Watcher error: %v\n", closeErr) + } + } + + done <- nil + return + } + } + }() + + err = watcher.Add(configDir) + if err != nil { + return nil, err + } + + return done, err +} + +func testTypes(accessor accessorCreatorFn) func(t *testing.T) { + return func(t *testing.T) { + t.Run("ArrayConfigType", func(t *testing.T) { + expected := ComplexTypeArray{ + {IntValue: 1}, + {IntValue: 4}, + } + + 
runEqualTest(t, accessor, &expected, &ComplexTypeArray{}) + }) + + t.Run("Lists", func(t *testing.T) { + expected := ConfigWithLists{ + ListOfStuff: []ComplexType{ + {IntValue: 1}, + {IntValue: 4}, + }, + } + + runEqualTest(t, accessor, &expected, &ConfigWithLists{}) + }) + + t.Run("Maps", func(t *testing.T) { + expected := ConfigWithMaps{ + MapOfStuff: map[string]ComplexType{ + "item1": {IntValue: 1}, + "item2": {IntValue: 3}, + }, + MapWithoutJSON: map[string]ComplexType{ + "it-1": {IntValue: 5}, + }, + } + + runEqualTest(t, accessor, &expected, &ConfigWithMaps{}) + }) + + t.Run("JsonUnmarshalableTypes", func(t *testing.T) { + expected := ConfigWithJSONTypes{ + Duration: config.Duration{ + Duration: time.Second * 10, + }, + } + + runEqualTest(t, accessor, &expected, &ConfigWithJSONTypes{}) + }) + } +} + +func runEqualTest(t *testing.T, accessor accessorCreatorFn, expected interface{}, emptyType interface{}) { + assert.NotPanics(t, func() { + reflect.TypeOf(expected).Elem() + }, "expected must be a Pointer type. Instead, it was %v", reflect.TypeOf(expected)) + + assert.Equal(t, reflect.TypeOf(expected), reflect.TypeOf(emptyType)) + + rootSection := config.NewRootSection() + sectionKey := fmt.Sprintf("rand-key-%v", getRandInt()%2000) + _, err := rootSection.RegisterSection(sectionKey, emptyType) + assert.NoError(t, err) + + m := map[string]interface{}{ + sectionKey: expected, + } + + raw, err := yaml.Marshal(m) + assert.NoError(t, err) + f := tempFileName("test_type_*.yaml") + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Remove(f)) }() + + assert.NoError(t, ioutil.WriteFile(f, raw, os.ModePerm)) + t.Logf("Generated yaml: %v", string(raw)) + assert.NoError(t, accessor(rootSection, f).UpdateConfig(context.TODO())) + + res := rootSection.GetSection(sectionKey).GetConfig() + t.Logf("Expected: %+v", expected) + t.Logf("Actual: %+v", res) + assert.True(t, reflect.DeepEqual(res, expected)) +} + +func TestAccessor_Integration(t *testing.T) { + accessorsToTest := make([]accessorCreatorFn, 0, len(providers)) + for _, provider := range providers { + accessorsToTest = append(accessorsToTest, func(r config.Section, configPath string) config.Accessor { + return provider(config.Options{ + SearchPaths: []string{configPath}, + RootSection: r, + }) + }) + } + + for _, accessor := range accessorsToTest { + t.Run(fmt.Sprintf(testNameFormatter, accessor(nil, "").ID(), "Types"), testTypes(accessor)) + } +} diff --git a/flytestdlib/config/tests/config_cmd_test.go b/flytestdlib/config/tests/config_cmd_test.go new file mode 100644 index 0000000000..3b15268292 --- /dev/null +++ b/flytestdlib/config/tests/config_cmd_test.go @@ -0,0 +1,92 @@ +package tests + +import ( + "bytes" + "fmt" + "os" + "testing" + + "github.com/lyft/flytestdlib/config" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func executeCommand(root *cobra.Command, args ...string) (output string, err error) { + _, output, err = executeCommandC(root, args...) 
+ return output, err +} + +func executeCommandC(root *cobra.Command, args ...string) (c *cobra.Command, output string, err error) { + buf := new(bytes.Buffer) + root.SetOutput(buf) + root.SetArgs(args) + + c, err = root.ExecuteC() + + return c, buf.String(), err +} + +func TestDiscoverCommand(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "No config file"), func(t *testing.T) { + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandDiscover) + assert.NoError(t, err) + assert.Contains(t, output, "Couldn't find a config file.") + }) + + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "Valid config file"), func(t *testing.T) { + dir, err := os.Getwd() + assert.NoError(t, err) + wd := os.ExpandEnv("$PWD/testdata") + err = os.Chdir(wd) + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Chdir(dir)) }() + + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandDiscover) + assert.NoError(t, err) + assert.Contains(t, output, "Config") + }) + } +} + +func TestValidateCommand(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "No config file"), func(t *testing.T) { + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandValidate) + assert.NoError(t, err) + assert.Contains(t, output, "Couldn't find a config file.") + }) + + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "Invalid Config file"), func(t *testing.T) { + dir, err := os.Getwd() + assert.NoError(t, err) + wd := os.ExpandEnv("$PWD/testdata") + err = os.Chdir(wd) + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Chdir(dir)) }() + + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandValidate, "--file=bad_config.yaml", "--strict") + assert.Error(t, err) + assert.Contains(t, output, "Failed") + }) + + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "Valid config file"), func(t *testing.T) { + dir, err := os.Getwd() + assert.NoError(t, err) + wd := os.ExpandEnv("$PWD/testdata") + err = os.Chdir(wd) + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Chdir(dir)) }() + + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandValidate) + assert.NoError(t, err) + assert.Contains(t, output, "successfully") + }) + } +} diff --git a/flytestdlib/config/tests/testdata/array_configs.yaml b/flytestdlib/config/tests/testdata/array_configs.yaml new file mode 100644 index 0000000000..6a02a280a6 --- /dev/null +++ b/flytestdlib/config/tests/testdata/array_configs.yaml @@ -0,0 +1,7 @@ +my-component: + str: Hello World + nested: + - int-val: 1 +array-config: + - int-val: 1 + - int-val: 4 diff --git a/flytestdlib/config/tests/testdata/bad_config.yaml b/flytestdlib/config/tests/testdata/bad_config.yaml new file mode 100644 index 0000000000..c0707abdeb --- /dev/null +++ b/flytestdlib/config/tests/testdata/bad_config.yaml @@ -0,0 +1,13 @@ +my-component: + str: Hello World +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' 
+ url-value: http://something.com + unknown-key: "something" + diff --git a/flytestdlib/config/tests/testdata/config.yaml b/flytestdlib/config/tests/testdata/config.yaml new file mode 100755 index 0000000000..ca78698fae --- /dev/null +++ b/flytestdlib/config/tests/testdata/config.yaml @@ -0,0 +1,11 @@ +my-component: + str: Hello World +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/flytestdlib/config/tests/testdata/nested_config.yaml b/flytestdlib/config/tests/testdata/nested_config.yaml new file mode 100755 index 0000000000..321f563a42 --- /dev/null +++ b/flytestdlib/config/tests/testdata/nested_config.yaml @@ -0,0 +1,11 @@ +my-component: + str: Hello World + nested: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/flytestdlib/config/tests/types_test.go b/flytestdlib/config/tests/types_test.go new file mode 100644 index 0000000000..c6150e93fa --- /dev/null +++ b/flytestdlib/config/tests/types_test.go @@ -0,0 +1,66 @@ +package tests + +import ( + "fmt" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + "github.com/spf13/pflag" +) + +const testNameFormatter = "[%v] %v" + +var providers = []config.AccessorProvider{viper.NewAccessor} + +type MyComponentConfig struct { + StringValue string `json:"str"` + StringValue2 string `json:"str2"` + StringValue3 string `json:"str3"` +} + +type OtherComponentConfig struct { + DurationValue config.Duration `json:"duration-value"` + URLValue config.URL `json:"url-value"` + StringValue string `json:"string-value"` + IntValue int `json:"int-val"` + StringArray []string `json:"strings"` + StringArrayWithDefaults []string `json:"strings-def"` +} + +func (MyComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str2"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str3"), "hello world", "life is short") + return cmdFlags +} + +func (OtherComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "string-value"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "duration-value"), "20s", "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "int-val"), 4, "this is an important flag") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "url-value"), "http://blah.com", "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strings-def"), []string{"default value"}, "Sets the type of storage to configure [s3/minio/local/mem].") + return cmdFlags +} + +type TestConfig struct { + MyComponentConfig MyComponentConfig `json:"my-component"` + OtherComponentConfig OtherComponentConfig `json:"other-component"` +} + +const ( + MyComponentSectionKey = "my-component" + OtherComponentSectionKey = "other-component" +) + +func init() { + if _, err := config.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}); err != nil { + panic(err) + } + + if _, err := config.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{}); err != nil { + panic(err) + } +} diff --git 
a/flytestdlib/config/url.go b/flytestdlib/config/url.go new file mode 100644 index 0000000000..4045caf962 --- /dev/null +++ b/flytestdlib/config/url.go @@ -0,0 +1,36 @@ +package config + +import ( + "encoding/json" + "errors" + "net/url" +) + +// A url.URL wrapper that can marshal and unmarshal into simple URL strings. +type URL struct { + url.URL +} + +func (d URL) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +func (d *URL) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + switch value := v.(type) { + case string: + u, err := url.Parse(value) + if err != nil { + return err + } + + d.URL = *u + return nil + default: + return errors.New("invalid url") + } +} diff --git a/flytestdlib/config/url_test.go b/flytestdlib/config/url_test.go new file mode 100644 index 0000000000..e4046b3b16 --- /dev/null +++ b/flytestdlib/config/url_test.go @@ -0,0 +1,59 @@ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/lyft/flytestdlib/internal/utils" + + "github.com/stretchr/testify/assert" +) + +func TestURL_MarshalJSON(t *testing.T) { + validURLs := []string{ + "http://localhost:123", + "http://localhost", + "https://non-existent.com/path/to/something", + } + + for i, validURL := range validURLs { + t.Run(fmt.Sprintf("Valid %v", i), func(t *testing.T) { + expected := URL{URL: utils.MustParseURL(validURL)} + + b, err := expected.MarshalJSON() + assert.NoError(t, err) + + actual := URL{} + err = actual.UnmarshalJSON(b) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(expected, actual)) + }) + } +} + +func TestURL_UnmarshalJSON(t *testing.T) { + invalidValues := []interface{}{ + "%gh&%ij", + 123, + true, + } + for i, invalidURL := range invalidValues { + t.Run(fmt.Sprintf("Invalid %v", i), func(t *testing.T) { + raw, err := json.Marshal(invalidURL) + assert.NoError(t, err) + + actual := URL{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) + } + + t.Run("Invalid json", func(t *testing.T) { + actual := URL{} + err := actual.UnmarshalJSON([]byte{}) + assert.Error(t, err) + }) +} diff --git a/flytestdlib/config/utils.go b/flytestdlib/config/utils.go new file mode 100644 index 0000000000..fcd833ff75 --- /dev/null +++ b/flytestdlib/config/utils.go @@ -0,0 +1,69 @@ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +// Uses Json marshal/unmarshal to make a deep copy of a config object. +func DeepCopyConfig(config Config) (Config, error) { + raw, err := json.Marshal(config) + if err != nil { + return nil, err + } + + t := reflect.TypeOf(config) + ptrValue := reflect.New(t) + newObj := ptrValue.Interface() + if err = json.Unmarshal(raw, newObj); err != nil { + return nil, err + } + + return ptrValue.Elem().Interface(), nil +} + +func DeepEqual(config1, config2 Config) bool { + return reflect.DeepEqual(config1, config2) +} + +func toInterface(config Config) (interface{}, error) { + raw, err := json.Marshal(config) + if err != nil { + return nil, err + } + + var m interface{} + err = json.Unmarshal(raw, &m) + return m, err +} + +// Builds a generic map out of the root section config and its sub-sections configs. 
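For illustration, a rough sketch of using the function defined just below to dump the full effective configuration; yaml here is the ghodss/yaml package already used elsewhere in this patch, and dumpAllConfigs is a made-up helper:

func dumpAllConfigs() ([]byte, error) {
	m, err := AllConfigsAsMap(GetRootSection())
	if err != nil {
		// The error aggregates per-section failures, e.g. a child section key
		// colliding with a native property of its parent config.
		return nil, err
	}

	return yaml.Marshal(m)
}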
+func AllConfigsAsMap(root Section) (m map[string]interface{}, err error) { + errs := ErrorCollection{} + allConfigs := make(map[string]interface{}, len(root.GetSections())) + if root.GetConfig() != nil { + rootConfig, err := toInterface(root.GetConfig()) + if !errs.Append(err) { + if asMap, isCasted := rootConfig.(map[string]interface{}); isCasted { + allConfigs = asMap + } else { + allConfigs[""] = rootConfig + } + } + } + + for k, section := range root.GetSections() { + if _, alreadyExists := allConfigs[k]; alreadyExists { + errs.Append(errors.Wrap(ErrChildConfigOverridesConfig, + fmt.Sprintf("section key [%v] overrides an existing native config property", k))) + } + + allConfigs[k], err = AllConfigsAsMap(section) + errs.Append(err) + } + + return allConfigs, errs.ErrorOrDefault() +} diff --git a/flytestdlib/config/utils_test.go b/flytestdlib/config/utils_test.go new file mode 100644 index 0000000000..e5d7d02105 --- /dev/null +++ b/flytestdlib/config/utils_test.go @@ -0,0 +1,39 @@ +package config + +import ( + "reflect" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDeepCopyConfig(t *testing.T) { + type pair struct { + first interface{} + second interface{} + } + + type fakeConfig struct { + I int + S string + Ptr *fakeConfig + } + + testCases := []pair{ + {3, 3}, + {"word", "word"}, + {fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}, fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}}, + {&fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}, &fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}}, + } + + for i, testCase := range testCases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + input, expected := testCase.first, testCase.second + actual, err := DeepCopyConfig(input) + assert.NoError(t, err) + assert.Equal(t, reflect.TypeOf(expected).String(), reflect.TypeOf(actual).String()) + assert.Equal(t, expected, actual) + }) + } +} diff --git a/flytestdlib/config/viper/collection.go b/flytestdlib/config/viper/collection.go new file mode 100644 index 0000000000..680052f494 --- /dev/null +++ b/flytestdlib/config/viper/collection.go @@ -0,0 +1,175 @@ +package viper + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + "github.com/lyft/flytestdlib/config" + + "github.com/lyft/flytestdlib/logger" + + viperLib "github.com/spf13/viper" + + "github.com/fsnotify/fsnotify" + "github.com/spf13/pflag" +) + +type Viper interface { + BindPFlags(flags *pflag.FlagSet) error + BindEnv(input ...string) error + AutomaticEnv() + ReadInConfig() error + OnConfigChange(run func(in fsnotify.Event)) + WatchConfig() + AllSettings() map[string]interface{} + ConfigFileUsed() string + MergeConfig(in io.Reader) error +} + +// A proxy object for a collection of Viper instances. 
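// Each underlying instance presumably tracks a single config file (a viper instance reads one file at a
// time), so calls like BindPFlags, BindEnv, and ReadInConfig fan out to every instance, while the proxy
// records the pflag set and env bindings so that MergeAllConfigs below can replay them onto a freshly
// merged instance.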
+type CollectionProxy struct { + underlying []Viper + pflags *pflag.FlagSet + envVars [][]string + automaticEnv bool +} + +func (c *CollectionProxy) BindPFlags(flags *pflag.FlagSet) error { + err := config.ErrorCollection{} + for _, v := range c.underlying { + err.Append(v.BindPFlags(flags)) + } + + c.pflags = flags + + return err.ErrorOrDefault() +} + +func (c *CollectionProxy) BindEnv(input ...string) error { + err := config.ErrorCollection{} + for _, v := range c.underlying { + err.Append(v.BindEnv(input...)) + } + + if c.envVars == nil { + c.envVars = make([][]string, 0, 1) + } + + c.envVars = append(c.envVars, input) + + return err.ErrorOrDefault() +} + +func (c *CollectionProxy) AutomaticEnv() { + for _, v := range c.underlying { + v.AutomaticEnv() + } + + c.automaticEnv = true +} + +func (c CollectionProxy) ReadInConfig() error { + err := config.ErrorCollection{} + for _, v := range c.underlying { + err.Append(v.ReadInConfig()) + } + + return err.ErrorOrDefault() +} + +func (c CollectionProxy) OnConfigChange(run func(in fsnotify.Event)) { + for _, v := range c.underlying { + v.OnConfigChange(run) + } +} + +func (c CollectionProxy) WatchConfig() { + for _, v := range c.underlying { + v.WatchConfig() + } +} + +func (c CollectionProxy) AllSettings() map[string]interface{} { + finalRes := map[string]interface{}{} + if len(c.underlying) == 0 { + return finalRes + } + + combinedConfig, err := c.MergeAllConfigs() + if err != nil { + logger.Warnf(context.TODO(), "Failed to merge config. Error: %v", err) + return finalRes + } + + return combinedConfig.AllSettings() +} + +func (c CollectionProxy) ConfigFileUsed() string { + return fmt.Sprintf("[%v]", strings.Join(c.ConfigFilesUsed(), ",")) +} + +func (c CollectionProxy) MergeConfig(in io.Reader) error { + panic("Not yet implemented.") +} + +func (c CollectionProxy) MergeAllConfigs() (all Viper, err error) { + combinedConfig := viperLib.New() + if c.envVars != nil { + for _, envConfig := range c.envVars { + err = combinedConfig.BindEnv(envConfig...) 
+ if err != nil { + return nil, err + } + } + } + + if c.automaticEnv { + combinedConfig.AutomaticEnv() + } + + if c.pflags != nil { + err = combinedConfig.BindPFlags(c.pflags) + if err != nil { + return nil, err + } + } + + for _, v := range c.underlying { + if _, isCollection := v.(*CollectionProxy); isCollection { + return nil, fmt.Errorf("merging nested CollectionProxies is not yet supported") + } + + if len(v.ConfigFileUsed()) == 0 { + continue + } + + combinedConfig.SetConfigFile(v.ConfigFileUsed()) + + reader, err := os.Open(v.ConfigFileUsed()) + if err != nil { + return nil, err + } + + err = combinedConfig.MergeConfig(reader) + if err != nil { + return nil, err + } + } + + return combinedConfig, nil +} + +func (c CollectionProxy) ConfigFilesUsed() []string { + res := make([]string, 0, len(c.underlying)) + for _, v := range c.underlying { + filePath := v.ConfigFileUsed() + if len(filePath) > 0 { + res = append(res, filePath) + } + } + + return res +} diff --git a/flytestdlib/config/viper/viper.go b/flytestdlib/config/viper/viper.go new file mode 100644 index 0000000000..05a5378ec7 --- /dev/null +++ b/flytestdlib/config/viper/viper.go @@ -0,0 +1,357 @@ +package viper + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/pkg/errors" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/files" + "github.com/lyft/flytestdlib/logger" + "github.com/spf13/cobra" + + "github.com/fsnotify/fsnotify" + "github.com/mitchellh/mapstructure" + + "github.com/spf13/pflag" + viperLib "github.com/spf13/viper" +) + +const ( + keyDelim = "." +) + +var ( + dereferencableKinds = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, + } +) + +type viperAccessor struct { + // Determines whether parsing config should fail if it contains un-registered sections. + strictMode bool + viper *CollectionProxy + rootConfig config.Section + // Ensures we initialize the file Watcher once. + watcherInitializer *sync.Once +} + +func (viperAccessor) ID() string { + return "Viper" +} + +func (viperAccessor) InitializeFlags(cmdFlags *flag.FlagSet) { + // TODO: Implement? +} + +func (v viperAccessor) InitializePflags(cmdFlags *pflag.FlagSet) { + err := v.addSectionsPFlags(cmdFlags) + if err != nil { + panic(errors.Wrap(err, "error adding config PFlags to flag set")) + } + + // Allow viper to read the value of the flags + err = v.viper.BindPFlags(cmdFlags) + if err != nil { + panic(errors.Wrap(err, "error binding PFlags")) + } +} + +func (v viperAccessor) addSectionsPFlags(flags *pflag.FlagSet) (err error) { + for key, section := range v.rootConfig.GetSections() { + if asPFlagProvider, ok := section.GetConfig().(config.PFlagProvider); ok { + flags.AddFlagSet(asPFlagProvider.GetPFlagSet(key + keyDelim)) + } + } + + return nil +} + +// Binds keys from all sections to viper env vars. 
This instructs viper to lookup those from env vars when we ask for +// viperLib.AllSettings() +func (v viperAccessor) bindViperConfigsFromEnv(root config.Section) (err error) { + allConfigs, err := config.AllConfigsAsMap(root) + if err != nil { + return err + } + + return v.bindViperConfigsEnvDepth(allConfigs, "") +} + +func (v viperAccessor) bindViperConfigsEnvDepth(m map[string]interface{}, prefix string) error { + errs := config.ErrorCollection{} + for key, val := range m { + subKey := prefix + key + if asMap, ok := val.(map[string]interface{}); ok { + errs.Append(v.bindViperConfigsEnvDepth(asMap, subKey+keyDelim)) + } else { + errs.Append(v.viper.BindEnv(subKey, strings.ToUpper(strings.Replace(subKey, "-", "_", -1)))) + } + } + + return errs.ErrorOrDefault() +} + +func (v viperAccessor) updateConfig(ctx context.Context, r config.Section) error { + // Binds all keys to env vars. + err := v.bindViperConfigsFromEnv(r) + if err != nil { + return err + } + + v.viper.AutomaticEnv() // read in environment variables that match + + shouldWatchChanges := true + // If a config file is found, read it in. + if err = v.viper.ReadInConfig(); err == nil { + logger.Printf(ctx, "Using config file: %+v", v.viper.ConfigFilesUsed()) + } else if asErrorCollection, ok := err.(config.ErrorCollection); ok { + shouldWatchChanges = false + for i, e := range asErrorCollection { + if _, isNotFound := errors.Cause(e).(viperLib.ConfigFileNotFoundError); isNotFound { + logger.Printf(ctx, "[%v] Couldn't find a config file [%v]. Relying on env vars and pflags.", + i, v.viper.underlying[i].ConfigFileUsed()) + } else { + return err + } + } + } else if reflect.TypeOf(err) == reflect.TypeOf(viperLib.ConfigFileNotFoundError{}) { + shouldWatchChanges = false + logger.Printf(ctx, "Couldn't find a config file. Relying on env vars and pflags.") + } else { + return err + } + + if shouldWatchChanges { + v.watcherInitializer.Do(func() { + // Watch config files to pick up on file changes without requiring a full application restart. + // This call must occur after *all* config paths have been added. + v.viper.OnConfigChange(func(e fsnotify.Event) { + fmt.Printf("Got a notification change for file [%v]\n", e.Name) + v.configChangeHandler() + }) + v.viper.WatchConfig() + }) + } + + return v.RefreshFromConfig(ctx, r) +} + +func (v viperAccessor) UpdateConfig(ctx context.Context) error { + return v.updateConfig(ctx, v.rootConfig) +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElement(t reflect.Kind) bool { + _, exists := dereferencableKinds[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshallerHook(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElement(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + ctx := context.Background() + raw, err := json.Marshal(data) + if err != nil { + logger.Printf(ctx, "Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + logger.Printf(ctx, "Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +// Parses RootType config from parsed Viper settings. This should be called after viper has parsed config file/pflags...etc. +func (v viperAccessor) parseViperConfig(root config.Section) error { + // We use AllSettings instead of AllKeys to get the root level keys folded. + return v.parseViperConfigRecursive(root, v.viper.AllSettings()) +} + +func (v viperAccessor) parseViperConfigRecursive(root config.Section, settings interface{}) error { + errs := config.ErrorCollection{} + var mine interface{} + myKeysCount := 0 + if asMap, casted := settings.(map[string]interface{}); casted { + myMap := map[string]interface{}{} + for childKey, childValue := range asMap { + if childSection, found := root.GetSections()[childKey]; found { + errs.Append(v.parseViperConfigRecursive(childSection, childValue)) + } else { + myMap[childKey] = childValue + } + } + + mine = myMap + myKeysCount = len(myMap) + } else if asSlice, casted := settings.([]interface{}); casted { + mine = settings + myKeysCount = len(asSlice) + } else { + mine = settings + if settings != nil { + myKeysCount = 1 + } + } + + if root.GetConfig() != nil { + c, err := config.DeepCopyConfig(root.GetConfig()) + errs.Append(err) + if err != nil { + return errs.ErrorOrDefault() + } + + errs.Append(decode(mine, defaultDecoderConfig(c, v.decoderConfigs()...))) + errs.Append(root.SetConfig(c)) + + return errs.ErrorOrDefault() + } else if myKeysCount > 0 { + // There are keys set that are meant to be decoded but no config to receive them. Fail if strict mode is on. + if v.strictMode { + errs.Append(errors.Wrap( + config.ErrStrictModeValidation, + fmt.Sprintf("strict mode is on but received keys [%+v] to decode with no config assigned to"+ + " receive them", mine))) + } + } + + return errs.ErrorOrDefault() +} + +// Adds any specific configs controlled by this viper accessor instance. +func (v viperAccessor) decoderConfigs() []viperLib.DecoderConfigOption { + return []viperLib.DecoderConfigOption{ + func(config *mapstructure.DecoderConfig) { + config.ErrorUnused = v.strictMode + }, + } +} + +// defaultDecoderConfig returns default mapsstructure.DecoderConfig with support +// of time.Duration values & string slices +func defaultDecoderConfig(output interface{}, opts ...viperLib.DecoderConfigOption) *mapstructure.DecoderConfig { + c := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: output, + WeaklyTypedInput: true, + TagName: "json", + DecodeHook: mapstructure.ComposeDecodeHookFunc( + jsonUnmarshallerHook, + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + ), + } + + for _, opt := range opts { + opt(c) + } + + return c +} + +// A wrapper around mapstructure.Decode that mimics the WeakDecode functionality +func decode(input interface{}, config *mapstructure.DecoderConfig) error { + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + return decoder.Decode(input) +} + +func (v viperAccessor) configChangeHandler() { + ctx := context.Background() + err := v.RefreshFromConfig(ctx, v.rootConfig) + if err != nil { + // TODO: Retry? panic? + logger.Printf(ctx, "Failed to update config. 
Error: %v", err) + } else { + logger.Printf(ctx, "Refreshed config in response to file(s) change.") + } +} + +func (v viperAccessor) RefreshFromConfig(ctx context.Context, r config.Section) error { + err := v.parseViperConfig(r) + if err != nil { + return err + } + + v.sendUpdatedEvents(ctx, r, "") + + return nil +} + +func (v viperAccessor) sendUpdatedEvents(ctx context.Context, root config.Section, sectionKey config.SectionKey) { + for key, section := range root.GetSections() { + if !section.GetConfigChangedAndClear() { + logger.Infof(ctx, "Config section [%v] hasn't changed.", sectionKey+key) + } else if section.GetConfigUpdatedHandler() == nil { + logger.Infof(ctx, "Config section [%v] updated. No update handler registered.", sectionKey+key) + } else { + logger.Infof(ctx, "Config section [%v] updated. Firing updated event.", sectionKey+key) + section.GetConfigUpdatedHandler()(ctx, section.GetConfig()) + } + + v.sendUpdatedEvents(ctx, section, sectionKey+key+keyDelim) + } +} + +func (v viperAccessor) ConfigFilesUsed() []string { + return v.viper.ConfigFilesUsed() +} + +// Creates a config accessor that implements Accessor interface and uses viper to load configs. +func NewAccessor(opts config.Options) config.Accessor { + return newAccessor(opts) +} + +func newAccessor(opts config.Options) viperAccessor { + vipers := make([]Viper, 0, 1) + configFiles := files.FindConfigFiles(opts.SearchPaths) + for _, configFile := range configFiles { + v := viperLib.New() + v.SetConfigFile(configFile) + + vipers = append(vipers, v) + } + + // Create a default viper even if we couldn't find any matching files + if len(configFiles) == 0 { + v := viperLib.New() + vipers = append(vipers, v) + } + + r := opts.RootSection + if r == nil { + r = config.GetRootSection() + } + + return viperAccessor{ + strictMode: opts.StrictMode, + rootConfig: r, + viper: &CollectionProxy{underlying: vipers}, + watcherInitializer: &sync.Once{}, + } +} + +// Gets the root level command that can be added to any cobra-powered cli to get config* commands. +func GetConfigCommand() *cobra.Command { + return config.NewConfigCommand(NewAccessor) +} diff --git a/flytestdlib/contextutils/context.go b/flytestdlib/contextutils/context.go new file mode 100644 index 0000000000..b5a9c00fa2 --- /dev/null +++ b/flytestdlib/contextutils/context.go @@ -0,0 +1,140 @@ +// Contains common flyte context utils. +package contextutils + +import ( + "context" + "fmt" +) + +type Key string + +const ( + AppNameKey Key = "app_name" + NamespaceKey Key = "ns" + TaskTypeKey Key = "tasktype" + ProjectKey Key = "project" + DomainKey Key = "domain" + WorkflowIDKey Key = "wf" + NodeIDKey Key = "node" + TaskIDKey Key = "task" + ExecIDKey Key = "exec_id" + JobIDKey Key = "job_id" + PhaseKey Key = "phase" +) + +func (k Key) String() string { + return string(k) +} + +var logKeys = []Key{ + AppNameKey, + JobIDKey, + NamespaceKey, + ExecIDKey, + NodeIDKey, + WorkflowIDKey, + TaskTypeKey, + PhaseKey, +} + +// Gets a new context with namespace set. +func WithNamespace(ctx context.Context, namespace string) context.Context { + return context.WithValue(ctx, NamespaceKey, namespace) +} + +// Gets a new context with JobId set. If the existing context already has a job id, the new context will have +// / set as the job id. 
+func WithJobID(ctx context.Context, jobID string) context.Context { + existingJobID := ctx.Value(JobIDKey) + if existingJobID != nil { + jobID = fmt.Sprintf("%v/%v", existingJobID, jobID) + } + + return context.WithValue(ctx, JobIDKey, jobID) +} + +// Gets a new context with AppName set. +func WithAppName(ctx context.Context, appName string) context.Context { + return context.WithValue(ctx, AppNameKey, appName) +} + +// Gets a new context with Phase set. +func WithPhase(ctx context.Context, phase string) context.Context { + return context.WithValue(ctx, PhaseKey, phase) +} + +// Gets a new context with ExecutionID set. +func WithExecutionID(ctx context.Context, execID string) context.Context { + return context.WithValue(ctx, ExecIDKey, execID) +} + +// Gets a new context with NodeID (nested) set. +func WithNodeID(ctx context.Context, nodeID string) context.Context { + existingNodeID := ctx.Value(NodeIDKey) + if existingNodeID != nil { + nodeID = fmt.Sprintf("%v/%v", existingNodeID, nodeID) + } + return context.WithValue(ctx, NodeIDKey, nodeID) +} + +// Gets a new context with WorkflowName set. +func WithWorkflowID(ctx context.Context, workflow string) context.Context { + return context.WithValue(ctx, WorkflowIDKey, workflow) +} + +// Get new context with Project and Domain values set +func WithProjectDomain(ctx context.Context, project, domain string) context.Context { + c := context.WithValue(ctx, ProjectKey, project) + return context.WithValue(c, DomainKey, domain) +} + +// Gets a new context with WorkflowName set. +func WithTaskID(ctx context.Context, taskID string) context.Context { + return context.WithValue(ctx, TaskIDKey, taskID) +} + +// Gets a new context with TaskType set. +func WithTaskType(ctx context.Context, taskType string) context.Context { + return context.WithValue(ctx, TaskTypeKey, taskType) +} + +func addFieldIfNotNil(ctx context.Context, m map[string]interface{}, fieldKey Key) { + val := ctx.Value(fieldKey) + if val != nil { + m[fieldKey.String()] = val + } +} + +func addStringFieldWithDefaults(ctx context.Context, m map[string]string, fieldKey Key) { + val := ctx.Value(fieldKey) + if val == nil { + val = "" + } + m[fieldKey.String()] = val.(string) +} + +// Gets a map of all known logKeys set on the context. logKeys are special and should be used incase, context fields +// are to be added to the log lines. 
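+// For example (illustrative):
+//   ctx := WithJobID(WithNamespace(context.Background(), "ns123"), "job123")
+//   fields := GetLogFields(ctx) // map containing "ns": "ns123" and "job_id": "job123"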
+func GetLogFields(ctx context.Context) map[string]interface{} { + res := map[string]interface{}{} + for _, k := range logKeys { + addFieldIfNotNil(ctx, res, k) + } + return res +} + +func Value(ctx context.Context, key Key) string { + val := ctx.Value(key) + if val != nil { + return val.(string) + } + return "" +} + +func Values(ctx context.Context, keys ...Key) map[string]string { + res := map[string]string{} + for _, k := range keys { + addStringFieldWithDefaults(ctx, res, k) + } + return res +} diff --git a/flytestdlib/contextutils/context_test.go b/flytestdlib/contextutils/context_test.go new file mode 100644 index 0000000000..99d29d3e3d --- /dev/null +++ b/flytestdlib/contextutils/context_test.go @@ -0,0 +1,113 @@ +package contextutils + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithAppName(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(AppNameKey)) + ctx = WithAppName(ctx, "application-name-123") + assert.Equal(t, "application-name-123", ctx.Value(AppNameKey)) + + ctx = WithAppName(ctx, "app-name-modified") + assert.Equal(t, "app-name-modified", ctx.Value(AppNameKey)) +} + +func TestWithPhase(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(PhaseKey)) + ctx = WithPhase(ctx, "Running") + assert.Equal(t, "Running", ctx.Value(PhaseKey)) + + ctx = WithPhase(ctx, "Failed") + assert.Equal(t, "Failed", ctx.Value(PhaseKey)) +} + +func TestWithJobId(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(JobIDKey)) + ctx = WithJobID(ctx, "job123") + assert.Equal(t, "job123", ctx.Value(JobIDKey)) + + ctx = WithJobID(ctx, "subtask") + assert.Equal(t, "job123/subtask", ctx.Value(JobIDKey)) +} + +func TestWithNamespace(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(NamespaceKey)) + ctx = WithNamespace(ctx, "flyte") + assert.Equal(t, "flyte", ctx.Value(NamespaceKey)) + + ctx = WithNamespace(ctx, "flyte2") + assert.Equal(t, "flyte2", ctx.Value(NamespaceKey)) +} + +func TestWithExecutionID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(ExecIDKey)) + ctx = WithExecutionID(ctx, "job123") + assert.Equal(t, "job123", ctx.Value(ExecIDKey)) +} + +func TestWithTaskType(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(TaskTypeKey)) + ctx = WithTaskType(ctx, "flyte") + assert.Equal(t, "flyte", ctx.Value(TaskTypeKey)) +} + +func TestWithWorkflowID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(WorkflowIDKey)) + ctx = WithWorkflowID(ctx, "flyte") + assert.Equal(t, "flyte", ctx.Value(WorkflowIDKey)) +} + +func TestWithNodeID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(NodeIDKey)) + ctx = WithNodeID(ctx, "n1") + assert.Equal(t, "n1", ctx.Value(NodeIDKey)) + + ctx = WithNodeID(ctx, "n2") + assert.Equal(t, "n1/n2", ctx.Value(NodeIDKey)) +} + +func TestWithProjectDomain(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(ProjectKey)) + assert.Nil(t, ctx.Value(DomainKey)) + ctx = WithProjectDomain(ctx, "proj", "domain") + assert.Equal(t, "proj", ctx.Value(ProjectKey)) + assert.Equal(t, "domain", ctx.Value(DomainKey)) +} + +func TestWithTaskID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(TaskIDKey)) + ctx = WithTaskID(ctx, "task") + assert.Equal(t, "task", ctx.Value(TaskIDKey)) +} + +func TestGetFields(t *testing.T) { + ctx := context.Background() + ctx = WithJobID(WithNamespace(ctx, "ns123"), "job123") + m := GetLogFields(ctx) + 
assert.Equal(t, "ns123", m[NamespaceKey.String()]) + assert.Equal(t, "job123", m[JobIDKey.String()]) +} + +func TestValues(t *testing.T) { + ctx := context.Background() + ctx = WithWorkflowID(ctx, "flyte") + m := Values(ctx, ProjectKey, WorkflowIDKey) + assert.NotNil(t, m) + assert.Equal(t, 2, len(m)) + assert.Equal(t, "flyte", m[WorkflowIDKey.String()]) + assert.Equal(t, "", m[ProjectKey.String()]) +} diff --git a/flytestdlib/internal/utils/parsers.go b/flytestdlib/internal/utils/parsers.go new file mode 100644 index 0000000000..c1fcfa3a4a --- /dev/null +++ b/flytestdlib/internal/utils/parsers.go @@ -0,0 +1,20 @@ +package utils + +import ( + "net/url" +) + +// A utility function to be used in tests. It parses urlString as url.URL or panics if it's invalid. +func MustParseURL(urlString string) url.URL { + u, err := url.Parse(urlString) + if err != nil { + panic(err) + } + + return *u +} + +// A utility function to be used in tests. It returns the address of the passed value. +func RefInt(val int) *int { + return &val +} diff --git a/flytestdlib/internal/utils/parsers_test.go b/flytestdlib/internal/utils/parsers_test.go new file mode 100644 index 0000000000..c15202b6ac --- /dev/null +++ b/flytestdlib/internal/utils/parsers_test.go @@ -0,0 +1,25 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMustParseURL(t *testing.T) { + t.Run("Valid URL", func(t *testing.T) { + MustParseURL("http://something-profound-localhost.com") + }) + + t.Run("Invalid URL", func(t *testing.T) { + assert.Panics(t, func() { + MustParseURL("invalid_url:is_here\\") + }) + }) +} + +func TestRefUint32(t *testing.T) { + input := int(5) + res := RefInt(input) + assert.Equal(t, input, *res) +} diff --git a/flytestdlib/ioutils/bytes.go b/flytestdlib/ioutils/bytes.go new file mode 100644 index 0000000000..ad69c0d97b --- /dev/null +++ b/flytestdlib/ioutils/bytes.go @@ -0,0 +1,21 @@ +package ioutils + +import ( + "bytes" + "io" +) + +// A Closeable Reader for bytes to mimic stream from inmemory byte storage +type BytesReadCloser struct { + *bytes.Reader +} + +func (*BytesReadCloser) Close() error { + return nil +} + +func NewBytesReadCloser(b []byte) io.ReadCloser { + return &BytesReadCloser{ + Reader: bytes.NewReader(b), + } +} diff --git a/flytestdlib/ioutils/bytes_test.go b/flytestdlib/ioutils/bytes_test.go new file mode 100644 index 0000000000..9745c62c91 --- /dev/null +++ b/flytestdlib/ioutils/bytes_test.go @@ -0,0 +1,17 @@ +package ioutils + +import ( + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewBytesReadCloser(t *testing.T) { + i := []byte("abc") + r := NewBytesReadCloser(i) + o, e := ioutil.ReadAll(r) + assert.NoError(t, e) + assert.Equal(t, o, i) + assert.NoError(t, r.Close()) +} diff --git a/flytestdlib/ioutils/timed_readers.go b/flytestdlib/ioutils/timed_readers.go new file mode 100644 index 0000000000..ceb6415952 --- /dev/null +++ b/flytestdlib/ioutils/timed_readers.go @@ -0,0 +1,17 @@ +package ioutils + +import ( + "io" + "io/ioutil" +) + +// Defines a common interface for timers. +type Timer interface { + // Stops the timer and reports observation. 
+ Stop() float64 +} + +func ReadAll(r io.Reader, t Timer) ([]byte, error) { + defer t.Stop() + return ioutil.ReadAll(r) +} diff --git a/flytestdlib/ioutils/timed_readers_test.go b/flytestdlib/ioutils/timed_readers_test.go new file mode 100644 index 0000000000..7fa74f7241 --- /dev/null +++ b/flytestdlib/ioutils/timed_readers_test.go @@ -0,0 +1,20 @@ +package ioutils + +import ( + "bytes" + "testing" + "time" + + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestReadAll(t *testing.T) { + r := bytes.NewReader([]byte("hello")) + s := promutils.NewTestScope() + w, e := s.NewStopWatch("x", "empty", time.Millisecond) + assert.NoError(t, e) + b, err := ReadAll(r, w.Start()) + assert.NoError(t, err) + assert.Equal(t, "hello", string(b)) +} diff --git a/flytestdlib/logger/config.go b/flytestdlib/logger/config.go new file mode 100644 index 0000000000..faa2da6f9e --- /dev/null +++ b/flytestdlib/logger/config.go @@ -0,0 +1,81 @@ +package logger + +import ( + "context" + + "github.com/lyft/flytestdlib/config" +) + +//go:generate pflags Config + +const configSectionKey = "Logger" + +type FormatterType = string + +const ( + FormatterJSON FormatterType = "json" + FormatterText FormatterType = "text" +) + +const ( + jsonDataKey string = "json" +) + +// Global logger config. +type Config struct { + // Determines whether to include source code location in logs. This might incurs a performance hit and is only + // recommended on debug/development builds. + IncludeSourceCode bool `json:"show-source" pflag:",Includes source code location in logs."` + + // Determines whether the logger should mute all logs (including panics) + Mute bool `json:"mute" pflag:",Mutes all logs regardless of severity. Intended for benchmarks/tests only."` + + // Determines the minimum log level to log. + Level Level `json:"level" pflag:"4,Sets the minimum logging level."` + + Formatter FormatterConfig `json:"formatter" pflag:",Sets logging format."` +} + +type FormatterConfig struct { + Type FormatterType `json:"type" pflag:"\"json\",Sets logging format type."` +} + +var globalConfig = Config{} + +// Sets global logger config +func SetConfig(cfg Config) { + globalConfig = cfg + + onConfigUpdated(cfg) +} + +// Level type. +type Level = int + +// These are the different logging levels. +const ( + // PanicLevel level, highest level of severity. Logs and then calls panic with the + // message passed to Debug, Info, ... + PanicLevel Level = iota + // FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the + // logging level is set to Panic. + FatalLevel + // ErrorLevel level. Logs. Used for errors that should definitely be noted. + // Commonly used for hooks to send errors to an error tracking service. + ErrorLevel + // WarnLevel level. Non-critical entries that deserve eyes. + WarnLevel + // InfoLevel level. General operational entries about what's going on inside the + // application. + InfoLevel + // DebugLevel level. Usually only enabled when debugging. Very verbose logging. + DebugLevel +) + +func init() { + if _, err := config.RegisterSectionWithUpdates(configSectionKey, &Config{}, func(ctx context.Context, newValue config.Config) { + SetConfig(*newValue.(*Config)) + }); err != nil { + panic(err) + } +} diff --git a/flytestdlib/logger/config_flags.go b/flytestdlib/logger/config_flags.go new file mode 100755 index 0000000000..cf8950b94f --- /dev/null +++ b/flytestdlib/logger/config_flags.go @@ -0,0 +1,21 @@ +// Code generated by go generate; DO NOT EDIT. 
+// This file was generated by robots. + +package logger + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "show-source"), *new(bool), "Includes source code location in logs.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "mute"), *new(bool), "Mutes all logs regardless of severity. Intended for benchmarks/tests only.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "level"), 4, "Sets the minimum logging level.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "formatter.type"), "json", "Sets logging format type.") + return cmdFlags +} diff --git a/flytestdlib/logger/config_flags_test.go b/flytestdlib/logger/config_flags_test.go new file mode 100755 index 0000000000..401d58d493 --- /dev/null +++ b/flytestdlib/logger/config_flags_test.go @@ -0,0 +1,190 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package logger + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_show-source", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("show-source"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("show-source", testValue) + if vBool, err := cmdFlags.GetBool("show-source"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.IncludeSourceCode) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_mute", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("mute"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("mute", testValue) + if vBool, err := cmdFlags.GetBool("mute"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Mute) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_level", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("level"); err == nil { + assert.Equal(t, int(4), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("level", testValue) + if vInt, err := cmdFlags.GetInt("level"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Level) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_formatter.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("formatter.type"); err == nil { + assert.Equal(t, string("json"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("formatter.type", 
testValue) + if vString, err := cmdFlags.GetString("formatter.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Formatter.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flytestdlib/logger/config_test.go b/flytestdlib/logger/config_test.go new file mode 100644 index 0000000000..7d2d3782b1 --- /dev/null +++ b/flytestdlib/logger/config_test.go @@ -0,0 +1,20 @@ +package logger + +import "testing" + +func TestSetConfig(t *testing.T) { + type args struct { + cfg Config + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SetConfig(tt.args.cfg) + }) + } +} diff --git a/flytestdlib/logger/logger.go b/flytestdlib/logger/logger.go new file mode 100644 index 0000000000..29aadc8bf8 --- /dev/null +++ b/flytestdlib/logger/logger.go @@ -0,0 +1,337 @@ +// Defines global context-aware logger. +// The default implementation uses logrus. This package registers "logger" config section on init(). The structure of the +// config section is expected to be un-marshal-able to Config struct. +package logger + +import ( + "context" + + "github.com/lyft/flytestdlib/contextutils" + + "fmt" + "runtime" + "strings" + + "github.com/sirupsen/logrus" +) + +//go:generate gotests -w -all $FILE + +const indentLevelKey contextutils.Key = "LoggerIndentLevel" + +func onConfigUpdated(cfg Config) { + logrus.SetLevel(logrus.Level(cfg.Level)) + + switch cfg.Formatter.Type { + case FormatterText: + if _, isText := logrus.StandardLogger().Formatter.(*logrus.JSONFormatter); !isText { + logrus.SetFormatter(&logrus.TextFormatter{ + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "ts", + }, + }) + } + default: + if _, isJSON := logrus.StandardLogger().Formatter.(*logrus.JSONFormatter); !isJSON { + logrus.SetFormatter(&logrus.JSONFormatter{ + DataKey: jsonDataKey, + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "ts", + }, + }) + } + } +} + +func getSourceLocation() string { + if globalConfig.IncludeSourceCode { + _, file, line, ok := runtime.Caller(3) + if !ok { + file = "???" + line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] + } + } + + return fmt.Sprintf("[%v:%v] ", file, line) + } + + return "" +} + +func wrapHeader(ctx context.Context, args ...interface{}) []interface{} { + args = append([]interface{}{getIndent(ctx)}, args...) + + if globalConfig.IncludeSourceCode { + return append( + []interface{}{ + fmt.Sprintf("%v", getSourceLocation()), + }, + args...) + } + + return args +} + +func wrapHeaderForMessage(ctx context.Context, message string) string { + message = fmt.Sprintf("%v%v", getIndent(ctx), message) + + if globalConfig.IncludeSourceCode { + return fmt.Sprintf("%v%v", getSourceLocation(), message) + } + + return message +} + +func getLogger(ctx context.Context) *logrus.Entry { + entry := logrus.WithFields(logrus.Fields(contextutils.GetLogFields(ctx))) + entry.Level = logrus.Level(globalConfig.Level) + return entry +} + +func WithIndent(ctx context.Context, additionalIndent string) context.Context { + indentLevel := getIndent(ctx) + additionalIndent + return context.WithValue(ctx, indentLevelKey, indentLevel) +} + +func getIndent(ctx context.Context) string { + if existing := ctx.Value(indentLevelKey); existing != nil { + return existing.(string) + } + + return "" +} + +// Gets a value indicating whether logs at this level will be written to the logger. 
This is particularly useful to avoid +// computing log messages unnecessarily. +func IsLoggable(ctx context.Context, level Level) bool { + return getLogger(ctx).Level >= logrus.Level(level) +} + +// Debug logs a message at level Debug on the standard logger. +func Debug(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Debug(wrapHeader(ctx, args)...) +} + +// Print logs a message at level Info on the standard logger. +func Print(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Print(wrapHeader(ctx, args)...) +} + +// Info logs a message at level Info on the standard logger. +func Info(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Info(wrapHeader(ctx, args)...) +} + +// Warn logs a message at level Warn on the standard logger. +func Warn(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warn(wrapHeader(ctx, args)...) +} + +// Warning logs a message at level Warn on the standard logger. +func Warning(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warning(wrapHeader(ctx, args)...) +} + +// Error logs a message at level Error on the standard logger. +func Error(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Error(wrapHeader(ctx, args)...) +} + +// Panic logs a message at level Panic on the standard logger. +func Panic(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Panic(wrapHeader(ctx, args)...) +} + +// Fatal logs a message at level Fatal on the standard logger. +func Fatal(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Fatal(wrapHeader(ctx, args)...) +} + +// Debugf logs a message at level Debug on the standard logger. +func Debugf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Debugf(wrapHeaderForMessage(ctx, format), args...) +} + +// Printf logs a message at level Info on the standard logger. +func Printf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Printf(wrapHeaderForMessage(ctx, format), args...) +} + +// Infof logs a message at level Info on the standard logger. +func Infof(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Infof(wrapHeaderForMessage(ctx, format), args...) +} + +// InfofNoCtx logs a formatted message without context. +func InfofNoCtx(format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(context.TODO()).Infof(format, args...) +} + +// Warnf logs a message at level Warn on the standard logger. +func Warnf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warnf(wrapHeaderForMessage(ctx, format), args...) +} + +// Warningf logs a message at level Warn on the standard logger. +func Warningf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warningf(wrapHeaderForMessage(ctx, format), args...) +} + +// Errorf logs a message at level Error on the standard logger. 
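+// For example (illustrative): logger.Errorf(ctx, "failed to load config: %v", err)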
+func Errorf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Errorf(wrapHeaderForMessage(ctx, format), args...) +} + +// Panicf logs a message at level Panic on the standard logger. +func Panicf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Panicf(wrapHeaderForMessage(ctx, format), args...) +} + +// Fatalf logs a message at level Fatal on the standard logger. +func Fatalf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Fatalf(wrapHeaderForMessage(ctx, format), args...) +} + +// Debugln logs a message at level Debug on the standard logger. +func Debugln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Debugln(wrapHeader(ctx, args)...) +} + +// Println logs a message at level Info on the standard logger. +func Println(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Println(wrapHeader(ctx, args)...) +} + +// Infoln logs a message at level Info on the standard logger. +func Infoln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Infoln(wrapHeader(ctx, args)...) +} + +// Warnln logs a message at level Warn on the standard logger. +func Warnln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warnln(wrapHeader(ctx, args)...) +} + +// Warningln logs a message at level Warn on the standard logger. +func Warningln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warningln(wrapHeader(ctx, args)...) +} + +// Errorln logs a message at level Error on the standard logger. +func Errorln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Errorln(wrapHeader(ctx, args)...) +} + +// Panicln logs a message at level Panic on the standard logger. +func Panicln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Panicln(wrapHeader(ctx, args)...) +} + +// Fatalln logs a message at level Fatal on the standard logger. +func Fatalln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Fatalln(wrapHeader(ctx, args)...) +} diff --git a/flytestdlib/logger/logger_test.go b/flytestdlib/logger/logger_test.go new file mode 100644 index 0000000000..75c73a9432 --- /dev/null +++ b/flytestdlib/logger/logger_test.go @@ -0,0 +1,643 @@ +// Defines global context-aware logger. +// The default implementation uses logrus. This package registers "logger" config section on init(). The structure of the +// config section is expected to be un-marshal-able to Config struct. 
+package logger + +import ( + "context" + "reflect" + "strings" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func init() { + SetConfig(Config{ + Level: InfoLevel, + IncludeSourceCode: true, + }) +} + +func Test_getSourceLocation(t *testing.T) { + tests := []struct { + name string + want string + }{ + {"current", " "}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getSourceLocation(); !strings.HasSuffix(got, tt.want) { + t.Errorf("getSourceLocation() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_wrapHeaderForMessage(t *testing.T) { + type args struct { + message string + } + tests := []struct { + name string + args args + want string + }{ + {"no args", args{message: ""}, " "}, + {"1 arg", args{message: "hello"}, " hello"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapHeaderForMessage(context.TODO(), tt.args.message); !strings.HasSuffix(got, tt.want) { + t.Errorf("wrapHeaderForMessage() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsLoggable(t *testing.T) { + type args struct { + ctx context.Context + level Level + } + tests := []struct { + name string + args args + want bool + }{ + {"Debug Is not loggable", args{ctx: context.TODO(), level: DebugLevel}, false}, + {"Info Is loggable", args{ctx: context.TODO(), level: InfoLevel}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsLoggable(tt.args.ctx, tt.args.level); got != tt.want { + t.Errorf("IsLoggable() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDebug(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Debug(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPrint(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Print(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestInfo(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Info(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarn(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warn(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarning(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warning(tt.args.ctx, tt.args.args...) 
+ }) + } +} + +func TestError(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Error(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPanic(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { + Panic(tt.args.ctx, tt.args.args...) + }) + }) + } +} + +func TestDebugf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Debugf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestPrintf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Printf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestInfof(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Infof(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestInfofNoCtx(t *testing.T) { + type args struct { + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{format: "%v", args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + InfofNoCtx(tt.args.format, tt.args.args...) + }) + } +} + +func TestWarnf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warnf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestWarningf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warningf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestErrorf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Errorf(tt.args.ctx, tt.args.format, tt.args.args...) 
+ }) + } +} + +func TestPanicf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { + Panicf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + }) + } +} + +func TestDebugln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Debugln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPrintln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Println(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestInfoln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Infoln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarnln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warnln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarningln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warningln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestErrorln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Errorln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPanicln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { + Panicln(tt.args.ctx, tt.args.args...) + }) + }) + } +} + +func Test_wrapHeader(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + want []interface{} + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapHeader(tt.args.ctx, tt.args.args...); !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrapHeader() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getLogger(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want *logrus.Entry + }{ + // TODO: Add test cases. 
+ } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getLogger(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getLogger() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWithIndent(t *testing.T) { + type args struct { + ctx context.Context + additionalIndent string + } + tests := []struct { + name string + args args + want context.Context + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := WithIndent(tt.args.ctx, tt.args.additionalIndent); !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithIndent() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getIndent(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getIndent(tt.args.ctx); got != tt.want { + t.Errorf("getIndent() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFatal(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Fatal(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestFatalf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Fatalf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestFatalln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Fatalln(tt.args.ctx, tt.args.args...) + }) + } +} + +func Test_onConfigUpdated(t *testing.T) { + type args struct { + cfg Config + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + onConfigUpdated(tt.args.cfg) + }) + } +} diff --git a/flytestdlib/pbhash/pbhash.go b/flytestdlib/pbhash/pbhash.go new file mode 100644 index 0000000000..820b5c511a --- /dev/null +++ b/flytestdlib/pbhash/pbhash.go @@ -0,0 +1,58 @@ +// This is a package that provides hashing utilities for Protobuf objects. 
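+// Illustrative usage (msg can be any proto.Message, ctx any context.Context):
+//   digest, err := pbhash.ComputeHashString(ctx, msg)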
+package pbhash + +import ( + "context" + "encoding/base64" + + goObjectHash "github.com/benlaurie/objecthash/go/objecthash" + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/lyft/flytestdlib/logger" +) + +var marshaller = &jsonpb.Marshaler{} + +func fromHashToByteArray(input [32]byte) []byte { + output := make([]byte, 32) + for idx, val := range input { + output[idx] = val + } + return output +} + +// Generate a deterministic hash in bytes for the pb object +func ComputeHash(ctx context.Context, pb proto.Message) ([]byte, error) { + // We marshal the pb object to JSON first which should provide a consistent mapping of pb to json fields as stated + // here: https://developers.google.com/protocol-buffers/docs/proto3#json + // jsonpb marshalling includes: + // - sorting map values to provide a stable output + // - omitting empty values which supports backwards compatibility of old protobuf definitions + // We do not use protobuf marshalling because it does not guarantee stable output because of how it handles + // unknown fields and ordering of fields. https://github.com/protocolbuffers/protobuf/issues/2830 + pbJSON, err := marshaller.MarshalToString(pb) + if err != nil { + logger.Warning(ctx, "failed to marshal pb [%+v] to JSON with err %v", pb, err) + return nil, err + } + + // Deterministically hash the JSON object to a byte array. The library will sort the map keys of the JSON object + // so that we do not run into the issues from pb marshalling. + hash, err := goObjectHash.CommonJSONHash(pbJSON) + if err != nil { + logger.Warning(ctx, "failed to hash JSON for pb [%+v] with err %v", pb, err) + return nil, err + } + + return fromHashToByteArray(hash), err +} + +// Generate a deterministic hash as a base64 encoded string for the pb object. 
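+// The returned string is the base64 (StdEncoding) form of the 32-byte hash produced by ComputeHash.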
+func ComputeHashString(ctx context.Context, pb proto.Message) (string, error) { + hashBytes, err := ComputeHash(ctx, pb) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(hashBytes), err +} diff --git a/flytestdlib/pbhash/pbhash_test.go b/flytestdlib/pbhash/pbhash_test.go new file mode 100644 index 0000000000..75735b4135 --- /dev/null +++ b/flytestdlib/pbhash/pbhash_test.go @@ -0,0 +1,145 @@ +package pbhash + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/duration" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/stretchr/testify/assert" +) + +// Mock a Protobuf generated GO object +type mockProtoMessage struct { + Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` + FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` + Datetime *timestamp.Timestamp `protobuf:"bytes,5,opt,name=datetime,proto3" json:"datetime,omitempty"` + Duration *duration.Duration `protobuf:"bytes,6,opt,name=duration,proto3" json:"duration,omitempty"` + MapValue map[string]string `protobuf:"bytes,7,rep,name=map_value,json=mapValue,proto3" json:"map_value,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Collections []string `protobuf:"bytes,8,rep,name=collections,proto3" json:"collections,omitempty"` +} + +func (mockProtoMessage) Reset() { +} + +func (m mockProtoMessage) String() string { + return proto.CompactTextString(m) +} + +func (mockProtoMessage) ProtoMessage() { +} + +// Mock an older version of the above pb object that doesn't have some fields +type mockOlderProto struct { + Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` + FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` +} + +func (mockOlderProto) Reset() { +} + +func (m mockOlderProto) String() string { + return proto.CompactTextString(m) +} + +func (mockOlderProto) ProtoMessage() { +} + +var sampleTime, _ = ptypes.TimestampProto( + time.Date(2019, 03, 29, 12, 0, 0, 0, time.UTC)) + +func TestProtoHash(t *testing.T) { + mockProto := &mockProtoMessage{ + Integer: 18, + FloatValue: 1.3, + StringValue: "lets test this", + Boolean: true, + Datetime: sampleTime, + Duration: ptypes.DurationProto(time.Millisecond), + MapValue: map[string]string{ + "z": "last", + "a": "first", + }, + Collections: []string{"1", "2", "3"}, + } + + expectedHashedMockProto := []byte{0x62, 0x95, 0xb2, 0x2c, 0x23, 0xf5, 0x35, 0x6d, 0x3, 0x56, 0x4d, 0xc7, 0x8f, 0xae, + 0x2d, 0x2b, 0xbd, 0x7, 0xff, 0xdb, 0x7e, 0xe5, 0xf4, 0x25, 0x8f, 0xbc, 0xb2, 0xc, 0xad, 0xa5, 0x48, 0x44} + expectedHashString := "YpWyLCP1NW0DVk3Hj64tK70H/9t+5fQlj7yyDK2lSEQ=" + + t.Run("TestFullProtoHash", func(t *testing.T) { + hashedBytes, err := ComputeHash(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProto, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err 
:= ComputeHashString(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashString) + }) + + t.Run("TestFullProtoHashReorderKeys", func(t *testing.T) { + mockProto.MapValue = map[string]string{"a": "first", "z": "last"} + hashedBytes, err := ComputeHash(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProto, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := ComputeHashString(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashString) + }) +} + +func TestPartialFilledProtoHash(t *testing.T) { + + mockProtoOmitEmpty := &mockProtoMessage{ + Integer: 18, + FloatValue: 1.3, + StringValue: "lets test this", + Boolean: true, + } + + expectedHashedMockProtoOmitEmpty := []byte{0x1a, 0x13, 0xcc, 0x4c, 0xab, 0xc9, 0x7d, 0x43, 0xc7, 0x2b, 0xc5, 0x37, + 0xbc, 0x49, 0xa8, 0x8b, 0xfc, 0x1d, 0x54, 0x1c, 0x7b, 0x21, 0x04, 0x8f, 0xab, 0x28, 0xc6, 0x5c, 0x06, 0x73, + 0xaa, 0xe2} + + expectedHashStringOmitEmpty := "GhPMTKvJfUPHK8U3vEmoi/wdVBx7IQSPqyjGXAZzquI=" + + t.Run("TestPartial", func(t *testing.T) { + hashedBytes, err := ComputeHash(context.Background(), mockProtoOmitEmpty) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProtoOmitEmpty, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := ComputeHashString(context.Background(), mockProtoOmitEmpty) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashStringOmitEmpty) + }) + + mockOldProtoMessage := &mockOlderProto{ + Integer: 18, + FloatValue: 1.3, + StringValue: "lets test this", + Boolean: true, + } + + t.Run("TestOlderProto", func(t *testing.T) { + hashedBytes, err := ComputeHash(context.Background(), mockOldProtoMessage) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProtoOmitEmpty, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := ComputeHashString(context.Background(), mockProtoOmitEmpty) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashStringOmitEmpty) + }) + +} diff --git a/flytestdlib/profutils/server.go b/flytestdlib/profutils/server.go new file mode 100644 index 0000000000..b8240cf25c --- /dev/null +++ b/flytestdlib/profutils/server.go @@ -0,0 +1,116 @@ +package profutils + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/lyft/flytestdlib/config" + + "github.com/lyft/flytestdlib/version" + + "github.com/lyft/flytestdlib/logger" + "github.com/prometheus/client_golang/prometheus/promhttp" + + _ "net/http/pprof" // Import for pprof server +) + +const ( + healthcheck = "/healthcheck" + metricsPath = "/metrics" + versionPath = "/version" + configPath = "/config" +) + +const ( + contentTypeHeader = "Content-Type" + contentTypeJSON = "application/json; charset=utf-8" +) + +func WriteStringResponse(resp http.ResponseWriter, code int, body string) error { + resp.WriteHeader(code) + _, err := resp.Write([]byte(body)) + return err +} + +func WriteJSONResponse(resp http.ResponseWriter, code int, body interface{}) error { + resp.Header().Set(contentTypeHeader, contentTypeJSON) + resp.WriteHeader(code) + j, err := json.Marshal(body) + if err != nil { + return WriteStringResponse(resp, http.StatusInternalServerError, err.Error()) + } + return WriteStringResponse(resp, http.StatusOK, string(j)) +} + +func healtcheckHandler(w http.ResponseWriter, req *http.Request) { + err := WriteStringResponse(w, http.StatusOK, http.StatusText(http.StatusOK)) + if err != nil { + panic(err) + } +} + +func 
versionHandler(w http.ResponseWriter, req *http.Request) { + err := WriteStringResponse(w, http.StatusOK, fmt.Sprintf("Build [%s], Version [%s]", version.Build, version.Version)) + if err != nil { + panic(err) + } +} + +func configHandler(w http.ResponseWriter, req *http.Request) { + m, err := config.AllConfigsAsMap(config.GetRootSection()) + if err != nil { + err = WriteStringResponse(w, http.StatusInternalServerError, err.Error()) + if err != nil { + logger.Errorf(context.TODO(), "Failed to write error response. Error: %v", err) + panic(err) + } + } + + if err := WriteJSONResponse(w, http.StatusOK, m); err != nil { + panic(err) + } +} + +// Starts an http server on the given port +func StartProfilingServer(ctx context.Context, pprofPort int) error { + logger.Infof(ctx, "Starting profiling server on port [%v]", pprofPort) + e := http.ListenAndServe(fmt.Sprintf(":%d", pprofPort), nil) + if e != nil { + logger.Errorf(ctx, "Failed to start profiling server. Error: %v", e) + return fmt.Errorf("failed to start profiling server, %s", e) + } + + return nil +} + +func configureGlobalHTTPHandler(handlers map[string]http.Handler) error { + if handlers == nil { + handlers = map[string]http.Handler{} + } + handlers[metricsPath] = promhttp.Handler() + handlers[healthcheck] = http.HandlerFunc(healtcheckHandler) + handlers[versionPath] = http.HandlerFunc(versionHandler) + handlers[configPath] = http.HandlerFunc(configHandler) + + for p, h := range handlers { + http.Handle(p, h) + } + + return nil +} + +// Forwards the call to StartProfilingServer +// Also registers: +// 1. the prometheus HTTP handler on '/metrics' path shared with the profiling server. +// 2. A healthcheck (L7) handler on '/healthcheck'. +// 3. A version handler on '/version' provides information about the specific build. +// 4. A config handler on '/config' provides a dump of the currently loaded config. 
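As an illustrative aside (not part of the patch), a minimal sketch of how a service might start the server with the handler set described above; the port and the extra "/custom" path are invented for the example.

package main

import (
	"context"
	"net/http"

	"github.com/lyft/flytestdlib/profutils"
)

func main() {
	ctx := context.Background()

	// Hypothetical extra handler; any path not already claimed by the defaults works.
	extra := map[string]http.Handler{
		"/custom": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			_, _ = w.Write([]byte("ok"))
		}),
	}

	// Blocks, serving pprof, /metrics, /healthcheck, /version and /config (port chosen arbitrarily).
	if err := profutils.StartProfilingServerWithDefaultHandlers(ctx, 10254, extra); err != nil {
		panic(err)
	}
}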
+func StartProfilingServerWithDefaultHandlers(ctx context.Context, pprofPort int, handlers map[string]http.Handler) error { + if err := configureGlobalHTTPHandler(handlers); err != nil { + return err + } + + return StartProfilingServer(ctx, pprofPort) +} diff --git a/flytestdlib/profutils/server_test.go b/flytestdlib/profutils/server_test.go new file mode 100644 index 0000000000..e2eb709943 --- /dev/null +++ b/flytestdlib/profutils/server_test.go @@ -0,0 +1,103 @@ +package profutils + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/lyft/flytestdlib/internal/utils" + + "github.com/stretchr/testify/assert" +) + +type MockResponseWriter struct { + Status int + Headers http.Header + Body []byte +} + +func (m *MockResponseWriter) Write(b []byte) (int, error) { + m.Body = b + return 1, nil +} + +func (m *MockResponseWriter) WriteHeader(statusCode int) { + m.Status = statusCode +} + +func (m *MockResponseWriter) Header() http.Header { + return m.Headers +} + +type TestObj struct { + X int `json:"x"` +} + +func init() { + if err := configureGlobalHTTPHandler(nil); err != nil { + panic(err) + } +} + +func TestWriteJsonResponse(t *testing.T) { + m := &MockResponseWriter{Headers: http.Header{}} + assert.NoError(t, WriteJSONResponse(m, http.StatusOK, TestObj{10})) + assert.Equal(t, http.StatusOK, m.Status) + assert.Equal(t, http.Header{contentTypeHeader: []string{contentTypeJSON}}, m.Headers) + assert.Equal(t, `{"x":10}`, string(m.Body)) +} + +func TestWriteStringResponse(t *testing.T) { + m := &MockResponseWriter{Headers: http.Header{}} + assert.NoError(t, WriteStringResponse(m, http.StatusOK, "OK")) + assert.Equal(t, http.StatusOK, m.Status) + assert.Equal(t, "OK", string(m.Body)) +} + +func TestConfigHandler(t *testing.T) { + writer := &MockResponseWriter{Headers: http.Header{}} + testURL := utils.MustParseURL(configPath) + request := &http.Request{ + URL: &testURL, + } + + http.DefaultServeMux.ServeHTTP(writer, request) + assert.Equal(t, http.StatusOK, writer.Status) + + m := map[string]interface{}{} + assert.NoError(t, json.Unmarshal(writer.Body, &m)) + assert.Equal(t, map[string]interface{}{ + "logger": map[string]interface{}{ + "show-source": false, + "mute": false, + "level": float64(0), + "formatter": map[string]interface{}{ + "type": "", + }, + }, + }, m) +} + +func TestVersionHandler(t *testing.T) { + writer := &MockResponseWriter{Headers: http.Header{}} + testURL := utils.MustParseURL(versionPath) + request := &http.Request{ + URL: &testURL, + } + + http.DefaultServeMux.ServeHTTP(writer, request) + assert.Equal(t, http.StatusOK, writer.Status) + assert.Equal(t, `Build [unknown], Version [unknown]`, string(writer.Body)) +} + +func TestHealthcheckHandler(t *testing.T) { + writer := &MockResponseWriter{Headers: http.Header{}} + testURL := utils.MustParseURL(healthcheck) + request := &http.Request{ + URL: &testURL, + } + + http.DefaultServeMux.ServeHTTP(writer, request) + assert.Equal(t, http.StatusOK, writer.Status) + assert.Equal(t, `OK`, string(writer.Body)) +} diff --git a/flytestdlib/promutils/labeled/counter.go b/flytestdlib/promutils/labeled/counter.go new file mode 100644 index 0000000000..c68ce02d23 --- /dev/null +++ b/flytestdlib/promutils/labeled/counter.go @@ -0,0 +1,65 @@ +package labeled + +import ( + "context" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" +) + +// Represents a counter labeled with values from the context. 
See labeled.SetMetricsKeys for information about to +// configure that. +type Counter struct { + *prometheus.CounterVec + + prometheus.Counter +} + +// Inc increments the counter by 1. Use Add to increment it by arbitrary non-negative values. The data point will be +// labeled with values from context. See labeled.SetMetricsKeys for information about to configure that. +func (c Counter) Inc(ctx context.Context) { + counter, err := c.CounterVec.GetMetricWith(contextutils.Values(ctx, metricKeys...)) + if err != nil { + panic(err.Error()) + } + counter.Inc() + + if c.Counter != nil { + c.Counter.Inc() + } +} + +// Add adds the given value to the counter. It panics if the value is < 0.. The data point will be labeled with values +// from context. See labeled.SetMetricsKeys for information about to configure that. +func (c Counter) Add(ctx context.Context, v float64) { + counter, err := c.CounterVec.GetMetricWith(contextutils.Values(ctx, metricKeys...)) + if err != nil { + panic(err.Error()) + } + counter.Add(v) + + if c.Counter != nil { + c.Counter.Add(v) + } +} + +// Creates a new labeled counter. Label keys must be set before instantiating a counter. See labeled.SetMetricsKeys for +// information about to configure that. +func NewCounter(name, description string, scope promutils.Scope, opts ...MetricOption) Counter { + if len(metricKeys) == 0 { + panic(ErrNeverSet) + } + + c := Counter{ + CounterVec: scope.MustNewCounterVec(name, description, metricStringKeys...), + } + + for _, opt := range opts { + if _, emitUnlabeledMetric := opt.(EmitUnlabeledMetricOption); emitUnlabeledMetric { + c.Counter = scope.MustNewCounter(GetUnlabeledMetricName(name), description) + } + } + + return c +} diff --git a/flytestdlib/promutils/labeled/counter_test.go b/flytestdlib/promutils/labeled/counter_test.go new file mode 100644 index 0000000000..130b8217a2 --- /dev/null +++ b/flytestdlib/promutils/labeled/counter_test.go @@ -0,0 +1,31 @@ +package labeled + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestLabeledCounter(t *testing.T) { + assert.NotPanics(t, func() { + SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) + }) + + scope := promutils.NewTestScope() + c := NewCounter("lbl_counter", "help", scope) + assert.NotNil(t, c) + ctx := context.TODO() + c.Inc(ctx) + c.Add(ctx, 1.0) + + ctx = contextutils.WithProjectDomain(ctx, "project", "domain") + c.Inc(ctx) + c.Add(ctx, 1.0) + + ctx = contextutils.WithTaskID(ctx, "task") + c.Inc(ctx) + c.Add(ctx, 1.0) +} diff --git a/flytestdlib/promutils/labeled/keys.go b/flytestdlib/promutils/labeled/keys.go new file mode 100644 index 0000000000..d8c8683750 --- /dev/null +++ b/flytestdlib/promutils/labeled/keys.go @@ -0,0 +1,47 @@ +package labeled + +import ( + "fmt" + "reflect" + "sync" + + "github.com/lyft/flytestdlib/contextutils" +) + +var ( + ErrAlreadySet = fmt.Errorf("cannot set metric keys more than once") + ErrEmpty = fmt.Errorf("cannot set metric keys to an empty set") + ErrNeverSet = fmt.Errorf("must call SetMetricKeys prior to using labeled package") + + // Metric Keys to label metrics with. These keys get pulled from context if they are present. Use contextutils to fill + // them in. 
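To make the relationship between metric keys and labeled metrics concrete, here is a hedged usage sketch (not part of the patch); the scope, metric name and label values are invented, and SetMetricKeys must be called exactly once, before the first labeled metric is created.

package main

import (
	"context"

	"github.com/lyft/flytestdlib/contextutils"
	"github.com/lyft/flytestdlib/promutils"
	"github.com/lyft/flytestdlib/promutils/labeled"
)

func main() {
	// Set once per process, before any labeled metric is instantiated.
	labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey)

	scope := promutils.NewScope("myservice") // placeholder scope name
	requests := labeled.NewCounter("requests", "requests served", scope, labeled.EmitUnlabeledMetric)

	// Label values are pulled from the context at emission time.
	ctx := contextutils.WithProjectDomain(context.Background(), "someproject", "production")
	requests.Inc(ctx)
	requests.Add(ctx, 4)
}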
+ metricKeys = make([]contextutils.Key, 0) + + // :(, we have to create a separate list to satisfy the MustNewCounterVec API as it accepts string only + metricStringKeys = make([]string, 0) + metricKeysAreSet = sync.Once{} +) + +// Sets keys to use with labeled metrics. The values of these keys will be pulled from context at runtime. +func SetMetricKeys(keys ...contextutils.Key) { + if len(keys) == 0 { + panic(ErrEmpty) + } + + ran := false + metricKeysAreSet.Do(func() { + ran = true + metricKeys = keys + for _, metricKey := range metricKeys { + metricStringKeys = append(metricStringKeys, metricKey.String()) + } + }) + + if !ran && !reflect.DeepEqual(keys, metricKeys) { + panic(ErrAlreadySet) + } +} + +func GetUnlabeledMetricName(metricName string) string { + return metricName + "_unlabeled" +} diff --git a/flytestdlib/promutils/labeled/keys_test.go b/flytestdlib/promutils/labeled/keys_test.go new file mode 100644 index 0000000000..4a8600aea3 --- /dev/null +++ b/flytestdlib/promutils/labeled/keys_test.go @@ -0,0 +1,24 @@ +package labeled + +import ( + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/stretchr/testify/assert" +) + +func TestMetricKeys(t *testing.T) { + input := []contextutils.Key{ + contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, + } + + assert.NotPanics(t, func() { SetMetricKeys(input...) }) + assert.Equal(t, input, metricKeys) + + for i, k := range metricKeys { + assert.Equal(t, k.String(), metricStringKeys[i]) + } + + assert.NotPanics(t, func() { SetMetricKeys(input...) }) + assert.Panics(t, func() { SetMetricKeys(contextutils.DomainKey) }) +} diff --git a/flytestdlib/promutils/labeled/metric_option.go b/flytestdlib/promutils/labeled/metric_option.go new file mode 100644 index 0000000000..08fb2f76f9 --- /dev/null +++ b/flytestdlib/promutils/labeled/metric_option.go @@ -0,0 +1,15 @@ +package labeled + +// Defines extra set of options to customize the emitted metric. +type MetricOption interface { + isMetricOption() +} + +// Instructs the metric to emit unlabeled metric (besides the labeled one). This is useful to get overall system +// performance. +type EmitUnlabeledMetricOption struct { +} + +func (EmitUnlabeledMetricOption) isMetricOption() {} + +var EmitUnlabeledMetric = EmitUnlabeledMetricOption{} diff --git a/flytestdlib/promutils/labeled/metric_option_test.go b/flytestdlib/promutils/labeled/metric_option_test.go new file mode 100644 index 0000000000..0a070f7420 --- /dev/null +++ b/flytestdlib/promutils/labeled/metric_option_test.go @@ -0,0 +1,13 @@ +package labeled + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMetricOption(t *testing.T) { + var opt MetricOption = &EmitUnlabeledMetric + _, isMetricOption := opt.(MetricOption) + assert.True(t, isMetricOption) +} diff --git a/flytestdlib/promutils/labeled/stopwatch.go b/flytestdlib/promutils/labeled/stopwatch.go new file mode 100644 index 0000000000..90e29971f9 --- /dev/null +++ b/flytestdlib/promutils/labeled/stopwatch.go @@ -0,0 +1,87 @@ +package labeled + +import ( + "context" + "time" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" +) + +type StopWatch struct { + *promutils.StopWatchVec + + // We use SummaryVec for emitting StopWatchVec, this computes percentiles per metric tags combination on the client- + // side. This makes it impossible to aggregate percentiles across tags (e.g. to have system-wide view). 
When enabled + // through a flag in the constructor, we initialize this additional untagged stopwatch to compute percentiles + // across tags. + promutils.StopWatch +} + +// Start creates a new Instance of the StopWatch called a Timer that is closeable/stoppable. +// Common pattern to time a scope would be +// { +// timer := stopWatch.Start(ctx) +// defer timer.Stop() +// .... +// } +func (c StopWatch) Start(ctx context.Context) Timer { + w, err := c.StopWatchVec.GetMetricWith(contextutils.Values(ctx, metricKeys...)) + if err != nil { + panic(err.Error()) + } + + if c.StopWatch.Observer == nil { + return w.Start() + } + + return timer{ + Timers: []Timer{ + w.Start(), + c.StopWatch.Start(), + }, + } +} + +// Observes specified duration between the start and end time. The data point will be labeled with values from context. +// See labeled.SetMetricsKeys for information about to configure that. +func (c StopWatch) Observe(ctx context.Context, start, end time.Time) { + w, err := c.StopWatchVec.GetMetricWith(contextutils.Values(ctx, metricKeys...)) + if err != nil { + panic(err.Error()) + } + w.Observe(start, end) + + if c.StopWatch.Observer != nil { + c.StopWatch.Observe(start, end) + } +} + +// This method observes the elapsed duration since the creation of the timer. The timer is created using a StopWatch. +// The data point will be labeled with values from context. See labeled.SetMetricsKeys for information about to +// configure that. +func (c StopWatch) Time(ctx context.Context, f func()) { + t := c.Start(ctx) + f() + t.Stop() +} + +// Creates a new labeled stopwatch. Label keys must be set before instantiating a counter. See labeled.SetMetricsKeys +// for information about to configure that. +func NewStopWatch(name, description string, scale time.Duration, scope promutils.Scope, opts ...MetricOption) StopWatch { + if len(metricKeys) == 0 { + panic(ErrNeverSet) + } + + sw := StopWatch{ + StopWatchVec: scope.MustNewStopWatchVec(name, description, scale, metricStringKeys...), + } + + for _, opt := range opts { + if _, emitUnableMetric := opt.(EmitUnlabeledMetricOption); emitUnableMetric { + sw.StopWatch = scope.MustNewStopWatch(GetUnlabeledMetricName(name), description, scale) + } + } + + return sw +} diff --git a/flytestdlib/promutils/labeled/stopwatch_test.go b/flytestdlib/promutils/labeled/stopwatch_test.go new file mode 100644 index 0000000000..d5adf0eade --- /dev/null +++ b/flytestdlib/promutils/labeled/stopwatch_test.go @@ -0,0 +1,47 @@ +package labeled + +import ( + "context" + "testing" + "time" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestLabeledStopWatch(t *testing.T) { + assert.NotPanics(t, func() { + SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) + }) + + t.Run("always labeled", func(t *testing.T) { + scope := promutils.NewTestScope() + c := NewStopWatch("lbl_counter", "help", time.Second, scope) + assert.NotNil(t, c) + ctx := context.TODO() + w := c.Start(ctx) + w.Stop() + + ctx = contextutils.WithProjectDomain(ctx, "project", "domain") + w = c.Start(ctx) + w.Stop() + + ctx = contextutils.WithTaskID(ctx, "task") + w = c.Start(ctx) + w.Stop() + + c.Observe(ctx, time.Now(), time.Now().Add(time.Second)) + c.Time(ctx, func() { + // Do nothing + }) + }) + + t.Run("unlabeled", func(t *testing.T) { + scope := promutils.NewTestScope() + c := NewStopWatch("lbl_counter_2", "help", time.Second, scope, EmitUnlabeledMetric) + 
assert.NotNil(t, c) + + c.Start(context.TODO()) + }) +} diff --git a/flytestdlib/promutils/labeled/timer_wrapper.go b/flytestdlib/promutils/labeled/timer_wrapper.go new file mode 100644 index 0000000000..75aa4bee94 --- /dev/null +++ b/flytestdlib/promutils/labeled/timer_wrapper.go @@ -0,0 +1,20 @@ +package labeled + +// Defines a common interface for timers. +type Timer interface { + // Stops the timer and reports observation. + Stop() float64 +} + +type timer struct { + Timers []Timer +} + +func (t timer) Stop() float64 { + var res float64 + for _, timer := range t.Timers { + res = timer.Stop() + } + + return res +} diff --git a/flytestdlib/promutils/labeled/timer_wrapper_test.go b/flytestdlib/promutils/labeled/timer_wrapper_test.go new file mode 100644 index 0000000000..375836c557 --- /dev/null +++ b/flytestdlib/promutils/labeled/timer_wrapper_test.go @@ -0,0 +1,28 @@ +package labeled + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type fakeTimer struct { + stopCount int +} + +func (f *fakeTimer) Stop() float64 { + f.stopCount++ + return 0 +} + +func TestTimerStop(t *testing.T) { + ft := &fakeTimer{} + tim := timer{ + Timers: []Timer{ + ft, ft, ft, + }, + } + + tim.Stop() + assert.Equal(t, 3, ft.stopCount) +} diff --git a/flytestdlib/promutils/scope.go b/flytestdlib/promutils/scope.go new file mode 100644 index 0000000000..d32fab4cfa --- /dev/null +++ b/flytestdlib/promutils/scope.go @@ -0,0 +1,434 @@ +package promutils + +import ( + "strings" + "time" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/prometheus/client_golang/prometheus" +) + +const defaultScopeDelimiterStr = ":" +const defaultMetricDelimiterStr = "_" + +func panicIfError(err error) { + if err != nil { + panic("Failed to register metrics. Error: " + err.Error()) + } +} + +// A Simple StopWatch that works with prometheus summary +// It will scale the output to match the expected time scale (milliseconds, seconds etc) +// NOTE: Do not create a StopWatch object by hand, use a Scope to get a new instance of the StopWatch object +type StopWatch struct { + prometheus.Observer + outputScale time.Duration +} + +// Start creates a new Instance of the StopWatch called a Timer that is closeable/stoppable. +// Common pattern to time a scope would be +// { +// timer := stopWatch.Start() +// defer timer.Stop() +// .... +// } +func (s StopWatch) Start() Timer { + return Timer{ + start: time.Now(), + outputScale: s.outputScale, + timer: s.Observer, + } +} + +// Observes specified duration between the start and end time +func (s StopWatch) Observe(start, end time.Time) { + observed := end.Sub(start).Nanoseconds() + outputScaleDuration := s.outputScale.Nanoseconds() + if outputScaleDuration == 0 { + s.Observer.Observe(0) + return + } + scaled := float64(observed / outputScaleDuration) + s.Observer.Observe(scaled) +} + +// Observes/records the time to execute the given function synchronously +func (s StopWatch) Time(f func()) { + t := s.Start() + f() + t.Stop() +} + +// A Simple StopWatch that works with prometheus summary +// It will scale the output to match the expected time scale (milliseconds, seconds etc) +// NOTE: Do not create a StopWatch object by hand, use a Scope to get a new instance of the StopWatch object +type StopWatchVec struct { + *prometheus.SummaryVec + outputScale time.Duration +} + +// Gets a concrete StopWatch instance that can be used to start a timer and record observations. 
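The Start/Stop pattern documented above, sketched end to end from the point of view of a hypothetical consumer (not part of the patch); scope and metric names are placeholders, and the stopwatch comes from a Scope rather than being constructed by hand.

package main

import (
	"time"

	"github.com/lyft/flytestdlib/promutils"
)

func doWork() {
	time.Sleep(10 * time.Millisecond)
}

func main() {
	scope := promutils.NewScope("myservice") // placeholder scope name

	// The metric name is auto-suffixed with the scale, i.e. it registers as "myservice:do_work_ms".
	watch := scope.MustNewStopWatch("do_work", "time spent in doWork", time.Millisecond)

	timer := watch.Start()
	doWork()
	timer.Stop()

	// Equivalent, using the helper that times a closure.
	watch.Time(doWork)
}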
+func (s StopWatchVec) WithLabelValues(values ...string) StopWatch { + return StopWatch{ + Observer: s.SummaryVec.WithLabelValues(values...), + outputScale: s.outputScale, + } +} + +func (s StopWatchVec) GetMetricWith(labels prometheus.Labels) (StopWatch, error) { + sVec, err := s.SummaryVec.GetMetricWith(labels) + if err != nil { + return StopWatch{}, err + } + return StopWatch{ + Observer: sVec, + outputScale: s.outputScale, + }, nil +} + +// This is a stoppable instance of a StopWatch or a Timer +// A Timer can only be stopped. On stopping it will output the elapsed duration to prometheus +type Timer struct { + start time.Time + outputScale time.Duration + timer prometheus.Observer +} + +// This method observes the elapsed duration since the creation of the timer. The timer is created using a StopWatch +func (s Timer) Stop() float64 { + observed := time.Since(s.start).Nanoseconds() + outputScaleDuration := s.outputScale.Nanoseconds() + if outputScaleDuration == 0 { + s.timer.Observe(0) + return 0 + } + scaled := float64(observed / outputScaleDuration) + s.timer.Observe(scaled) + return scaled +} + +// A Scope represents a prefix in Prometheus. It is nestable, thus every metric that is published does not need to +// provide a prefix, but just the name of the metric. As long as the Scope is used to create a new instance of the metric +// The prefix (or scope) is automatically set. +type Scope interface { + // Creates new prometheus.Gauge metric with the prefix as the CurrentScope + // Name is a string that follows prometheus conventions (mostly [_a-z]) + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewGauge(name, description string) (prometheus.Gauge, error) + MustNewGauge(name, description string) prometheus.Gauge + + // Creates new prometheus.GaugeVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewGaugeVec(name, description string, labelNames ...string) (*prometheus.GaugeVec, error) + MustNewGaugeVec(name, description string, labelNames ...string) *prometheus.GaugeVec + + // Creates new prometheus.Summary metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewSummary(name, description string) (prometheus.Summary, error) + MustNewSummary(name, description string) prometheus.Summary + + // Creates new prometheus.SummaryVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewSummaryVec(name, description string, labelNames ...string) (*prometheus.SummaryVec, error) + MustNewSummaryVec(name, description string, labelNames ...string) *prometheus.SummaryVec + + // Creates new prometheus.Histogram metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewHistogram(name, description string) (prometheus.Histogram, error) + MustNewHistogram(name, description string) prometheus.Histogram + + // Creates new prometheus.HistogramVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewHistogramVec(name, description string, labelNames ...string) (*prometheus.HistogramVec, error) + MustNewHistogramVec(name, description string, labelNames ...string) *prometheus.HistogramVec + + // Creates new prometheus.Counter metric with the prefix as the CurrentScope + // Refer to 
https://prometheus.io/docs/concepts/metric_types/ for more information + // Important to note, counters are not like typical counters. These are ever increasing and cumulative. + // So if you want to observe counters within buckets use Summary/Histogram + NewCounter(name, description string) (prometheus.Counter, error) + MustNewCounter(name, description string) prometheus.Counter + + // Creates new prometheus.GaugeVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewCounterVec(name, description string, labelNames ...string) (*prometheus.CounterVec, error) + MustNewCounterVec(name, description string, labelNames ...string) *prometheus.CounterVec + + // This is a custom wrapper to create a StopWatch object in the current Scope. + // Duration is to specify the scale of the Timer. For example if you are measuring times in milliseconds + // pass scale=times.Millisecond + // https://golang.org/pkg/time/#Duration + // The metric name is auto-suffixed with the right scale. Refer to DurationToString to understand + NewStopWatch(name, description string, scale time.Duration) (StopWatch, error) + MustNewStopWatch(name, description string, scale time.Duration) StopWatch + + // This is a custom wrapper to create a StopWatch object in the current Scope. + // Duration is to specify the scale of the Timer. For example if you are measuring times in milliseconds + // pass scale=times.Millisecond + // https://golang.org/pkg/time/#Duration + // The metric name is auto-suffixed with the right scale. Refer to DurationToString to understand + NewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) (*StopWatchVec, error) + MustNewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) *StopWatchVec + + // In case nesting is desired for metrics, create a new subScope. This is generally useful in creating + // Scoped and SubScoped metrics + NewSubScope(name string) Scope + + // Returns the current ScopeName. Use for creating your own metrics + CurrentScope() string + + // Method that provides a scoped metric name. Can be used, if you want to directly create your own metric + NewScopedMetricName(name string) string +} + +type metricsScope struct { + scope string +} + +func (m metricsScope) NewGauge(name, description string) (prometheus.Gauge, error) { + g := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + return g, prometheus.Register(g) +} + +func (m metricsScope) MustNewGauge(name, description string) prometheus.Gauge { + g, err := m.NewGauge(name, description) + panicIfError(err) + return g +} + +func (m metricsScope) NewGaugeVec(name, description string, labelNames ...string) (*prometheus.GaugeVec, error) { + g := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + return g, prometheus.Register(g) +} + +func (m metricsScope) MustNewGaugeVec(name, description string, labelNames ...string) *prometheus.GaugeVec { + g, err := m.NewGaugeVec(name, description, labelNames...) 
+ panicIfError(err) + return g +} + +func (m metricsScope) NewSummary(name, description string) (prometheus.Summary, error) { + s := prometheus.NewSummary( + prometheus.SummaryOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + + return s, prometheus.Register(s) +} + +func (m metricsScope) MustNewSummary(name, description string) prometheus.Summary { + s, err := m.NewSummary(name, description) + panicIfError(err) + return s +} + +func (m metricsScope) NewSummaryVec(name, description string, labelNames ...string) (*prometheus.SummaryVec, error) { + s := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + + return s, prometheus.Register(s) +} +func (m metricsScope) MustNewSummaryVec(name, description string, labelNames ...string) *prometheus.SummaryVec { + s, err := m.NewSummaryVec(name, description, labelNames...) + panicIfError(err) + return s +} + +func (m metricsScope) NewHistogram(name, description string) (prometheus.Histogram, error) { + h := prometheus.NewHistogram( + prometheus.HistogramOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + return h, prometheus.Register(h) +} + +func (m metricsScope) MustNewHistogram(name, description string) prometheus.Histogram { + h, err := m.NewHistogram(name, description) + panicIfError(err) + return h +} + +func (m metricsScope) NewHistogramVec(name, description string, labelNames ...string) (*prometheus.HistogramVec, error) { + h := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + return h, prometheus.Register(h) +} + +func (m metricsScope) MustNewHistogramVec(name, description string, labelNames ...string) *prometheus.HistogramVec { + h, err := m.NewHistogramVec(name, description, labelNames...) + panicIfError(err) + return h +} + +func (m metricsScope) NewCounter(name, description string) (prometheus.Counter, error) { + c := prometheus.NewCounter( + prometheus.CounterOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + return c, prometheus.Register(c) +} + +func (m metricsScope) MustNewCounter(name, description string) prometheus.Counter { + c, err := m.NewCounter(name, description) + panicIfError(err) + return c +} + +func (m metricsScope) NewCounterVec(name, description string, labelNames ...string) (*prometheus.CounterVec, error) { + c := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + return c, prometheus.Register(c) +} + +func (m metricsScope) MustNewCounterVec(name, description string, labelNames ...string) *prometheus.CounterVec { + c, err := m.NewCounterVec(name, description, labelNames...) 
+ panicIfError(err) + return c +} + +func (m metricsScope) NewStopWatch(name, description string, scale time.Duration) (StopWatch, error) { + if !strings.HasSuffix(name, defaultMetricDelimiterStr) { + name += defaultMetricDelimiterStr + } + name += DurationToString(scale) + s, err := m.NewSummary(name, description) + if err != nil { + return StopWatch{}, err + } + + return StopWatch{ + Observer: s, + outputScale: scale, + }, nil +} + +func (m metricsScope) MustNewStopWatch(name, description string, scale time.Duration) StopWatch { + s, err := m.NewStopWatch(name, description, scale) + panicIfError(err) + return s +} + +func (m metricsScope) NewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) (*StopWatchVec, error) { + if !strings.HasSuffix(name, defaultMetricDelimiterStr) { + name += defaultMetricDelimiterStr + } + name += DurationToString(scale) + s, err := m.NewSummaryVec(name, description, labelNames...) + if err != nil { + return &StopWatchVec{}, err + } + + return &StopWatchVec{ + SummaryVec: s, + outputScale: scale, + }, nil +} + +func (m metricsScope) MustNewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) *StopWatchVec { + s, err := m.NewStopWatchVec(name, description, scale, labelNames...) + panicIfError(err) + return s +} + +func (m metricsScope) CurrentScope() string { + return m.scope +} + +// Creates a metric name under the scope. Scope will always have a defaultScopeDelimiterRune as the last character +func (m metricsScope) NewScopedMetricName(name string) string { + if name == "" { + panic("metric name cannot be an empty string") + } + + return m.scope + name +} + +func (m metricsScope) NewSubScope(subscopeName string) Scope { + if subscopeName == "" { + panic("scope name cannot be an empty string") + } + + // If the last character of the new subscope is already a defaultScopeDelimiterRune, do not add anything + if !strings.HasSuffix(subscopeName, defaultScopeDelimiterStr) { + subscopeName += defaultScopeDelimiterStr + } + + // Always add a new defaultScopeDelimiterRune to every scope name + return NewScope(m.scope + subscopeName) +} + +// Creates a new scope in the format `name + defaultScopeDelimiterRune` +// If the last character is already a defaultScopeDelimiterRune, then it does not add it to the scope name +func NewScope(name string) Scope { + if name == "" { + panic("base scope for a metric cannot be an empty string") + } + + // If the last character of the new subscope is already a defaultScopeDelimiterRune, do not add anything + if !strings.HasSuffix(name, defaultScopeDelimiterStr) { + name += defaultScopeDelimiterStr + } + + return metricsScope{ + scope: name, + } +} + +// Returns a randomly-named scope for use in tests. +// Prometheus requires that metric names begin with a single word, which is generated from the alphabetic testScopeNameCharset. +func NewTestScope() Scope { + return NewScope("test" + rand.String(6)) +} + +// DurationToString converts the duration to a string suffix that indicates the scale of the timer. 
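A small hypothetical sketch of the scope-nesting behaviour implemented above (not part of the patch); the names are placeholders, and each nesting level contributes a ':'-delimited prefix to the final Prometheus metric name.

package main

import (
	"fmt"

	"github.com/lyft/flytestdlib/promutils"
)

func main() {
	root := promutils.NewScope("myapp")     // CurrentScope() == "myapp:"
	dbScope := root.NewSubScope("database") // CurrentScope() == "myapp:database:"

	// Registers a counter with fqName "myapp:database:queries".
	queries := dbScope.MustNewCounter("queries", "number of queries issued")
	queries.Inc()

	fmt.Println(dbScope.NewScopedMetricName("queries")) // myapp:database:queries
}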
+func DurationToString(duration time.Duration) string { + if duration >= time.Hour { + return "h" + } + if duration >= time.Minute { + return "m" + } + if duration >= time.Second { + return "s" + } + if duration >= time.Millisecond { + return "ms" + } + if duration >= time.Microsecond { + return "us" + } + return "ns" +} diff --git a/flytestdlib/promutils/scope_test.go b/flytestdlib/promutils/scope_test.go new file mode 100644 index 0000000000..cb076ee980 --- /dev/null +++ b/flytestdlib/promutils/scope_test.go @@ -0,0 +1,151 @@ +package promutils + +import ( + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +func TestDurationToString(t *testing.T) { + assert.Equal(t, "m", DurationToString(time.Minute)) + assert.Equal(t, "m", DurationToString(time.Minute*10)) + assert.Equal(t, "h", DurationToString(time.Hour)) + assert.Equal(t, "h", DurationToString(time.Hour*10)) + assert.Equal(t, "s", DurationToString(time.Second)) + assert.Equal(t, "s", DurationToString(time.Second*10)) + assert.Equal(t, "us", DurationToString(time.Microsecond*10)) + assert.Equal(t, "us", DurationToString(time.Microsecond)) + assert.Equal(t, "ms", DurationToString(time.Millisecond*10)) + assert.Equal(t, "ms", DurationToString(time.Millisecond)) + assert.Equal(t, "ns", DurationToString(1)) +} + +func TestNewScope(t *testing.T) { + assert.Panics(t, func() { + NewScope("") + }) + s := NewScope("test") + assert.Equal(t, "test:", s.CurrentScope()) + assert.Equal(t, "test:hello:", s.NewSubScope("hello").CurrentScope()) + assert.Panics(t, func() { + s.NewSubScope("") + }) + assert.Equal(t, "test:timer_x", s.NewScopedMetricName("timer_x")) + assert.Equal(t, "test:hello:timer:", s.NewSubScope("hello").NewSubScope("timer").CurrentScope()) + assert.Equal(t, "test:hello:timer:", s.NewSubScope("hello").NewSubScope("timer:").CurrentScope()) +} + +func TestMetricsScope(t *testing.T) { + s := NewScope("test") + const description = "some x" + if !assert.NotNil(t, prometheus.DefaultRegisterer) { + assert.Fail(t, "Prometheus registrar failed") + } + t.Run("Counter", func(t *testing.T) { + m := s.MustNewCounter("xc", description) + assert.Equal(t, `Desc{fqName: "test:xc", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewCounterVec("xcv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + _ = s.MustNewCounter("xc", description) + }) + assert.Panics(t, func() { + _ = s.MustNewCounterVec("xcv", description) + }) + }) + + t.Run("Histogram", func(t *testing.T) { + m := s.MustNewHistogram("xh", description) + assert.Equal(t, `Desc{fqName: "test:xh", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewHistogramVec("xhv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + _ = s.MustNewHistogram("xh", description) + }) + assert.Panics(t, func() { + _ = s.MustNewHistogramVec("xhv", description) + }) + }) + + t.Run("Summary", func(t *testing.T) { + m := s.MustNewSummary("xs", description) + assert.Equal(t, `Desc{fqName: "test:xs", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewSummaryVec("xsv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + _ = s.MustNewSummary("xs", description) + }) + assert.Panics(t, func() { + _ = s.MustNewSummaryVec("xsv", description) + }) + }) + + t.Run("Gauge", func(t *testing.T) { + m := s.MustNewGauge("xg", description) + assert.Equal(t, 
`Desc{fqName: "test:xg", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewGaugeVec("xgv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + m = s.MustNewGauge("xg", description) + }) + assert.Panics(t, func() { + _ = s.MustNewGaugeVec("xgv", description) + }) + }) + + t.Run("Timer", func(t *testing.T) { + m := s.MustNewStopWatch("xt", description, time.Second) + asDesc, ok := m.Observer.(prometheus.Metric) + assert.True(t, ok) + assert.Equal(t, `Desc{fqName: "test:xt_s", help: "some x", constLabels: {}, variableLabels: []}`, asDesc.Desc().String()) + assert.Panics(t, func() { + _ = s.MustNewStopWatch("xt", description, time.Second) + }) + }) + +} + +func TestStopWatch_Start(t *testing.T) { + scope := NewTestScope() + s, e := scope.NewStopWatch("yt"+rand.String(3), "timer", time.Millisecond) + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, s.outputScale) + i := s.Start() + assert.Equal(t, time.Millisecond, i.outputScale) + assert.NotNil(t, i.start) + i.Stop() +} + +func TestStopWatch_Observe(t *testing.T) { + scope := NewTestScope() + s, e := scope.NewStopWatch("yt"+rand.String(3), "timer", time.Millisecond) + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, s.outputScale) + s.Observe(time.Now(), time.Now().Add(time.Second)) +} + +func TestStopWatch_Time(t *testing.T) { + scope := NewTestScope() + s, e := scope.NewStopWatch("yt"+rand.String(3), "timer", time.Millisecond) + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, s.outputScale) + s.Time(func() { + }) +} + +func TestStopWatchVec_WithLabelValues(t *testing.T) { + scope := NewTestScope() + v, e := scope.NewStopWatchVec("yt"+rand.String(3), "timer", time.Millisecond, "workflow", "label") + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, v.outputScale) + s := v.WithLabelValues("my_wf", "something") + assert.NotNil(t, s) + i := s.Start() + assert.Equal(t, time.Millisecond, i.outputScale) + assert.NotNil(t, i.start) + i.Stop() +} diff --git a/flytestdlib/promutils/workqueue.go b/flytestdlib/promutils/workqueue.go new file mode 100644 index 0000000000..24de6a0053 --- /dev/null +++ b/flytestdlib/promutils/workqueue.go @@ -0,0 +1,82 @@ +// Source: https://raw.githubusercontent.com/kubernetes/kubernetes/3dbbd0bdf44cb07fdde85aa392adf99ea7e95939/pkg/util/workqueue/prometheus/prometheus.go +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package promutils + +import ( + "k8s.io/client-go/util/workqueue" + + "github.com/prometheus/client_golang/prometheus" +) + +// Package prometheus sets the workqueue DefaultMetricsFactory to produce +// prometheus metrics. To use this package, you just have to import it. 
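Because registration happens in an init function, a blank import is enough to wire workqueue metrics to Prometheus; in the sketch below (not part of the patch) the queue name is arbitrary and the choice of client-go constructor is an assumption.

package main

import (
	"k8s.io/client-go/util/workqueue"

	// Blank import: init() installs the prometheus metrics provider for workqueue.
	_ "github.com/lyft/flytestdlib/promutils"
)

func main() {
	// With the provider installed, metrics such as "events_depth" and "events_adds" are registered for this queue.
	q := workqueue.NewNamed("events")
	defer q.ShutDown()

	q.Add("some-item")
}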
+ +func init() { + workqueue.SetProvider(prometheusMetricsProvider{}) +} + +type prometheusMetricsProvider struct{} + +func (prometheusMetricsProvider) NewDepthMetric(name string) workqueue.GaugeMetric { + depth := prometheus.NewGauge(prometheus.GaugeOpts{ + Subsystem: name, + Name: "depth", + Help: "Current depth of workqueue: " + name, + }) + prometheus.MustRegister(depth) + return depth +} + +func (prometheusMetricsProvider) NewAddsMetric(name string) workqueue.CounterMetric { + adds := prometheus.NewCounter(prometheus.CounterOpts{ + Subsystem: name, + Name: "adds", + Help: "Total number of adds handled by workqueue: " + name, + }) + prometheus.MustRegister(adds) + return adds +} + +func (prometheusMetricsProvider) NewLatencyMetric(name string) workqueue.SummaryMetric { + latency := prometheus.NewSummary(prometheus.SummaryOpts{ + Subsystem: name, + Name: "queue_latency_us", + Help: "How long an item stays in workqueue" + name + " before being requested.", + }) + prometheus.MustRegister(latency) + return latency +} + +func (prometheusMetricsProvider) NewWorkDurationMetric(name string) workqueue.SummaryMetric { + workDuration := prometheus.NewSummary(prometheus.SummaryOpts{ + Subsystem: name, + Name: "work_duration_us", + Help: "How long processing an item from workqueue" + name + " takes.", + }) + prometheus.MustRegister(workDuration) + return workDuration +} + +func (prometheusMetricsProvider) NewRetriesMetric(name string) workqueue.CounterMetric { + retries := prometheus.NewCounter(prometheus.CounterOpts{ + Subsystem: name, + Name: "retries", + Help: "Total number of retries handled by workqueue: " + name, + }) + prometheus.MustRegister(retries) + return retries +} diff --git a/flytestdlib/promutils/workqueue_test.go b/flytestdlib/promutils/workqueue_test.go new file mode 100644 index 0000000000..4c5bbcae9e --- /dev/null +++ b/flytestdlib/promutils/workqueue_test.go @@ -0,0 +1,42 @@ +package promutils + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +var provider = prometheusMetricsProvider{} + +func TestPrometheusMetricsProvider(t *testing.T) { + t.Run("Adds", func(t *testing.T) { + c := provider.NewAddsMetric("x") + _, ok := c.(prometheus.Counter) + assert.True(t, ok) + }) + + t.Run("Depth", func(t *testing.T) { + c := provider.NewDepthMetric("x") + _, ok := c.(prometheus.Gauge) + assert.True(t, ok) + }) + + t.Run("Latency", func(t *testing.T) { + c := provider.NewLatencyMetric("x") + _, ok := c.(prometheus.Summary) + assert.True(t, ok) + }) + + t.Run("Retries", func(t *testing.T) { + c := provider.NewRetriesMetric("x") + _, ok := c.(prometheus.Counter) + assert.True(t, ok) + }) + + t.Run("WorkDuration", func(t *testing.T) { + c := provider.NewWorkDurationMetric("x") + _, ok := c.(prometheus.Summary) + assert.True(t, ok) + }) +} diff --git a/flytestdlib/sets/generic_set.go b/flytestdlib/sets/generic_set.go new file mode 100644 index 0000000000..abc739f8aa --- /dev/null +++ b/flytestdlib/sets/generic_set.go @@ -0,0 +1,195 @@ +package sets + +import ( + "sort" +) + +type SetObject interface { + GetID() string +} + +type Generic map[string]SetObject + +// New creates a Generic from a list of values. +func NewGeneric(items ...SetObject) Generic { + gs := Generic{} + gs.Insert(items...) + return gs +} + +// Insert adds items to the set. +func (g Generic) Insert(items ...SetObject) { + for _, item := range items { + g[item.GetID()] = item + } +} + +// Delete removes all items from the set. 
+func (g Generic) Delete(items ...SetObject) { + for _, item := range items { + delete(g, item.GetID()) + } +} + +// Has returns true if and only if item is contained in the set. +func (g Generic) Has(item SetObject) bool { + _, contained := g[item.GetID()] + return contained +} + +// HasAll returns true if and only if all items are contained in the set. +func (g Generic) HasAll(items ...SetObject) bool { + for _, item := range items { + if !g.Has(item) { + return false + } + } + return true +} + +// HasAny returns true if any items are contained in the set. +func (g Generic) HasAny(items ...SetObject) bool { + for _, item := range items { + if g.Has(item) { + return true + } + } + return false +} + +// Difference returns a set of objects that are not in s2 +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (g Generic) Difference(g2 Generic) Generic { + result := NewGeneric() + for _, v := range g { + if !g2.Has(v) { + result.Insert(v) + } + } + return result +} + +// Union returns a new set which includes items in either s1 or s2. +// For example: +// s1 = {a1, a2} +// s2 = {a3, a4} +// s1.Union(s2) = {a1, a2, a3, a4} +// s2.Union(s1) = {a1, a2, a3, a4} +func (g Generic) Union(s2 Generic) Generic { + result := NewGeneric() + for _, v := range g { + result.Insert(v) + } + for _, v := range s2 { + result.Insert(v) + } + return result +} + +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {a1, a2} +// s2 = {a2, a3} +// s1.Intersection(s2) = {a2} +func (g Generic) Intersection(s2 Generic) Generic { + var walk, other Generic + result := NewGeneric() + if g.Len() < s2.Len() { + walk = g + other = s2 + } else { + walk = s2 + other = g + } + for _, v := range walk { + if other.Has(v) { + result.Insert(v) + } + } + return result +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (g Generic) IsSuperset(s2 Generic) bool { + for _, v := range s2 { + if !g.Has(v) { + return false + } + } + return true +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (g Generic) Equal(s2 Generic) bool { + return len(g) == len(s2) && g.IsSuperset(s2) +} + +type sortableSliceOfGeneric []string + +func (s sortableSliceOfGeneric) Len() int { return len(s) } +func (s sortableSliceOfGeneric) Less(i, j int) bool { return lessString(s[i], s[j]) } +func (s sortableSliceOfGeneric) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// List returns the contents as a sorted string slice. +func (g Generic) ListKeys() []string { + res := make(sortableSliceOfGeneric, 0, len(g)) + for key := range g { + res = append(res, key) + } + sort.Sort(res) + return []string(res) +} + +// List returns the contents as a sorted string slice. +func (g Generic) List() []SetObject { + keys := g.ListKeys() + res := make([]SetObject, 0, len(keys)) + for _, k := range keys { + s := g[k] + res = append(res, s) + } + return res +} + +// UnsortedList returns the slice with contents in random order. +func (g Generic) UnsortedListKeys() []string { + res := make([]string, 0, len(g)) + for key := range g { + res = append(res, key) + } + return res +} + +// UnsortedList returns the slice with contents in random order. 
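For a sense of how the SetObject-keyed set above is meant to be consumed, a short sketch (not part of the patch); the task element type is invented, and anything with a stable GetID() works.

package main

import (
	"fmt"

	"github.com/lyft/flytestdlib/sets"
)

type task string

func (t task) GetID() string { return string(t) }

func main() {
	scheduled := sets.NewGeneric(task("a"), task("b"), task("c"))
	completed := sets.NewGeneric(task("b"))

	pending := scheduled.Difference(completed)
	fmt.Println(pending.ListKeys())              // [a c]
	fmt.Println(pending.Has(task("a")))          // true
	fmt.Println(scheduled.IsSuperset(completed)) // true
}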
+func (g Generic) UnsortedList() []SetObject { + res := make([]SetObject, 0, len(g)) + for _, v := range g { + res = append(res, v) + } + return res +} + +// Returns a single element from the set. +func (g Generic) PopAny() (SetObject, bool) { + for _, v := range g { + g.Delete(v) + return v, true + } + var zeroValue SetObject + return zeroValue, false +} + +// Len returns the size of the set. +func (g Generic) Len() int { + return len(g) +} + +func lessString(lhs, rhs string) bool { + return lhs < rhs +} diff --git a/flytestdlib/sets/generic_set_test.go b/flytestdlib/sets/generic_set_test.go new file mode 100644 index 0000000000..9d9f165ed5 --- /dev/null +++ b/flytestdlib/sets/generic_set_test.go @@ -0,0 +1,116 @@ +package sets + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type GenericVal string + +func (g GenericVal) GetID() string { + return string(g) +} + +func TestGenericSet(t *testing.T) { + assert.Equal(t, []string{"a", "b"}, NewGeneric(GenericVal("a"), GenericVal("b")).ListKeys()) + assert.Equal(t, []string{"a", "b"}, NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("a")).ListKeys()) + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + { + g := g1.Intersection(g2) + assert.Equal(t, []string{"a", "b"}, g.ListKeys()) + } + { + g := g2.Intersection(g1) + assert.Equal(t, []string{"a", "b"}, g.ListKeys()) + } + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g := g2.Difference(g1) + assert.Equal(t, []string{"c"}, g.ListKeys()) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g := g1.Difference(g2) + assert.Equal(t, []string{}, g.ListKeys()) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.True(t, g1.Has(GenericVal("a"))) + assert.False(t, g1.Has(GenericVal("c"))) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.False(t, g1.HasAll(GenericVal("a"), GenericVal("b"), GenericVal("c"))) + assert.True(t, g1.HasAll(GenericVal("a"), GenericVal("b"))) + } + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g := g1.Union(g2) + assert.Equal(t, []string{"a", "b", "c"}, g.ListKeys()) + } + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + assert.True(t, g2.IsSuperset(g1)) + assert.False(t, g1.IsSuperset(g2)) + } + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.True(t, g1.Has(GenericVal("a"))) + assert.False(t, g1.Has(GenericVal("c"))) + assert.Equal(t, []SetObject{ + GenericVal("a"), + GenericVal("b"), + }, g1.List()) + g2 := NewGeneric(g1.UnsortedList()...) 
+ assert.True(t, g1.Equal(g2)) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g3 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.False(t, g1.Equal(g2)) + assert.True(t, g1.Equal(g3)) + + assert.Equal(t, 2, g1.Len()) + g1.Insert(GenericVal("b")) + assert.Equal(t, 2, g1.Len()) + assert.True(t, g1.Equal(g3)) + g1.Insert(GenericVal("c")) + assert.Equal(t, 3, g1.Len()) + assert.True(t, g1.Equal(g2)) + assert.True(t, g1.HasAny(GenericVal("a"), GenericVal("d"))) + assert.False(t, g1.HasAny(GenericVal("f"), GenericVal("d"))) + g1.Delete(GenericVal("f")) + assert.True(t, g1.Equal(g2)) + g1.Delete(GenericVal("c")) + assert.True(t, g1.Equal(g3)) + + { + p, ok := g1.PopAny() + assert.NotNil(t, p) + assert.True(t, ok) + } + { + p, ok := g1.PopAny() + assert.NotNil(t, p) + assert.True(t, ok) + } + { + p, ok := g1.PopAny() + assert.Nil(t, p) + assert.False(t, ok) + } + } +} diff --git a/flytestdlib/storage/cached_rawstore.go b/flytestdlib/storage/cached_rawstore.go new file mode 100644 index 0000000000..2b539d7bb1 --- /dev/null +++ b/flytestdlib/storage/cached_rawstore.go @@ -0,0 +1,123 @@ +package storage + +import ( + "bytes" + "context" + "io" + "runtime/debug" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/coocood/freecache" + "github.com/lyft/flytestdlib/ioutils" + "github.com/lyft/flytestdlib/logger" +) + +const neverExpire = 0 + +// TODO Freecache has bunch of metrics it calculates. Lets write a prom collector to publish these metrics +type cacheMetrics struct { + CacheHit prometheus.Counter + CacheMiss prometheus.Counter + CacheWriteError prometheus.Counter + FetchLatency promutils.StopWatch +} + +type cachedRawStore struct { + RawStore + cache *freecache.Cache + scope promutils.Scope + metrics *cacheMetrics +} + +// Gets metadata about the reference. This should generally be a light weight operation. +func (s *cachedRawStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + key := []byte(reference) + if oRaw, err := s.cache.Get(key); err == nil { + s.metrics.CacheHit.Inc() + // Found, Cache hit + size := int64(len(oRaw)) + // return size in metadata + return StowMetadata{exists: true, size: size}, nil + } + s.metrics.CacheMiss.Inc() + return s.RawStore.Head(ctx, reference) +} + +// Retrieves a byte array from the Blob store or an error +func (s *cachedRawStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + key := []byte(reference) + if oRaw, err := s.cache.Get(key); err == nil { + // Found, Cache hit + s.metrics.CacheHit.Inc() + return ioutils.NewBytesReadCloser(oRaw), nil + } + s.metrics.CacheMiss.Inc() + reader, err := s.RawStore.ReadRaw(ctx, reference) + if err != nil { + return nil, err + } + + defer func() { + err = reader.Close() + if err != nil { + logger.Warnf(ctx, "Failed to close reader [%v]. Error: %v", reference, err) + } + }() + + b, err := ioutils.ReadAll(reader, s.metrics.FetchLatency.Start()) + if err != nil { + return nil, err + } + + err = s.cache.Set(key, b, 0) + if err != nil { + // TODO Ignore errors in writing to cache? + logger.Debugf(ctx, "Failed to Cache the metadata") + } + + return ioutils.NewBytesReadCloser(b), nil +} + +// Stores a raw byte array. 
+func (s *cachedRawStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + var buf bytes.Buffer + teeReader := io.TeeReader(raw, &buf) + err := s.RawStore.WriteRaw(ctx, reference, size, opts, teeReader) + if err != nil { + return err + } + + err = s.cache.Set([]byte(reference), buf.Bytes(), neverExpire) + if err != nil { + s.metrics.CacheWriteError.Inc() + } + + // TODO ignore errors? + return err +} + +// Creates a CachedStore if Caching is enabled, otherwise returns a RawStore +func newCachedRawStore(cfg *Config, store RawStore, scope promutils.Scope) RawStore { + if cfg.Cache.MaxSizeMegabytes > 0 { + c := &cachedRawStore{ + RawStore: store, + cache: freecache.NewCache(cfg.Cache.MaxSizeMegabytes * 1024 * 1024), + scope: scope, + metrics: &cacheMetrics{ + FetchLatency: scope.MustNewStopWatch("remote_fetch", "Total Time to read from remote metastore", time.Millisecond), + CacheHit: scope.MustNewCounter("cache_hit", "Number of times metadata was found in cache"), + CacheMiss: scope.MustNewCounter("cache_miss", "Number of times metadata was not found in cache and remote fetch was required"), + CacheWriteError: scope.MustNewCounter("cache_write_err", "Failed to write to cache"), + }, + } + if cfg.Cache.TargetGCPercent > 0 { + debug.SetGCPercent(cfg.Cache.TargetGCPercent) + } + return c + } + return store +} diff --git a/flytestdlib/storage/cached_rawstore_test.go b/flytestdlib/storage/cached_rawstore_test.go new file mode 100644 index 0000000000..316f999bf7 --- /dev/null +++ b/flytestdlib/storage/cached_rawstore_test.go @@ -0,0 +1,182 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "runtime/debug" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/ioutils" + "github.com/stretchr/testify/assert" +) + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + +func TestNewCachedStore(t *testing.T) { + + t.Run("CachingDisabled", func(t *testing.T) { + testScope := promutils.NewTestScope() + cfg := &Config{} + assert.Nil(t, newCachedRawStore(cfg, nil, testScope)) + store, err := NewInMemoryRawStore(cfg, testScope) + assert.NoError(t, err) + assert.Equal(t, store, newCachedRawStore(cfg, store, testScope)) + }) + + t.Run("CachingEnabled", func(t *testing.T) { + testScope := promutils.NewTestScope() + cfg := &Config{ + Cache: CachingConfig{ + MaxSizeMegabytes: 1, + TargetGCPercent: 20, + }, + } + store, err := NewInMemoryRawStore(cfg, testScope) + assert.NoError(t, err) + cStore := newCachedRawStore(cfg, store, testScope) + assert.Equal(t, 20, debug.SetGCPercent(100)) + assert.NotNil(t, cStore) + assert.NotNil(t, cStore.(*cachedRawStore).cache) + }) +} + +func dummyCacheStore(t *testing.T, store RawStore, scope promutils.Scope) *cachedRawStore { + cfg := &Config{ + Cache: CachingConfig{ + MaxSizeMegabytes: 1, + TargetGCPercent: 20, + }, + } + cStore := newCachedRawStore(cfg, store, scope) + assert.NotNil(t, cStore) + return cStore.(*cachedRawStore) +} + +type dummyStore struct { + copyImpl + HeadCb func(ctx context.Context, reference DataReference) (Metadata, error) + ReadRawCb func(ctx context.Context, reference DataReference) (io.ReadCloser, error) + WriteRawCb func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error +} + +func (d *dummyStore) 
GetBaseContainerFQN(ctx context.Context) DataReference { + return "dummy" +} + +func (d *dummyStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + return d.HeadCb(ctx, reference) +} + +func (d *dummyStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + return d.ReadRawCb(ctx, reference) +} + +func (d *dummyStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + return d.WriteRawCb(ctx, reference, size, opts, raw) +} + +func TestCachedRawStore(t *testing.T) { + ctx := context.TODO() + k1 := DataReference("k1") + k2 := DataReference("k2") + d1 := []byte("abc") + d2 := []byte("xyz") + writeCalled := false + readCalled := false + store := &dummyStore{ + HeadCb: func(ctx context.Context, reference DataReference) (Metadata, error) { + if reference == "k1" { + return MemoryMetadata{exists: true, size: int64(len(d1))}, nil + } + return MemoryMetadata{}, fmt.Errorf("err") + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + if writeCalled { + assert.FailNow(t, "Should not be writeCalled") + } + writeCalled = true + if reference == "k2" { + b, err := ioutil.ReadAll(raw) + assert.NoError(t, err) + assert.Equal(t, d2, b) + return nil + } + return fmt.Errorf("err") + }, + ReadRawCb: func(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + if readCalled { + assert.FailNow(t, "Should not be invoked again") + } + readCalled = true + if reference == "k1" { + return ioutils.NewBytesReadCloser(d1), nil + } + return nil, fmt.Errorf("err") + }, + } + testScope := promutils.NewTestScope() + + store.copyImpl = newCopyImpl(store, testScope.NewSubScope("copy")) + + cStore := dummyCacheStore(t, store, testScope.NewSubScope("x")) + + t.Run("HeadExists", func(t *testing.T) { + m, err := cStore.Head(ctx, k1) + assert.NoError(t, err) + assert.Equal(t, int64(len(d1)), m.Size()) + assert.True(t, m.Exists()) + }) + + t.Run("HeadNotExists", func(t *testing.T) { + m, err := cStore.Head(ctx, k2) + assert.Error(t, err) + assert.False(t, m.Exists()) + }) + + t.Run("ReadCachePopulate", func(t *testing.T) { + o, err := cStore.ReadRaw(ctx, k1) + assert.NoError(t, err) + b, err := ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, d1, b) + assert.True(t, readCalled) + readCalled = false + o, err = cStore.ReadRaw(ctx, k1) + assert.NoError(t, err) + b, err = ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, d1, b) + assert.False(t, readCalled) + }) + + t.Run("ReadFail", func(t *testing.T) { + readCalled = false + _, err := cStore.ReadRaw(ctx, k2) + assert.Error(t, err) + assert.True(t, readCalled) + }) + + t.Run("WriteAndRead", func(t *testing.T) { + readCalled = false + assert.NoError(t, cStore.WriteRaw(ctx, k2, int64(len(d2)), Options{}, bytes.NewReader(d2))) + assert.True(t, writeCalled) + + o, err := cStore.ReadRaw(ctx, k2) + assert.NoError(t, err) + b, err := ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, d2, b) + assert.False(t, readCalled) + }) + +} diff --git a/flytestdlib/storage/config.go b/flytestdlib/storage/config.go new file mode 100644 index 0000000000..59db0197c3 --- /dev/null +++ b/flytestdlib/storage/config.go @@ -0,0 +1,85 @@ +package storage + +import ( + "context" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" +) + +//go:generate pflags Config + +// Defines the storage config type. 
+type Type = string
+
+// The reserved config section key for storage.
+const configSectionKey = "Storage"
+
+const (
+	TypeMemory Type = "mem"
+	TypeS3     Type = "s3"
+	TypeLocal  Type = "local"
+	TypeMinio  Type = "minio"
+)
+
+const (
+	KiB int64 = 1024
+	MiB int64 = 1024 * KiB
+)
+
+var (
+	ConfigSection = config.MustRegisterSection(configSectionKey, &Config{})
+)
+
+// A common storage config.
+type Config struct {
+	Type          Type             `json:"type" pflag:"\"s3\",Sets the type of storage to configure [s3/minio/local/mem]."`
+	Connection    ConnectionConfig `json:"connection"`
+	InitContainer string           `json:"container" pflag:",Initial container to create -if it doesn't exist-.'"`
+	// Caching is recommended to improve the performance of the underlying systems. Metadata is cached so that
+	// resolving inputs becomes faster. The cache can consume a significant amount of memory, so make sure you
+	// understand how to size it (see CachingConfig below).
+	// TODO provide some default config choices
+	// If this section is skipped, caching is disabled.
+	Cache  CachingConfig `json:"cache"`
+	Limits LimitsConfig  `json:"limits" pflag:",Sets limits for stores."`
+}
+
+// Defines connection configurations.
+type ConnectionConfig struct {
+	Endpoint   config.URL `json:"endpoint" pflag:",URL for storage client to connect to."`
+	AuthType   string     `json:"auth-type" pflag:"\"iam\",Auth Type to use [iam,accesskey]."`
+	AccessKey  string     `json:"access-key" pflag:",Access key to use. Only required when authtype is set to accesskey."`
+	SecretKey  string     `json:"secret-key" pflag:",Secret to use when accesskey is set."`
+	Region     string     `json:"region" pflag:"\"us-east-1\",Region to connect to."`
+	DisableSSL bool       `json:"disable-ssl" pflag:",Disables SSL connection. Should only be used for development."`
+}
+
+type CachingConfig struct {
+	// Maximum size of the cache where the Blob store data is cached in-memory.
+	// Refer to https://github.com/coocood/freecache to understand how to set the value.
+	// If not specified or set to 0, the cache is not used.
+	// NOTE: if object sizes are larger than 1/1024 of the cache size, the entry will not be written to the cache.
+	// Also refer to https://github.com/coocood/freecache/issues/17 to understand how to size the cache.
+	MaxSizeMegabytes int `json:"max_size_mbs" pflag:",Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used"`
+	// Sets the garbage collection target percentage:
+	// a collection is triggered when the ratio of freshly allocated data
+	// to live data remaining after the previous collection reaches this percentage.
+	// Refer to https://golang.org/pkg/runtime/debug/#SetGCPercent
+	// If not specified or set to 0, the GC percent is not tweaked.
+	TargetGCPercent int `json:"target_gc_percent" pflag:",Sets the garbage collection target percentage."`
+}
+
+// Specifies limits for the storage package.
+type LimitsConfig struct {
+	GetLimitMegabytes int64 `json:"maxDownloadMBs" pflag:"2,Maximum allowed download size (in MBs) per call."`
+}
+
+// Retrieves the current global config for storage.
+func GetConfig() *Config {
+	if c, ok := ConfigSection.GetConfig().(*Config); ok {
+		return c
+	}
+
+	logger.Warnf(context.TODO(), "Failed to retrieve config section [%v].", configSectionKey)
+	return nil
+}
diff --git a/flytestdlib/storage/config_flags.go b/flytestdlib/storage/config_flags.go
new file mode 100755
index 0000000000..9a74efdd64
--- /dev/null
+++ b/flytestdlib/storage/config_flags.go
@@ -0,0 +1,28 @@
+// Code generated by go generate; DO NOT EDIT.
+// This file was generated by robots.
+ +package storage + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), "", "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.secret-key"), *new(string), "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.region"), "us-east-1", "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "container"), *new(string), "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") + return cmdFlags +} diff --git a/flytestdlib/storage/config_flags_test.go b/flytestdlib/storage/config_flags_test.go new file mode 100755 index 0000000000..2f39f00006 --- /dev/null +++ b/flytestdlib/storage/config_flags_test.go @@ -0,0 +1,344 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package storage + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("type"); err == nil { + assert.Equal(t, string("s3"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("type", testValue) + if vString, err := cmdFlags.GetString("type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { + assert.Equal(t, string(""), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.endpoint", testValue) + if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.auth-type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.auth-type"); err == nil { + assert.Equal(t, string("iam"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.auth-type", testValue) + if vString, err := cmdFlags.GetString("connection.auth-type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.AuthType) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.access-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.access-key"); err == nil { + assert.Equal(t, string(*new(string)), 
vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.access-key", testValue) + if vString, err := cmdFlags.GetString("connection.access-key"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.AccessKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.secret-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.secret-key"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.secret-key", testValue) + if vString, err := cmdFlags.GetString("connection.secret-key"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.SecretKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.region", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.region"); err == nil { + assert.Equal(t, string("us-east-1"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.region", testValue) + if vString, err := cmdFlags.GetString("connection.region"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.Region) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.disable-ssl", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("connection.disable-ssl"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.disable-ssl", testValue) + if vBool, err := cmdFlags.GetBool("connection.disable-ssl"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Connection.DisableSSL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_container", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("container"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("container", testValue) + if vString, err := cmdFlags.GetString("container"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.InitContainer) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cache.max_size_mbs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("cache.max_size_mbs"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cache.max_size_mbs", testValue) + if vInt, err := cmdFlags.GetInt("cache.max_size_mbs"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), 
&actual.Cache.MaxSizeMegabytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cache.target_gc_percent", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("cache.target_gc_percent"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cache.target_gc_percent", testValue) + if vInt, err := cmdFlags.GetInt("cache.target_gc_percent"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Cache.TargetGCPercent) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_limits.maxDownloadMBs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("limits.maxDownloadMBs"); err == nil { + assert.Equal(t, int64(2), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("limits.maxDownloadMBs", testValue) + if vInt64, err := cmdFlags.GetInt64("limits.maxDownloadMBs"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.Limits.GetLimitMegabytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flytestdlib/storage/config_test.go b/flytestdlib/storage/config_test.go new file mode 100644 index 0000000000..93a5fe887c --- /dev/null +++ b/flytestdlib/storage/config_test.go @@ -0,0 +1,45 @@ +package storage + +import ( + "flag" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/ghodss/yaml" + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! +var update = flag.Bool("update", false, "Updates testdata") + +func TestMarshal(t *testing.T) { + expected := Config{ + Type: "s3", + Connection: ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, + AuthType: "accesskey", + AccessKey: "minio", + SecretKey: "miniostorage", + Region: "us-east-1", + DisableSSL: true, + }, + } + + if *update { + t.Log("Updating config file.") + raw, err := yaml.Marshal(expected) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "config.yaml"), raw, os.ModePerm)) + } + + actual := Config{} + raw, err := ioutil.ReadFile(filepath.Join("testdata", "config.yaml")) + assert.NoError(t, err) + assert.NoError(t, yaml.Unmarshal(raw, &actual)) + assert.True(t, reflect.DeepEqual(expected, actual)) +} diff --git a/flytestdlib/storage/copy_impl.go b/flytestdlib/storage/copy_impl.go new file mode 100644 index 0000000000..43f97f026f --- /dev/null +++ b/flytestdlib/storage/copy_impl.go @@ -0,0 +1,60 @@ +package storage + +import ( + "context" + "io" + "time" + + "github.com/lyft/flytestdlib/ioutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" +) + +type copyImpl struct { + rawStore RawStore + metrics copyMetrics +} + +type copyMetrics struct { + CopyLatency labeled.StopWatch + ComputeLengthLatency labeled.StopWatch +} + +// A naiive implementation for copy that reads all data locally then writes them to destination. +// TODO: We should upstream an API change to stow to implement copy more natively. E.g. 
Use s3 copy: +// https://docs.aws.amazon.com/AmazonS3/latest/dev/CopyingObjectUsingREST.html +func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference, opts Options) error { + rc, err := c.rawStore.ReadRaw(ctx, source) + if err != nil { + return err + } + + length := int64(0) + if _, isSeeker := rc.(io.Seeker); !isSeeker { + // If the returned ReadCloser doesn't implement Seeker interface, then the underlying writer won't be able to + // calculate content length on its own. Some implementations (e.g. S3 Stow Store) will error if it can't. + var raw []byte + raw, err = ioutils.ReadAll(rc, c.metrics.ComputeLengthLatency.Start(ctx)) + if err != nil { + return err + } + + length = int64(len(raw)) + } + + return c.rawStore.WriteRaw(ctx, destination, length, Options{}, rc) +} + +func newCopyMetrics(scope promutils.Scope) copyMetrics { + return copyMetrics{ + CopyLatency: labeled.NewStopWatch("overall", "Overall copy latency", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + ComputeLengthLatency: labeled.NewStopWatch("length", "Latency involved in computing length of content before writing.", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + } +} + +func newCopyImpl(store RawStore, metricsScope promutils.Scope) copyImpl { + return copyImpl{ + rawStore: store, + metrics: newCopyMetrics(metricsScope.NewSubScope("copy")), + } +} diff --git a/flytestdlib/storage/copy_impl_test.go b/flytestdlib/storage/copy_impl_test.go new file mode 100644 index 0000000000..fc8d78cd1a --- /dev/null +++ b/flytestdlib/storage/copy_impl_test.go @@ -0,0 +1,82 @@ +package storage + +import ( + "context" + "io" + "testing" + + "github.com/lyft/flytestdlib/ioutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +type notSeekerReader struct { + bytesCount int +} + +func (notSeekerReader) Close() error { + return nil +} + +func (r *notSeekerReader) Read(p []byte) (n int, err error) { + if len(p) < 1 { + return 0, nil + } + + p[0] = byte(10) + + r.bytesCount-- + if r.bytesCount <= 0 { + return 0, io.EOF + } + + return 1, nil +} + +func newNotSeekerReader(bytesCount int) *notSeekerReader { + return ¬SeekerReader{ + bytesCount: bytesCount, + } +} + +func TestCopyRaw(t *testing.T) { + t.Run("Called", func(t *testing.T) { + readerCalled := false + writerCalled := false + store := dummyStore{ + ReadRawCb: func(ctx context.Context, reference DataReference) (closer io.ReadCloser, e error) { + readerCalled = true + return ioutils.NewBytesReadCloser([]byte{}), nil + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + writerCalled = true + return nil + }, + } + + copier := newCopyImpl(&store, promutils.NewTestScope()) + assert.NoError(t, copier.CopyRaw(context.Background(), DataReference("source.pb"), DataReference("dest.pb"), Options{})) + assert.True(t, readerCalled) + assert.True(t, writerCalled) + }) + + t.Run("Not Seeker", func(t *testing.T) { + readerCalled := false + writerCalled := false + store := dummyStore{ + ReadRawCb: func(ctx context.Context, reference DataReference) (closer io.ReadCloser, e error) { + readerCalled = true + return newNotSeekerReader(10), nil + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + writerCalled = true + return nil + }, + } + + copier := newCopyImpl(&store, promutils.NewTestScope()) + assert.NoError(t, copier.CopyRaw(context.Background(), DataReference("source.pb"), 
DataReference("dest.pb"), Options{})) + assert.True(t, readerCalled) + assert.True(t, writerCalled) + }) +} diff --git a/flytestdlib/storage/localstore.go b/flytestdlib/storage/localstore.go new file mode 100644 index 0000000000..450c102a24 --- /dev/null +++ b/flytestdlib/storage/localstore.go @@ -0,0 +1,48 @@ +package storage + +import ( + "fmt" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/graymeta/stow" + "github.com/graymeta/stow/local" +) + +func getLocalStowConfigMap(cfg *Config) stow.ConfigMap { + stowConfig := stow.ConfigMap{} + if endpoint := cfg.Connection.Endpoint.String(); endpoint != "" { + stowConfig[local.ConfigKeyPath] = endpoint + } + + return stowConfig +} + +// Creates a Data store backed by Stow-S3 raw store. +func newLocalRawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) { + if cfg.InitContainer == "" { + return nil, fmt.Errorf("initContainer is required") + } + + loc, err := stow.Dial(local.Kind, getLocalStowConfigMap(cfg)) + + if err != nil { + return emptyStore, fmt.Errorf("unable to configure the storage for local. Error: %v", err) + } + + c, err := loc.Container(cfg.InitContainer) + if err != nil { + if IsNotFound(err) { + c, err = loc.CreateContainer(cfg.InitContainer) + if err != nil && !IsExists(err) { + return emptyStore, fmt.Errorf("unable to initialize container [%v]. Error: %v", cfg.InitContainer, err) + } + + return NewStowRawStore(DataReference(c.Name()), c, metricsScope) + } + + return emptyStore, err + } + + return NewStowRawStore(DataReference(c.Name()), c, metricsScope) +} diff --git a/flytestdlib/storage/localstore_test.go b/flytestdlib/storage/localstore_test.go new file mode 100644 index 0000000000..31f8aabeb2 --- /dev/null +++ b/flytestdlib/storage/localstore_test.go @@ -0,0 +1,66 @@ +package storage + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewLocalStore(t *testing.T) { + t.Run("Valid config", func(t *testing.T) { + testScope := promutils.NewTestScope() + store, err := newLocalRawStore(&Config{ + Connection: ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("./")}, + }, + InitContainer: "testdata", + }, testScope.NewSubScope("x")) + + assert.NoError(t, err) + assert.NotNil(t, store) + + // Stow local store expects the full path after the container portion (looks like a bug to me) + rc, err := store.ReadRaw(context.TODO(), DataReference("file://testdata/testdata/config.yaml")) + assert.NoError(t, err) + assert.NotNil(t, rc) + assert.NoError(t, rc.Close()) + }) + + t.Run("Invalid config", func(t *testing.T) { + testScope := promutils.NewTestScope() + _, err := newLocalRawStore(&Config{}, testScope) + assert.Error(t, err) + }) + + t.Run("Initialize container", func(t *testing.T) { + testScope := promutils.NewTestScope() + tmpDir, err := ioutil.TempDir("", "stdlib_local") + assert.NoError(t, err) + + stats, err := os.Stat(tmpDir) + assert.NoError(t, err) + assert.NotNil(t, stats) + + store, err := newLocalRawStore(&Config{ + Connection: ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL(tmpDir)}, + }, + InitContainer: "tmp", + }, testScope.NewSubScope("y")) + + assert.NoError(t, err) + assert.NotNil(t, store) + + stats, err = os.Stat(filepath.Join(tmpDir, "tmp")) + assert.NoError(t, err) + assert.True(t, stats.IsDir()) + }) +} diff --git a/flytestdlib/storage/mem_store.go 
b/flytestdlib/storage/mem_store.go new file mode 100644 index 0000000000..cc0c5854c0 --- /dev/null +++ b/flytestdlib/storage/mem_store.go @@ -0,0 +1,74 @@ +package storage + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "os" + + "github.com/lyft/flytestdlib/promutils" +) + +type rawFile = []byte + +type InMemoryStore struct { + copyImpl + cache map[DataReference]rawFile +} + +type MemoryMetadata struct { + exists bool + size int64 +} + +func (m MemoryMetadata) Size() int64 { + return m.size +} + +func (m MemoryMetadata) Exists() bool { + return m.exists +} + +func (s *InMemoryStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + data, found := s.cache[reference] + return MemoryMetadata{exists: found, size: int64(len(data))}, nil +} + +func (s *InMemoryStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + if raw, found := s.cache[reference]; found { + return ioutil.NopCloser(bytes.NewReader(raw)), nil + } + + return nil, os.ErrNotExist +} + +func (s *InMemoryStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) ( + err error) { + + rawBytes, err := ioutil.ReadAll(raw) + if err != nil { + return err + } + + s.cache[reference] = rawBytes + return nil +} + +func (s *InMemoryStore) Clear(ctx context.Context) error { + s.cache = map[DataReference]rawFile{} + return nil +} + +func (s *InMemoryStore) GetBaseContainerFQN(ctx context.Context) DataReference { + return DataReference("") +} + +func NewInMemoryRawStore(_ *Config, scope promutils.Scope) (RawStore, error) { + self := &InMemoryStore{ + cache: map[DataReference]rawFile{}, + } + + self.copyImpl = newCopyImpl(self, scope) + return self, nil +} diff --git a/flytestdlib/storage/mem_store_test.go b/flytestdlib/storage/mem_store_test.go new file mode 100644 index 0000000000..fdfe2b724a --- /dev/null +++ b/flytestdlib/storage/mem_store_test.go @@ -0,0 +1,78 @@ +package storage + +import ( + "bytes" + "context" + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/stretchr/testify/assert" +) + +func TestInMemoryStore_Head(t *testing.T) { + t.Run("Empty store", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + metadata, err := s.Head(context.TODO(), DataReference("hello")) + assert.NoError(t, err) + assert.False(t, metadata.Exists()) + }) + + t.Run("Existing Item", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + err = s.WriteRaw(context.TODO(), DataReference("hello"), 0, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + + metadata, err := s.Head(context.TODO(), DataReference("hello")) + assert.NoError(t, err) + assert.True(t, metadata.Exists()) + }) +} + +func TestInMemoryStore_ReadRaw(t *testing.T) { + t.Run("Empty store", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + + raw, err := s.ReadRaw(context.TODO(), DataReference("hello")) + assert.Error(t, err) + assert.Nil(t, raw) + }) + + t.Run("Existing Item", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + + err = s.WriteRaw(context.TODO(), DataReference("hello"), 0, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + + _, err = s.ReadRaw(context.TODO(), 
DataReference("hello")) + assert.NoError(t, err) + }) +} + +func TestInMemoryStore_Clear(t *testing.T) { + testScope := promutils.NewTestScope() + m, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + + mStore := m.(*InMemoryStore) + err = m.WriteRaw(context.TODO(), DataReference("hello"), 0, Options{}, bytes.NewReader([]byte("world"))) + assert.NoError(t, err) + + _, err = m.ReadRaw(context.TODO(), DataReference("hello")) + assert.NoError(t, err) + + err = mStore.Clear(context.TODO()) + assert.NoError(t, err) + + _, err = m.ReadRaw(context.TODO(), DataReference("hello")) + assert.Error(t, err) + assert.True(t, IsNotFound(err)) +} diff --git a/flytestdlib/storage/protobuf_store.go b/flytestdlib/storage/protobuf_store.go new file mode 100644 index 0000000000..ba11d3311c --- /dev/null +++ b/flytestdlib/storage/protobuf_store.go @@ -0,0 +1,85 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flytestdlib/ioutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + + "github.com/golang/protobuf/proto" + errs "github.com/pkg/errors" +) + +type protoMetrics struct { + FetchLatency promutils.StopWatch + MarshalTime promutils.StopWatch + UnmarshalTime promutils.StopWatch + MarshalFailure prometheus.Counter + UnmarshalFailure prometheus.Counter +} + +// Implements ProtobufStore to marshal and unmarshal protobufs to/from a RawStore +type DefaultProtobufStore struct { + RawStore + metrics *protoMetrics +} + +func (s DefaultProtobufStore) ReadProtobuf(ctx context.Context, reference DataReference, msg proto.Message) error { + rc, err := s.ReadRaw(ctx, reference) + if err != nil { + return errs.Wrap(err, fmt.Sprintf("path:%v", reference)) + } + + defer func() { + err = rc.Close() + if err != nil { + logger.Warn(ctx, "Failed to close reference [%v]. 
Error: %v", reference, err) + } + }() + + docContents, err := ioutils.ReadAll(rc, s.metrics.FetchLatency.Start()) + if err != nil { + return errs.Wrap(err, fmt.Sprintf("readAll: %v", reference)) + } + + t := s.metrics.UnmarshalTime.Start() + err = proto.Unmarshal(docContents, msg) + t.Stop() + if err != nil { + s.metrics.UnmarshalFailure.Inc() + return errs.Wrap(err, fmt.Sprintf("unmarshall: %v", reference)) + } + + return nil +} + +func (s DefaultProtobufStore) WriteProtobuf(ctx context.Context, reference DataReference, opts Options, msg proto.Message) error { + t := s.metrics.MarshalTime.Start() + raw, err := proto.Marshal(msg) + t.Stop() + if err != nil { + s.metrics.MarshalFailure.Inc() + return err + } + + return s.WriteRaw(ctx, reference, int64(len(raw)), opts, bytes.NewReader(raw)) +} + +func NewDefaultProtobufStore(store RawStore, metricsScope promutils.Scope) DefaultProtobufStore { + return DefaultProtobufStore{ + RawStore: store, + metrics: &protoMetrics{ + FetchLatency: metricsScope.MustNewStopWatch("proto_fetch", "Time to read data before unmarshalling", time.Millisecond), + MarshalTime: metricsScope.MustNewStopWatch("marshal", "Time incurred in marshalling data before writing", time.Millisecond), + UnmarshalTime: metricsScope.MustNewStopWatch("unmarshal", "Time incurred in unmarshalling received data", time.Millisecond), + MarshalFailure: metricsScope.MustNewCounter("marshal_failure", "Failures when marshalling"), + UnmarshalFailure: metricsScope.MustNewCounter("unmarshal_failure", "Failures when unmarshalling"), + }, + } +} diff --git a/flytestdlib/storage/protobuf_store_test.go b/flytestdlib/storage/protobuf_store_test.go new file mode 100644 index 0000000000..160239bb73 --- /dev/null +++ b/flytestdlib/storage/protobuf_store_test.go @@ -0,0 +1,41 @@ +package storage + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" +) + +type mockProtoMessage struct { + X int64 `protobuf:"varint,2,opt,name=x,json=x,proto3" json:"x,omitempty"` +} + +func (mockProtoMessage) Reset() { +} + +func (m mockProtoMessage) String() string { + return proto.CompactTextString(m) +} + +func (mockProtoMessage) ProtoMessage() { +} + +func TestDefaultProtobufStore_ReadProtobuf(t *testing.T) { + t.Run("Read after Write", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewDataStore(&Config{Type: TypeMemory}, testScope) + assert.NoError(t, err) + + err = s.WriteProtobuf(context.TODO(), DataReference("hello"), Options{}, &mockProtoMessage{X: 5}) + assert.NoError(t, err) + + m := &mockProtoMessage{} + err = s.ReadProtobuf(context.TODO(), DataReference("hello"), m) + assert.NoError(t, err) + assert.Equal(t, int64(5), m.X) + }) +} diff --git a/flytestdlib/storage/rawstores.go b/flytestdlib/storage/rawstores.go new file mode 100644 index 0000000000..4a6d35573c --- /dev/null +++ b/flytestdlib/storage/rawstores.go @@ -0,0 +1,40 @@ +package storage + +import ( + "fmt" + + "github.com/lyft/flytestdlib/promutils" +) + +type dataStoreCreateFn func(cfg *Config, metricsScope promutils.Scope) (RawStore, error) + +var stores = map[string]dataStoreCreateFn{ + TypeMemory: NewInMemoryRawStore, + TypeLocal: newLocalRawStore, + TypeMinio: newS3RawStore, + TypeS3: newS3RawStore, +} + +// Creates a new Data Store with the supplied config. 
+func NewDataStore(cfg *Config, metricsScope promutils.Scope) (s *DataStore, err error) { + var rawStore RawStore + if fn, found := stores[cfg.Type]; found { + rawStore, err = fn(cfg, metricsScope) + if err != nil { + return &emptyStore, err + } + + protoStore := NewDefaultProtobufStore(newCachedRawStore(cfg, rawStore, metricsScope), metricsScope) + return NewCompositeDataStore(URLPathConstructor{}, protoStore), nil + } + + return &emptyStore, fmt.Errorf("type is of an invalid value [%v]", cfg.Type) +} + +// Composes a new DataStore. +func NewCompositeDataStore(refConstructor ReferenceConstructor, composedProtobufStore ComposedProtobufStore) *DataStore { + return &DataStore{ + ReferenceConstructor: refConstructor, + ComposedProtobufStore: composedProtobufStore, + } +} diff --git a/flytestdlib/storage/s3store.go b/flytestdlib/storage/s3store.go new file mode 100644 index 0000000000..3c96730e0a --- /dev/null +++ b/flytestdlib/storage/s3store.go @@ -0,0 +1,102 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/aws/aws-sdk-go/aws/awserr" + awsS3 "github.com/aws/aws-sdk-go/service/s3" + "github.com/lyft/flytestdlib/logger" + "github.com/pkg/errors" + + "github.com/graymeta/stow" + "github.com/graymeta/stow/s3" +) + +func getStowConfigMap(cfg *Config) stow.ConfigMap { + // Non-nullable fields + stowConfig := stow.ConfigMap{ + s3.ConfigAuthType: cfg.Connection.AuthType, + s3.ConfigRegion: cfg.Connection.Region, + } + + // Fields that differ between minio and real S3 + if endpoint := cfg.Connection.Endpoint.String(); endpoint != "" { + stowConfig[s3.ConfigEndpoint] = endpoint + } + + if accessKey := cfg.Connection.AccessKey; accessKey != "" { + stowConfig[s3.ConfigAccessKeyID] = accessKey + } + + if secretKey := cfg.Connection.SecretKey; secretKey != "" { + stowConfig[s3.ConfigSecretKey] = secretKey + } + + if disableSsl := cfg.Connection.DisableSSL; disableSsl { + stowConfig[s3.ConfigDisableSSL] = "True" + } + + return stowConfig + +} + +func s3FQN(bucket string) DataReference { + return DataReference(fmt.Sprintf("s3://%s", bucket)) +} + +func newS3RawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) { + if cfg.InitContainer == "" { + return nil, fmt.Errorf("initContainer is required") + } + + loc, err := stow.Dial(s3.Kind, getStowConfigMap(cfg)) + + if err != nil { + return emptyStore, fmt.Errorf("unable to configure the storage for s3. Error: %v", err) + } + + c, err := loc.Container(cfg.InitContainer) + if err != nil { + if IsNotFound(err) || awsBucketIsNotFound(err) { + c, err := loc.CreateContainer(cfg.InitContainer) + if err != nil { + // If the container's already created, move on. Otherwise, fail with error. + if awsBucketAlreadyExists(err) { + logger.Infof(context.TODO(), "Storage init-container already exists [%v].", cfg.InitContainer) + return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) + } + } + return emptyStore, fmt.Errorf("unable to initialize container [%v]. 
Error: %v", cfg.InitContainer, err) + } + return emptyStore, err + } + + return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) +} + +func awsBucketIsNotFound(err error) bool { + if IsNotFound(err) { + return true + } + + if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { + return awsErr.Code() == awsS3.ErrCodeNoSuchBucket + } + + return false +} + +func awsBucketAlreadyExists(err error) bool { + if IsExists(err) { + return true + } + + if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { + return awsErr.Code() == awsS3.ErrCodeBucketAlreadyOwnedByYou + } + + return false +} diff --git a/flytestdlib/storage/s3stsore_test.go b/flytestdlib/storage/s3stsore_test.go new file mode 100644 index 0000000000..2e8674a834 --- /dev/null +++ b/flytestdlib/storage/s3stsore_test.go @@ -0,0 +1,26 @@ +package storage + +import ( + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewS3RawStore(t *testing.T) { + t.Run("Missing required config", func(t *testing.T) { + testScope := promutils.NewTestScope() + _, err := NewDataStore(&Config{ + Type: TypeMinio, + InitContainer: "some-container", + Connection: ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, + }, + }, testScope) + + assert.Error(t, err) + }) +} diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go new file mode 100644 index 0000000000..ffbc2c62b0 --- /dev/null +++ b/flytestdlib/storage/storage.go @@ -0,0 +1,95 @@ +// Defines extensible storage interface. +// This package registers "storage" config section that maps to Config struct. Use NewDataStore(cfg) to initialize a +// DataStore with the provided config. The package provides default implementation to access local, S3 (and minio), +// and In-Memory storage. Use NewCompositeDataStore to swap any portions of the DataStore interface with an external +// implementation (e.g. a cached protobuf store). The underlying storage is provided by extensible "stow" library. You +// can use NewStowRawStore(cfg) to create a Raw store based on any other stow-supported configs (e.g. Azure Blob Storage) +package storage + +import ( + "context" + "strings" + + "io" + "net/url" + + "github.com/golang/protobuf/proto" +) + +// Defines a reference to data location. +type DataReference string + +var emptyStore = DataStore{} + +// Holder for recording storage options. It is used to pass Metadata (like headers for S3) and also tags or labels for +// objects +type Options struct { + Metadata map[string]interface{} +} + +// Placeholder for data reference metadata. +type Metadata interface { + Exists() bool + Size() int64 +} + +// A simplified interface for accessing and storing data in one of the Cloud stores. +// Today we rely on Stow for multi-cloud support, but this interface abstracts that part +type DataStore struct { + ComposedProtobufStore + ReferenceConstructor +} + +// Defines a low level interface for accessing and storing bytes. +type RawStore interface { + // returns a FQN DataReference with the configured base init container + GetBaseContainerFQN(ctx context.Context) DataReference + + // Gets metadata about the reference. This should generally be a light weight operation. 
+ Head(ctx context.Context, reference DataReference) (Metadata, error) + + // Retrieves a byte array from the Blob store or an error + ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) + + // Stores a raw byte array. + WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error + + // Copies from source to destination. + CopyRaw(ctx context.Context, source, destination DataReference, opts Options) error +} + +// Defines an interface for building data reference paths. +type ReferenceConstructor interface { + // Creates a new dataReference that matches the storage structure. + ConstructReference(ctx context.Context, reference DataReference, nestedKeys ...string) (DataReference, error) +} + +// Defines an interface for reading and writing protobuf messages +type ProtobufStore interface { + // Retrieves the entire blob from blobstore and unmarshals it to the passed protobuf + ReadProtobuf(ctx context.Context, reference DataReference, msg proto.Message) error + + // Serializes and stores the protobuf. + WriteProtobuf(ctx context.Context, reference DataReference, opts Options, msg proto.Message) error +} + +// A ProtobufStore needs a RawStore to get the RawData. This interface provides all the necessary components to make +// Protobuf fetching work +type ComposedProtobufStore interface { + RawStore + ProtobufStore +} + +// Splits the data reference into parts. +func (r DataReference) Split() (scheme, container, key string, err error) { + u, err := url.Parse(string(r)) + if err != nil { + return "", "", "", err + } + + return u.Scheme, u.Host, strings.Trim(u.Path, "/"), nil +} + +func (r DataReference) String() string { + return string(r) +} diff --git a/flytestdlib/storage/storage_test.go b/flytestdlib/storage/storage_test.go new file mode 100644 index 0000000000..1895b0ac5f --- /dev/null +++ b/flytestdlib/storage/storage_test.go @@ -0,0 +1,52 @@ +package storage + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/stretchr/testify/assert" +) + +func TestDataReference_Split(t *testing.T) { + input := DataReference("s3://container/path/to/file") + scheme, container, key, err := input.Split() + + assert.NoError(t, err) + assert.Equal(t, "s3", scheme) + assert.Equal(t, "container", container) + assert.Equal(t, "path/to/file", key) +} + +func ExampleNewDataStore() { + testScope := promutils.NewTestScope() + ctx := context.Background() + fmt.Println("Creating in memory data store.") + store, err := NewDataStore(&Config{ + Type: TypeMemory, + }, testScope.NewSubScope("exp_new")) + + if err != nil { + fmt.Printf("Failed to create data store. Error: %v", err) + } + + ref, err := store.ConstructReference(ctx, DataReference("root"), "subkey", "subkey2") + if err != nil { + fmt.Printf("Failed to construct data reference. Error: %v", err) + } + + fmt.Printf("Constructed data reference [%v] and writing data to it.\n", ref) + + dataToStore := "hello world" + err = store.WriteRaw(ctx, ref, int64(len(dataToStore)), Options{}, strings.NewReader(dataToStore)) + if err != nil { + fmt.Printf("Failed to write data. Error: %v", err) + } + + // Output: + // Creating in memory data store. + // Constructed data reference [/root/subkey/subkey2] and writing data to it. 
+} diff --git a/flytestdlib/storage/stow_store.go b/flytestdlib/storage/stow_store.go new file mode 100644 index 0000000000..b13170bf09 --- /dev/null +++ b/flytestdlib/storage/stow_store.go @@ -0,0 +1,174 @@ +package storage + +import ( + "context" + "io" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/graymeta/stow" + errs "github.com/pkg/errors" +) + +type stowMetrics struct { + BadReference prometheus.Counter + BadContainer prometheus.Counter + + HeadFailure prometheus.Counter + HeadLatency promutils.StopWatch + + ReadFailure prometheus.Counter + ReadOpenLatency promutils.StopWatch + + WriteFailure prometheus.Counter + WriteLatency promutils.StopWatch +} + +// Implements DataStore to talk to stow location store. +type StowStore struct { + stow.Container + copyImpl + metrics *stowMetrics + containerBaseFQN DataReference +} + +type StowMetadata struct { + exists bool + size int64 +} + +func (s StowMetadata) Size() int64 { + return s.size +} + +func (s StowMetadata) Exists() bool { + return s.exists +} + +func (s *StowStore) getContainer(container string) (c stow.Container, err error) { + if s.Container.Name() != container { + s.metrics.BadContainer.Inc() + return nil, stow.ErrNotFound + } + + return s.Container, nil +} + +func (s *StowStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc() + return nil, err + } + + container, err := s.getContainer(c) + if err != nil { + return nil, err + } + + t := s.metrics.HeadLatency.Start() + item, err := container.Item(k) + if err == nil { + if _, err = item.Metadata(); err == nil { + size, err := item.Size() + if err == nil { + t.Stop() + return StowMetadata{ + exists: true, + size: size, + }, nil + } + } + } + s.metrics.HeadFailure.Inc() + if IsNotFound(err) { + return StowMetadata{exists: false}, nil + } + return StowMetadata{exists: false}, errs.Wrapf(err, "path:%v", k) +} + +func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc() + return nil, err + } + + container, err := s.getContainer(c) + if err != nil { + return nil, err + } + + t := s.metrics.ReadOpenLatency.Start() + item, err := container.Item(k) + if err != nil { + s.metrics.ReadFailure.Inc() + return nil, err + } + t.Stop() + + sizeBytes, err := item.Size() + if err != nil { + return nil, err + } + + if sizeBytes/MiB > GetConfig().Limits.GetLimitMegabytes { + return nil, ErrExceedsLimit + } + + return item.Open() +} + +func (s *StowStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc() + return err + } + + container, err := s.getContainer(c) + if err != nil { + return err + } + + t := s.metrics.WriteLatency.Start() + _, err = container.Put(k, raw, size, opts.Metadata) + if err != nil { + s.metrics.WriteFailure.Inc() + return errs.Wrapf(err, "Failed to write data [%vb] to path [%v].", size, k) + } + t.Stop() + + return nil +} + +func (s *StowStore) GetBaseContainerFQN(ctx context.Context) DataReference { + return s.containerBaseFQN +} + +func NewStowRawStore(containerBaseFQN DataReference, container stow.Container, metricsScope promutils.Scope) (*StowStore, error) { + self := &StowStore{ + Container: container, + containerBaseFQN: 
containerBaseFQN, + metrics: &stowMetrics{ + BadReference: metricsScope.MustNewCounter("bad_key", "Indicates the provided storage reference/key is incorrectly formatted"), + BadContainer: metricsScope.MustNewCounter("bad_container", "Indicates request for a container that has not been initialized"), + + HeadFailure: metricsScope.MustNewCounter("head_failure", "Indicates failure in HEAD for a given reference"), + HeadLatency: metricsScope.MustNewStopWatch("head", "Indicates time to fetch metadata using the Head API", time.Millisecond), + + ReadFailure: metricsScope.MustNewCounter("read_failure", "Indicates failure in GET for a given reference"), + ReadOpenLatency: metricsScope.MustNewStopWatch("read_open", "Indicates time to first byte when reading", time.Millisecond), + + WriteFailure: metricsScope.MustNewCounter("write_failure", "Indicates failure in storing/PUT for a given reference"), + WriteLatency: metricsScope.MustNewStopWatch("write", "Time to write an object irrespective of size", time.Millisecond), + }, + } + + self.copyImpl = newCopyImpl(self, metricsScope) + + return self, nil +} diff --git a/flytestdlib/storage/stow_store_test.go b/flytestdlib/storage/stow_store_test.go new file mode 100644 index 0000000000..f0fd282340 --- /dev/null +++ b/flytestdlib/storage/stow_store_test.go @@ -0,0 +1,133 @@ +package storage + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "net/url" + "testing" + "time" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/graymeta/stow" + "github.com/stretchr/testify/assert" +) + +type mockStowContainer struct { + id string + items map[string]mockStowItem +} + +func (m mockStowContainer) ID() string { + return m.id +} + +func (m mockStowContainer) Name() string { + return m.id +} + +func (m mockStowContainer) Item(id string) (stow.Item, error) { + if item, found := m.items[id]; found { + return item, nil + } + + return nil, stow.ErrNotFound +} + +func (mockStowContainer) Items(prefix, cursor string, count int) ([]stow.Item, string, error) { + return []stow.Item{}, "", nil +} + +func (mockStowContainer) RemoveItem(id string) error { + return nil +} + +func (m *mockStowContainer) Put(name string, r io.Reader, size int64, metadata map[string]interface{}) (stow.Item, error) { + item := mockStowItem{url: name, size: size} + m.items[name] = item + return item, nil +} + +func newMockStowContainer(id string) *mockStowContainer { + return &mockStowContainer{ + id: id, + items: map[string]mockStowItem{}, + } +} + +type mockStowItem struct { + url string + size int64 +} + +func (m mockStowItem) ID() string { + return m.url +} + +func (m mockStowItem) Name() string { + return m.url +} + +func (m mockStowItem) URL() *url.URL { + u, err := url.Parse(m.url) + if err != nil { + panic(err) + } + + return u +} + +func (m mockStowItem) Size() (int64, error) { + return m.size, nil +} + +func (mockStowItem) Open() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader([]byte{})), nil +} + +func (mockStowItem) ETag() (string, error) { + return "", nil +} + +func (mockStowItem) LastMod() (time.Time, error) { + return time.Now(), nil +} + +func (mockStowItem) Metadata() (map[string]interface{}, error) { + return map[string]interface{}{}, nil +} + +func TestStowStore_ReadRaw(t *testing.T) { + t.Run("Happy Path", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewStowRawStore(s3FQN("container"), newMockStowContainer("container"), testScope) + assert.NoError(t, err) + err = s.WriteRaw(context.TODO(), 
DataReference("s3://container/path"), 0, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + metadata, err := s.Head(context.TODO(), DataReference("s3://container/path")) + assert.NoError(t, err) + assert.True(t, metadata.Exists()) + raw, err := s.ReadRaw(context.TODO(), DataReference("s3://container/path")) + assert.NoError(t, err) + rawBytes, err := ioutil.ReadAll(raw) + assert.NoError(t, err) + assert.Equal(t, 0, len(rawBytes)) + assert.Equal(t, DataReference("s3://container"), s.GetBaseContainerFQN(context.TODO())) + }) + + t.Run("Exceeds limit", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewStowRawStore(s3FQN("container"), newMockStowContainer("container"), testScope) + assert.NoError(t, err) + err = s.WriteRaw(context.TODO(), DataReference("s3://container/path"), 3*MiB, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + metadata, err := s.Head(context.TODO(), DataReference("s3://container/path")) + assert.NoError(t, err) + assert.True(t, metadata.Exists()) + _, err = s.ReadRaw(context.TODO(), DataReference("s3://container/path")) + assert.Error(t, err) + assert.True(t, IsExceedsLimit(err)) + }) +} diff --git a/flytestdlib/storage/testdata/config.yaml b/flytestdlib/storage/testdata/config.yaml new file mode 100755 index 0000000000..d8664ca347 --- /dev/null +++ b/flytestdlib/storage/testdata/config.yaml @@ -0,0 +1,14 @@ +cache: + max_size_mbs: 0 + target_gc_percent: 0 +connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://minio:9000 + region: us-east-1 + secret-key: miniostorage +container: "" +limits: + maxDownloadMBs: 0 +type: s3 diff --git a/flytestdlib/storage/url_path.go b/flytestdlib/storage/url_path.go new file mode 100644 index 0000000000..94e26e317a --- /dev/null +++ b/flytestdlib/storage/url_path.go @@ -0,0 +1,44 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/pkg/errors" + + "net/url" + "os" + "path/filepath" + + "github.com/lyft/flytestdlib/logger" +) + +// Implements ReferenceConstructor that assumes paths are URL-compatible. 
+type URLPathConstructor struct { +} + +func ensureEndingPathSeparator(path DataReference) DataReference { + if len(path) > 0 && path[len(path)-1] == os.PathSeparator { + return path + } + + return path + "/" +} + +func (URLPathConstructor) ConstructReference(ctx context.Context, reference DataReference, nestedKeys ...string) (DataReference, error) { + u, err := url.Parse(string(ensureEndingPathSeparator(reference))) + if err != nil { + logger.Errorf(ctx, "Failed to parse prefix: %v", reference) + return "", errors.Wrap(err, fmt.Sprintf("Reference is of an invalid format [%v]", reference)) + } + + rel, err := url.Parse(filepath.Join(nestedKeys...)) + if err != nil { + logger.Errorf(ctx, "Failed to parse nested keys: %v", reference) + return "", errors.Wrap(err, fmt.Sprintf("Reference is of an invalid format [%v]", reference)) + } + + u = u.ResolveReference(rel) + + return DataReference(u.String()), nil +} diff --git a/flytestdlib/storage/url_path_test.go b/flytestdlib/storage/url_path_test.go new file mode 100644 index 0000000000..5dc24661cf --- /dev/null +++ b/flytestdlib/storage/url_path_test.go @@ -0,0 +1,15 @@ +package storage + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUrlPathConstructor_ConstructReference(t *testing.T) { + s := URLPathConstructor{} + r, err := s.ConstructReference(context.TODO(), DataReference("hello"), "key1", "key2/", "key3") + assert.NoError(t, err) + assert.Equal(t, "/hello/key1/key2/key3", r.String()) +} diff --git a/flytestdlib/storage/utils.go b/flytestdlib/storage/utils.go new file mode 100644 index 0000000000..62fb2aa22c --- /dev/null +++ b/flytestdlib/storage/utils.go @@ -0,0 +1,34 @@ +package storage + +import ( + "fmt" + "os" + + "github.com/graymeta/stow" + "github.com/pkg/errors" +) + +var ErrExceedsLimit = fmt.Errorf("limit exceeded") + +// Gets a value indicating whether the underlying error is a Not Found error. +func IsNotFound(err error) bool { + if root := errors.Cause(err); root == stow.ErrNotFound || os.IsNotExist(root) { + return true + } + + return false +} + +// Gets a value indicating whether the underlying error is "already exists" error. +func IsExists(err error) bool { + if root := errors.Cause(err); os.IsExist(root) { + return true + } + + return false +} + +// Gets a value indicating whether the root cause of error is a "limit exceeded" error. +func IsExceedsLimit(err error) bool { + return errors.Cause(err) == ErrExceedsLimit +} diff --git a/flytestdlib/tests/config_test.go b/flytestdlib/tests/config_test.go new file mode 100644 index 0000000000..15c890cf08 --- /dev/null +++ b/flytestdlib/tests/config_test.go @@ -0,0 +1,78 @@ +package tests + +import ( + "context" + "flag" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/lyft/flytestdlib/config/viper" + + "github.com/ghodss/yaml" + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! 
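+// Running the tests with -update regenerates the golden file(s) under testdata before asserting against them.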
+var update = flag.Bool("update", false, "Updates testdata") + +func TestStorageAndLoggerConfig(t *testing.T) { + type CompositeConfig struct { + Storage storage.Config `json:"storage"` + Logger logger.Config `json:"logger"` + } + + expected := CompositeConfig{ + Storage: storage.Config{ + Type: "s3", + Connection: storage.ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, + AuthType: "accesskey", + AccessKey: "minio", + SecretKey: "miniostorage", + Region: "us-east-1", + DisableSSL: true, + }, + }, + Logger: logger.Config{ + Level: logger.DebugLevel, + }, + } + + configPath := filepath.Join("testdata", "combined.yaml") + if *update { + t.Log("Updating golden files.") + raw, err := yaml.Marshal(expected) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(configPath, raw, os.ModePerm)) + } + + actual := CompositeConfig{} + /* #nosec */ + raw, err := ioutil.ReadFile(configPath) + assert.NoError(t, err) + assert.NoError(t, yaml.Unmarshal(raw, &actual)) + assert.True(t, reflect.DeepEqual(expected, actual)) +} + +func TestParseExistingConfig(t *testing.T) { + accessor := viper.NewAccessor(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "combined.yaml")}, + }) + + assert.NoError(t, accessor.UpdateConfig(context.TODO())) + + assert.NotNil(t, storage.ConfigSection) + + if _, ok := storage.ConfigSection.GetConfig().(*storage.Config); ok { + assert.True(t, ok) + } else { + assert.FailNow(t, "Retrieved section is not of type storage.") + } +} diff --git a/flytestdlib/tests/testdata/combined.yaml b/flytestdlib/tests/testdata/combined.yaml new file mode 100755 index 0000000000..f167b1ab33 --- /dev/null +++ b/flytestdlib/tests/testdata/combined.yaml @@ -0,0 +1,19 @@ +logger: + level: 5 + mute: false + show-source: false +storage: + cache: + max_size_mbs: 0 + target_gc_percent: 0 + connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://minio:9000 + region: us-east-1 + secret-key: miniostorage + container: "" + limits: + maxDownloadMBs: 0 + type: s3 diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go new file mode 100644 index 0000000000..a231bbd7c2 --- /dev/null +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -0,0 +1,99 @@ +package utils + +import ( + "context" + "sync" + "time" + + "github.com/lyft/flytestdlib/logger" + "k8s.io/apimachinery/pkg/util/wait" +) + +// AutoRefreshCache with regular GetOrCreate and Delete along with background asynchronous refresh. Caller provides +// callbacks for create, refresh and delete item. +// The cache doesn't provide apis to update items. +type AutoRefreshCache interface { + // starts background refresh of items + Start(ctx context.Context) + + // Get item by id if exists else null + Get(id string) CacheItem + + // Get object if exists else create it + GetOrCreate(item CacheItem) (CacheItem, error) +} + +type CacheItem interface { + ID() string +} + +type CacheSyncItem func(ctx context.Context, obj CacheItem) (CacheItem, error) + +func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resyncPeriod time.Duration) AutoRefreshCache { + cache := &autoRefreshCache{ + syncCb: syncCb, + syncRateLimiter: syncRateLimiter, + resyncPeriod: resyncPeriod, + } + + return cache +} + +// Thread-safe general purpose auto-refresh cache that watches for updates asynchronously for the keys after they are added to +// the cache. An item can be inserted only once. 
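+// An item is removed from the cache when the sync callback returns a nil item for it.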
+// +// Get reads from sync.map while refresh is invoked on a snapshot of keys. Cache eventually catches up on deleted items. +// +// Sync is run as a fixed-interval-scheduled-task, and is skipped if sync from previous cycle is still running. +type autoRefreshCache struct { + syncCb CacheSyncItem + syncMap sync.Map + syncRateLimiter RateLimiter + resyncPeriod time.Duration +} + +func (w *autoRefreshCache) Start(ctx context.Context) { + go wait.Until(func() { w.sync(ctx) }, w.resyncPeriod, ctx.Done()) +} + +func (w *autoRefreshCache) Get(id string) CacheItem { + if val, ok := w.syncMap.Load(id); ok { + return val.(CacheItem) + } + return nil +} + +// Return the item if exists else create it. +// Create should be invoked only once. recreating the object is not supported. +func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { + if val, ok := w.syncMap.Load(item.ID()); ok { + return val.(CacheItem), nil + } + + w.syncMap.Store(item.ID(), item) + return item, nil +} + +func (w *autoRefreshCache) sync(ctx context.Context) { + w.syncMap.Range(func(key, value interface{}) bool { + if w.syncRateLimiter != nil { + err := w.syncRateLimiter.Wait(ctx) + if err != nil { + logger.Warnf(ctx, "unexpected failure in rate-limiter wait %v", key) + return true + } + } + item, err := w.syncCb(ctx, value.(CacheItem)) + if err != nil { + logger.Error(ctx, "failed to get latest copy of the item %v", key) + } + + if item == nil { + w.syncMap.Delete(key) + } else { + w.syncMap.Store(key, item) + } + + return true + }) +} diff --git a/flytestdlib/utils/auto_refresh_cache_test.go b/flytestdlib/utils/auto_refresh_cache_test.go new file mode 100644 index 0000000000..85a09ce508 --- /dev/null +++ b/flytestdlib/utils/auto_refresh_cache_test.go @@ -0,0 +1,108 @@ +package utils + +import ( + "context" + "testing" + "time" + + atomic2 "sync/atomic" + + "github.com/lyft/flytestdlib/atomic" + "github.com/stretchr/testify/assert" +) + +type testCacheItem struct { + val int + deleted atomic.Bool + resyncPeriod time.Duration +} + +func (m *testCacheItem) ID() string { + return "id" +} + +func (m *testCacheItem) moveNext() { + // change value and spare enough time for cache to process the change. + m.val++ + time.Sleep(m.resyncPeriod * 5) +} + +func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { + if m.deleted.Load() { + return nil, nil + } + return m, nil +} + +type testAutoIncrementItem struct { + val int32 +} + +func (a *testAutoIncrementItem) ID() string { + return "autoincrement" +} + +func (a *testAutoIncrementItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { + atomic2.AddInt32(&a.val, 1) + return a, nil +} + +func TestCache(t *testing.T) { + testResyncPeriod := time.Millisecond + rateLimiter := NewRateLimiter("mockLimiter", 100, 1) + + item := &testCacheItem{val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false)} + cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) + + //ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + cache.Start(ctx) + + // create + _, err := cache.GetOrCreate(item) + assert.NoError(t, err, "unexpected GetOrCreate failure") + + // synced? + item.moveNext() + m := cache.Get(item.ID()).(*testCacheItem) + assert.Equal(t, 1, m.val) + + // synced again? + item.moveNext() + m = cache.Get(item.ID()).(*testCacheItem) + assert.Equal(t, 2, m.val) + + // removed? 
+ item.moveNext() + item.deleted.Store(true) + time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! + val := cache.Get(item.ID()) + + assert.Nil(t, val) + cancel() +} + +func TestCacheContextCancel(t *testing.T) { + testResyncPeriod := time.Millisecond + rateLimiter := NewRateLimiter("mockLimiter", 10000, 1) + + item := &testAutoIncrementItem{val: 0} + cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) + + ctx, cancel := context.WithCancel(context.Background()) + cache.Start(ctx) + _, err := cache.GetOrCreate(item) + assert.NoError(t, err, "failed to add item to cache") + time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! + cancel() + + // Get item + m, err := cache.GetOrCreate(item) + val1 := m.(*testAutoIncrementItem).val + assert.NoError(t, err, "unexpected GetOrCreate failure") + + // wait a few more resync periods and check that nothings has changed as auto-refresh is stopped + time.Sleep(testResyncPeriod * 20) + val2 := m.(*testAutoIncrementItem).val + assert.Equal(t, val1, val2) +} diff --git a/flytestdlib/utils/auto_refresh_example_test.go b/flytestdlib/utils/auto_refresh_example_test.go new file mode 100644 index 0000000000..1c47d39ec5 --- /dev/null +++ b/flytestdlib/utils/auto_refresh_example_test.go @@ -0,0 +1,88 @@ +package utils + +import ( + "context" + "fmt" + "time" +) + +type ExampleItemStatus string + +const ( + ExampleStatusNotStarted ExampleItemStatus = "Not-started" + ExampleStatusStarted ExampleItemStatus = "Started" + ExampleStatusSucceeded ExampleItemStatus = "Completed" +) + +type ExampleCacheItem struct { + status ExampleItemStatus + id string +} + +func (e *ExampleCacheItem) ID() string { + return e.id +} + +type ExampleService struct { + jobStatus map[string]ExampleItemStatus +} + +func newExampleService() *ExampleService { + return &ExampleService{jobStatus: make(map[string]ExampleItemStatus)} +} + +// advance the status to next, and return +func (f *ExampleService) getStatus(id string) *ExampleCacheItem { + if _, ok := f.jobStatus[id]; !ok { + f.jobStatus[id] = ExampleStatusStarted + } + f.jobStatus[id] = ExampleStatusSucceeded + return &ExampleCacheItem{f.jobStatus[id], id} +} + +func ExampleNewAutoRefreshCache() { + // This auto-refresh cache can be used for cases where keys are created by caller but processed by + // an external service and we want to asynchronously keep track of its progress. + exampleService := newExampleService() + + // define a sync method that the cache can use to auto-refresh in background + syncItemCb := func(ctx context.Context, obj CacheItem) (CacheItem, error) { + return exampleService.getStatus(obj.(*ExampleCacheItem).ID()), nil + } + + // define resync period as time duration we want cache to refresh. We can go as low as we want but cache + // would still be constrained by time it takes to run Sync call for each item. + resyncPeriod := time.Millisecond + + // Since number of items in the cache is dynamic, rate limiter is our knob to control resources we spend on + // sync. + rateLimiter := NewRateLimiter("ExampleRateLimiter", 10000, 1) + + // since cache refreshes itself asynchronously, it may not notice that an object has been deleted immediately, + // so users of the cache should have the delete logic aware of this shortcoming (eg. not-exists may be a valid + // error during removal if based on status in cache). 
+ cache := NewAutoRefreshCache(syncItemCb, rateLimiter, resyncPeriod) + + // start the cache with a context that would be to stop the cache by cancelling the context + ctx, cancel := context.WithCancel(context.Background()) + cache.Start(ctx) + + // creating objects that go through a couple of state transitions to reach the final state. + item1 := &ExampleCacheItem{status: ExampleStatusNotStarted, id: "item1"} + item2 := &ExampleCacheItem{status: ExampleStatusNotStarted, id: "item2"} + _, err1 := cache.GetOrCreate(item1) + _, err2 := cache.GetOrCreate(item2) + if err1 != nil || err2 != nil { + fmt.Printf("unexpected error in create; err1: %v, err2: %v", err1, err2) + } + + // wait for the cache to go through a few refresh cycles and then check status + time.Sleep(resyncPeriod * 10) + fmt.Printf("Current status for item1 is %v", cache.Get(item1.ID()).(*ExampleCacheItem).status) + + // stop the cache + cancel() + + // Output: + // Current status for item1 is Completed +} diff --git a/flytestdlib/utils/rate_limiter.go b/flytestdlib/utils/rate_limiter.go new file mode 100644 index 0000000000..6a28b21da9 --- /dev/null +++ b/flytestdlib/utils/rate_limiter.go @@ -0,0 +1,35 @@ +package utils + +import ( + "context" + + "github.com/lyft/flytestdlib/logger" + "golang.org/x/time/rate" +) + +// Interface to use rate limiter +type RateLimiter interface { + Wait(ctx context.Context) error +} + +type rateLimiter struct { + name string + limiter rate.Limiter +} + +// Blocking method which waits for the next token as per the tps and burst values defined +func (r *rateLimiter) Wait(ctx context.Context) error { + logger.Debugf(ctx, "Waiting for a token from rate limiter %s", r.name) + if err := r.limiter.Wait(ctx); err != nil { + return err + } + return nil +} + +// Create a new rate-limiter with the tps and burst values +func NewRateLimiter(rateLimiterName string, tps float64, burst int) RateLimiter { + return &rateLimiter{ + name: rateLimiterName, + limiter: *rate.NewLimiter(rate.Limit(tps), burst), + } +} diff --git a/flytestdlib/utils/rate_limiter_test.go b/flytestdlib/utils/rate_limiter_test.go new file mode 100644 index 0000000000..3aaf1df522 --- /dev/null +++ b/flytestdlib/utils/rate_limiter_test.go @@ -0,0 +1,39 @@ +package utils + +import ( + "context" + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestInfiniteRateLimiter(t *testing.T) { + infRateLimiter := NewRateLimiter("test_rate_limiter", math.MaxFloat64, 0) + start := time.Now() + + for i := 0; i < 100; i++ { + err := infRateLimiter.Wait(context.Background()) + assert.NoError(t, err, "unexpected failure in wait") + } + assert.True(t, time.Since(start) < 100*time.Millisecond) +} + +func TestRateLimiter(t *testing.T) { + rateLimiter := NewRateLimiter("test_rate_limiter", 1, 1) + start := time.Now() + + for i := 0; i < 5; i++ { + err := rateLimiter.Wait(context.Background()) + assert.Nil(t, err) + } + assert.True(t, time.Since(start) > 3*time.Second) + assert.True(t, time.Since(start) < 5*time.Second) +} + +func TestInvalidRateLimitConfig(t *testing.T) { + rateLimiter := NewRateLimiter("test_rate_limiter", 1, 0) + err := rateLimiter.Wait(context.Background()) + assert.NotNil(t, err) +} diff --git a/flytestdlib/utils/sequencer.go b/flytestdlib/utils/sequencer.go new file mode 100644 index 0000000000..1d41dde367 --- /dev/null +++ b/flytestdlib/utils/sequencer.go @@ -0,0 +1,39 @@ +package utils + +import ( + "sync" + "sync/atomic" +) + +// Sequencer is a thread-safe incremental integer counter. 
Note that it is a singleton, so +// GetSequencer.GetNext may not always start at 0. +type Sequencer interface { + GetNext() uint64 + GetCur() uint64 +} + +type sequencer struct { + val *uint64 +} + +var once sync.Once +var instance Sequencer + +func GetSequencer() Sequencer { + once.Do(func() { + val := uint64(0) + instance = &sequencer{val: &val} + }) + return instance +} + +// Get the next sequence number, 1 higher than the last and set it as the current one +func (s sequencer) GetNext() uint64 { + x := atomic.AddUint64(s.val, 1) + return x +} + +// Get the current sequence number +func (s sequencer) GetCur() uint64 { + return *s.val +} diff --git a/flytestdlib/utils/sequencer_test.go b/flytestdlib/utils/sequencer_test.go new file mode 100644 index 0000000000..2444048764 --- /dev/null +++ b/flytestdlib/utils/sequencer_test.go @@ -0,0 +1,54 @@ +package utils + +import ( + "sync" + "testing" + + "fmt" + + "github.com/stretchr/testify/assert" +) + +func TestSequencer(t *testing.T) { + size := 3 + sequencer := GetSequencer() + curVal := sequencer.GetCur() + 1 + // sum = n(a0 + aN) / 2 + expectedSum := uint64(size) * (curVal + curVal + uint64(size-1)) / 2 + numbers := make(chan uint64, size) + + var wg sync.WaitGroup + wg.Add(size) + + iter := 0 + for iter < size { + go func() { + number := sequencer.GetNext() + fmt.Printf("list value: %d", number) + numbers <- number + wg.Done() + }() + iter++ + } + wg.Wait() + close(numbers) + + unique, sum := uniqueAndSum(numbers) + assert.True(t, unique, "sequencer generated duplicate numbers") + assert.Equal(t, expectedSum, sum, "sequencer generated sequence numbers with gap %d %d", expectedSum, sum) +} + +func uniqueAndSum(list chan uint64) (bool, uint64) { + set := make(map[uint64]struct{}) + var sum uint64 + + for elem := range list { + fmt.Printf("list value: %d\n", elem) + if _, ok := set[elem]; ok { + return false, sum + } + set[elem] = struct{}{} + sum += elem + } + return true, sum +} diff --git a/flytestdlib/version/version.go b/flytestdlib/version/version.go new file mode 100644 index 0000000000..ab3e4cf112 --- /dev/null +++ b/flytestdlib/version/version.go @@ -0,0 +1,29 @@ +package version + +import ( + "time" + + "github.com/sirupsen/logrus" +) + +// This module provides the ability to inject Build (git sha) and Version information at compile time. +// To set these values invoke go build as follows +// go build -ldflags “-X github.com/lyft/flytestdlib/version.Build=xyz -X github.com/lyft/flytestdlib/version.Version=1.2.3" +// NOTE: If the version is set and server.StartProfilingServerWithDefaultHandlers are initialized then, `/version` +// will provide the build and version information +var ( + // Specifies the GIT sha of the build + Build = "unknown" + // Version for the build, should follow a semver + Version = "unknown" + // Build timestamp + BuildTime = time.Now() +) + +// Use this method to log the build information for the current app. The app name should be provided. 
To inject the build +// and version information refer to the top-level comment in this file +func LogBuildInformation(appName string) { + logrus.Info("------------------------------------------------------------------------") + logrus.Infof("App [%s], Version [%s], BuildSHA [%s], BuildTS [%s]", appName, Version, Build, BuildTime.String()) + logrus.Info("------------------------------------------------------------------------") +} diff --git a/flytestdlib/version/version_test.go b/flytestdlib/version/version_test.go new file mode 100644 index 0000000000..d4ddb2def4 --- /dev/null +++ b/flytestdlib/version/version_test.go @@ -0,0 +1,29 @@ +package version + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/magiconair/properties/assert" + "github.com/sirupsen/logrus" +) + +type dFormat struct { +} + +func (dFormat) Format(e *logrus.Entry) ([]byte, error) { + return []byte(e.Message), nil +} + +func TestLogBuildInformation(t *testing.T) { + + n := time.Now() + BuildTime = n + buf := bytes.NewBufferString("") + logrus.SetFormatter(dFormat{}) + logrus.SetOutput(buf) + LogBuildInformation("hello") + assert.Equal(t, buf.String(), fmt.Sprintf("------------------------------------------------------------------------App [hello], Version [unknown], BuildSHA [unknown], BuildTS [%s]------------------------------------------------------------------------", n.String())) +} diff --git a/flytestdlib/yamlutils/yaml_json.go b/flytestdlib/yamlutils/yaml_json.go new file mode 100644 index 0000000000..152c6ce5f4 --- /dev/null +++ b/flytestdlib/yamlutils/yaml_json.go @@ -0,0 +1,17 @@ +package yamlutils + +import ( + "io/ioutil" + "path/filepath" + + "github.com/ghodss/yaml" +) + +func ReadYamlFileAsJSON(path string) ([]byte, error) { + r, err := ioutil.ReadFile(filepath.Clean(path)) + if err != nil { + return nil, err + } + + return yaml.YAMLToJSON(r) +} From ab8b66753a7c10d0f7f371cca33e3800f5a86aac Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 16:00:46 -0700 Subject: [PATCH 0002/1918] Update ReadMe Title --- flytestdlib/README.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/flytestdlib/README.rst b/flytestdlib/README.rst index f78785fa3d..a43dc38bff 100644 --- a/flytestdlib/README.rst +++ b/flytestdlib/README.rst @@ -1,4 +1,4 @@ -K8s Standard Library +Common Go Tools ===================== Shared components we found ourselves building time and time again, so we collected them in one place! @@ -10,26 +10,34 @@ This library consists of: - cli/pflags Tool to generate a pflags for all fields in a given struct. + - storage Abstract storage library that uses stow behind the scenes to connect to s3/azure/gcs but also offers configurable factory, in-memory storage (for testing) as well as native protobuf support. + - contextutils Wrapper around golang's context to set/get known keys. + - logger Wrapper around logrus that's configurable, taggable and context-aware. + - profutils Starts an http server that serves /metrics (exposes prometheus metrics), /healthcheck and /version endpoints. + - promutils Exposes a Scope instance that's a more convenient way to construct prometheus metrics and scope them per component. + - atomic Wrapper around sync.atomic library to offer AtomicInt32 and other convenient types. + - sets Offers strongly types and convenient interface sets. 
+ - utils - version From b4462cc4bb1a69d28893af4e725e51f0e072539f Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 16:01:08 -0700 Subject: [PATCH 0003/1918] Update Travis GoVersion --- flytestdlib/.travis.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flytestdlib/.travis.yml b/flytestdlib/.travis.yml index 91723384da..05d3a1b406 100644 --- a/flytestdlib/.travis.yml +++ b/flytestdlib/.travis.yml @@ -1,9 +1,7 @@ sudo: required language: go go: - - "1.10" -services: - - docker + - "1.11" jobs: include: - stage: test From 660c51c0e77b6eae002f52780a43d21fda005b2d Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 16:03:56 -0700 Subject: [PATCH 0004/1918] Temporarily add docker service --- flytestdlib/.travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flytestdlib/.travis.yml b/flytestdlib/.travis.yml index 05d3a1b406..21c689d91e 100644 --- a/flytestdlib/.travis.yml +++ b/flytestdlib/.travis.yml @@ -2,6 +2,8 @@ sudo: required language: go go: - "1.11" +services: + - docker jobs: include: - stage: test From 24e21649711687114778554bafb285bd43be6a5e Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 16:33:13 -0700 Subject: [PATCH 0005/1918] Remove services from travis config --- flytestdlib/.travis.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/flytestdlib/.travis.yml b/flytestdlib/.travis.yml index 21c689d91e..05d3a1b406 100644 --- a/flytestdlib/.travis.yml +++ b/flytestdlib/.travis.yml @@ -2,8 +2,6 @@ sudo: required language: go go: - "1.11" -services: - - docker jobs: include: - stage: test From 60a7971b106452dd8f7d7e5e0fccc6e7b5a41893 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 16:39:34 -0700 Subject: [PATCH 0006/1918] Add CODEOWNERS file --- flytestdlib/CODEOWNERS | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 flytestdlib/CODEOWNERS diff --git a/flytestdlib/CODEOWNERS b/flytestdlib/CODEOWNERS new file mode 100644 index 0000000000..f64c50ba34 --- /dev/null +++ b/flytestdlib/CODEOWNERS @@ -0,0 +1,11 @@ +# These owners will be the default owners for everything in +# the repo. Unless a later match takes precedence, +# @global-owner1 and @global-owner2 will be requested for +# review when someone opens a pull request. +* @EngHabu @katrogan @kumare3 @wild-endeavor + +# Order is important; the last matching pattern takes the most +# precedence. When someone opens a pull request that only +# modifies JS files, only @js-owner and not the global +# owner(s) will be requested for a review. +# *.js @js-owner From 1d47c19933d2f79c4d7ef852c17af7a95693bd58 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 10 Apr 2019 08:44:09 -0700 Subject: [PATCH 0007/1918] Remove comments --- flytestdlib/CODEOWNERS | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/flytestdlib/CODEOWNERS b/flytestdlib/CODEOWNERS index f64c50ba34..c05664d497 100644 --- a/flytestdlib/CODEOWNERS +++ b/flytestdlib/CODEOWNERS @@ -1,11 +1,3 @@ # These owners will be the default owners for everything in -# the repo. Unless a later match takes precedence, -# @global-owner1 and @global-owner2 will be requested for -# review when someone opens a pull request. +# the repo. Unless a later match takes precedence. * @EngHabu @katrogan @kumare3 @wild-endeavor - -# Order is important; the last matching pattern takes the most -# precedence. 
When someone opens a pull request that only -# modifies JS files, only @js-owner and not the global -# owner(s) will be requested for a review. -# *.js @js-owner From 8b92db6330fa358397d1446aa663529c8f44c02c Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 10 Apr 2019 15:40:57 -0700 Subject: [PATCH 0008/1918] regenerate --- .../lyft/golang_test_targets/Makefile | 2 +- flytestdlib/cli/pflags/api/generator.go | 75 +++++++++++++++---- flytestdlib/cli/pflags/api/generator_test.go | 32 +++++++- flytestdlib/cli/pflags/api/sample.go | 5 +- flytestdlib/cli/pflags/api/templates.go | 18 ++++- .../cli/pflags/api/testdata/testtype.go | 51 ++++++++----- .../cli/pflags/api/testdata/testtype_test.go | 32 ++++---- flytestdlib/cli/pflags/api/utils.go | 28 +++++-- flytestdlib/cli/pflags/cmd/root.go | 9 ++- flytestdlib/logger/config.go | 19 +++-- flytestdlib/logger/config_flags.go | 27 +++++-- flytestdlib/logger/config_flags_test.go | 41 +++++++++- flytestdlib/profutils/server_test.go | 4 +- flytestdlib/storage/config.go | 20 +++-- flytestdlib/storage/config_flags.go | 41 +++++++--- flytestdlib/storage/config_flags_test.go | 22 +++--- 16 files changed, 320 insertions(+), 106 deletions(-) diff --git a/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile b/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile index 1c6f893521..04b79ba99e 100644 --- a/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile +++ b/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile @@ -1,6 +1,6 @@ .PHONY: lint lint: #lints the package for common code smells - which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.10 + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.11 golangci-lint run # If code is failing goimports linter, this will fix. diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go index 2e4dd30c54..0ee2a61de4 100644 --- a/flytestdlib/cli/pflags/api/generator.go +++ b/flytestdlib/cli/pflags/api/generator.go @@ -19,8 +19,9 @@ const ( // PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields. type PFlagProviderGenerator struct { - pkg *types.Package - st *types.Named + pkg *types.Package + st *types.Named + defaultVar *types.Var } // This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings. @@ -54,7 +55,7 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage emptyDefaultValue := `[]string{}` if b, ok := t.Elem().(*types.Basic); !ok { logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem()) - if !jsonUnmarshaler(t.Elem()) { + if !isJSONUnmarshaler(t.Elem()) { return FieldInfo{}, fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported", t.Elem().String()) @@ -85,9 +86,17 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage }, nil } +func appendAccessorIfNotEmpty(baseAccessor, childAccessor string) string { + if len(baseAccessor) == 0 { + return baseAccessor + } + + return baseAccessor + "." + childAccessor +} + // Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is // met; encountered a basic type (e.g. string, int... etc.) 
or the field type implements UnmarshalJSON. -func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo, error) { +func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor string) ([]FieldInfo, error) { logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) @@ -111,9 +120,11 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo tag.Name = v.Name() } + isPtr := false typ := v.Type() - if ptr, isPtr := typ.(*types.Pointer); isPtr { + if ptr, casted := typ.(*types.Pointer); casted { typ = ptr.Elem() + isPtr = true } switch t := typ.(type) { @@ -137,12 +148,21 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo t.String(), t.Kind(), allowedKinds) } + defaultValue := tag.DefaultValue + if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { + defaultValue = accessor + + if isPtr { + defaultValue = fmt.Sprintf("cfg.elemValueOrNil(%s).(%s)", defaultValue, t.Name()) + } + } + fields = append(fields, FieldInfo{ Name: tag.Name, GoName: v.Name(), Typ: t, FlagMethodName: camelCase(t.String()), - DefaultValue: tag.DefaultValue, + DefaultValue: defaultValue, UsageString: tag.Usage, TestValue: `"1"`, TestStrategy: JSON, @@ -155,7 +175,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo // If the type has json unmarshaler, then stop the recursion and assume the type is string. config package // will use json unmarshaler to fill in the final config object. - jsonUnmarshaler := jsonUnmarshaler(t) + jsonUnmarshaler := isJSONUnmarshaler(t) testValue := tag.DefaultValue if len(tag.DefaultValue) == 0 { @@ -163,6 +183,16 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo testValue = `"1"` } + defaultValue := tag.DefaultValue + if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { + defaultValue = accessor + if isStringer(t) { + defaultValue = defaultValue + ".String()" + } else { + defaultValue = fmt.Sprintf("fmt.Sprintf(\"%%v\",%s)", defaultValue) + } + } + logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue) if jsonUnmarshaler { @@ -173,7 +203,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo GoName: v.Name(), Typ: types.Typ[types.String], FlagMethodName: "String", - DefaultValue: tag.DefaultValue, + DefaultValue: defaultValue, UsageString: tag.Usage, TestValue: testValue, TestStrategy: JSON, @@ -181,7 +211,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo } else { logger.Infof(ctx, "Traversing fields in type.") - nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t) + nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, appendAccessorIfNotEmpty(defaultValueAccessor, v.Name())) if err != nil { return nil, err } @@ -228,7 +258,8 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo // NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. 
If pkg is not filled in, // it's assumed to be current package (which is expected to be the common use case when invoking pflags from // go:generate comments) -func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) { +func NewGenerator(pkg, targetTypeName, defaultVariableName string) (*PFlagProviderGenerator, error) { + ctx := context.Background() var err error // Resolve package path if pkg == "" || pkg[0] == '.' { @@ -257,9 +288,22 @@ func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) { return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying()) } + var defaultVar *types.Var + obj = targetPackage.Scope().Lookup(defaultVariableName) + if obj != nil { + defaultVar = obj.(*types.Var) + } + + if defaultVar != nil { + logger.Infof(ctx, "Using default variable with name [%v] to assign all default values.", defaultVariableName) + } else { + logger.Infof(ctx, "Using default values defined in tags if any.") + } + return &PFlagProviderGenerator{ - st: st, - pkg: targetPackage, + st: st, + pkg: targetPackage, + defaultVar: defaultVar, }, nil } @@ -268,7 +312,12 @@ func (g PFlagProviderGenerator) GetTargetPackage() *types.Package { } func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) { - fields, err := discoverFieldsRecursive(ctx, g.st) + defaultValueAccessor := "" + if g.defaultVar != nil { + defaultValueAccessor = g.defaultVar.Name() + } + + fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor) if err != nil { return PFlagProvider{}, err } diff --git a/flytestdlib/cli/pflags/api/generator_test.go b/flytestdlib/cli/pflags/api/generator_test.go index edfab6c1ca..b1d7c617c1 100644 --- a/flytestdlib/cli/pflags/api/generator_test.go +++ b/flytestdlib/cli/pflags/api/generator_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "os" "path/filepath" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -14,8 +15,37 @@ import ( // Make sure existing config file(s) parse correctly before overriding them with this flag! var update = flag.Bool("update", false, "Updates testdata") +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. 
+func elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func TestElemValueOrNil(t *testing.T) { + var iPtr *int + assert.Equal(t, 0, elemValueOrNil(iPtr)) + var sPtr *string + assert.Equal(t, "", elemValueOrNil(sPtr)) + var i int + assert.Equal(t, 0, elemValueOrNil(i)) + var s string + assert.Equal(t, "", elemValueOrNil(s)) + var arr []string + assert.Equal(t, arr, elemValueOrNil(arr)) +} + func TestNewGenerator(t *testing.T) { - g, err := NewGenerator(".", "TestType") + g, err := NewGenerator(".", "TestType", "DefaultTestType") assert.NoError(t, err) ctx := context.Background() diff --git a/flytestdlib/cli/pflags/api/sample.go b/flytestdlib/cli/pflags/api/sample.go index b1ebb50684..43dfce7585 100644 --- a/flytestdlib/cli/pflags/api/sample.go +++ b/flytestdlib/cli/pflags/api/sample.go @@ -3,10 +3,13 @@ package api import ( "encoding/json" "errors" - "github.com/lyft/flytestdlib/storage" ) +var DefaultTestType = &TestType{ + StringValue: "Welcome to defaults", +} + type TestType struct { StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""` BoolValue bool `json:"bl" pflag:"true"` diff --git a/flytestdlib/cli/pflags/api/templates.go b/flytestdlib/cli/pflags/api/templates.go index a7adba8922..545d0ceae1 100644 --- a/flytestdlib/cli/pflags/api/templates.go +++ b/flytestdlib/cli/pflags/api/templates.go @@ -26,9 +26,25 @@ import ( {{$name}} "{{$path}}"{{end}} ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func ({{ .Name }}) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func ({{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg {{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError) {{- range .Fields }} cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }}) diff --git a/flytestdlib/cli/pflags/api/testdata/testtype.go b/flytestdlib/cli/pflags/api/testdata/testtype.go index 87f5cb7dfe..dde0f48c80 100755 --- a/flytestdlib/cli/pflags/api/testdata/testtype.go +++ b/flytestdlib/cli/pflags/api/testdata/testtype.go @@ -5,32 +5,49 @@ package api import ( "fmt" + "reflect" "github.com/spf13/pflag" ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (TestType) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in TestType and its nested types. 
The format of the // flags is json-name.json-sub-name... etc. -func (TestType) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg TestType) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("TestType", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), true, "") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "nested.i"), *new(int), "this is an important flag") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), DefaultTestType.StringValue, "life is short") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), DefaultTestType.BoolValue, "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "nested.i"), DefaultTestType.NestedType.IntValue, "this is an important flag") cmdFlags.IntSlice(fmt.Sprintf("%v%v", prefix, "ints"), []int{12, 1}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strs"), []string{"12", "1"}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "complexArr"), []string{}, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), "", "I'm a complex type but can be converted from string.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), "", "URL for storage client to connect to.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.secret-key"), *new(string), "Secret to use when accesskey is set.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.region"), "us-east-1", "Region to connect to.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "storage.connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.container"), *new(string), "Initial container to create -if it doesn't exist-.'") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. 
If not specified or set to 0, cache is not used") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") - cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), *new(int), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), fmt.Sprintf("%v", DefaultTestType.StringToJSON), "I'm a complex type but can be converted from string.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), DefaultTestType.StorageConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), DefaultTestType.StorageConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), DefaultTestType.StorageConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), DefaultTestType.StorageConfig.Connection.AccessKey, "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.secret-key"), DefaultTestType.StorageConfig.Connection.SecretKey, "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.region"), DefaultTestType.StorageConfig.Connection.Region, "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "storage.connection.disable-ssl"), DefaultTestType.StorageConfig.Connection.DisableSSL, "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.container"), DefaultTestType.StorageConfig.InitContainer, "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), DefaultTestType.StorageConfig.Cache.MaxSizeMegabytes, "Maximum size of the cache where the Blob store data is cached in-memory. 
If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), DefaultTestType.StorageConfig.Cache.TargetGCPercent, "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), DefaultTestType.StorageConfig.Limits.GetLimitMegabytes, "Maximum allowed download size (in MBs) per call.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), cfg.elemValueOrNil(DefaultTestType.IntValue).(int), "") return cmdFlags } diff --git a/flytestdlib/cli/pflags/api/testdata/testtype_test.go b/flytestdlib/cli/pflags/api/testdata/testtype_test.go index f8b81bbe11..03412f0d86 100755 --- a/flytestdlib/cli/pflags/api/testdata/testtype_test.go +++ b/flytestdlib/cli/pflags/api/testdata/testtype_test.go @@ -103,7 +103,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("str"); err == nil { - assert.Equal(t, string("hello world"), vString) + assert.Equal(t, string(DefaultTestType.StringValue), vString) } else { assert.FailNow(t, err.Error()) } @@ -125,7 +125,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("bl"); err == nil { - assert.Equal(t, bool(true), vBool) + assert.Equal(t, bool(DefaultTestType.BoolValue), vBool) } else { assert.FailNow(t, err.Error()) } @@ -147,7 +147,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("nested.i"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(DefaultTestType.NestedType.IntValue), vInt) } else { assert.FailNow(t, err.Error()) } @@ -235,7 +235,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("c"); err == nil { - assert.Equal(t, string(""), vString) + assert.Equal(t, string(fmt.Sprintf("%v", DefaultTestType.StringToJSON)), vString) } else { assert.FailNow(t, err.Error()) } @@ -257,7 +257,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.type"); err == nil { - assert.Equal(t, string("s3"), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Type), vString) } else { assert.FailNow(t, err.Error()) } @@ -279,7 +279,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { - assert.Equal(t, string(""), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.Endpoint.String()), vString) } else { assert.FailNow(t, err.Error()) } @@ -301,7 +301,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.auth-type"); err == nil { - assert.Equal(t, string("iam"), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.AuthType), vString) } else { assert.FailNow(t, err.Error()) } @@ -323,7 +323,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that 
default value is set properly if vString, err := cmdFlags.GetString("storage.connection.access-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.AccessKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -345,7 +345,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.secret-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.SecretKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -367,7 +367,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.connection.region"); err == nil { - assert.Equal(t, string("us-east-1"), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.Connection.Region), vString) } else { assert.FailNow(t, err.Error()) } @@ -389,7 +389,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("storage.connection.disable-ssl"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(DefaultTestType.StorageConfig.Connection.DisableSSL), vBool) } else { assert.FailNow(t, err.Error()) } @@ -411,7 +411,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("storage.container"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(DefaultTestType.StorageConfig.InitContainer), vString) } else { assert.FailNow(t, err.Error()) } @@ -433,7 +433,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("storage.cache.max_size_mbs"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(DefaultTestType.StorageConfig.Cache.MaxSizeMegabytes), vInt) } else { assert.FailNow(t, err.Error()) } @@ -455,7 +455,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("storage.cache.target_gc_percent"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(DefaultTestType.StorageConfig.Cache.TargetGCPercent), vInt) } else { assert.FailNow(t, err.Error()) } @@ -477,7 +477,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt64, err := cmdFlags.GetInt64("storage.limits.maxDownloadMBs"); err == nil { - assert.Equal(t, int64(2), vInt64) + assert.Equal(t, int64(DefaultTestType.StorageConfig.Limits.GetLimitMegabytes), vInt64) } else { assert.FailNow(t, err.Error()) } @@ -499,7 +499,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("i"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(cfg.elemValueOrNil(DefaultTestType.IntValue).(int)), vInt) } else { assert.FailNow(t, err.Error()) } diff --git a/flytestdlib/cli/pflags/api/utils.go b/flytestdlib/cli/pflags/api/utils.go index 
4c71fbb1c4..16edb50d38 100644 --- a/flytestdlib/cli/pflags/api/utils.go +++ b/flytestdlib/cli/pflags/api/utils.go @@ -20,13 +20,29 @@ func camelCase(str string) string { return str } -func jsonUnmarshaler(t types.Type) bool { +func isJSONUnmarshaler(t types.Type) bool { + found, _ := implementsAnyOfMethods(t, "UnmarshalJSON") + return found +} + +func isStringer(t types.Type) bool { + found, _ := implementsAnyOfMethods(t, "String") + return found +} + +func implementsAnyOfMethods(t types.Type, methodNames ...string) (found, implementedByPtr bool) { mset := types.NewMethodSet(t) - jsonUnmarshaler := mset.Lookup(nil, "UnmarshalJSON") - if jsonUnmarshaler == nil { - mset = types.NewMethodSet(types.NewPointer(t)) - jsonUnmarshaler = mset.Lookup(nil, "UnmarshalJSON") + for _, name := range methodNames { + if mset.Lookup(nil, name) != nil { + return true, false + } + } + mset = types.NewMethodSet(types.NewPointer(t)) + for _, name := range methodNames { + if mset.Lookup(nil, name) != nil { + return true, true + } } - return jsonUnmarshaler != nil + return false, false } diff --git a/flytestdlib/cli/pflags/cmd/root.go b/flytestdlib/cli/pflags/cmd/root.go index b6562d8a1b..d78d4c4d4a 100644 --- a/flytestdlib/cli/pflags/cmd/root.go +++ b/flytestdlib/cli/pflags/cmd/root.go @@ -3,7 +3,6 @@ package cmd import ( "bytes" "context" - "flag" "fmt" "strings" @@ -13,7 +12,8 @@ import ( ) var ( - pkg = flag.String("pkg", ".", "what package to get the interface from") + pkg string + defaultValuesVariable string ) var root = cobra.Command{ @@ -31,7 +31,8 @@ type MyStruct struct { } func init() { - root.Flags().StringP("package", "p", ".", "Determines the source/destination package.") + root.Flags().StringVarP(&pkg, "package", "p", ".", "Determines the source/destination package.") + root.Flags().StringVar(&defaultValuesVariable, "default-var", "defaultConfig", "Points to a variable to use to load default configs. If specified & found, it'll be used instead of the values specified in the tag.") } func Execute() error { @@ -45,7 +46,7 @@ func generatePflagsProvider(cmd *cobra.Command, args []string) error { } ctx := context.Background() - gen, err := api.NewGenerator(*pkg, structName) + gen, err := api.NewGenerator(pkg, structName, defaultValuesVariable) if err != nil { return err } diff --git a/flytestdlib/logger/config.go b/flytestdlib/logger/config.go index faa2da6f9e..1c665ac523 100644 --- a/flytestdlib/logger/config.go +++ b/flytestdlib/logger/config.go @@ -6,7 +6,7 @@ import ( "github.com/lyft/flytestdlib/config" ) -//go:generate pflags Config +//go:generate pflags Config --default-var defaultConfig const configSectionKey = "Logger" @@ -21,6 +21,13 @@ const ( jsonDataKey string = "json" ) +var defaultConfig = &Config{ + Formatter: FormatterConfig{ + Type: FormatterJSON, + }, + Level: InfoLevel, +} + // Global logger config. type Config struct { // Determines whether to include source code location in logs. This might incurs a performance hit and is only @@ -31,13 +38,13 @@ type Config struct { Mute bool `json:"mute" pflag:",Mutes all logs regardless of severity. Intended for benchmarks/tests only."` // Determines the minimum log level to log. 
- Level Level `json:"level" pflag:"4,Sets the minimum logging level."` + Level Level `json:"level" pflag:",Sets the minimum logging level."` Formatter FormatterConfig `json:"formatter" pflag:",Sets logging format."` } type FormatterConfig struct { - Type FormatterType `json:"type" pflag:"\"json\",Sets logging format type."` + Type FormatterType `json:"type" pflag:",Sets logging format type."` } var globalConfig = Config{} @@ -73,9 +80,7 @@ const ( ) func init() { - if _, err := config.RegisterSectionWithUpdates(configSectionKey, &Config{}, func(ctx context.Context, newValue config.Config) { + config.MustRegisterSectionWithUpdates(configSectionKey, defaultConfig, func(ctx context.Context, newValue config.Config) { SetConfig(*newValue.(*Config)) - }); err != nil { - panic(err) - } + }) } diff --git a/flytestdlib/logger/config_flags.go b/flytestdlib/logger/config_flags.go index cf8950b94f..27be2e3440 100755 --- a/flytestdlib/logger/config_flags.go +++ b/flytestdlib/logger/config_flags.go @@ -5,17 +5,34 @@ package logger import ( "fmt" + "reflect" "github.com/spf13/pflag" ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "show-source"), *new(bool), "Includes source code location in logs.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "mute"), *new(bool), "Mutes all logs regardless of severity. Intended for benchmarks/tests only.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "level"), 4, "Sets the minimum logging level.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "formatter.type"), "json", "Sets logging format type.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "show-source"), defaultConfig.IncludeSourceCode, "Includes source code location in logs.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "mute"), defaultConfig.Mute, "Mutes all logs regardless of severity. 
Intended for benchmarks/tests only.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "level"), defaultConfig.Level, "Sets the minimum logging level.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "formatter.type"), defaultConfig.Formatter.Type, "Sets logging format type.") return cmdFlags } diff --git a/flytestdlib/logger/config_flags_test.go b/flytestdlib/logger/config_flags_test.go index 401d58d493..853aeac0b0 100755 --- a/flytestdlib/logger/config_flags_test.go +++ b/flytestdlib/logger/config_flags_test.go @@ -103,7 +103,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("show-source"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(defaultConfig.IncludeSourceCode), vBool) } else { assert.FailNow(t, err.Error()) } @@ -125,7 +125,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("mute"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(defaultConfig.Mute), vBool) } else { assert.FailNow(t, err.Error()) } @@ -147,7 +147,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("level"); err == nil { - assert.Equal(t, int(4), vInt) + assert.Equal(t, int(defaultConfig.Level), vInt) } else { assert.FailNow(t, err.Error()) } @@ -169,7 +169,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("formatter.type"); err == nil { - assert.Equal(t, string("json"), vString) + assert.Equal(t, string(defaultConfig.Formatter.Type), vString) } else { assert.FailNow(t, err.Error()) } @@ -188,3 +188,36 @@ func TestConfig_SetFlags(t *testing.T) { }) }) } + +func TestConfig_elemValueOrNil(t *testing.T) { + type fields struct { + IncludeSourceCode bool + Mute bool + Level Level + Formatter FormatterConfig + } + type args struct { + v interface{} + } + tests := []struct { + name string + fields fields + args args + want interface{} + }{ + // TODO: Add test cases. 
+ } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Config{ + IncludeSourceCode: tt.fields.IncludeSourceCode, + Mute: tt.fields.Mute, + Level: tt.fields.Level, + Formatter: tt.fields.Formatter, + } + if got := c.elemValueOrNil(tt.args.v); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.elemValueOrNil() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/flytestdlib/profutils/server_test.go b/flytestdlib/profutils/server_test.go index e2eb709943..a7ef350002 100644 --- a/flytestdlib/profutils/server_test.go +++ b/flytestdlib/profutils/server_test.go @@ -70,9 +70,9 @@ func TestConfigHandler(t *testing.T) { "logger": map[string]interface{}{ "show-source": false, "mute": false, - "level": float64(0), + "level": float64(4), "formatter": map[string]interface{}{ - "type": "", + "type": "json", }, }, }, m) diff --git a/flytestdlib/storage/config.go b/flytestdlib/storage/config.go index 59db0197c3..780ef12ac9 100644 --- a/flytestdlib/storage/config.go +++ b/flytestdlib/storage/config.go @@ -28,12 +28,22 @@ const ( ) var ( - ConfigSection = config.MustRegisterSection(configSectionKey, &Config{}) + ConfigSection = config.MustRegisterSection(configSectionKey, defaultConfig) + defaultConfig = &Config{ + Type: TypeS3, + Limits: LimitsConfig{ + GetLimitMegabytes: 2, + }, + Connection: ConnectionConfig{ + Region: "us-east-1", + AuthType: "iam", + }, + } ) // A common storage config. type Config struct { - Type Type `json:"type" pflag:"\"s3\",Sets the type of storage to configure [s3/minio/local/mem]."` + Type Type `json:"type" pflag:",Sets the type of storage to configure [s3/minio/local/mem]."` Connection ConnectionConfig `json:"connection"` InitContainer string `json:"container" pflag:",Initial container to create -if it doesn't exist-.'"` // Caching is recommended to improve the performance of underlying systems. It caches the metadata and resolving @@ -47,10 +57,10 @@ type Config struct { // Defines connection configurations. type ConnectionConfig struct { Endpoint config.URL `json:"endpoint" pflag:",URL for storage client to connect to."` - AuthType string `json:"auth-type" pflag:"\"iam\",Auth Type to use [iam,accesskey]."` + AuthType string `json:"auth-type" pflag:",Auth Type to use [iam,accesskey]."` AccessKey string `json:"access-key" pflag:",Access key to use. Only required when authtype is set to accesskey."` SecretKey string `json:"secret-key" pflag:",Secret to use when accesskey is set."` - Region string `json:"region" pflag:"\"us-east-1\",Region to connect to."` + Region string `json:"region" pflag:",Region to connect to."` DisableSSL bool `json:"disable-ssl" pflag:",Disables SSL connection. Should only be used for development."` } @@ -71,7 +81,7 @@ type CachingConfig struct { // Specifies limits for storage package. type LimitsConfig struct { - GetLimitMegabytes int64 `json:"maxDownloadMBs" pflag:"2,Maximum allowed download size (in MBs) per call."` + GetLimitMegabytes int64 `json:"maxDownloadMBs" pflag:",Maximum allowed download size (in MBs) per call."` } // Retrieve current global config for storage. diff --git a/flytestdlib/storage/config_flags.go b/flytestdlib/storage/config_flags.go index 9a74efdd64..4cde49bec6 100755 --- a/flytestdlib/storage/config_flags.go +++ b/flytestdlib/storage/config_flags.go @@ -5,24 +5,41 @@ package storage import ( "fmt" + "reflect" "github.com/spf13/pflag" ) +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. 
+func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. -func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), "", "URL for storage client to connect to.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.secret-key"), *new(string), "Secret to use when accesskey is set.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.region"), "us-east-1", "Region to connect to.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "container"), *new(string), "Initial container to create -if it doesn't exist-.'") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") - cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), defaultConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), defaultConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), defaultConfig.Connection.AccessKey, "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.secret-key"), defaultConfig.Connection.SecretKey, "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.region"), defaultConfig.Connection.Region, "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "connection.disable-ssl"), defaultConfig.Connection.DisableSSL, "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "container"), defaultConfig.InitContainer, "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), defaultConfig.Cache.MaxSizeMegabytes, "Maximum size of the cache where the Blob store data is cached in-memory. 
If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), defaultConfig.Cache.TargetGCPercent, "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), defaultConfig.Limits.GetLimitMegabytes, "Maximum allowed download size (in MBs) per call.") return cmdFlags } diff --git a/flytestdlib/storage/config_flags_test.go b/flytestdlib/storage/config_flags_test.go index 2f39f00006..429af71283 100755 --- a/flytestdlib/storage/config_flags_test.go +++ b/flytestdlib/storage/config_flags_test.go @@ -103,7 +103,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("type"); err == nil { - assert.Equal(t, string("s3"), vString) + assert.Equal(t, string(defaultConfig.Type), vString) } else { assert.FailNow(t, err.Error()) } @@ -125,7 +125,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { - assert.Equal(t, string(""), vString) + assert.Equal(t, string(defaultConfig.Connection.Endpoint.String()), vString) } else { assert.FailNow(t, err.Error()) } @@ -147,7 +147,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.auth-type"); err == nil { - assert.Equal(t, string("iam"), vString) + assert.Equal(t, string(defaultConfig.Connection.AuthType), vString) } else { assert.FailNow(t, err.Error()) } @@ -169,7 +169,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.access-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(defaultConfig.Connection.AccessKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -191,7 +191,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.secret-key"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, string(defaultConfig.Connection.SecretKey), vString) } else { assert.FailNow(t, err.Error()) } @@ -213,7 +213,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("connection.region"); err == nil { - assert.Equal(t, string("us-east-1"), vString) + assert.Equal(t, string(defaultConfig.Connection.Region), vString) } else { assert.FailNow(t, err.Error()) } @@ -235,7 +235,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vBool, err := cmdFlags.GetBool("connection.disable-ssl"); err == nil { - assert.Equal(t, bool(*new(bool)), vBool) + assert.Equal(t, bool(defaultConfig.Connection.DisableSSL), vBool) } else { assert.FailNow(t, err.Error()) } @@ -257,7 +257,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("container"); err == nil { - assert.Equal(t, string(*new(string)), vString) + assert.Equal(t, 
string(defaultConfig.InitContainer), vString) } else { assert.FailNow(t, err.Error()) } @@ -279,7 +279,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("cache.max_size_mbs"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(defaultConfig.Cache.MaxSizeMegabytes), vInt) } else { assert.FailNow(t, err.Error()) } @@ -301,7 +301,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("cache.target_gc_percent"); err == nil { - assert.Equal(t, int(*new(int)), vInt) + assert.Equal(t, int(defaultConfig.Cache.TargetGCPercent), vInt) } else { assert.FailNow(t, err.Error()) } @@ -323,7 +323,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt64, err := cmdFlags.GetInt64("limits.maxDownloadMBs"); err == nil { - assert.Equal(t, int64(2), vInt64) + assert.Equal(t, int64(defaultConfig.Limits.GetLimitMegabytes), vInt64) } else { assert.FailNow(t, err.Error()) } From 17757daad422af6d2830649fad3bd0b1a8bc615e Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 10 Apr 2019 15:41:36 -0700 Subject: [PATCH 0009/1918] Update target go in lint make target --- flytestdlib/boilerplate/lyft/golang_test_targets/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile b/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile index 1c6f893521..04b79ba99e 100644 --- a/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile +++ b/flytestdlib/boilerplate/lyft/golang_test_targets/Makefile @@ -1,6 +1,6 @@ .PHONY: lint lint: #lints the package for common code smells - which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.10 + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.11 golangci-lint run # If code is failing goimports linter, this will fix. 
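For reference, the elemValueOrNil helper that the regenerated config_flags.go above attaches to Config exists so that the defaultConfig-driven flag defaults can be materialized without ever dereferencing a nil pointer. The following is a minimal standalone sketch of the behaviour its comment describes; it is not a verbatim copy (it adds a nil-type guard), and the names used in main are illustrative only:

```go
package main

import (
	"fmt"
	"reflect"
)

// elemValueOrNil, sketched: a nil pointer yields the zero value of its element
// type, a non-nil pointer passes through unchanged, and non-pointer values are
// returned as-is.
func elemValueOrNil(v interface{}) interface{} {
	if t := reflect.TypeOf(v); t != nil && t.Kind() == reflect.Ptr {
		if reflect.ValueOf(v).IsNil() {
			return reflect.Zero(t.Elem()).Interface()
		}
		return reflect.ValueOf(v).Interface()
	}
	return v
}

func main() {
	var missing *int
	fmt.Println(elemValueOrNil(missing)) // prints 0: zero value of int, no nil dereference
	n := 3
	fmt.Println(elemValueOrNil(&n)) // prints the non-nil pointer itself
	fmt.Println(elemValueOrNil("s3")) // prints "s3": non-pointers are untouched
}
```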
From 13aff7ac2f34aaf0e9b8b9b7d589b7b2845d5f12 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 17 Apr 2019 11:37:50 -0700 Subject: [PATCH 0010/1918] Add lods to debug test concurrent issue --- flytestdlib/config/tests/accessor_test.go | 42 +++++++++++++++++------ 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go index 34d86237e9..702c9fc86a 100644 --- a/flytestdlib/config/tests/accessor_test.go +++ b/flytestdlib/config/tests/accessor_test.go @@ -30,6 +30,11 @@ import ( type accessorCreatorFn func(registry config.Section, configPath string) config.Accessor +type testLogger interface { + Logf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + func getRandInt() uint64 { c := 10 b := make([]byte, c) @@ -323,7 +328,7 @@ func TestAccessor_UpdateConfig(t *testing.T) { r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) firstValue := r.StringValue - fileUpdated, err := beginWaitForFileChange(configFile) + fileUpdated, err := beginWaitForFileChange(t, configFile) assert.NoError(t, err) _, err = populateConfigData(configFile) @@ -346,7 +351,7 @@ func TestAccessor_UpdateConfig(t *testing.T) { // Independently watch for when symlink underlying change happens to know when do we expect accessor to have picked up // the changes - fileUpdated, err := beginWaitForFileChange(configFile) + fileUpdated, err := beginWaitForFileChange(t, configFile) assert.NoError(t, err) // 2. Start accessor with the symlink as config location @@ -383,7 +388,7 @@ func TestAccessor_UpdateConfig(t *testing.T) { // Wait for filewatcher event assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) - time.Sleep(2 * time.Second) + time.Sleep(5 * time.Second) r = section.GetConfig().(*MyComponentConfig) secondValue := r.StringValue @@ -462,7 +467,7 @@ func waitForFileChangeOrTimeout(done chan error) error { } } -func beginWaitForFileChange(filename string) (done chan error, terminalErr error) { +func beginWaitForFileChange(logger testLogger, filename string) (done chan error, terminalErr error) { watcher, err := fsnotify.NewWatcher() if err != nil { return nil, err @@ -480,12 +485,21 @@ func beginWaitForFileChange(filename string) (done chan error, terminalErr error go func() { for { select { - case event := <-watcher.Events: + case event, channelOpen := <-watcher.Events: + if !channelOpen { + logger.Logf("Events Channel has been closed") + done <- nil + return + } + + logger.Logf("Received watcher event [%v], %v", event) // we only care about the config file currentConfigFile, err := filepath.EvalSymlinks(filename) if err != nil { + logger.Errorf("Failed to EvalSymLinks. Will attempt to close watcher now. Error: %v", err) closeErr := watcher.Close() if closeErr != nil { + logger.Errorf("Failed to close watcher. Error: %v", closeErr) done <- closeErr } else { done <- err @@ -501,10 +515,12 @@ func beginWaitForFileChange(filename string) (done chan error, terminalErr error if (filepath.Clean(event.Name) == configFile && event.Op&writeOrCreateMask != 0) || (currentConfigFile != "" && currentConfigFile != realConfigFile) { + + logger.Logf("CurrentConfigFile [%v], RealConfigFile [%v]", currentConfigFile, realConfigFile) realConfigFile = currentConfigFile closeErr := watcher.Close() if closeErr != nil { - fmt.Printf("Close Watcher error: %v\n", closeErr) + logger.Errorf("Failed to close watcher. 
Error: %v", closeErr) } else { done <- nil } @@ -512,21 +528,25 @@ func beginWaitForFileChange(filename string) (done chan error, terminalErr error return } else if filepath.Clean(event.Name) == configFile && event.Op&fsnotify.Remove&fsnotify.Remove != 0 { + + logger.Logf("ConfigFile [%v] Removed.", configFile) closeErr := watcher.Close() if closeErr != nil { - fmt.Printf("Close Watcher error: %v\n", closeErr) + logger.Logf("Close Watcher error: %v", closeErr) } else { done <- nil } return } - case err, ok := <-watcher.Errors: - if ok { - fmt.Printf("Watcher error: %v\n", err) + case err, channelOpen := <-watcher.Errors: + if !channelOpen { + logger.Logf("Error Channel has been closed.") + } else { + logger.Logf("Watcher error: %v", err) closeErr := watcher.Close() if closeErr != nil { - fmt.Printf("Close Watcher error: %v\n", closeErr) + logger.Logf("Close Watcher error: %v\n", closeErr) } } From 6cfceec205ec460b8bd0ddca122740d0249140b3 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 17 Apr 2019 16:29:41 -0700 Subject: [PATCH 0011/1918] atomic symlink --- flytestdlib/config/tests/accessor_test.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go index 702c9fc86a..5bffe8650e 100644 --- a/flytestdlib/config/tests/accessor_test.go +++ b/flytestdlib/config/tests/accessor_test.go @@ -399,8 +399,8 @@ func TestAccessor_UpdateConfig(t *testing.T) { } func changeSymLink(targetPath, symLink string) error { + tmpLink := tempFileName("temp-sym-link-*") if runtime.GOOS == "windows" { - tmpLink := tempFileName("temp-sym-link-*") err := exec.Command("mklink", filepath.Clean(tmpLink), filepath.Clean(targetPath)).Run() if err != nil { return err @@ -414,7 +414,15 @@ func changeSymLink(targetPath, symLink string) error { return exec.Command("del", filepath.Clean(tmpLink)).Run() } - return exec.Command("ln", "-sfn", filepath.Clean(targetPath), filepath.Clean(symLink)).Run() + // ln -sfn is not an atomic operation. Under the hood, it first calls the system unlink then symlink calls. During + // that, there will be a brief moment when there is no symlink at all. mv operation is, however, atomic. That's + // why we make this command instead + err := exec.Command("ln", "-s", filepath.Clean(targetPath), filepath.Clean(tmpLink)).Run() + if err != nil { + return err + } + + return exec.Command("mv", "-Tf", filepath.Clean(tmpLink), filepath.Clean(symLink)).Run() } // 1. 
Create Dir structure: @@ -492,7 +500,7 @@ func beginWaitForFileChange(logger testLogger, filename string) (done chan error return } - logger.Logf("Received watcher event [%v], %v", event) + logger.Logf("Received watcher event [%v]", event) // we only care about the config file currentConfigFile, err := filepath.EvalSymlinks(filename) if err != nil { From 6cc4a4d48f3ada4101cf6e9909626c37de48982b Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 17 Apr 2019 23:29:42 -0700 Subject: [PATCH 0012/1918] simplify tests --- flytestdlib/cli/pflags/api/generator.go | 43 +++-- flytestdlib/cli/pflags/api/generator_test.go | 4 +- flytestdlib/cli/pflags/api/sample.go | 1 + flytestdlib/config/tests/accessor_test.go | 167 ++----------------- 4 files changed, 51 insertions(+), 164 deletions(-) diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go index 0ee2a61de4..c6ea06fc21 100644 --- a/flytestdlib/cli/pflags/api/generator.go +++ b/flytestdlib/cli/pflags/api/generator.go @@ -5,6 +5,7 @@ import ( "fmt" "go/types" "path/filepath" + "strings" "github.com/lyft/flytestdlib/logger" @@ -86,17 +87,35 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage }, nil } -func appendAccessorIfNotEmpty(baseAccessor, childAccessor string) string { - if len(baseAccessor) == 0 { - return baseAccessor +// Appends field accessors using "." as the delimiter. +// e.g. appendAccessors("var1", "field1", "subField") will output "var1.field1.subField" +func appendAccessors(accessors ...string) string { + sb := strings.Builder{} + switch len(accessors) { + case 0: + return "" + case 1: + return accessors[0] } - return baseAccessor + "." + childAccessor + for _, s := range accessors { + if len(s) > 0 { + if sb.Len() > 0 { + sb.WriteString(".") + } + + sb.WriteString(s) + } + } + + return sb.String() } // Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is // met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON. -func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor string) ([]FieldInfo, error) { +// If passed a non-empty defaultValueAccessor, it'll be used to fill in default values instead of any default value +// specified in pflag tag. 
+func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string) ([]FieldInfo, error) { logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) @@ -149,8 +168,8 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue } defaultValue := tag.DefaultValue - if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { - defaultValue = accessor + if len(defaultValueAccessor) > 0 { + defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) if isPtr { defaultValue = fmt.Sprintf("cfg.elemValueOrNil(%s).(%s)", defaultValue, t.Name()) @@ -184,11 +203,13 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue } defaultValue := tag.DefaultValue - if accessor := appendAccessorIfNotEmpty(defaultValueAccessor, v.Name()); len(accessor) > 0 { - defaultValue = accessor + if len(defaultValueAccessor) > 0 { + defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) if isStringer(t) { defaultValue = defaultValue + ".String()" } else { + logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+ + " Will use fmt.Sprintf() to get its default value.", v.Name(), t.String()) defaultValue = fmt.Sprintf("fmt.Sprintf(\"%%v\",%s)", defaultValue) } } @@ -211,7 +232,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue } else { logger.Infof(ctx, "Traversing fields in type.") - nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, appendAccessorIfNotEmpty(defaultValueAccessor, v.Name())) + nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name())) if err != nil { return nil, err } @@ -317,7 +338,7 @@ func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, er defaultValueAccessor = g.defaultVar.Name() } - fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor) + fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "") if err != nil { return PFlagProvider{}, err } diff --git a/flytestdlib/cli/pflags/api/generator_test.go b/flytestdlib/cli/pflags/api/generator_test.go index b1d7c617c1..26f77a64e0 100644 --- a/flytestdlib/cli/pflags/api/generator_test.go +++ b/flytestdlib/cli/pflags/api/generator_test.go @@ -21,9 +21,9 @@ func elemValueOrNil(v interface{}) interface{} { if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { if reflect.ValueOf(v).IsNil() { return reflect.Zero(t.Elem()).Interface() - } else { - return reflect.ValueOf(v).Interface() } + + return reflect.ValueOf(v).Interface() } else if v == nil { return reflect.Zero(t).Interface() } diff --git a/flytestdlib/cli/pflags/api/sample.go b/flytestdlib/cli/pflags/api/sample.go index 43dfce7585..f1b40e5b74 100644 --- a/flytestdlib/cli/pflags/api/sample.go +++ b/flytestdlib/cli/pflags/api/sample.go @@ -3,6 +3,7 @@ package api import ( "encoding/json" "errors" + "github.com/lyft/flytestdlib/storage" ) diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go index 5bffe8650e..d8c9e2a5a0 100644 --- a/flytestdlib/config/tests/accessor_test.go +++ b/flytestdlib/config/tests/accessor_test.go @@ -16,8 +16,6 @@ import ( "testing" "time" - "github.com/fsnotify/fsnotify" - k8sRand "k8s.io/apimachinery/pkg/util/rand" "github.com/lyft/flytestdlib/config" @@ -30,11 +28,6 @@ import ( type accessorCreatorFn 
func(registry config.Section, configPath string) config.Accessor -type testLogger interface { - Logf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) -} - func getRandInt() uint64 { c := 10 b := make([]byte, c) @@ -328,16 +321,11 @@ func TestAccessor_UpdateConfig(t *testing.T) { r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) firstValue := r.StringValue - fileUpdated, err := beginWaitForFileChange(t, configFile) - assert.NoError(t, err) - _, err = populateConfigData(configFile) assert.NoError(t, err) - // Simulate filewatcher event - assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) - - time.Sleep(2 * time.Second) + // Wait enough for the file change notification to propagate. + time.Sleep(5 * time.Second) r = reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) secondValue := r.StringValue @@ -345,20 +333,17 @@ func TestAccessor_UpdateConfig(t *testing.T) { }) t.Run(fmt.Sprintf("[%v] Change handler k8s configmaps", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + section, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + var firstValue string + // 1. Create Dir structure watchDir, configFile, cleanup := newSymlinkedConfigFile(t) defer cleanup() - // Independently watch for when symlink underlying change happens to know when do we expect accessor to have picked up - // the changes - fileUpdated, err := beginWaitForFileChange(t, configFile) - assert.NoError(t, err) - // 2. Start accessor with the symlink as config location - reg := config.NewRootSection() - section, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) - assert.NoError(t, err) - opts := config.Options{ SearchPaths: []string{configFile}, RootSection: reg, @@ -368,7 +353,8 @@ func TestAccessor_UpdateConfig(t *testing.T) { assert.NoError(t, err) r := section.GetConfig().(*MyComponentConfig) - firstValue := r.StringValue + firstValue = r.StringValue + t.Logf("First value: %v", firstValue) // 3. Now update /data symlink to point to data2 dataDir2 := path.Join(watchDir, "data2") @@ -376,8 +362,9 @@ func TestAccessor_UpdateConfig(t *testing.T) { assert.NoError(t, err) configFile2 := path.Join(dataDir2, "config.yaml") - _, err = populateConfigData(configFile2) + newData, err := populateConfigData(configFile2) assert.NoError(t, err) + t.Logf("New value written to file: %v", newData.MyComponentConfig.StringValue) // change the symlink using the `ln -sfn` command err = changeSymLink(dataDir2, path.Join(watchDir, "data")) @@ -385,9 +372,6 @@ func TestAccessor_UpdateConfig(t *testing.T) { t.Logf("New config Location: %v", configFile2) - // Wait for filewatcher event - assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated)) - time.Sleep(5 * time.Second) r = section.GetConfig().(*MyComponentConfig) @@ -414,15 +398,9 @@ func changeSymLink(targetPath, symLink string) error { return exec.Command("del", filepath.Clean(tmpLink)).Run() } - // ln -sfn is not an atomic operation. Under the hood, it first calls the system unlink then symlink calls. During - // that, there will be a brief moment when there is no symlink at all. mv operation is, however, atomic. 
That's - // why we make this command instead - err := exec.Command("ln", "-s", filepath.Clean(targetPath), filepath.Clean(tmpLink)).Run() - if err != nil { - return err - } - - return exec.Command("mv", "-Tf", filepath.Clean(tmpLink), filepath.Clean(symLink)).Run() + //// ln -sfn is not an atomic operation. Under the hood, it first calls the system unlink then symlink calls. During + //// that, there will be a brief moment when there is no symlink at all. + return exec.Command("ln", "-sfn", filepath.Clean(targetPath), filepath.Clean(symLink)).Run() } // 1. Create Dir structure: @@ -444,6 +422,7 @@ func newSymlinkedConfigFile(t *testing.T) (watchDir, configFile string, cleanup assert.NoError(t, err) cleanup = func() { + t.Logf("Removing watchDir [%v]", watchDir) assert.NoError(t, os.RemoveAll(watchDir)) } @@ -458,120 +437,6 @@ func newSymlinkedConfigFile(t *testing.T) (watchDir, configFile string, cleanup return watchDir, configFile, cleanup } -func waitForFileChangeOrTimeout(done chan error) error { - timeout := make(chan bool, 1) - go func() { - time.Sleep(5 * time.Second) - timeout <- true - }() - - for { - select { - case <-timeout: - return fmt.Errorf("timed out") - case err := <-done: - return err - } - } -} - -func beginWaitForFileChange(logger testLogger, filename string) (done chan error, terminalErr error) { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } - - configFile := filepath.Clean(filename) - realConfigFile, err := filepath.EvalSymlinks(configFile) - if err != nil { - return nil, err - } - - configDir, _ := filepath.Split(configFile) - - done = make(chan error) - go func() { - for { - select { - case event, channelOpen := <-watcher.Events: - if !channelOpen { - logger.Logf("Events Channel has been closed") - done <- nil - return - } - - logger.Logf("Received watcher event [%v]", event) - // we only care about the config file - currentConfigFile, err := filepath.EvalSymlinks(filename) - if err != nil { - logger.Errorf("Failed to EvalSymLinks. Will attempt to close watcher now. Error: %v", err) - closeErr := watcher.Close() - if closeErr != nil { - logger.Errorf("Failed to close watcher. Error: %v", closeErr) - done <- closeErr - } else { - done <- err - } - - return - } - - // We only care about the config file with the following cases: - // 1 - if the config file was modified or created - // 2 - if the real path to the config file changed (eg: k8s ConfigMap replacement) - const writeOrCreateMask = fsnotify.Write | fsnotify.Create - if (filepath.Clean(event.Name) == configFile && - event.Op&writeOrCreateMask != 0) || - (currentConfigFile != "" && currentConfigFile != realConfigFile) { - - logger.Logf("CurrentConfigFile [%v], RealConfigFile [%v]", currentConfigFile, realConfigFile) - realConfigFile = currentConfigFile - closeErr := watcher.Close() - if closeErr != nil { - logger.Errorf("Failed to close watcher. 
Error: %v", closeErr) - } else { - done <- nil - } - - return - } else if filepath.Clean(event.Name) == configFile && - event.Op&fsnotify.Remove&fsnotify.Remove != 0 { - - logger.Logf("ConfigFile [%v] Removed.", configFile) - closeErr := watcher.Close() - if closeErr != nil { - logger.Logf("Close Watcher error: %v", closeErr) - } else { - done <- nil - } - - return - } - case err, channelOpen := <-watcher.Errors: - if !channelOpen { - logger.Logf("Error Channel has been closed.") - } else { - logger.Logf("Watcher error: %v", err) - closeErr := watcher.Close() - if closeErr != nil { - logger.Logf("Close Watcher error: %v\n", closeErr) - } - } - - done <- nil - return - } - } - }() - - err = watcher.Add(configDir) - if err != nil { - return nil, err - } - - return done, err -} - func testTypes(accessor accessorCreatorFn) func(t *testing.T) { return func(t *testing.T) { t.Run("ArrayConfigType", func(t *testing.T) { From a7ce2b52a30d7fdaa9a30a204a265e289c6f96ed Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 00:07:53 -0700 Subject: [PATCH 0013/1918] Add marshal utils --- flytestdlib/utils/marshal_utils.go | 66 ++++++++++ flytestdlib/utils/marshal_utils_test.go | 165 ++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 flytestdlib/utils/marshal_utils.go create mode 100644 flytestdlib/utils/marshal_utils_test.go diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go new file mode 100644 index 0000000000..4a434108d6 --- /dev/null +++ b/flytestdlib/utils/marshal_utils.go @@ -0,0 +1,66 @@ +package utils + +import ( + "encoding/json" + "fmt" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + structpb "github.com/golang/protobuf/ptypes/struct" +) + +var jsonPbMarshaler = jsonpb.Marshaler{} + +func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { + if structObj == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + if err != nil { + return err + } + + if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { + return err + } + + return nil +} + +func MarshalPbToStruct(in proto.Message, out *structpb.Struct) error { + if out == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := jsonPbMarshaler.MarshalToString(in) + if err != nil { + return err + } + + if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { + return err + } + + return nil +} + +func MarshalPbToString(msg proto.Message) (string, error) { + return jsonPbMarshaler.MarshalToString(msg) +} + +// TODO: Use the stdlib version in the future, or move there if not there. +// Don't use this if input is a proto Message. 
+func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { + b, err := json.Marshal(input) + if err != nil { + return nil, err + } + + // Turn JSON into a protobuf struct + structObj := &structpb.Struct{} + if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { + return nil, err + } + return structObj, nil +} diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go new file mode 100644 index 0000000000..33d185f3ee --- /dev/null +++ b/flytestdlib/utils/marshal_utils_test.go @@ -0,0 +1,165 @@ +package utils + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/golang/protobuf/proto" + structpb "github.com/golang/protobuf/ptypes/struct" +) + +// Simple proto +type TestProto struct { + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *TestProto) Reset() { *m = TestProto{} } +func (m *TestProto) String() string { return proto.CompactTextString(m) } +func (*TestProto) ProtoMessage() {} +func (*TestProto) Descriptor() ([]byte, []int) { + return []byte{}, []int{0} +} +func (m *TestProto) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_TestProto.Unmarshal(m, b) +} +func (m *TestProto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_TestProto.Marshal(b, m, deterministic) +} +func (dst *TestProto) XXX_Merge(src proto.Message) { + xxx_messageInfo_TestProto.Merge(dst, src) +} +func (m *TestProto) XXX_Size() int { + return xxx_messageInfo_TestProto.Size(m) +} +func (m *TestProto) XXX_DiscardUnknown() { + xxx_messageInfo_TestProto.DiscardUnknown(m) +} + +var xxx_messageInfo_TestProto proto.InternalMessageInfo + +func (m *TestProto) GetWorkflowId() string { + if m != nil { + return m.StringValue + } + return "" +} + +func init() { + proto.RegisterType((*TestProto)(nil), "test.package.TestProto") +} + +func TestMarshalPbToString(t *testing.T) { + type args struct { + msg proto.Message + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"empty", args{msg: &TestProto{}}, "{}", false}, + {"has value", args{msg: &TestProto{StringValue: "hello"}}, `{"stringValue":"hello"}`, false}, + {"nil input", args{msg: nil}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalPbToString(tt.args.msg) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalToString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("MarshalToString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMarshalObjToStruct(t *testing.T) { + type args struct { + input interface{} + } + tests := []struct { + name string + args args + want *structpb.Struct + wantErr bool + }{ + {"has value", args{input: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + "string_value": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalObjToStruct(tt.args.input) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalObjToStruct() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalObjToStruct() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshalStructToPb(t *testing.T) { + 
type args struct { + structObj *structpb.Struct + msg proto.Message + } + tests := []struct { + name string + args args + expected proto.Message + wantErr bool + }{ + {"empty", args{structObj: &structpb.Struct{Fields: map[string]*structpb.Value{}}, msg: &TestProto{}}, &TestProto{}, false}, + {"has value", args{structObj: &structpb.Struct{Fields: map[string]*structpb.Value{ + "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, msg: &TestProto{}}, &TestProto{StringValue: "hello"}, false}, + {"nil input", args{structObj: nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := UnmarshalStructToPb(tt.args.structObj, tt.args.msg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalStructToPb() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.Equal(t, tt.expected, tt.args.msg) + } + }) + } +} + +func TestMarshalPbToStruct(t *testing.T) { + type args struct { + in proto.Message + out *structpb.Struct + } + tests := []struct { + name string + args args + expected *structpb.Struct + wantErr bool + }{ + {"empty", args{in: &TestProto{}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{}}, false}, + {"has value", args{in: &TestProto{StringValue: "hello"}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := MarshalPbToStruct(tt.args.in, tt.args.out); (err != nil) != tt.wantErr { + t.Errorf("MarshalPbToStruct() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.Equal(t, tt.expected.Fields, tt.args.out.Fields) + } + }) + } +} From 362c166faaf43e1c4df6e20c89d8a59039355578 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 10:26:17 -0700 Subject: [PATCH 0014/1918] lint --- flytestdlib/utils/marshal_utils_test.go | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index 33d185f3ee..1cece1e580 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -7,15 +7,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/golang/protobuf/proto" - structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/golang/protobuf/ptypes/struct" ) // Simple proto type TestProto struct { - StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` } func (m *TestProto) Reset() { *m = TestProto{} } @@ -24,24 +21,6 @@ func (*TestProto) ProtoMessage() {} func (*TestProto) Descriptor() ([]byte, []int) { return []byte{}, []int{0} } -func (m *TestProto) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_TestProto.Unmarshal(m, b) -} -func (m *TestProto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_TestProto.Marshal(b, m, deterministic) -} -func (dst *TestProto) XXX_Merge(src proto.Message) { - xxx_messageInfo_TestProto.Merge(dst, src) -} -func (m *TestProto) XXX_Size() int { - return xxx_messageInfo_TestProto.Size(m) -} -func (m *TestProto) XXX_DiscardUnknown() { - 
xxx_messageInfo_TestProto.DiscardUnknown(m) -} - -var xxx_messageInfo_TestProto proto.InternalMessageInfo - func (m *TestProto) GetWorkflowId() string { if m != nil { return m.StringValue From 904eeff242ec78453baaee1b600d02eb7e67eaa0 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 10:31:20 -0700 Subject: [PATCH 0015/1918] lint --- flytestdlib/utils/marshal_utils_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index 1cece1e580..a5f1512e19 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -21,7 +21,7 @@ func (*TestProto) ProtoMessage() {} func (*TestProto) Descriptor() ([]byte, []int) { return []byte{}, []int{0} } -func (m *TestProto) GetWorkflowId() string { +func (m *TestProto) GetWorkflowID() string { if m != nil { return m.StringValue } From 97117a81a755a6f272c302a06497a102f9b29cb2 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 12:01:16 -0700 Subject: [PATCH 0016/1918] docs & refactor --- flytestdlib/utils/marshal_utils.go | 52 ++++++++++++++++--------- flytestdlib/utils/marshal_utils_test.go | 21 ++++++---- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go index 4a434108d6..4d9cc14eff 100644 --- a/flytestdlib/utils/marshal_utils.go +++ b/flytestdlib/utils/marshal_utils.go @@ -1,66 +1,80 @@ package utils import ( + "bytes" "encoding/json" "fmt" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" - structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/golang/protobuf/ptypes/struct" + "github.com/pkg/errors" ) var jsonPbMarshaler = jsonpb.Marshaler{} +// Unmarshals a proto struct into a proto message using jsonPb marshaler. func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { if structObj == nil { - return fmt.Errorf("nil Struct Object passed") + return fmt.Errorf("nil Struct object passed") + } + + if msg == nil { + return fmt.Errorf("nil proto.Message object passed") } jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) if err != nil { - return err + return errors.WithMessage(err, "Failed to marshal strcutObj input") } if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { - return err + return errors.WithMessage(err, "Failed to unmarshal json obj into proto") } return nil } -func MarshalPbToStruct(in proto.Message, out *structpb.Struct) error { - if out == nil { - return fmt.Errorf("nil Struct Object passed") +// Marshals a proto message into proto Struct using jsonPb marshaler. +func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { + if in == nil { + return nil, fmt.Errorf("nil proto message passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(in) - if err != nil { - return err + var buf bytes.Buffer + if err := jsonPbMarshaler.Marshal(&buf, in); err != nil { + return nil, errors.WithMessage(err, "Failed to marshal input proto message") } - if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { - return err + out = &structpb.Struct{} + if err = jsonpb.Unmarshal(bytes.NewReader(buf.Bytes()), out); err != nil { + return nil, errors.WithMessage(err, "Failed to unmarshal json object into struct") } - return nil + return out, nil } +// Marshals a proto message using jsonPb marshaler to string. 
func MarshalPbToString(msg proto.Message) (string, error) { return jsonPbMarshaler.MarshalToString(msg) } -// TODO: Use the stdlib version in the future, or move there if not there. -// Don't use this if input is a proto Message. +// Marshals obj into a struct. Will use jsonPb if input is a proto message, otherwise, it'll use json +// marshaler. func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { + if p, casted := input.(proto.Message); casted { + return MarshalPbToStruct(p) + } + b, err := json.Marshal(input) if err != nil { - return nil, err + return nil, errors.WithMessage(err, "Failed to marshal input proto message") } // Turn JSON into a protobuf struct structObj := &structpb.Struct{} - if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { - return nil, err + if err := jsonpb.Unmarshal(bytes.NewReader(b), structObj); err != nil { + return nil, errors.WithMessage(err, "Failed to unmarshal json object into struct") } + return structObj, nil } diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index a5f1512e19..4ac0fc130b 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -10,6 +10,10 @@ import ( "github.com/golang/protobuf/ptypes/struct" ) +type SimpleType struct { + StringValue string `json:"string_value,omitempty"` +} + // Simple proto type TestProto struct { StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` @@ -70,9 +74,13 @@ func TestMarshalObjToStruct(t *testing.T) { want *structpb.Struct wantErr bool }{ - {"has value", args{input: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + {"has proto value", args{input: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, + }}, false}, + {"has struct value", args{input: SimpleType{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ "string_value": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, }}, false}, + {"has string value", args{input: "hello"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -118,8 +126,7 @@ func TestUnmarshalStructToPb(t *testing.T) { func TestMarshalPbToStruct(t *testing.T) { type args struct { - in proto.Message - out *structpb.Struct + in proto.Message } tests := []struct { name string @@ -127,17 +134,17 @@ func TestMarshalPbToStruct(t *testing.T) { expected *structpb.Struct wantErr bool }{ - {"empty", args{in: &TestProto{}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{}}, false}, - {"has value", args{in: &TestProto{StringValue: "hello"}, out: &structpb.Struct{}}, &structpb.Struct{Fields: map[string]*structpb.Value{ + {"empty", args{in: &TestProto{}}, &structpb.Struct{Fields: map[string]*structpb.Value{}}, false}, + {"has value", args{in: &TestProto{StringValue: "hello"}}, &structpb.Struct{Fields: map[string]*structpb.Value{ "stringValue": {Kind: &structpb.Value_StringValue{StringValue: "hello"}}, }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := MarshalPbToStruct(tt.args.in, tt.args.out); (err != nil) != tt.wantErr { + if got, err := MarshalPbToStruct(tt.args.in); (err != nil) != tt.wantErr { t.Errorf("MarshalPbToStruct() error = %v, wantErr %v", err, tt.wantErr) } else { - assert.Equal(t, tt.expected.Fields, tt.args.out.Fields) + 
assert.Equal(t, tt.expected.Fields, got.Fields) } }) } From c444a2cfeaafffbb97de31fd5ca862b6daddec74 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 12:58:47 -0700 Subject: [PATCH 0017/1918] minor --- flytestdlib/cli/pflags/api/generator.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go index c6ea06fc21..2b7faa4eb3 100644 --- a/flytestdlib/cli/pflags/api/generator.go +++ b/flytestdlib/cli/pflags/api/generator.go @@ -139,11 +139,10 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue tag.Name = v.Name() } - isPtr := false typ := v.Type() - if ptr, casted := typ.(*types.Pointer); casted { + ptr, isPtr := typ.(*types.Pointer) + if isPtr { typ = ptr.Elem() - isPtr = true } switch t := typ.(type) { From 74d591556dd6304ac30f57795aa366276157612c Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 13:33:26 -0700 Subject: [PATCH 0018/1918] lint --- flytestdlib/utils/marshal_utils.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go index 4d9cc14eff..129de4f87d 100644 --- a/flytestdlib/utils/marshal_utils.go +++ b/flytestdlib/utils/marshal_utils.go @@ -4,9 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/struct" + structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" ) From 578749afa715e3718db4246ba1b761a9ec58a443 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 14:11:48 -0700 Subject: [PATCH 0019/1918] solidify unit tests --- flytestdlib/cli/pflags/api/generator.go | 10 ++++++++-- flytestdlib/config/tests/accessor_test.go | 21 ++++++++++++++++++++ flytestdlib/utils/auto_refresh_cache_test.go | 16 ++++++++++++--- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go index 2b7faa4eb3..2f41faab9e 100644 --- a/flytestdlib/cli/pflags/api/generator.go +++ b/flytestdlib/cli/pflags/api/generator.go @@ -101,10 +101,16 @@ func appendAccessors(accessors ...string) string { for _, s := range accessors { if len(s) > 0 { if sb.Len() > 0 { - sb.WriteString(".") + if _, err := sb.WriteString("."); err != nil { + fmt.Printf("Failed to writeString, error: %v", err) + return "" + } } - sb.WriteString(s) + if _, err := sb.WriteString(s); err != nil { + fmt.Printf("Failed to writeString, error: %v", err) + return "" + } } } diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go index d8c9e2a5a0..e6f93095a1 100644 --- a/flytestdlib/config/tests/accessor_test.go +++ b/flytestdlib/config/tests/accessor_test.go @@ -379,6 +379,27 @@ func TestAccessor_UpdateConfig(t *testing.T) { // Make sure values have changed assert.NotEqual(t, firstValue, secondValue) }) + + t.Run(fmt.Sprintf("[%v] Default variables", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{ + StringValue: "default value 1", + StringValue2: "default value 2", + }) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + key := strings.ToUpper("my-component.str") + assert.NoError(t, os.Setenv(key, "Set 
From Env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Set From Env", r.StringValue) + assert.Equal(t, "default value 2", r.StringValue2) + }) } } diff --git a/flytestdlib/utils/auto_refresh_cache_test.go b/flytestdlib/utils/auto_refresh_cache_test.go index 85a09ce508..05d80ed2ba 100644 --- a/flytestdlib/utils/auto_refresh_cache_test.go +++ b/flytestdlib/utils/auto_refresh_cache_test.go @@ -2,6 +2,7 @@ package utils import ( "context" + "sync" "testing" "time" @@ -15,6 +16,7 @@ type testCacheItem struct { val int deleted atomic.Bool resyncPeriod time.Duration + wg sync.WaitGroup } func (m *testCacheItem) ID() string { @@ -28,6 +30,8 @@ func (m *testCacheItem) moveNext() { } func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { + defer func() { m.wg.Done() }() + if m.deleted.Load() { return nil, nil } @@ -51,10 +55,15 @@ func TestCache(t *testing.T) { testResyncPeriod := time.Millisecond rateLimiter := NewRateLimiter("mockLimiter", 100, 1) - item := &testCacheItem{val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false)} + wg := sync.WaitGroup{} + wg.Add(1) + item := &testCacheItem{ + val: 0, + resyncPeriod: testResyncPeriod, + deleted: atomic.NewBool(false), + wg: wg,} cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) - //ctx := context.Background() ctx, cancel := context.WithCancel(context.Background()) cache.Start(ctx) @@ -75,7 +84,8 @@ func TestCache(t *testing.T) { // removed? item.moveNext() item.deleted.Store(true) - time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! + wg.Wait() + time.Sleep(testResyncPeriod * 2) // spare enough time to process remove! val := cache.Get(item.ID()) assert.Nil(t, val) From 7e82abf88f3579580e9c1edc386b5f724bcc65d0 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 14:34:14 -0700 Subject: [PATCH 0020/1918] Refactor test --- flytestdlib/utils/auto_refresh_cache_test.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/flytestdlib/utils/auto_refresh_cache_test.go b/flytestdlib/utils/auto_refresh_cache_test.go index 05d80ed2ba..b7007a702d 100644 --- a/flytestdlib/utils/auto_refresh_cache_test.go +++ b/flytestdlib/utils/auto_refresh_cache_test.go @@ -2,7 +2,6 @@ package utils import ( "context" - "sync" "testing" "time" @@ -16,7 +15,7 @@ type testCacheItem struct { val int deleted atomic.Bool resyncPeriod time.Duration - wg sync.WaitGroup + synced atomic.Int32 } func (m *testCacheItem) ID() string { @@ -30,11 +29,12 @@ func (m *testCacheItem) moveNext() { } func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { - defer func() { m.wg.Done() }() + defer func() { m.synced.Inc() }() if m.deleted.Load() { return nil, nil } + return m, nil } @@ -55,13 +55,11 @@ func TestCache(t *testing.T) { testResyncPeriod := time.Millisecond rateLimiter := NewRateLimiter("mockLimiter", 100, 1) - wg := sync.WaitGroup{} - wg.Add(1) item := &testCacheItem{ val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false), - wg: wg,} + synced: atomic.NewInt32(0),} cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) ctx, cancel := context.WithCancel(context.Background()) @@ -83,9 +81,14 @@ func TestCache(t *testing.T) { // removed? 
item.moveNext() + currentSyncCount := item.synced.Load() item.deleted.Store(true) - wg.Wait() - time.Sleep(testResyncPeriod * 2) // spare enough time to process remove! + for currentSyncCount == item.synced.Load() { + time.Sleep(testResyncPeriod * 5) // spare enough time to process remove! + } + + time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! + val := cache.Get(item.ID()) assert.Nil(t, val) From 5ec38597168d850b0076ed92bd933bcd6b37d50f Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 18 Apr 2019 14:52:49 -0700 Subject: [PATCH 0021/1918] lint --- flytestdlib/utils/auto_refresh_cache_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytestdlib/utils/auto_refresh_cache_test.go b/flytestdlib/utils/auto_refresh_cache_test.go index b7007a702d..dab8c9d907 100644 --- a/flytestdlib/utils/auto_refresh_cache_test.go +++ b/flytestdlib/utils/auto_refresh_cache_test.go @@ -59,7 +59,7 @@ func TestCache(t *testing.T) { val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false), - synced: atomic.NewInt32(0),} + synced: atomic.NewInt32(0)} cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) ctx, cancel := context.WithCancel(context.Background()) From 29a2c517ca022613266b7093853ac0a2c865f551 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 22 Apr 2019 14:13:53 -0700 Subject: [PATCH 0022/1918] adding lru Cache --- flytestdlib/Gopkg.lock | 13 ++++ flytestdlib/utils/auto_refresh_cache.go | 98 ++++++++++++++++++------- 2 files changed, 85 insertions(+), 26 deletions(-) diff --git a/flytestdlib/Gopkg.lock b/flytestdlib/Gopkg.lock index a56ebd2dba..c2d2438f89 100644 --- a/flytestdlib/Gopkg.lock +++ b/flytestdlib/Gopkg.lock @@ -162,6 +162,17 @@ pruneopts = "UT" revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" +[[projects]] + digest = "1:d15ee511aa0f56baacc1eb4c6b922fa1c03b38413b6be18166b996d82a0156ea" + name = "github.com/hashicorp/golang-lru" + packages = [ + ".", + "simplelru", + ] + pruneopts = "UT" + revision = "7087cb70de9f7a8bc0a10c375cb0d2280a8edf9c" + version = "v0.5.1" + [[projects]] digest = "1:c0d19ab64b32ce9fe5cf4ddceba78d5bc9807f0016db6b1183599da3dcc24d10" name = "github.com/hashicorp/hcl" @@ -492,10 +503,12 @@ "github.com/golang/protobuf/proto", "github.com/golang/protobuf/ptypes", "github.com/golang/protobuf/ptypes/duration", + "github.com/golang/protobuf/ptypes/struct", "github.com/golang/protobuf/ptypes/timestamp", "github.com/graymeta/stow", "github.com/graymeta/stow/local", "github.com/graymeta/stow/s3", + "github.com/hashicorp/golang-lru", "github.com/magiconair/properties/assert", "github.com/mitchellh/mapstructure", "github.com/pkg/errors", diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index a231bbd7c2..e231fa558c 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -2,11 +2,13 @@ package utils import ( "context" - "sync" - "time" - + "github.com/hashicorp/golang-lru" "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" "k8s.io/apimachinery/pkg/util/wait" + "sync" + "time" ) // AutoRefreshCache with regular GetOrCreate and Delete along with background asynchronous refresh. 
Caller provides @@ -27,16 +29,50 @@ type CacheItem interface { ID() string } -type CacheSyncItem func(ctx context.Context, obj CacheItem) (CacheItem, error) +// Possible actions for the cache to take as a result of running the sync function on any given cache item +type CacheSyncAction int + +const ( + // The item returned has been updated and should be updated in the cache + Update CacheSyncAction = iota + + // The item should be removed from the cache + Delete +) + +type CacheSyncItem func(ctx context.Context, obj CacheItem) ( + newItem CacheItem, result CacheSyncAction, err error) + +func getEvictionFunction(counter prometheus.Counter) func(key interface{}, value interface{}) { + return func(_ interface{}, _ interface{}) { + counter.Inc() + } +} + +func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resyncPeriod time.Duration, + size int, scope promutils.Scope) (AutoRefreshCache, error) { + + var evictionFunction func(key interface{}, value interface{}) + + // If a scope is specified, we'll add a function to log a metric when evicting + if scope != nil { + counter := scope.MustNewCounter("lru_evictions", "Counter for evictions from LRU") + evictionFunction = getEvictionFunction(counter) + } + lruCache, err := lru.NewWithEvict(size, evictionFunction) + if err != nil { + return nil, err + } -func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resyncPeriod time.Duration) AutoRefreshCache { cache := &autoRefreshCache{ syncCb: syncCb, + lruMap: *lruCache, syncRateLimiter: syncRateLimiter, resyncPeriod: resyncPeriod, + scope: scope, } - return cache + return cache, nil } // Thread-safe general purpose auto-refresh cache that watches for updates asynchronously for the keys after they are added to @@ -48,8 +84,10 @@ func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resy type autoRefreshCache struct { syncCb CacheSyncItem syncMap sync.Map + lruMap lru.Cache syncRateLimiter RateLimiter resyncPeriod time.Duration + scope promutils.Scope } func (w *autoRefreshCache) Start(ctx context.Context) { @@ -57,7 +95,7 @@ func (w *autoRefreshCache) Start(ctx context.Context) { } func (w *autoRefreshCache) Get(id string) CacheItem { - if val, ok := w.syncMap.Load(id); ok { + if val, ok := w.lruMap.Get(id); ok { return val.(CacheItem) } return nil @@ -66,34 +104,42 @@ func (w *autoRefreshCache) Get(id string) CacheItem { // Return the item if exists else create it. // Create should be invoked only once. recreating the object is not supported. func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { - if val, ok := w.syncMap.Load(item.ID()); ok { + if val, ok := w.lruMap.Get(item.ID()); ok { return val.(CacheItem), nil } - w.syncMap.Store(item.ID(), item) + w.lruMap.Add(item.ID(), item) return item, nil } +// This function is called internally by its own timer. Roughly, it will, +// - List keys +// - For each of the keys, call syncCb, which tells us if the item has been updated +// - If it has, then do a remove followed by an add. We can get away with this because it is guaranteed that +// this loop will run to completion before the next one begins. +// +// What happens when the number of things that a user is trying to keep track of exceeds the size +// of the cache? Trivial case where the cache is size 1 and we're trying to keep track of two things. 
+// * Plugin asks for update on item 1 - cache evicts item 2, stores 1 and returns it unchanged +// * Plugin asks for update on item 2 - cache evicts item 1, stores 2 and returns it unchanged +// * Sync loop updates item 2, repeat func (w *autoRefreshCache) sync(ctx context.Context) { - w.syncMap.Range(func(key, value interface{}) bool { - if w.syncRateLimiter != nil { - err := w.syncRateLimiter.Wait(ctx) + keys := w.lruMap.Keys() + for _, k := range keys { + // If not ok, it means evicted between the item was evicted between calling the keys and the iteration loop + if value, ok := w.lruMap.Peek(k); ok { + newItem, result, err := w.syncCb(ctx, value.(CacheItem)) if err != nil { - logger.Warnf(ctx, "unexpected failure in rate-limiter wait %v", key) - return true + logger.Error(ctx, "failed to get latest copy of the item %v", key) } - } - item, err := w.syncCb(ctx, value.(CacheItem)) - if err != nil { - logger.Error(ctx, "failed to get latest copy of the item %v", key) - } - if item == nil { - w.syncMap.Delete(key) - } else { - w.syncMap.Store(key, item) - } + if result == Update { + w.lruMap.Remove(k) + w.lruMap.Add(k, newItem) - return true - }) + } else if result == Delete { + w.lruMap.Remove(k) + } + } + } } From a54187bc39049f772cc808791c721827cf73774e Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 22 Apr 2019 14:21:12 -0700 Subject: [PATCH 0023/1918] add a state for unchanged --- flytestdlib/utils/auto_refresh_cache.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index e231fa558c..ffe9b61429 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -31,10 +31,11 @@ type CacheItem interface { // Possible actions for the cache to take as a result of running the sync function on any given cache item type CacheSyncAction int - const ( + Unchanged CacheSyncAction = iota + // The item returned has been updated and should be updated in the cache - Update CacheSyncAction = iota + Update // The item should be removed from the cache Delete @@ -52,9 +53,8 @@ func getEvictionFunction(counter prometheus.Counter) func(key interface{}, value func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resyncPeriod time.Duration, size int, scope promutils.Scope) (AutoRefreshCache, error) { + // If a scope is specified, we'll add a function to log a metric when an object gets evicted var evictionFunction func(key interface{}, value interface{}) - - // If a scope is specified, we'll add a function to log a metric when evicting if scope != nil { counter := scope.MustNewCounter("lru_evictions", "Counter for evictions from LRU") evictionFunction = getEvictionFunction(counter) @@ -130,7 +130,7 @@ func (w *autoRefreshCache) sync(ctx context.Context) { if value, ok := w.lruMap.Peek(k); ok { newItem, result, err := w.syncCb(ctx, value.(CacheItem)) if err != nil { - logger.Error(ctx, "failed to get latest copy of the item %v", key) + logger.Error(ctx, "failed to get latest copy of the item %v", k) } if result == Update { From 6b9a891eb45b1fd11e134c7385c7be533c4deb32 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 22 Apr 2019 17:58:41 -0700 Subject: [PATCH 0024/1918] adding basic unit test --- flytestdlib/utils/auto_refresh_cache.go | 11 +- flytestdlib/utils/auto_refresh_cache_test.go | 165 ++++++++---------- .../utils/auto_refresh_example_test.go | 14 +- 3 files changed, 89 insertions(+), 101 deletions(-) diff --git 
a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index ffe9b61429..59a1d916e1 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -31,6 +31,7 @@ type CacheItem interface { // Possible actions for the cache to take as a result of running the sync function on any given cache item type CacheSyncAction int + const ( Unchanged CacheSyncAction = iota @@ -41,6 +42,10 @@ const ( Delete ) +// Your implementation of this function for your cache instance is responsible for returning +// 1. The new CacheItem, and +// 2. What action should be taken. The sync function has no insight into your object, and needs to be +// told explicitly if the new item is different from the old one. type CacheSyncItem func(ctx context.Context, obj CacheItem) ( newItem CacheItem, result CacheSyncAction, err error) @@ -105,9 +110,11 @@ func (w *autoRefreshCache) Get(id string) CacheItem { // Create should be invoked only once. recreating the object is not supported. func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { if val, ok := w.lruMap.Get(item.ID()); ok { + //fmt.Println("existing") return val.(CacheItem), nil } + //fmt.Println("adding") w.lruMap.Add(item.ID(), item) return item, nil } @@ -126,7 +133,8 @@ func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { func (w *autoRefreshCache) sync(ctx context.Context) { keys := w.lruMap.Keys() for _, k := range keys { - // If not ok, it means evicted between the item was evicted between calling the keys and the iteration loop + // If not ok, it means evicted between the item was evicted between getting the keys and this update loop + // which is fine, we can just ignore. if value, ok := w.lruMap.Peek(k); ok { newItem, result, err := w.syncCb(ctx, value.(CacheItem)) if err != nil { @@ -136,7 +144,6 @@ func (w *autoRefreshCache) sync(ctx context.Context) { if result == Update { w.lruMap.Remove(k) w.lruMap.Add(k, newItem) - } else if result == Delete { w.lruMap.Remove(k) } diff --git a/flytestdlib/utils/auto_refresh_cache_test.go b/flytestdlib/utils/auto_refresh_cache_test.go index dab8c9d907..d0968892ea 100644 --- a/flytestdlib/utils/auto_refresh_cache_test.go +++ b/flytestdlib/utils/auto_refresh_cache_test.go @@ -2,120 +2,93 @@ package utils import ( "context" + "fmt" + "github.com/stretchr/testify/assert" "testing" "time" - - atomic2 "sync/atomic" - - "github.com/lyft/flytestdlib/atomic" - "github.com/stretchr/testify/assert" ) -type testCacheItem struct { - val int - deleted atomic.Bool - resyncPeriod time.Duration - synced atomic.Int32 -} +const fakeCacheItemValueLimit = 10 -func (m *testCacheItem) ID() string { - return "id" +type fakeCacheItem struct { + id string + val int } -func (m *testCacheItem) moveNext() { - // change value and spare enough time for cache to process the change. 
- m.val++ - time.Sleep(m.resyncPeriod * 5) +func (f fakeCacheItem) ID() string { + return f.id } -func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { - defer func() { m.synced.Inc() }() - - if m.deleted.Load() { - return nil, nil +func syncFakeItem(_ context.Context, obj CacheItem) (CacheItem, CacheSyncAction, error) { + item := obj.(fakeCacheItem) + if item.val == fakeCacheItemValueLimit { + // After the item has gone through ten update cycles, leave it unchanged + return obj, Unchanged, nil } - return m, nil -} - -type testAutoIncrementItem struct { - val int32 + return fakeCacheItem{id: item.ID(), val: item.val + 1}, Update, nil } -func (a *testAutoIncrementItem) ID() string { - return "autoincrement" +func syncFakeItemLagged(ctx context.Context, obj CacheItem) (CacheItem, CacheSyncAction, error) { + time.Sleep(100 * time.Millisecond) + return syncFakeItem(ctx, obj) } -func (a *testAutoIncrementItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { - atomic2.AddInt32(&a.val, 1) - return a, nil +func syncFakeItemAlwaysDelete(_ context.Context, obj CacheItem) (CacheItem, CacheSyncAction, error) { + return obj, Delete, nil } -func TestCache(t *testing.T) { +func TestCacheTwo(t *testing.T) { testResyncPeriod := time.Millisecond rateLimiter := NewRateLimiter("mockLimiter", 100, 1) - item := &testCacheItem{ - val: 0, - resyncPeriod: testResyncPeriod, - deleted: atomic.NewBool(false), - synced: atomic.NewInt32(0)} - cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) - - ctx, cancel := context.WithCancel(context.Background()) - cache.Start(ctx) - - // create - _, err := cache.GetOrCreate(item) - assert.NoError(t, err, "unexpected GetOrCreate failure") - - // synced? - item.moveNext() - m := cache.Get(item.ID()).(*testCacheItem) - assert.Equal(t, 1, m.val) - - // synced again? - item.moveNext() - m = cache.Get(item.ID()).(*testCacheItem) - assert.Equal(t, 2, m.val) - - // removed? - item.moveNext() - currentSyncCount := item.synced.Load() - item.deleted.Store(true) - for currentSyncCount == item.synced.Load() { - time.Sleep(testResyncPeriod * 5) // spare enough time to process remove! - } - - time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! - - val := cache.Get(item.ID()) - - assert.Nil(t, val) - cancel() -} - -func TestCacheContextCancel(t *testing.T) { - testResyncPeriod := time.Millisecond - rateLimiter := NewRateLimiter("mockLimiter", 10000, 1) - - item := &testAutoIncrementItem{val: 0} - cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) - - ctx, cancel := context.WithCancel(context.Background()) - cache.Start(ctx) - _, err := cache.GetOrCreate(item) - assert.NoError(t, err, "failed to add item to cache") - time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! 
- cancel() - - // Get item - m, err := cache.GetOrCreate(item) - val1 := m.(*testAutoIncrementItem).val - assert.NoError(t, err, "unexpected GetOrCreate failure") - - // wait a few more resync periods and check that nothings has changed as auto-refresh is stopped - time.Sleep(testResyncPeriod * 20) - val2 := m.(*testAutoIncrementItem).val - assert.Equal(t, val1, val2) + t.Run("normal operation", func(t *testing.T) { + // the size of the cache is at least as large as the number of items we're storing + cache, err := NewAutoRefreshCache(syncFakeItem, rateLimiter, testResyncPeriod, 10, nil) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cache.Start(ctx) + + // Create ten items in the cache + for i := 1; i <= 10; i++ { + cache.GetOrCreate(fakeCacheItem{ + id: fmt.Sprintf("%d", i), + val: 0, + }) + } + + // Wait half a second for all resync periods to complete + time.Sleep(500 * time.Millisecond) + for i := 1; i <= 10; i++ { + item := cache.Get(fmt.Sprintf("%d", i)) + assert.Equal(t, 10, item.(fakeCacheItem).val) + } + cancel() + }) + + t.Run("deleting objects from cache", func(t *testing.T) { + // the size of the cache is at least as large as the number of items we're storing + cache, err := NewAutoRefreshCache(syncFakeItemAlwaysDelete, rateLimiter, testResyncPeriod, 10, nil) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cache.Start(ctx) + + // Create ten items in the cache + for i := 1; i <= 10; i++ { + cache.GetOrCreate(fakeCacheItem{ + id: fmt.Sprintf("%d", i), + val: 0, + }) + } + + // Wait for all resync periods to complete + time.Sleep(50 * time.Millisecond) + for i := 1; i <= 10; i++ { + obj := cache.Get(fmt.Sprintf("%d", i)) + assert.Nil(t, obj) + } + cancel() + }) } diff --git a/flytestdlib/utils/auto_refresh_example_test.go b/flytestdlib/utils/auto_refresh_example_test.go index 1c47d39ec5..b01ef65a1f 100644 --- a/flytestdlib/utils/auto_refresh_example_test.go +++ b/flytestdlib/utils/auto_refresh_example_test.go @@ -46,8 +46,13 @@ func ExampleNewAutoRefreshCache() { exampleService := newExampleService() // define a sync method that the cache can use to auto-refresh in background - syncItemCb := func(ctx context.Context, obj CacheItem) (CacheItem, error) { - return exampleService.getStatus(obj.(*ExampleCacheItem).ID()), nil + syncItemCb := func(ctx context.Context, obj CacheItem) (CacheItem, CacheSyncAction, error) { + oldItem := obj.(*ExampleCacheItem) + newItem := exampleService.getStatus(oldItem.ID()) + if newItem.status != oldItem.status { + return newItem, Update, nil + } + return newItem, Unchanged, nil } // define resync period as time duration we want cache to refresh. We can go as low as we want but cache @@ -61,7 +66,10 @@ func ExampleNewAutoRefreshCache() { // since cache refreshes itself asynchronously, it may not notice that an object has been deleted immediately, // so users of the cache should have the delete logic aware of this shortcoming (eg. not-exists may be a valid // error during removal if based on status in cache). 
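	// [Editor's sketch, not part of the original example] One hypothetical way a caller can make its
	// delete path tolerant of the lag described above. deleteFromService and errNotFound are made-up
	// names used purely for illustration and do not exist in flytestdlib:
	//
	//	func deleteIfPresent(id string) error {
	//		if err := deleteFromService(id); err != nil && err != errNotFound {
	//			return err
	//		}
	//		return nil // not-exists is acceptable here: the cache may simply not have caught up yet
	//	}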
- cache := NewAutoRefreshCache(syncItemCb, rateLimiter, resyncPeriod) + cache, err := NewAutoRefreshCache(syncItemCb, rateLimiter, resyncPeriod, 100, nil) + if err != nil { + panic(err) + } // start the cache with a context that would be to stop the cache by cancelling the context ctx, cancel := context.WithCancel(context.Background()) From e9d5819be538975f233295d49191b57f3cdcd206 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 22 Apr 2019 18:28:24 -0700 Subject: [PATCH 0025/1918] go imports --- flytestdlib/utils/auto_refresh_cache.go | 7 +++---- flytestdlib/utils/auto_refresh_cache_test.go | 14 ++++++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index 59a1d916e1..5242beaf63 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -2,13 +2,13 @@ package utils import ( "context" - "github.com/hashicorp/golang-lru" + "time" + + lru "github.com/hashicorp/golang-lru" "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" "k8s.io/apimachinery/pkg/util/wait" - "sync" - "time" ) // AutoRefreshCache with regular GetOrCreate and Delete along with background asynchronous refresh. Caller provides @@ -88,7 +88,6 @@ func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resy // Sync is run as a fixed-interval-scheduled-task, and is skipped if sync from previous cycle is still running. type autoRefreshCache struct { syncCb CacheSyncItem - syncMap sync.Map lruMap lru.Cache syncRateLimiter RateLimiter resyncPeriod time.Duration diff --git a/flytestdlib/utils/auto_refresh_cache_test.go b/flytestdlib/utils/auto_refresh_cache_test.go index d0968892ea..c30dcde769 100644 --- a/flytestdlib/utils/auto_refresh_cache_test.go +++ b/flytestdlib/utils/auto_refresh_cache_test.go @@ -3,9 +3,10 @@ package utils import ( "context" "fmt" - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/stretchr/testify/assert" ) const fakeCacheItemValueLimit = 10 @@ -29,11 +30,6 @@ func syncFakeItem(_ context.Context, obj CacheItem) (CacheItem, CacheSyncAction, return fakeCacheItem{id: item.ID(), val: item.val + 1}, Update, nil } -func syncFakeItemLagged(ctx context.Context, obj CacheItem) (CacheItem, CacheSyncAction, error) { - time.Sleep(100 * time.Millisecond) - return syncFakeItem(ctx, obj) -} - func syncFakeItemAlwaysDelete(_ context.Context, obj CacheItem) (CacheItem, CacheSyncAction, error) { return obj, Delete, nil } @@ -52,10 +48,11 @@ func TestCacheTwo(t *testing.T) { // Create ten items in the cache for i := 1; i <= 10; i++ { - cache.GetOrCreate(fakeCacheItem{ + _, err := cache.GetOrCreate(fakeCacheItem{ id: fmt.Sprintf("%d", i), val: 0, }) + assert.NoError(t, err) } // Wait half a second for all resync periods to complete @@ -77,10 +74,11 @@ func TestCacheTwo(t *testing.T) { // Create ten items in the cache for i := 1; i <= 10; i++ { - cache.GetOrCreate(fakeCacheItem{ + _, err = cache.GetOrCreate(fakeCacheItem{ id: fmt.Sprintf("%d", i), val: 0, }) + assert.NoError(t, err) } // Wait for all resync periods to complete From 3cb2359358cce0ba70a8d36ef7c718f26ccc9f3b Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 23 Apr 2019 09:27:52 -0700 Subject: [PATCH 0026/1918] linting is complaining that i'm copying a mutex even though i copy it before using it --- flytestdlib/utils/auto_refresh_cache.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) 
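[Editor's note, illustrative only and not part of this patch] The lint warning here is most likely go vet's copylocks check (run as part of golangci-lint): lru.Cache guards its internal state with a lock, so a dereferenced copy such as `lruMap: *lruCache` copies that lock, and the check fires regardless of when the copy happens. Holding a *lru.Cache, as this patch does, avoids the copy altogether. A minimal reproduction with made-up names:

    package main

    import "sync"

    // guarded stands in for any type that embeds a lock, as lru.Cache does internally.
    type guarded struct {
    	mu sync.Mutex
    	n  int
    }

    // copyValue is reported by go vet (copylocks): the parameter copies guarded, and with it mu.
    func copyValue(g guarded) int { return g.n }

    // usePointer is fine: callers share the single lock instead of duplicating it.
    func usePointer(g *guarded) int { return g.n }

    func main() {
    	g := &guarded{n: 1}
    	_ = usePointer(g)
    	_ = copyValue(*g) // vet: copyValue copies a lock value (guarded contains sync.Mutex)
    }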
diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index 5242beaf63..5f1f808912 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -71,7 +71,7 @@ func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resy cache := &autoRefreshCache{ syncCb: syncCb, - lruMap: *lruCache, + lruMap: lruCache, syncRateLimiter: syncRateLimiter, resyncPeriod: resyncPeriod, scope: scope, @@ -88,7 +88,7 @@ func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resy // Sync is run as a fixed-interval-scheduled-task, and is skipped if sync from previous cycle is still running. type autoRefreshCache struct { syncCb CacheSyncItem - lruMap lru.Cache + lruMap *lru.Cache syncRateLimiter RateLimiter resyncPeriod time.Duration scope promutils.Scope From 7987dc537e5590db5d717a257467d09fc4b16b91 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 23 Apr 2019 11:18:38 -0700 Subject: [PATCH 0027/1918] Update utils/auto_refresh_cache.go Co-Authored-By: wild-endeavor --- flytestdlib/utils/auto_refresh_cache.go | 1 - 1 file changed, 1 deletion(-) diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index 5f1f808912..903d8d4cba 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -113,7 +113,6 @@ func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { return val.(CacheItem), nil } - //fmt.Println("adding") w.lruMap.Add(item.ID(), item) return item, nil } From 7f46ee0524082893eeaa546c3b604366167fe553 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 23 Apr 2019 11:18:44 -0700 Subject: [PATCH 0028/1918] Update utils/auto_refresh_cache.go Co-Authored-By: wild-endeavor --- flytestdlib/utils/auto_refresh_cache.go | 1 - 1 file changed, 1 deletion(-) diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index 903d8d4cba..7d2b5e076f 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -109,7 +109,6 @@ func (w *autoRefreshCache) Get(id string) CacheItem { // Create should be invoked only once. recreating the object is not supported. 
func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { if val, ok := w.lruMap.Get(item.ID()); ok { - //fmt.Println("existing") return val.(CacheItem), nil } From 0a2e1aac58c3473887c9dbceb6ddff0fc587bb78 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 23 Apr 2019 11:18:51 -0700 Subject: [PATCH 0029/1918] Update utils/auto_refresh_cache.go Co-Authored-By: wild-endeavor --- flytestdlib/utils/auto_refresh_cache.go | 1 - 1 file changed, 1 deletion(-) diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index 7d2b5e076f..3453483914 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -139,7 +139,6 @@ func (w *autoRefreshCache) sync(ctx context.Context) { } if result == Update { - w.lruMap.Remove(k) w.lruMap.Add(k, newItem) } else if result == Delete { w.lruMap.Remove(k) From 6df902fc2067725a37336413f276756fb4f89bc3 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 23 Apr 2019 15:08:31 -0700 Subject: [PATCH 0030/1918] Add free badges for the project (#4) - Badges for release, build passing, license and docs - TODO add badge for coverage --- flytestdlib/{README.rst => README.md} | 5 +++++ 1 file changed, 5 insertions(+) rename flytestdlib/{README.rst => README.md} (72%) diff --git a/flytestdlib/README.rst b/flytestdlib/README.md similarity index 72% rename from flytestdlib/README.rst rename to flytestdlib/README.md index a43dc38bff..6411abcfb5 100644 --- a/flytestdlib/README.rst +++ b/flytestdlib/README.md @@ -1,5 +1,10 @@ Common Go Tools ===================== +[![Current Release](https://img.shields.io/github/release/flytestdlib.svg)](https://github.com/lyft/flytestdlib/releases/latest) +[![Build Status](https://travis-ci.org/lyft/flytestdlib.svg?branch=master)](https://travis-ci.org/lyft/flytestdlib) +[![GoDoc](https://godoc.org/github.com/lyft/flytestdlib?status.svg)](https://godoc.org/github.com/lyft/flytestdlib) +[![License](https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg)](http://www.apache.org/licenses/LICENSE-2.0.html) + Shared components we found ourselves building time and time again, so we collected them in one place! 
This library consists of: From 556fa6ca68d8d312331953a3d44da4bb37a45254 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 23 Apr 2019 16:11:47 -0700 Subject: [PATCH 0031/1918] Updated Version Handler to support BuildTime (#7) * Updated Version Handler to support BuildTime - Also added some documentation - Version information is outputted as JSON * lint fix --- flytestdlib/profutils/server.go | 15 ++++++++++++++- flytestdlib/profutils/server_test.go | 12 +++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/flytestdlib/profutils/server.go b/flytestdlib/profutils/server.go index b8240cf25c..b4dcc00ac3 100644 --- a/flytestdlib/profutils/server.go +++ b/flytestdlib/profutils/server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "time" "github.com/lyft/flytestdlib/config" @@ -28,12 +29,20 @@ const ( contentTypeJSON = "application/json; charset=utf-8" ) +type BuildVersion struct { + Build string `json:"build"` + Version string `json:"version"` + Timestamp time.Time `json:"timestamp,string"` +} + +// Writes a string to the Http output stream func WriteStringResponse(resp http.ResponseWriter, code int, body string) error { resp.WriteHeader(code) _, err := resp.Write([]byte(body)) return err } +// Writes a JSON to the http output stream func WriteJSONResponse(resp http.ResponseWriter, code int, body interface{}) error { resp.Header().Set(contentTypeHeader, contentTypeJSON) resp.WriteHeader(code) @@ -44,6 +53,8 @@ func WriteJSONResponse(resp http.ResponseWriter, code int, body interface{}) err return WriteStringResponse(resp, http.StatusOK, string(j)) } +// Simple healthcheck module that returns OK and provides a simple L7 healthcheck +// TODO we may want to provide a simple function that returns a bool, where users could provide custom healthchecks func healtcheckHandler(w http.ResponseWriter, req *http.Request) { err := WriteStringResponse(w, http.StatusOK, http.StatusText(http.StatusOK)) if err != nil { @@ -51,13 +62,15 @@ func healtcheckHandler(w http.ResponseWriter, req *http.Request) { } } +// Handler that returns a JSON response indicating the Build Version information (refer to #version module) func versionHandler(w http.ResponseWriter, req *http.Request) { - err := WriteStringResponse(w, http.StatusOK, fmt.Sprintf("Build [%s], Version [%s]", version.Build, version.Version)) + err := WriteJSONResponse(w, http.StatusOK, BuildVersion{Build: version.Build, Version: version.Version, Timestamp: version.BuildTime}) if err != nil { panic(err) } } +// Provides a handler that dumps the config information as a string func configHandler(w http.ResponseWriter, req *http.Request) { m, err := config.AllConfigsAsMap(config.GetRootSection()) if err != nil { diff --git a/flytestdlib/profutils/server_test.go b/flytestdlib/profutils/server_test.go index a7ef350002..deb3383128 100644 --- a/flytestdlib/profutils/server_test.go +++ b/flytestdlib/profutils/server_test.go @@ -4,6 +4,9 @@ import ( "encoding/json" "net/http" "testing" + "time" + + "github.com/lyft/flytestdlib/version" "github.com/lyft/flytestdlib/internal/utils" @@ -85,9 +88,16 @@ func TestVersionHandler(t *testing.T) { URL: &testURL, } + version.BuildTime = time.Now() + http.DefaultServeMux.ServeHTTP(writer, request) assert.Equal(t, http.StatusOK, writer.Status) - assert.Equal(t, `Build [unknown], Version [unknown]`, string(writer.Body)) + assert.NotNil(t, writer.Body) + bv := BuildVersion{} + assert.NoError(t, json.Unmarshal(writer.Body, &bv)) + assert.Equal(t, bv.Timestamp.Unix(), version.BuildTime.Unix()) 
+ assert.Equal(t, bv.Build, version.Build) + assert.Equal(t, bv.Version, version.Version) } func TestHealthcheckHandler(t *testing.T) { From 1f05d0c589e17e745306ed377baeb56ee50fe1bf Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 23 Apr 2019 19:03:54 -0700 Subject: [PATCH 0032/1918] [pflags] If there is a default value, use it for tests. (#8) * If there is a default value, use it for tests. * fixes * regenerate --- flytestdlib/cli/pflags/api/generator.go | 17 ++++++++--------- flytestdlib/cli/pflags/api/templates.go | 11 +++++++++++ flytestdlib/cli/pflags/api/testdata/testtype.go | 17 ++++++++++++++--- .../cli/pflags/api/testdata/testtype_test.go | 8 ++++---- flytestdlib/config/port.go | 4 ++++ flytestdlib/storage/config_flags_test.go | 2 +- flytestdlib/tests/testdata/combined.yaml | 2 ++ 7 files changed, 44 insertions(+), 17 deletions(-) diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go index 2f41faab9e..93a3a040b2 100644 --- a/flytestdlib/cli/pflags/api/generator.go +++ b/flytestdlib/cli/pflags/api/generator.go @@ -177,7 +177,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) if isPtr { - defaultValue = fmt.Sprintf("cfg.elemValueOrNil(%s).(%s)", defaultValue, t.Name()) + defaultValue = fmt.Sprintf("%s.elemValueOrNil(%s).(%s)", defaultValueAccessor, defaultValue, t.Name()) } } @@ -201,12 +201,6 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue // will use json unmarshaler to fill in the final config object. jsonUnmarshaler := isJSONUnmarshaler(t) - testValue := tag.DefaultValue - if len(tag.DefaultValue) == 0 { - tag.DefaultValue = `""` - testValue = `"1"` - } - defaultValue := tag.DefaultValue if len(defaultValueAccessor) > 0 { defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) @@ -214,11 +208,16 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue defaultValue = defaultValue + ".String()" } else { logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+ - " Will use fmt.Sprintf() to get its default value.", v.Name(), t.String()) - defaultValue = fmt.Sprintf("fmt.Sprintf(\"%%v\",%s)", defaultValue) + " Will use %s.mustMarshalJSON() to get its default value.", defaultValueAccessor, v.Name(), t.String()) + defaultValue = fmt.Sprintf("%s.mustMarshalJSON(%s)", defaultValueAccessor, defaultValue) } } + testValue := defaultValue + if len(testValue) == 0 { + testValue = `"1"` + } + logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue) if jsonUnmarshaler { diff --git a/flytestdlib/cli/pflags/api/templates.go b/flytestdlib/cli/pflags/api/templates.go index 545d0ceae1..15057374f2 100644 --- a/flytestdlib/cli/pflags/api/templates.go +++ b/flytestdlib/cli/pflags/api/templates.go @@ -20,6 +20,8 @@ var mainTmpl = template.Must(template.New("MainFile").Parse( package {{ .Package }} import ( + "encoding/json" + "github.com/spf13/pflag" "fmt" {{range $path, $name := .Imports}} @@ -42,6 +44,15 @@ func ({{ .Name }}) elemValueOrNil(v interface{}) interface{} { return v } +func ({{ .Name }}) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + // GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the // flags is json-name.json-sub-name... 
etc. func (cfg {{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { diff --git a/flytestdlib/cli/pflags/api/testdata/testtype.go b/flytestdlib/cli/pflags/api/testdata/testtype.go index dde0f48c80..ff6b7e63ba 100755 --- a/flytestdlib/cli/pflags/api/testdata/testtype.go +++ b/flytestdlib/cli/pflags/api/testdata/testtype.go @@ -4,9 +4,11 @@ package api import ( - "fmt" + "encoding/json" "reflect" + "fmt" + "github.com/spf13/pflag" ) @@ -26,6 +28,15 @@ func (TestType) elemValueOrNil(v interface{}) interface{} { return v } +func (TestType) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + // GetPFlagSet will return strongly types pflags for all fields in TestType and its nested types. The format of the // flags is json-name.json-sub-name... etc. func (cfg TestType) GetPFlagSet(prefix string) *pflag.FlagSet { @@ -36,7 +47,7 @@ func (cfg TestType) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.IntSlice(fmt.Sprintf("%v%v", prefix, "ints"), []int{12, 1}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strs"), []string{"12", "1"}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "complexArr"), []string{}, "") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), fmt.Sprintf("%v", DefaultTestType.StringToJSON), "I'm a complex type but can be converted from string.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), DefaultTestType.mustMarshalJSON(DefaultTestType.StringToJSON), "I'm a complex type but can be converted from string.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), DefaultTestType.StorageConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), DefaultTestType.StorageConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), DefaultTestType.StorageConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") @@ -48,6 +59,6 @@ func (cfg TestType) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), DefaultTestType.StorageConfig.Cache.MaxSizeMegabytes, "Maximum size of the cache where the Blob store data is cached in-memory. 
If not specified or set to 0, cache is not used") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), DefaultTestType.StorageConfig.Cache.TargetGCPercent, "Sets the garbage collection target percentage.") cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), DefaultTestType.StorageConfig.Limits.GetLimitMegabytes, "Maximum allowed download size (in MBs) per call.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), cfg.elemValueOrNil(DefaultTestType.IntValue).(int), "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), DefaultTestType.elemValueOrNil(DefaultTestType.IntValue).(int), "") return cmdFlags } diff --git a/flytestdlib/cli/pflags/api/testdata/testtype_test.go b/flytestdlib/cli/pflags/api/testdata/testtype_test.go index 03412f0d86..450e04068b 100755 --- a/flytestdlib/cli/pflags/api/testdata/testtype_test.go +++ b/flytestdlib/cli/pflags/api/testdata/testtype_test.go @@ -235,14 +235,14 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("c"); err == nil { - assert.Equal(t, string(fmt.Sprintf("%v", DefaultTestType.StringToJSON)), vString) + assert.Equal(t, string(DefaultTestType.mustMarshalJSON(DefaultTestType.StringToJSON)), vString) } else { assert.FailNow(t, err.Error()) } }) t.Run("Override", func(t *testing.T) { - testValue := "1" + testValue := DefaultTestType.mustMarshalJSON(DefaultTestType.StringToJSON) cmdFlags.Set("c", testValue) if vString, err := cmdFlags.GetString("c"); err == nil { @@ -286,7 +286,7 @@ func TestTestType_SetFlags(t *testing.T) { }) t.Run("Override", func(t *testing.T) { - testValue := "1" + testValue := DefaultTestType.StorageConfig.Connection.Endpoint.String() cmdFlags.Set("storage.connection.endpoint", testValue) if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { @@ -499,7 +499,7 @@ func TestTestType_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vInt, err := cmdFlags.GetInt("i"); err == nil { - assert.Equal(t, int(cfg.elemValueOrNil(DefaultTestType.IntValue).(int)), vInt) + assert.Equal(t, int(DefaultTestType.elemValueOrNil(DefaultTestType.IntValue).(int)), vInt) } else { assert.FailNow(t, err.Error()) } diff --git a/flytestdlib/config/port.go b/flytestdlib/config/port.go index 87bbc854e2..fef6e4d11b 100644 --- a/flytestdlib/config/port.go +++ b/flytestdlib/config/port.go @@ -12,6 +12,10 @@ type Port struct { Port int `json:"port,omitempty"` } +func (p Port) String() string { + return strconv.Itoa(p.Port) +} + func (p Port) MarshalJSON() ([]byte, error) { return json.Marshal(p.Port) } diff --git a/flytestdlib/storage/config_flags_test.go b/flytestdlib/storage/config_flags_test.go index 429af71283..4809b6b78a 100755 --- a/flytestdlib/storage/config_flags_test.go +++ b/flytestdlib/storage/config_flags_test.go @@ -132,7 +132,7 @@ func TestConfig_SetFlags(t *testing.T) { }) t.Run("Override", func(t *testing.T) { - testValue := "1" + testValue := defaultConfig.Connection.Endpoint.String() cmdFlags.Set("connection.endpoint", testValue) if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { diff --git a/flytestdlib/tests/testdata/combined.yaml b/flytestdlib/tests/testdata/combined.yaml index f167b1ab33..24a463b339 100755 --- a/flytestdlib/tests/testdata/combined.yaml +++ b/flytestdlib/tests/testdata/combined.yaml @@ -1,4 +1,6 @@ logger: + formatter: + type: "" level: 5 mute: false 
show-source: false From 23a3d52e66c75aaec78736ffd3c6aaaaf5808063 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 23 Apr 2019 19:09:42 -0700 Subject: [PATCH 0033/1918] Fix release tag image --- flytestdlib/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytestdlib/README.md b/flytestdlib/README.md index 6411abcfb5..6045bf3f46 100644 --- a/flytestdlib/README.md +++ b/flytestdlib/README.md @@ -1,6 +1,6 @@ Common Go Tools ===================== -[![Current Release](https://img.shields.io/github/release/flytestdlib.svg)](https://github.com/lyft/flytestdlib/releases/latest) +[![Current Release](https://img.shields.io/github/release/lyft/flytestdlib.svg)](https://github.com/lyft/flytestdlib/releases/latest) [![Build Status](https://travis-ci.org/lyft/flytestdlib.svg?branch=master)](https://travis-ci.org/lyft/flytestdlib) [![GoDoc](https://godoc.org/github.com/lyft/flytestdlib?status.svg)](https://godoc.org/github.com/lyft/flytestdlib) [![License](https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg)](http://www.apache.org/licenses/LICENSE-2.0.html) From 23d5f1d4f632fb3f73fe3e4da8167eadc3bd82d6 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Tue, 23 Apr 2019 23:32:33 -0700 Subject: [PATCH 0034/1918] Scoop update for flytestdlib version v0.2.2-alpha.1 --- flytestdlib/flytestdlib.json | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 flytestdlib/flytestdlib.json diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json new file mode 100644 index 0000000000..4af8f55d08 --- /dev/null +++ b/flytestdlib/flytestdlib.json @@ -0,0 +1,26 @@ +{ + "version": "0.2.2-alpha.1", + "architecture": { + "32bit": { + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.1/flytestdlib_0.2.2-alpha.1_Windows_i386.zip", + "bin": [ + "pflags.exe" + ], + "hash": "47d5e8405da7ab80e402e4bc02fc935d165b6bd0d7e42a13a728dfb756fc170f" + }, + "64bit": { + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.1/flytestdlib_0.2.2-alpha.1_Windows_x86_64.zip", + "bin": [ + "pflags.exe" + ], + "hash": "59a9766bba67bb20fadcbe907137de8ad9d88996147bc3905098457e73103a00" + } + }, + "homepage": "https://godoc.org/github.com/lyft/flytestdlib", + "license": "Apache-2.0", + "description": "Common Go utilities (Typed-Config, PFlags, Prometheus Metrics,... 
more).", + "persist": [ + "data", + "config.toml" + ] +} \ No newline at end of file From 9de816c73ea55a2b8646e956997707fa6a8fbae4 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Wed, 24 Apr 2019 00:23:14 -0700 Subject: [PATCH 0035/1918] Scoop update for flytestdlib version v0.2.2-alpha.4 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 4af8f55d08..264a4e9ed6 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.2-alpha.1", + "version": "0.2.2-alpha.4", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.1/flytestdlib_0.2.2-alpha.1_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.4/flytestdlib_0.2.2-alpha.4_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "47d5e8405da7ab80e402e4bc02fc935d165b6bd0d7e42a13a728dfb756fc170f" + "hash": "cb3cd184ba9d371f0959782d99bdcc6638c9e337b1a7ed763ba26d14e65ed83b" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.1/flytestdlib_0.2.2-alpha.1_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.4/flytestdlib_0.2.2-alpha.4_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "59a9766bba67bb20fadcbe907137de8ad9d88996147bc3905098457e73103a00" + "hash": "2c4fc1edc8b61ff6d489bc5de046ffb7fd3dbe3f7b9a7a53b3262d7e5d150ca9" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From fa1cfb16c5955f4bbd06e06300cc1b1664953a48 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 24 Apr 2019 09:40:22 -0700 Subject: [PATCH 0036/1918] Hookkup Goreleaser (#10) * Goreleaser * Add contributing section * Update README.md * Add encrypted github token * Add version info to binary * update archive to archives * Update unit tests & docs * Update README.md * lint * lint again --- flytestdlib/.goreleaser.yml | 65 +++++++++++++++++++++++++++ flytestdlib/.travis.yml | 36 +++++++++------ flytestdlib/README.md | 38 ++++++++++++++++ flytestdlib/cli/pflags/cmd/version.go | 2 +- flytestdlib/profutils/server.go | 11 ++--- flytestdlib/profutils/server_test.go | 4 +- flytestdlib/version/version.go | 4 +- flytestdlib/version/version_test.go | 2 +- 8 files changed, 136 insertions(+), 26 deletions(-) create mode 100644 flytestdlib/.goreleaser.yml diff --git a/flytestdlib/.goreleaser.yml b/flytestdlib/.goreleaser.yml new file mode 100644 index 0000000000..6e193fe0e9 --- /dev/null +++ b/flytestdlib/.goreleaser.yml @@ -0,0 +1,65 @@ +before: + hooks: + - dep ensure -vendor-only +builds: + - env: + - CGO_ENABLED=0 + main: ./cli/pflags/main.go + binary: pflags + goos: + - linux + - windows + - darwin + ldflags: + - -s -w -X github.com/lyft/flytestdlib/version.Version={{.Version}} -X github.com/lyft/flytestdlib/version.Build={{.ShortCommit}} -X github.com/lyft/flytestdlib/version.BuildTime={{.Date}} +archives: + - replacements: + darwin: macOS + linux: Linux + windows: Windows + 386: i386 + amd64: x86_64 + format_overrides: + - goos: windows + format: zip +checksum: + name_template: 'checksums.txt' +snapshot: + name_template: "{{ .Tag }}-next" +changelog: + sort: asc + filters: + exclude: + - '^docs:' + - '^test:' +scoop: + # Default is "https://github.com///releases/download/{{ .Tag }}/{{ .ArtifactName }}" + # url_template: "http://github.mycompany.com/foo/bar/releases/{{ .Tag }}/{{ .ArtifactName }}" + + # 
Repository to push the app manifest to. + bucket: + owner: lyft + name: flytestdlib + + # Git author used to commit to the repository. + # Defaults are shown. + commit_author: + name: goreleaserbot + email: goreleaser@carlosbecker.com + + # Your app's homepage. + # Default is empty. + homepage: "https://godoc.org/github.com/lyft/flytestdlib" + + # Your app's description. + # Default is empty. + description: "Common Go utilities (Typed-Config, PFlags, Prometheus Metrics,... more)." + + # Your app's license + # Default is empty. + license: Apache-2.0 + + # Persist data between application updates + persist: + - "data" + - "config.toml" diff --git a/flytestdlib/.travis.yml b/flytestdlib/.travis.yml index 05d3a1b406..99ebfa3e36 100644 --- a/flytestdlib/.travis.yml +++ b/flytestdlib/.travis.yml @@ -1,18 +1,28 @@ sudo: required language: go go: - - "1.11" +- '1.11' jobs: include: - - stage: test - name: unit tests - install: make install - script: make test_unit - - stage: test - name: benchmark tests - install: make install - script: make test_benchmark - - stage: test - install: make install - name: lint - script: make lint + - stage: test + name: unit tests + install: make install + script: make test_unit + - stage: test + name: benchmark tests + install: make install + script: make test_benchmark + - stage: test + install: make install + name: lint + script: make lint +deploy: +- provider: script + skip_cleanup: true + script: curl -sL https://git.io/goreleaser | bash + on: + tags: true + condition: "$TRAVIS_OS_NAME = linux" +env: + global: + secure: xqB2LwI1qbrPUIYXtaOnIoBX5b1h1ydxOvy5Cu1sS/R2t1BfYHFX5oH0/0Z23bfpQBpzEweXA3145xyjg1Q2vJiN2ebPLmMPduWdjp1be/4xWCnkftZuGW7LDEFg3zuREUMKdfDhkb0uQ5gzIte3TvGh/tJwfxwUHVHQEl1aqPYqbRHRqoLJZiuhgIH+17su5mBFfu/62xXMP8zImLUq4WLrmbmMszLWg3IOu+oawpMXuDsjoxkucdFjmo2rsVUNr3QNo7ock7hl1OYHJZvWuRV+HxCaNRNUrbr8GuWYUSNOB51Ml7kLAlSxnmKJMs1fZRxTPlXR/0+XA8zAWahcKvKxRqguoFNVqYEESS/yRoJhLctgwAjx/btSc1a4BXCwIDFXNFVBGyZiVcLnh9PG6WWXI2YRWSbXmoBG3QN8Dtdpz54qoCpCA7IVWijWBHVXiVbyIn9XmTMFCdMXIFZQ7mzzk6K+894taPSRsia305LCJ2/h1df8bLsw5zcXmjXjZpkxM7rK5nJqx6IiaZ94GmeRER3OQxKTxxBuoZvcWcn9+ni+FtA2EzJuMBxbWLh+jinfqqieLkoOPHeBzAN6YyaPuUQje/dT4tjdf95V+wuojfu/TqIk/o7WwMPgfYWP2tlj0R5GwoA3ZocZRXZYRP/gg1Cje6wCQCOiIPsFC5g= diff --git a/flytestdlib/README.md b/flytestdlib/README.md index 6045bf3f46..37986d95d9 100644 --- a/flytestdlib/README.md +++ b/flytestdlib/README.md @@ -46,3 +46,41 @@ This library consists of: - utils - version + +Contributing +------------ + +## Versioning + +This repo follows [semantic versioning](https://semver.org/). + +## Releases + +This repository is hooked up with [goreleaser](https://goreleaser.com/). Maintainers are expected to create tags and let goreleaser compose the release message and create a release. + +To create a new release, follow these steps: + +- Create a PR with your changes. + +- [Optional] Create an alpha tag on your branch and push that. + + - First get existing tags `git describe --abbrev=0 --tags` + + - Figure out the next alpha version (e.g. if tag is v1.2.3 then you should create a v1.2.4-alpha.0 tag) + + - Create a tag `git tag v1.2.4-alpha.0` + + - Push tag `git push --tags` + +- Merge your changes and checkout master branch `git checkout master && git pull` + +- Bump version tag and push to branch. + + - First get existing tags `git describe --abbrev=0 --tags` + + - Figure out the next release version (e.g. if tag is v1.2.3 then you should create a v1.2.4 tag or v1.3.0 or a v2.0.0 depending on what has changed. 
Refer to [Semantic Versioning](https://semver.org/) for information about when to bump each) + + - Create a tag `git tag v1.2.4` + + - Push tag `git push --tags` + diff --git a/flytestdlib/cli/pflags/cmd/version.go b/flytestdlib/cli/pflags/cmd/version.go index 8ee00af2e4..5d15a4d7c1 100644 --- a/flytestdlib/cli/pflags/cmd/version.go +++ b/flytestdlib/cli/pflags/cmd/version.go @@ -8,7 +8,7 @@ import ( var versionCmd = &cobra.Command{ Aliases: []string{"version", "ver"}, Run: func(cmd *cobra.Command, args []string) { - cmd.Printf("Version: %s\nBuildSHA: %s\nBuildTS: %s\n", version.Version, version.Build, version.BuildTime.String()) + version.LogBuildInformation("pflags") }, } diff --git a/flytestdlib/profutils/server.go b/flytestdlib/profutils/server.go index b4dcc00ac3..fe0157fca9 100644 --- a/flytestdlib/profutils/server.go +++ b/flytestdlib/profutils/server.go @@ -5,13 +5,10 @@ import ( "encoding/json" "fmt" "net/http" - "time" "github.com/lyft/flytestdlib/config" - - "github.com/lyft/flytestdlib/version" - "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/version" "github.com/prometheus/client_golang/prometheus/promhttp" _ "net/http/pprof" // Import for pprof server @@ -30,9 +27,9 @@ const ( ) type BuildVersion struct { - Build string `json:"build"` - Version string `json:"version"` - Timestamp time.Time `json:"timestamp,string"` + Build string `json:"build"` + Version string `json:"version"` + Timestamp string `json:"timestamp,string"` } // Writes a string to the Http output stream diff --git a/flytestdlib/profutils/server_test.go b/flytestdlib/profutils/server_test.go index deb3383128..c989bcbbf0 100644 --- a/flytestdlib/profutils/server_test.go +++ b/flytestdlib/profutils/server_test.go @@ -88,14 +88,14 @@ func TestVersionHandler(t *testing.T) { URL: &testURL, } - version.BuildTime = time.Now() + version.BuildTime = time.Now().String() http.DefaultServeMux.ServeHTTP(writer, request) assert.Equal(t, http.StatusOK, writer.Status) assert.NotNil(t, writer.Body) bv := BuildVersion{} assert.NoError(t, json.Unmarshal(writer.Body, &bv)) - assert.Equal(t, bv.Timestamp.Unix(), version.BuildTime.Unix()) + assert.Equal(t, bv.Timestamp, version.BuildTime) assert.Equal(t, bv.Build, version.Build) assert.Equal(t, bv.Version, version.Version) } diff --git a/flytestdlib/version/version.go b/flytestdlib/version/version.go index ab3e4cf112..08536c05e4 100644 --- a/flytestdlib/version/version.go +++ b/flytestdlib/version/version.go @@ -17,13 +17,13 @@ var ( // Version for the build, should follow a semver Version = "unknown" // Build timestamp - BuildTime = time.Now() + BuildTime = time.Now().String() ) // Use this method to log the build information for the current app. The app name should be provided. 
To inject the build // and version information refer to the top-level comment in this file func LogBuildInformation(appName string) { logrus.Info("------------------------------------------------------------------------") - logrus.Infof("App [%s], Version [%s], BuildSHA [%s], BuildTS [%s]", appName, Version, Build, BuildTime.String()) + logrus.Infof("App [%s], Version [%s], BuildSHA [%s], BuildTS [%s]", appName, Version, Build, BuildTime) logrus.Info("------------------------------------------------------------------------") } diff --git a/flytestdlib/version/version_test.go b/flytestdlib/version/version_test.go index d4ddb2def4..be7826ab76 100644 --- a/flytestdlib/version/version_test.go +++ b/flytestdlib/version/version_test.go @@ -20,7 +20,7 @@ func (dFormat) Format(e *logrus.Entry) ([]byte, error) { func TestLogBuildInformation(t *testing.T) { n := time.Now() - BuildTime = n + BuildTime = n.String() buf := bytes.NewBufferString("") logrus.SetFormatter(dFormat{}) logrus.SetOutput(buf) From 9f9155072bcad0964d0fa4b3c4a120ba0fc25591 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Wed, 24 Apr 2019 09:46:17 -0700 Subject: [PATCH 0037/1918] Scoop update for flytestdlib version v0.2.2 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 264a4e9ed6..b708048a40 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.2-alpha.4", + "version": "0.2.2", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.4/flytestdlib_0.2.2-alpha.4_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2/flytestdlib_0.2.2_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "cb3cd184ba9d371f0959782d99bdcc6638c9e337b1a7ed763ba26d14e65ed83b" + "hash": "3693ea92f967aa57568d1b646be44a8871261578e489113b2ff8d1926d47bf17" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2-alpha.4/flytestdlib_0.2.2-alpha.4_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2/flytestdlib_0.2.2_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "2c4fc1edc8b61ff6d489bc5de046ffb7fd3dbe3f7b9a7a53b3262d7e5d150ca9" + "hash": "c073e0dce9b45630bdd862eb7c1a45bc000c664e7d25f74be2ca0768a516cf3d" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 555f18ff2f71d1bbf023842863d3f0475a1f7437 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 24 Apr 2019 13:39:25 -0700 Subject: [PATCH 0038/1918] Add godownloader script (#11) * Add godownloader script * Update README.md * typo --- flytestdlib/README.md | 26 ++- flytestdlib/godownloader.sh | 381 ++++++++++++++++++++++++++++++++++++ 2 files changed, 400 insertions(+), 7 deletions(-) create mode 100644 flytestdlib/godownloader.sh diff --git a/flytestdlib/README.md b/flytestdlib/README.md index 37986d95d9..2e979e460d 100644 --- a/flytestdlib/README.md +++ b/flytestdlib/README.md @@ -16,6 +16,18 @@ This library consists of: Tool to generate a pflags for all fields in a given struct. 
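  _Editor's sketch (not part of the original README):_ typical wiring looks something like the following. The `Config` struct, field, and flag description are invented for illustration, and the exact `go:generate` arguments may differ from what the tool actually expects:

  ```go
  //go:generate pflags Config

  type Config struct {
      Endpoint string `json:"endpoint" pflag:",URL for the service to connect to."`
  }

  // The generated file adds, roughly:
  //   func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet
  // which registers an "endpoint" flag (honoring the prefix) with the description above.
  ```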
+ #### Install + + On POSIX systems, run: `curl -sfL https://raw.githubusercontent.com/lyft/flytestdlib/godownloader/godownloader.sh | sh` + + On Windows: + + Install scoop: `iex (new-object net.webclient).downloadstring('https://get.scoop.sh')` + + Run: `scoop bucket add flytestdlib https://github.com/lyft/flytestdlib.git` + + Run: `scoop install pflags` + - storage Abstract storage library that uses stow behind the scenes to connect to s3/azure/gcs but also offers configurable factory, in-memory storage (for testing) as well as native protobuf support. @@ -63,13 +75,13 @@ To create a new release, follow these steps: - Create a PR with your changes. - [Optional] Create an alpha tag on your branch and push that. - + - First get existing tags `git describe --abbrev=0 --tags` - + - Figure out the next alpha version (e.g. if tag is v1.2.3 then you should create a v1.2.4-alpha.0 tag) - + - Create a tag `git tag v1.2.4-alpha.0` - + - Push tag `git push --tags` - Merge your changes and checkout master branch `git checkout master && git pull` @@ -77,10 +89,10 @@ To create a new release, follow these steps: - Bump version tag and push to branch. - First get existing tags `git describe --abbrev=0 --tags` - + - Figure out the next release version (e.g. if tag is v1.2.3 then you should create a v1.2.4 tag or v1.3.0 or a v2.0.0 depending on what has changed. Refer to [Semantic Versioning](https://semver.org/) for information about when to bump each) - + - Create a tag `git tag v1.2.4` - + - Push tag `git push --tags` diff --git a/flytestdlib/godownloader.sh b/flytestdlib/godownloader.sh new file mode 100644 index 0000000000..679511f30b --- /dev/null +++ b/flytestdlib/godownloader.sh @@ -0,0 +1,381 @@ +#!/bin/sh +set -e +# Code generated by godownloader on 2019-04-24T17:57:45Z. DO NOT EDIT. +# + +usage() { + this=$1 + cat </dev/null +} +echoerr() { + echo "$@" 1>&2 +} +log_prefix() { + echo "$0" +} +_logp=6 +log_set_priority() { + _logp="$1" +} +log_priority() { + if test -z "$1"; then + echo "$_logp" + return + fi + [ "$1" -le "$_logp" ] +} +log_tag() { + case $1 in + 0) echo "emerg" ;; + 1) echo "alert" ;; + 2) echo "crit" ;; + 3) echo "err" ;; + 4) echo "warning" ;; + 5) echo "notice" ;; + 6) echo "info" ;; + 7) echo "debug" ;; + *) echo "$1" ;; + esac +} +log_debug() { + log_priority 7 || return 0 + echoerr "$(log_prefix)" "$(log_tag 7)" "$@" +} +log_info() { + log_priority 6 || return 0 + echoerr "$(log_prefix)" "$(log_tag 6)" "$@" +} +log_err() { + log_priority 3 || return 0 + echoerr "$(log_prefix)" "$(log_tag 3)" "$@" +} +log_crit() { + log_priority 2 || return 0 + echoerr "$(log_prefix)" "$(log_tag 2)" "$@" +} +uname_os() { + os=$(uname -s | tr '[:upper:]' '[:lower:]') + case "$os" in + msys_nt) os="windows" ;; + esac + echo "$os" +} +uname_arch() { + arch=$(uname -m) + case $arch in + x86_64) arch="amd64" ;; + x86) arch="386" ;; + i686) arch="386" ;; + i386) arch="386" ;; + aarch64) arch="arm64" ;; + armv5*) arch="armv5" ;; + armv6*) arch="armv6" ;; + armv7*) arch="armv7" ;; + esac + echo ${arch} +} +uname_os_check() { + os=$(uname_os) + case "$os" in + darwin) return 0 ;; + dragonfly) return 0 ;; + freebsd) return 0 ;; + linux) return 0 ;; + android) return 0 ;; + nacl) return 0 ;; + netbsd) return 0 ;; + openbsd) return 0 ;; + plan9) return 0 ;; + solaris) return 0 ;; + windows) return 0 ;; + esac + log_crit "uname_os_check '$(uname -s)' got converted to '$os' which is not a GOOS value. 
Please file bug at https://github.com/client9/shlib" + return 1 +} +uname_arch_check() { + arch=$(uname_arch) + case "$arch" in + 386) return 0 ;; + amd64) return 0 ;; + arm64) return 0 ;; + armv5) return 0 ;; + armv6) return 0 ;; + armv7) return 0 ;; + ppc64) return 0 ;; + ppc64le) return 0 ;; + mips) return 0 ;; + mipsle) return 0 ;; + mips64) return 0 ;; + mips64le) return 0 ;; + s390x) return 0 ;; + amd64p32) return 0 ;; + esac + log_crit "uname_arch_check '$(uname -m)' got converted to '$arch' which is not a GOARCH value. Please file bug report at https://github.com/client9/shlib" + return 1 +} +untar() { + tarball=$1 + case "${tarball}" in + *.tar.gz | *.tgz) tar -xzf "${tarball}" ;; + *.tar) tar -xf "${tarball}" ;; + *.zip) unzip "${tarball}" ;; + *) + log_err "untar unknown archive format for ${tarball}" + return 1 + ;; + esac +} +http_download_curl() { + local_file=$1 + source_url=$2 + header=$3 + if [ -z "$header" ]; then + code=$(curl -w '%{http_code}' -sL -o "$local_file" "$source_url") + else + code=$(curl -w '%{http_code}' -sL -H "$header" -o "$local_file" "$source_url") + fi + if [ "$code" != "200" ]; then + log_debug "http_download_curl received HTTP status $code" + return 1 + fi + return 0 +} +http_download_wget() { + local_file=$1 + source_url=$2 + header=$3 + if [ -z "$header" ]; then + wget -q -O "$local_file" "$source_url" + else + wget -q --header "$header" -O "$local_file" "$source_url" + fi +} +http_download() { + log_debug "http_download $2" + if is_command curl; then + http_download_curl "$@" + return + elif is_command wget; then + http_download_wget "$@" + return + fi + log_crit "http_download unable to find wget or curl" + return 1 +} +http_copy() { + tmp=$(mktemp) + http_download "${tmp}" "$1" "$2" || return 1 + body=$(cat "$tmp") + rm -f "${tmp}" + echo "$body" +} +github_release() { + owner_repo=$1 + version=$2 + test -z "$version" && version="latest" + giturl="https://github.com/${owner_repo}/releases/${version}" + json=$(http_copy "$giturl" "Accept:application/json") + test -z "$json" && return 1 + version=$(echo "$json" | tr -s '\n' ' ' | sed 's/.*"tag_name":"//' | sed 's/".*//') + test -z "$version" && return 1 + echo "$version" +} +hash_sha256() { + TARGET=${1:-/dev/stdin} + if is_command gsha256sum; then + hash=$(gsha256sum "$TARGET") || return 1 + echo "$hash" | cut -d ' ' -f 1 + elif is_command sha256sum; then + hash=$(sha256sum "$TARGET") || return 1 + echo "$hash" | cut -d ' ' -f 1 + elif is_command shasum; then + hash=$(shasum -a 256 "$TARGET" 2>/dev/null) || return 1 + echo "$hash" | cut -d ' ' -f 1 + elif is_command openssl; then + hash=$(openssl -dst openssl dgst -sha256 "$TARGET") || return 1 + echo "$hash" | cut -d ' ' -f a + else + log_crit "hash_sha256 unable to find command to compute sha-256 hash" + return 1 + fi +} +hash_sha256_verify() { + TARGET=$1 + checksums=$2 + if [ -z "$checksums" ]; then + log_err "hash_sha256_verify checksum file not specified in arg2" + return 1 + fi + BASENAME=${TARGET##*/} + want=$(grep "${BASENAME}" "${checksums}" 2>/dev/null | tr '\t' ' ' | cut -d ' ' -f 1) + if [ -z "$want" ]; then + log_err "hash_sha256_verify unable to find checksum for '${TARGET}' in '${checksums}'" + return 1 + fi + got=$(hash_sha256 "$TARGET") + if [ "$want" != "$got" ]; then + log_err "hash_sha256_verify checksum for '$TARGET' did not verify ${want} vs $got" + return 1 + fi +} +cat /dev/null < Date: Wed, 24 Apr 2019 15:22:06 -0700 Subject: [PATCH 0039/1918] Adding codecov to report unit test coverage (#13) Setup new makefile target to 
get codecoverage and report it using codecov --- flytestdlib/.travis.yml | 6 +++--- flytestdlib/Makefile | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/flytestdlib/.travis.yml b/flytestdlib/.travis.yml index 99ebfa3e36..2f28738cd9 100644 --- a/flytestdlib/.travis.yml +++ b/flytestdlib/.travis.yml @@ -5,9 +5,9 @@ go: jobs: include: - stage: test - name: unit tests + name: unit tests and coverage install: make install - script: make test_unit + script: make test_unit_codecov - stage: test name: benchmark tests install: make install @@ -25,4 +25,4 @@ deploy: condition: "$TRAVIS_OS_NAME = linux" env: global: - secure: xqB2LwI1qbrPUIYXtaOnIoBX5b1h1ydxOvy5Cu1sS/R2t1BfYHFX5oH0/0Z23bfpQBpzEweXA3145xyjg1Q2vJiN2ebPLmMPduWdjp1be/4xWCnkftZuGW7LDEFg3zuREUMKdfDhkb0uQ5gzIte3TvGh/tJwfxwUHVHQEl1aqPYqbRHRqoLJZiuhgIH+17su5mBFfu/62xXMP8zImLUq4WLrmbmMszLWg3IOu+oawpMXuDsjoxkucdFjmo2rsVUNr3QNo7ock7hl1OYHJZvWuRV+HxCaNRNUrbr8GuWYUSNOB51Ml7kLAlSxnmKJMs1fZRxTPlXR/0+XA8zAWahcKvKxRqguoFNVqYEESS/yRoJhLctgwAjx/btSc1a4BXCwIDFXNFVBGyZiVcLnh9PG6WWXI2YRWSbXmoBG3QN8Dtdpz54qoCpCA7IVWijWBHVXiVbyIn9XmTMFCdMXIFZQ7mzzk6K+894taPSRsia305LCJ2/h1df8bLsw5zcXmjXjZpkxM7rK5nJqx6IiaZ94GmeRER3OQxKTxxBuoZvcWcn9+ni+FtA2EzJuMBxbWLh+jinfqqieLkoOPHeBzAN6YyaPuUQje/dT4tjdf95V+wuojfu/TqIk/o7WwMPgfYWP2tlj0R5GwoA3ZocZRXZYRP/gg1Cje6wCQCOiIPsFC5g= + - secure: xqB2LwI1qbrPUIYXtaOnIoBX5b1h1ydxOvy5Cu1sS/R2t1BfYHFX5oH0/0Z23bfpQBpzEweXA3145xyjg1Q2vJiN2ebPLmMPduWdjp1be/4xWCnkftZuGW7LDEFg3zuREUMKdfDhkb0uQ5gzIte3TvGh/tJwfxwUHVHQEl1aqPYqbRHRqoLJZiuhgIH+17su5mBFfu/62xXMP8zImLUq4WLrmbmMszLWg3IOu+oawpMXuDsjoxkucdFjmo2rsVUNr3QNo7ock7hl1OYHJZvWuRV+HxCaNRNUrbr8GuWYUSNOB51Ml7kLAlSxnmKJMs1fZRxTPlXR/0+XA8zAWahcKvKxRqguoFNVqYEESS/yRoJhLctgwAjx/btSc1a4BXCwIDFXNFVBGyZiVcLnh9PG6WWXI2YRWSbXmoBG3QN8Dtdpz54qoCpCA7IVWijWBHVXiVbyIn9XmTMFCdMXIFZQ7mzzk6K+894taPSRsia305LCJ2/h1df8bLsw5zcXmjXjZpkxM7rK5nJqx6IiaZ94GmeRER3OQxKTxxBuoZvcWcn9+ni+FtA2EzJuMBxbWLh+jinfqqieLkoOPHeBzAN6YyaPuUQje/dT4tjdf95V+wuojfu/TqIk/o7WwMPgfYWP2tlj0R5GwoA3ZocZRXZYRP/gg1Cje6wCQCOiIPsFC5g= diff --git a/flytestdlib/Makefile b/flytestdlib/Makefile index 623663f5f1..211f908f0a 100644 --- a/flytestdlib/Makefile +++ b/flytestdlib/Makefile @@ -22,3 +22,8 @@ compile: gen-config: which pflags || (go get github.com/lyft/flytestdlib/cli/pflags) @go generate ./... + +.PHONY: test_unit_codecov +test_unit_codecov: + go test ./... 
-race -coverprofile=coverage.txt -covermode=atomic; curl -s https://codecov.io/bash > codecov_bash.sh; bash codecov_bash.sh + From 1b96fb4bea24b45c21de4eb21e7801a878ec6da8 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 24 Apr 2019 16:32:32 -0700 Subject: [PATCH 0040/1918] More Badges: Activity, Code coverage, report card (#15) * More badges for activity reporting * no links for activity --- flytestdlib/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flytestdlib/README.md b/flytestdlib/README.md index 2e979e460d..085e2c914b 100644 --- a/flytestdlib/README.md +++ b/flytestdlib/README.md @@ -4,6 +4,10 @@ Common Go Tools [![Build Status](https://travis-ci.org/lyft/flytestdlib.svg?branch=master)](https://travis-ci.org/lyft/flytestdlib) [![GoDoc](https://godoc.org/github.com/lyft/flytestdlib?status.svg)](https://godoc.org/github.com/lyft/flytestdlib) [![License](https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg)](http://www.apache.org/licenses/LICENSE-2.0.html) +[![CodeCoverage](https://img.shields.io/codecov/c/github/lyft/flytestdlib.svg)](https://codecov.io/gh/lyft/flytestdlib) +[![Go Report Card](https://goreportcard.com/badge/github.com/lyft/flytestdlib)](https://goreportcard.com/report/github.com/lyft/flytestdlib) +![Commit activity](https://img.shields.io/github/commit-activity/w/lyft/flytestdlib.svg?style=plastic) +![Commit since last releast](https://img.shields.io/github/commits-since/lyft/flytestdlib/latest.svg?style=plastic) Shared components we found ourselves building time and time again, so we collected them in one place! From f99988b07b232172656ec0e745b52618a2e857bf Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Fri, 26 Apr 2019 11:29:22 -0700 Subject: [PATCH 0041/1918] Fix Godownloader generated script (#14) * Add godownloader script * Update README.md * typo * Fix generated godownloader script * Default download to /Users/hamabuelfutuh/src/go/bin --- flytestdlib/godownloader.sh | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) mode change 100644 => 100755 flytestdlib/godownloader.sh diff --git a/flytestdlib/godownloader.sh b/flytestdlib/godownloader.sh old mode 100644 new mode 100755 index 679511f30b..44093b6d07 --- a/flytestdlib/godownloader.sh +++ b/flytestdlib/godownloader.sh @@ -26,7 +26,7 @@ parse_args() { #BINDIR is ./bin unless set be ENV # over-ridden by flag below - BINDIR=${BINDIR:-./bin} + BINDIR=${BINDIR:-${GOPATH}/bin} while getopts "b:dh?x" arg; do case "$arg" in b) BINDIR="$OPTARG" ;; @@ -66,10 +66,13 @@ is_supported_platform() { case "$platform" in linux/amd64) found=0 ;; linux/386) found=0 ;; + linux/x86_64) found=0 ;; windows/amd64) found=0 ;; windows/386) found=0 ;; + windows/x86_64) found=0 ;; darwin/amd64) found=0 ;; darwin/386) found=0 ;; + darwin/x86_64) found=0 ;; esac return $found } @@ -99,11 +102,28 @@ tag_to_version() { } adjust_format() { # change format (tar.gz or zip) based on OS - true + FORMAT="tar.gz" + case "$os" in + darwin) return 0 ;; + dragonfly) return 0 ;; + freebsd) return 0 ;; + linux) return 0 ;; + android) return 0 ;; + nacl) return 0 ;; + netbsd) return 0 ;; + openbsd) return 0 ;; + plan9) return 0 ;; + solaris) return 0 ;; + esac + FORMAT="zip" + return 0 } adjust_os() { # adjust archive name based on OS - true + case "$os" in + darwin) OS="macOS"; + esac + return 0; } adjust_arch() { # adjust archive name based on ARCH @@ -177,7 +197,7 @@ uname_os() { uname_arch() { arch=$(uname -m) case $arch in - x86_64) arch="amd64" ;; + x86_64) arch="x86_64" ;; x86) 
arch="386" ;; i686) arch="386" ;; i386) arch="386" ;; @@ -223,6 +243,7 @@ uname_arch_check() { mips64le) return 0 ;; s390x) return 0 ;; amd64p32) return 0 ;; + x86_64) return 0 ;; esac log_crit "uname_arch_check '$(uname -m)' got converted to '$arch' which is not a GOARCH value. Please file bug report at https://github.com/client9/shlib" return 1 @@ -371,11 +392,12 @@ adjust_arch log_info "found version: ${VERSION} for ${TAG}/${OS}/${ARCH}" -NAME= +NAME=${PROJECT_NAME}_${VERSION}_${OS}_${ARCH} TARBALL=${NAME}.${FORMAT} TARBALL_URL=${GITHUB_DOWNLOAD}/${TAG}/${TARBALL} CHECKSUM=checksums.txt CHECKSUM_URL=${GITHUB_DOWNLOAD}/${TAG}/${CHECKSUM} +log_info "Tarball url: ${TARBALL_URL}" execute From a12bdcf8746c91398c1e127aa9dfe12590d44d85 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 30 Apr 2019 11:12:36 -0700 Subject: [PATCH 0042/1918] Race condition in logger.SetConfig (#16) * remove globalconfig variable to avoid data race * Add WithRoutineLabel * rename * simplify logger a bit * format * doc --- flytestdlib/config/section.go | 4 + flytestdlib/contextutils/context.go | 37 ++- flytestdlib/contextutils/context_test.go | 10 + flytestdlib/logger/config.go | 39 +-- flytestdlib/logger/config_test.go | 10 +- flytestdlib/logger/logger.go | 298 +++++++++++------------ flytestdlib/logger/logger_test.go | 63 +---- flytestdlib/profutils/server.go | 9 +- 8 files changed, 217 insertions(+), 253 deletions(-) diff --git a/flytestdlib/config/section.go b/flytestdlib/config/section.go index 41fddd0c1b..34f5f7a5dd 100644 --- a/flytestdlib/config/section.go +++ b/flytestdlib/config/section.go @@ -199,6 +199,10 @@ func (r *section) SetConfig(c Config) error { r.lockObj.Lock() defer r.lockObj.Unlock() + if reflect.TypeOf(c).Kind() != reflect.Ptr { + return fmt.Errorf("config must be a Pointer") + } + if !DeepEqual(r.config, c) { r.config = c r.isDirty.Store(true) diff --git a/flytestdlib/contextutils/context.go b/flytestdlib/contextutils/context.go index b5a9c00fa2..e5be11e219 100644 --- a/flytestdlib/contextutils/context.go +++ b/flytestdlib/contextutils/context.go @@ -4,22 +4,24 @@ package contextutils import ( "context" "fmt" + "runtime/pprof" ) type Key string const ( - AppNameKey Key = "app_name" - NamespaceKey Key = "ns" - TaskTypeKey Key = "tasktype" - ProjectKey Key = "project" - DomainKey Key = "domain" - WorkflowIDKey Key = "wf" - NodeIDKey Key = "node" - TaskIDKey Key = "task" - ExecIDKey Key = "exec_id" - JobIDKey Key = "job_id" - PhaseKey Key = "phase" + AppNameKey Key = "app_name" + NamespaceKey Key = "ns" + TaskTypeKey Key = "tasktype" + ProjectKey Key = "project" + DomainKey Key = "domain" + WorkflowIDKey Key = "wf" + NodeIDKey Key = "node" + TaskIDKey Key = "task" + ExecIDKey Key = "exec_id" + JobIDKey Key = "job_id" + PhaseKey Key = "phase" + RoutineLabelKey Key = "routine" ) func (k Key) String() string { @@ -35,6 +37,7 @@ var logKeys = []Key{ WorkflowIDKey, TaskTypeKey, PhaseKey, + RoutineLabelKey, } // Gets a new context with namespace set. @@ -98,6 +101,14 @@ func WithTaskType(ctx context.Context, taskType string) context.Context { return context.WithValue(ctx, TaskTypeKey, taskType) } +// Gets a new context with Go Routine label key set and a label assigned to the context using pprof.Labels. +// You can then call pprof.SetGoroutineLabels(ctx) to annotate the current go-routine and have that show up in +// pprof analysis. 
+func WithGoroutineLabel(ctx context.Context, routineLabel string) context.Context { + ctx = pprof.WithLabels(ctx, pprof.Labels(RoutineLabelKey.String(), routineLabel)) + return context.WithValue(ctx, RoutineLabelKey, routineLabel) +} + func addFieldIfNotNil(ctx context.Context, m map[string]interface{}, fieldKey Key) { val := ctx.Value(fieldKey) if val != nil { @@ -110,6 +121,7 @@ func addStringFieldWithDefaults(ctx context.Context, m map[string]string, fieldK if val == nil { val = "" } + m[fieldKey.String()] = val.(string) } @@ -120,6 +132,7 @@ func GetLogFields(ctx context.Context) map[string]interface{} { for _, k := range logKeys { addFieldIfNotNil(ctx, res, k) } + return res } @@ -128,6 +141,7 @@ func Value(ctx context.Context, key Key) string { if val != nil { return val.(string) } + return "" } @@ -136,5 +150,6 @@ func Values(ctx context.Context, keys ...Key) map[string]string { for _, k := range keys { addStringFieldWithDefaults(ctx, res, k) } + return res } diff --git a/flytestdlib/contextutils/context_test.go b/flytestdlib/contextutils/context_test.go index 99d29d3e3d..d65c2a2be3 100644 --- a/flytestdlib/contextutils/context_test.go +++ b/flytestdlib/contextutils/context_test.go @@ -2,6 +2,7 @@ package contextutils import ( "context" + "runtime/pprof" "testing" "github.com/stretchr/testify/assert" @@ -111,3 +112,12 @@ func TestValues(t *testing.T) { assert.Equal(t, "flyte", m[WorkflowIDKey.String()]) assert.Equal(t, "", m[ProjectKey.String()]) } + +func TestWithGoroutineLabel(t *testing.T) { + ctx := context.Background() + ctx = WithGoroutineLabel(ctx, "my_routine_123") + pprof.SetGoroutineLabels(ctx) + m := Values(ctx, RoutineLabelKey) + assert.Equal(t, 1, len(m)) + assert.Equal(t, "my_routine_123", m[RoutineLabelKey.String()]) +} diff --git a/flytestdlib/logger/config.go b/flytestdlib/logger/config.go index 1c665ac523..993c624579 100644 --- a/flytestdlib/logger/config.go +++ b/flytestdlib/logger/config.go @@ -21,12 +21,18 @@ const ( jsonDataKey string = "json" ) -var defaultConfig = &Config{ - Formatter: FormatterConfig{ - Type: FormatterJSON, - }, - Level: InfoLevel, -} +var ( + defaultConfig = &Config{ + Formatter: FormatterConfig{ + Type: FormatterJSON, + }, + Level: InfoLevel, + } + + configSection = config.MustRegisterSectionWithUpdates(configSectionKey, defaultConfig, func(ctx context.Context, newValue config.Config) { + onConfigUpdated(*newValue.(*Config)) + }) +) // Global logger config. type Config struct { @@ -47,13 +53,18 @@ type FormatterConfig struct { Type FormatterType `json:"type" pflag:",Sets logging format type."` } -var globalConfig = Config{} - // Sets global logger config -func SetConfig(cfg Config) { - globalConfig = cfg +func SetConfig(cfg *Config) error { + if err := configSection.SetConfig(cfg); err != nil { + return err + } - onConfigUpdated(cfg) + onConfigUpdated(*cfg) + return nil +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) } // Level type. @@ -78,9 +89,3 @@ const ( // DebugLevel level. Usually only enabled when debugging. Very verbose logging. 
DebugLevel ) - -func init() { - config.MustRegisterSectionWithUpdates(configSectionKey, defaultConfig, func(ctx context.Context, newValue config.Config) { - SetConfig(*newValue.(*Config)) - }) -} diff --git a/flytestdlib/logger/config_test.go b/flytestdlib/logger/config_test.go index 7d2d3782b1..08be8dec6a 100644 --- a/flytestdlib/logger/config_test.go +++ b/flytestdlib/logger/config_test.go @@ -1,10 +1,14 @@ package logger -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestSetConfig(t *testing.T) { type args struct { - cfg Config + cfg *Config } tests := []struct { name string @@ -14,7 +18,7 @@ func TestSetConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetConfig(tt.args.cfg) + assert.NoError(t, SetConfig(tt.args.cfg)) }) } } diff --git a/flytestdlib/logger/logger.go b/flytestdlib/logger/logger.go index 29aadc8bf8..3d8eccc2db 100644 --- a/flytestdlib/logger/logger.go +++ b/flytestdlib/logger/logger.go @@ -17,7 +17,12 @@ import ( //go:generate gotests -w -all $FILE -const indentLevelKey contextutils.Key = "LoggerIndentLevel" +const ( + indentLevelKey contextutils.Key = "LoggerIndentLevel" + sourceCodeKey string = "src" +) + +var noopLogger = NoopLogger{} func onConfigUpdated(cfg Config) { logrus.SetLevel(logrus.Level(cfg.Level)) @@ -44,51 +49,35 @@ func onConfigUpdated(cfg Config) { } func getSourceLocation() string { - if globalConfig.IncludeSourceCode { - _, file, line, ok := runtime.Caller(3) - if !ok { - file = "???" - line = 1 - } else { - slash := strings.LastIndex(file, "/") - if slash >= 0 { - file = file[slash+1:] - } + // The reason we pass 3 here: 0 means this function (getSourceLocation), 1 means the getLogger function (only caller + // to getSourceLocation, 2 means the logging function (e.g. Debugln), and 3 means the caller for the logging function. + _, file, line, ok := runtime.Caller(3) + if !ok { + file = "???" + line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] } - - return fmt.Sprintf("[%v:%v] ", file, line) } - return "" + return fmt.Sprintf("%v:%v", file, line) } -func wrapHeader(ctx context.Context, args ...interface{}) []interface{} { - args = append([]interface{}{getIndent(ctx)}, args...) - - if globalConfig.IncludeSourceCode { - return append( - []interface{}{ - fmt.Sprintf("%v", getSourceLocation()), - }, - args...) +func getLogger(ctx context.Context) logrus.FieldLogger { + cfg := GetConfig() + if cfg.Mute { + return noopLogger } - return args -} - -func wrapHeaderForMessage(ctx context.Context, message string) string { - message = fmt.Sprintf("%v%v", getIndent(ctx), message) - - if globalConfig.IncludeSourceCode { - return fmt.Sprintf("%v%v", getSourceLocation(), message) + entry := logrus.WithFields(logrus.Fields(contextutils.GetLogFields(ctx))) + if cfg.IncludeSourceCode { + entry = entry.WithField(sourceCodeKey, getSourceLocation()) } - return message -} + entry.Level = logrus.Level(cfg.Level) -func getLogger(ctx context.Context) *logrus.Entry { - entry := logrus.WithFields(logrus.Fields(contextutils.GetLogFields(ctx))) - entry.Level = logrus.Level(globalConfig.Level) return entry } @@ -107,231 +96,218 @@ func getIndent(ctx context.Context) string { // Gets a value indicating whether logs at this level will be written to the logger. This is particularly useful to avoid // computing log messages unnecessarily. 
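A minimal guard built on the IsLoggable helper shown next (buildExpensiveDump and req are hypothetical stand-ins for any costly formatting work):

    // Skip building the message entirely when debug logging is disabled.
    if logger.IsLoggable(ctx, logger.DebugLevel) {
        logger.Debugf(ctx, "request dump: %v", buildExpensiveDump(req))
    }
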
-func IsLoggable(ctx context.Context, level Level) bool { - return getLogger(ctx).Level >= logrus.Level(level) +func IsLoggable(_ context.Context, level Level) bool { + return GetConfig().Level >= level } // Debug logs a message at level Debug on the standard logger. func Debug(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Debug(wrapHeader(ctx, args)...) + getLogger(ctx).Debug(args...) } // Print logs a message at level Info on the standard logger. func Print(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Print(wrapHeader(ctx, args)...) + getLogger(ctx).Print(args...) } // Info logs a message at level Info on the standard logger. func Info(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Info(wrapHeader(ctx, args)...) + getLogger(ctx).Info(args...) } // Warn logs a message at level Warn on the standard logger. func Warn(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warn(wrapHeader(ctx, args)...) + getLogger(ctx).Warn(args...) } // Warning logs a message at level Warn on the standard logger. func Warning(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warning(wrapHeader(ctx, args)...) + getLogger(ctx).Warning(args...) } // Error logs a message at level Error on the standard logger. func Error(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Error(wrapHeader(ctx, args)...) + getLogger(ctx).Error(args...) } // Panic logs a message at level Panic on the standard logger. func Panic(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Panic(wrapHeader(ctx, args)...) + getLogger(ctx).Panic(args...) } // Fatal logs a message at level Fatal on the standard logger. func Fatal(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Fatal(wrapHeader(ctx, args)...) + getLogger(ctx).Fatal(args...) } // Debugf logs a message at level Debug on the standard logger. func Debugf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Debugf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Debugf(format, args...) } // Printf logs a message at level Info on the standard logger. func Printf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Printf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Printf(format, args...) } // Infof logs a message at level Info on the standard logger. func Infof(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Infof(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Infof(format, args...) } // InfofNoCtx logs a formatted message without context. func InfofNoCtx(format string, args ...interface{}) { - if globalConfig.Mute { - return - } - getLogger(context.TODO()).Infof(format, args...) } // Warnf logs a message at level Warn on the standard logger. func Warnf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warnf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Warnf(format, args...) } // Warningf logs a message at level Warn on the standard logger. 
func Warningf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warningf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Warningf(format, args...) } // Errorf logs a message at level Error on the standard logger. func Errorf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Errorf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Errorf(format, args...) } // Panicf logs a message at level Panic on the standard logger. func Panicf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Panicf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Panicf(format, args...) } // Fatalf logs a message at level Fatal on the standard logger. func Fatalf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Fatalf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Fatalf(format, args...) } // Debugln logs a message at level Debug on the standard logger. func Debugln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Debugln(wrapHeader(ctx, args)...) + getLogger(ctx).Debugln(args...) } // Println logs a message at level Info on the standard logger. func Println(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Println(wrapHeader(ctx, args)...) + getLogger(ctx).Println(args...) } // Infoln logs a message at level Info on the standard logger. func Infoln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Infoln(wrapHeader(ctx, args)...) + getLogger(ctx).Infoln(args...) } // Warnln logs a message at level Warn on the standard logger. func Warnln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warnln(wrapHeader(ctx, args)...) + getLogger(ctx).Warnln(args...) } // Warningln logs a message at level Warn on the standard logger. func Warningln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warningln(wrapHeader(ctx, args)...) + getLogger(ctx).Warningln(args...) } // Errorln logs a message at level Error on the standard logger. func Errorln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Errorln(wrapHeader(ctx, args)...) + getLogger(ctx).Errorln(args...) } // Panicln logs a message at level Panic on the standard logger. func Panicln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Panicln(wrapHeader(ctx, args)...) + getLogger(ctx).Panicln(args...) } // Fatalln logs a message at level Fatal on the standard logger. func Fatalln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } + getLogger(ctx).Fatalln(args...) 
+} + +type NoopLogger struct { +} + +func (NoopLogger) WithField(key string, value interface{}) *logrus.Entry { + return nil +} + +func (NoopLogger) WithFields(fields logrus.Fields) *logrus.Entry { + return nil +} + +func (NoopLogger) WithError(err error) *logrus.Entry { + return nil +} + +func (NoopLogger) Debugf(format string, args ...interface{}) { +} + +func (NoopLogger) Infof(format string, args ...interface{}) { +} + +func (NoopLogger) Warnf(format string, args ...interface{}) { +} + +func (NoopLogger) Warningf(format string, args ...interface{}) { +} + +func (NoopLogger) Errorf(format string, args ...interface{}) { +} + +func (NoopLogger) Debug(args ...interface{}) { +} + +func (NoopLogger) Info(args ...interface{}) { +} + +func (NoopLogger) Warn(args ...interface{}) { +} + +func (NoopLogger) Warning(args ...interface{}) { +} + +func (NoopLogger) Error(args ...interface{}) { +} + +func (NoopLogger) Debugln(args ...interface{}) { +} + +func (NoopLogger) Infoln(args ...interface{}) { +} + +func (NoopLogger) Warnln(args ...interface{}) { +} + +func (NoopLogger) Warningln(args ...interface{}) { +} + +func (NoopLogger) Errorln(args ...interface{}) { +} + +func (NoopLogger) Print(...interface{}) { +} + +func (NoopLogger) Printf(string, ...interface{}) { +} + +func (NoopLogger) Println(...interface{}) { +} + +func (NoopLogger) Fatal(...interface{}) { +} + +func (NoopLogger) Fatalf(string, ...interface{}) { +} + +func (NoopLogger) Fatalln(...interface{}) { +} + +func (NoopLogger) Panic(...interface{}) { +} + +func (NoopLogger) Panicf(string, ...interface{}) { +} - getLogger(ctx).Fatalln(wrapHeader(ctx, args)...) +func (NoopLogger) Panicln(...interface{}) { } diff --git a/flytestdlib/logger/logger_test.go b/flytestdlib/logger/logger_test.go index 75c73a9432..ce81002046 100644 --- a/flytestdlib/logger/logger_test.go +++ b/flytestdlib/logger/logger_test.go @@ -6,7 +6,6 @@ package logger import ( "context" "reflect" - "strings" "testing" "github.com/sirupsen/logrus" @@ -14,46 +13,11 @@ import ( ) func init() { - SetConfig(Config{ + if err := SetConfig(&Config{ Level: InfoLevel, IncludeSourceCode: true, - }) -} - -func Test_getSourceLocation(t *testing.T) { - tests := []struct { - name string - want string - }{ - {"current", " "}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getSourceLocation(); !strings.HasSuffix(got, tt.want) { - t.Errorf("getSourceLocation() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_wrapHeaderForMessage(t *testing.T) { - type args struct { - message string - } - tests := []struct { - name string - args args - want string - }{ - {"no args", args{message: ""}, " "}, - {"1 arg", args{message: "hello"}, " hello"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := wrapHeaderForMessage(context.TODO(), tt.args.message); !strings.HasSuffix(got, tt.want) { - t.Errorf("wrapHeaderForMessage() = %v, want %v", got, tt.want) - } - }) + }); err != nil { + panic(err) } } @@ -488,27 +452,6 @@ func TestPanicln(t *testing.T) { } } -func Test_wrapHeader(t *testing.T) { - type args struct { - ctx context.Context - args []interface{} - } - tests := []struct { - name string - args args - want []interface{} - }{ - // TODO: Add test cases. 
- } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := wrapHeader(tt.args.ctx, tt.args.args...); !reflect.DeepEqual(got, tt.want) { - t.Errorf("wrapHeader() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_getLogger(t *testing.T) { type args struct { ctx context.Context diff --git a/flytestdlib/profutils/server.go b/flytestdlib/profutils/server.go index fe0157fca9..a1681c191d 100644 --- a/flytestdlib/profutils/server.go +++ b/flytestdlib/profutils/server.go @@ -61,7 +61,14 @@ func healtcheckHandler(w http.ResponseWriter, req *http.Request) { // Handler that returns a JSON response indicating the Build Version information (refer to #version module) func versionHandler(w http.ResponseWriter, req *http.Request) { - err := WriteJSONResponse(w, http.StatusOK, BuildVersion{Build: version.Build, Version: version.Version, Timestamp: version.BuildTime}) + err := WriteJSONResponse( + w, + http.StatusOK, + BuildVersion{ + Build: version.Build, + Version: version.Version, + Timestamp: version.BuildTime, + }) if err != nil { panic(err) } From 6ed22f3d5b87e88d9e4ab2fb7122bb8f2e5a2c52 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Tue, 30 Apr 2019 11:18:41 -0700 Subject: [PATCH 0043/1918] Scoop update for flytestdlib version v0.2.3 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index b708048a40..4e010cb572 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.2", + "version": "0.2.3", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2/flytestdlib_0.2.2_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.3/flytestdlib_0.2.3_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "3693ea92f967aa57568d1b646be44a8871261578e489113b2ff8d1926d47bf17" + "hash": "ac46479a6b1391b673b830bf0530beb2430a16f468c1d59a74301409c408d5f4" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.2/flytestdlib_0.2.2_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.3/flytestdlib_0.2.3_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "c073e0dce9b45630bdd862eb7c1a45bc000c664e7d25f74be2ca0768a516cf3d" + "hash": "80ccbd8fc038c66ee1d86724d401a8c58e496c859b8269d0ecf82b2ecf2db54b" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 35afe44d8a1d282717a1e3d7717290d5f7e8ece8 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Mon, 6 May 2019 14:59:15 -0700 Subject: [PATCH 0044/1918] Scoop update for flytestdlib version v0.2.4 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 4e010cb572..90a16b2775 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.3", + "version": "0.2.4", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.3/flytestdlib_0.2.3_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.4/flytestdlib_0.2.4_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "ac46479a6b1391b673b830bf0530beb2430a16f468c1d59a74301409c408d5f4" + "hash": "4677e67045c65d71d026dcf7910bd7af6d59b062eef0754a368b19528995796d" }, "64bit": { - "url": 
"https://github.com/lyft/flytestdlib/releases/download/v0.2.3/flytestdlib_0.2.3_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.4/flytestdlib_0.2.4_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "80ccbd8fc038c66ee1d86724d401a8c58e496c859b8269d0ecf82b2ecf2db54b" + "hash": "6cfd92c39d98390995034bf2bbdba41920d5388787f18146dd62feb56d4a1b7b" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From be32824b2d27bd5bc7fc0df416f9b844de9a1930 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 8 May 2019 13:39:26 -0700 Subject: [PATCH 0045/1918] InitialzePFlags doesn't add flags for sub-sections (#17) * Allow InitializePFlags to recursively discover sub-section flags * Add unit test --- flytestdlib/config/tests/accessor_test.go | 30 +++++++++++++++++++++++ flytestdlib/config/viper/viper.go | 13 ++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go index e6f93095a1..18a81e5f11 100644 --- a/flytestdlib/config/tests/accessor_test.go +++ b/flytestdlib/config/tests/accessor_test.go @@ -159,6 +159,36 @@ func TestAccessor_InitializePflags(t *testing.T) { assert.Equal(t, 4, otherC.IntValue) assert.Equal(t, []string{"default value"}, otherC.StringArrayWithDefaults) }) + + t.Run(fmt.Sprintf("[%v] Sub-sections", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + sec, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = sec.RegisterSection("nested", &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "nested_config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ExitOnError) + v.InitializePflags(set) + assert.NoError(t, set.Parse([]string{"--my-component.nested.int-val=3"})) + assert.True(t, set.Parsed()) + + flagValue, err := set.GetInt("my-component.nested.int-val") + assert.NoError(t, err) + assert.Equal(t, 3, flagValue) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + + nested := sec.GetSection("nested").GetConfig().(*OtherComponentConfig) + assert.Equal(t, 3, nested.IntValue) + }) } } diff --git a/flytestdlib/config/viper/viper.go b/flytestdlib/config/viper/viper.go index 05a5378ec7..c9df21fabc 100644 --- a/flytestdlib/config/viper/viper.go +++ b/flytestdlib/config/viper/viper.go @@ -64,9 +64,18 @@ func (v viperAccessor) InitializePflags(cmdFlags *pflag.FlagSet) { } func (v viperAccessor) addSectionsPFlags(flags *pflag.FlagSet) (err error) { - for key, section := range v.rootConfig.GetSections() { + return v.addSubsectionsPFlags(flags, "", v.rootConfig) +} + +func (v viperAccessor) addSubsectionsPFlags(flags *pflag.FlagSet, rootKey string, root config.Section) error { + for key, section := range root.GetSections() { + prefix := rootKey + key + keyDelim if asPFlagProvider, ok := section.GetConfig().(config.PFlagProvider); ok { - flags.AddFlagSet(asPFlagProvider.GetPFlagSet(key + keyDelim)) + flags.AddFlagSet(asPFlagProvider.GetPFlagSet(prefix)) + } + + if err := v.addSubsectionsPFlags(flags, prefix, section); err != nil { + return err } } From 0da638dc8a8126aed8338fa7367cdc8382248430 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Wed, 8 May 2019 13:45:05 -0700 Subject: [PATCH 0046/1918] Scoop 
update for flytestdlib version v0.2.5 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 90a16b2775..92058464fd 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.4", + "version": "0.2.5", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.4/flytestdlib_0.2.4_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.5/flytestdlib_0.2.5_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "4677e67045c65d71d026dcf7910bd7af6d59b062eef0754a368b19528995796d" + "hash": "94a8fdd1306fcb515f0fa473eaf769d16f3fcbcacba24c0e65b965b316b171d6" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.4/flytestdlib_0.2.4_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.5/flytestdlib_0.2.5_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "6cfd92c39d98390995034bf2bbdba41920d5388787f18146dd62feb56d4a1b7b" + "hash": "37d491a54ef214961e77dc39a7b332214748cffd6d73f02c220fe02c721413dd" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 3fae7a3c0bc386eb32a586e5b70117333ab8c3c8 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 30 May 2019 13:57:38 -0700 Subject: [PATCH 0047/1918] Fix s3 container initialization: don't error when container creation succeeds (#19) --- flytestdlib/storage/s3store.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytestdlib/storage/s3store.go b/flytestdlib/storage/s3store.go index 3c96730e0a..c2a6d9e290 100644 --- a/flytestdlib/storage/s3store.go +++ b/flytestdlib/storage/s3store.go @@ -68,8 +68,9 @@ func newS3RawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) logger.Infof(context.TODO(), "Storage init-container already exists [%v].", cfg.InitContainer) return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) } + return emptyStore, fmt.Errorf("unable to initialize container [%v]. Error: %v", cfg.InitContainer, err) } - return emptyStore, fmt.Errorf("unable to initialize container [%v]. 
Error: %v", cfg.InitContainer, err) + return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) } return emptyStore, err } From e03f2b7ac1e13bab82dce31d07c02b30db5d4513 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 6 Jun 2019 13:23:51 -0700 Subject: [PATCH 0048/1918] Call onChange on first config load (#20) * Call onChange on first config load * Ensure we have a repro in unit test & lint --- flytestdlib/config/tests/accessor_test.go | 25 +++++++++++++++++++++++ flytestdlib/config/viper/viper.go | 14 ++++++------- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/flytestdlib/config/tests/accessor_test.go b/flytestdlib/config/tests/accessor_test.go index 18a81e5f11..11511f7a61 100644 --- a/flytestdlib/config/tests/accessor_test.go +++ b/flytestdlib/config/tests/accessor_test.go @@ -331,6 +331,31 @@ func TestAccessor_UpdateConfig(t *testing.T) { }) t.Run(fmt.Sprintf("[%v] Change handler", provider(config.Options{}).ID()), func(t *testing.T) { + configFile := tempFileName("config-*.yaml") + defer func() { assert.NoError(t, os.Remove(configFile)) }() + cfg, err := populateConfigData(configFile) + assert.NoError(t, err) + + reg := config.NewRootSection() + called := false + _, err = reg.RegisterSectionWithUpdates(MyComponentSectionKey, &cfg.MyComponentConfig, + func(ctx context.Context, newValue config.Config) { + called = true + }) + assert.NoError(t, err) + + opts := config.Options{ + SearchPaths: []string{configFile}, + RootSection: reg, + } + v := provider(opts) + err = v.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + assert.True(t, called) + }) + + t.Run(fmt.Sprintf("[%v] Change handler on change", provider(config.Options{}).ID()), func(t *testing.T) { configFile := tempFileName("config-*.yaml") defer func() { assert.NoError(t, os.Remove(configFile)) }() _, err := populateConfigData(configFile) diff --git a/flytestdlib/config/viper/viper.go b/flytestdlib/config/viper/viper.go index c9df21fabc..3b44211f31 100644 --- a/flytestdlib/config/viper/viper.go +++ b/flytestdlib/config/viper/viper.go @@ -149,7 +149,7 @@ func (v viperAccessor) updateConfig(ctx context.Context, r config.Section) error }) } - return v.RefreshFromConfig(ctx, r) + return v.RefreshFromConfig(ctx, r, true) } func (v viperAccessor) UpdateConfig(ctx context.Context) error { @@ -287,7 +287,7 @@ func decode(input interface{}, config *mapstructure.DecoderConfig) error { func (v viperAccessor) configChangeHandler() { ctx := context.Background() - err := v.RefreshFromConfig(ctx, v.rootConfig) + err := v.RefreshFromConfig(ctx, v.rootConfig, false) if err != nil { // TODO: Retry? panic? logger.Printf(ctx, "Failed to update config. 
Error: %v", err) @@ -296,20 +296,20 @@ func (v viperAccessor) configChangeHandler() { } } -func (v viperAccessor) RefreshFromConfig(ctx context.Context, r config.Section) error { +func (v viperAccessor) RefreshFromConfig(ctx context.Context, r config.Section, forceSendUpdates bool) error { err := v.parseViperConfig(r) if err != nil { return err } - v.sendUpdatedEvents(ctx, r, "") + v.sendUpdatedEvents(ctx, r, forceSendUpdates, "") return nil } -func (v viperAccessor) sendUpdatedEvents(ctx context.Context, root config.Section, sectionKey config.SectionKey) { +func (v viperAccessor) sendUpdatedEvents(ctx context.Context, root config.Section, forceSend bool, sectionKey config.SectionKey) { for key, section := range root.GetSections() { - if !section.GetConfigChangedAndClear() { + if !section.GetConfigChangedAndClear() && !forceSend { logger.Infof(ctx, "Config section [%v] hasn't changed.", sectionKey+key) } else if section.GetConfigUpdatedHandler() == nil { logger.Infof(ctx, "Config section [%v] updated. No update handler registered.", sectionKey+key) @@ -318,7 +318,7 @@ func (v viperAccessor) sendUpdatedEvents(ctx context.Context, root config.Sectio section.GetConfigUpdatedHandler()(ctx, section.GetConfig()) } - v.sendUpdatedEvents(ctx, section, sectionKey+key+keyDelim) + v.sendUpdatedEvents(ctx, section, forceSend, sectionKey+key+keyDelim) } } From 2a8caf22a40a600de32fb4cf7b39eaed3256805c Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Thu, 6 Jun 2019 13:32:46 -0700 Subject: [PATCH 0049/1918] Scoop update for flytestdlib version v0.2.6 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 92058464fd..b88d443d77 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.5", + "version": "0.2.6", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.5/flytestdlib_0.2.5_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.6/flytestdlib_0.2.6_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "94a8fdd1306fcb515f0fa473eaf769d16f3fcbcacba24c0e65b965b316b171d6" + "hash": "2cb866af98ded87b014b76c7077cd69004b0d02c9eec7301d812f8979ea8b26b" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.5/flytestdlib_0.2.5_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.6/flytestdlib_0.2.6_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "37d491a54ef214961e77dc39a7b332214748cffd6d73f02c220fe02c721413dd" + "hash": "0e201ad563922638e9c86ce75672c2acc8e3c2eba4a2868b2d34fc183a767a12" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 3964f0125ef3a14c0af09b4d84193894ae845dae Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 11 Jun 2019 10:32:04 -0700 Subject: [PATCH 0050/1918] Fix copy implementation (#23) --- flytestdlib/storage/copy_impl.go | 1 + 1 file changed, 1 insertion(+) diff --git a/flytestdlib/storage/copy_impl.go b/flytestdlib/storage/copy_impl.go index 43f97f026f..4ac132a229 100644 --- a/flytestdlib/storage/copy_impl.go +++ b/flytestdlib/storage/copy_impl.go @@ -39,6 +39,7 @@ func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference return err } + rc = ioutils.NewBytesReadCloser(raw) length = int64(len(raw)) } From 662d42107a3ce54e64ade3935b24ff36577e0c7f Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Tue, 11 Jun 
2019 10:37:49 -0700 Subject: [PATCH 0051/1918] Scoop update for flytestdlib version v0.2.7 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index b88d443d77..8fd5487df2 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.6", + "version": "0.2.7", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.6/flytestdlib_0.2.6_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.7/flytestdlib_0.2.7_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "2cb866af98ded87b014b76c7077cd69004b0d02c9eec7301d812f8979ea8b26b" + "hash": "03799847f32b1694d8c51d6bcf0cde70eed22a3fca6516d8921644ec9b6f48b2" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.6/flytestdlib_0.2.6_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.7/flytestdlib_0.2.7_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "0e201ad563922638e9c86ce75672c2acc8e3c2eba4a2868b2d34fc183a767a12" + "hash": "6a2678cd81bde9130717c8bff4f7c1b1d3f28cfe33f705609ba2b95542645cbe" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From bb2445fae5cfc0b3806e69018761196bca2d0726 Mon Sep 17 00:00:00 2001 From: Anand Swaminathan Date: Tue, 11 Jun 2019 14:42:48 -0700 Subject: [PATCH 0052/1918] Expose Log writer that can be used with external libraries (#22) * Flinkk8soperator needs Log writer to be passed to --- flytestdlib/logger/logger.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/flytestdlib/logger/logger.go b/flytestdlib/logger/logger.go index 3d8eccc2db..f1c847542c 100644 --- a/flytestdlib/logger/logger.go +++ b/flytestdlib/logger/logger.go @@ -5,6 +5,7 @@ package logger import ( "context" + "io" "github.com/lyft/flytestdlib/contextutils" @@ -81,6 +82,12 @@ func getLogger(ctx context.Context) logrus.FieldLogger { return entry } +// Returns a standard io.PipeWriter that logs using the same logger configurations in this package. 
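A sketch of how the writer returned by GetLogWriter (introduced just below) might be wired into an external process, assuming the os/exec package and a hypothetical binary name; the *io.PipeWriter should be closed when the caller is done with it.

    w := logger.GetLogWriter(ctx)
    defer w.Close()
    cmd := exec.Command("external-tool") // hypothetical command whose output should flow through the logger
    cmd.Stdout = w
    cmd.Stderr = w
    _ = cmd.Run()
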
+func GetLogWriter(ctx context.Context) *io.PipeWriter { + logger := getLogger(ctx) + return logger.(*logrus.Entry).Writer() +} + func WithIndent(ctx context.Context, additionalIndent string) context.Context { indentLevel := getIndent(ctx) + additionalIndent return context.WithValue(ctx, indentLevelKey, indentLevel) From e190203e38ed21aeef1cc280b546189a15f7e883 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Tue, 11 Jun 2019 15:03:55 -0700 Subject: [PATCH 0053/1918] Scoop update for flytestdlib version v0.2.8 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 8fd5487df2..4cf7fabf81 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.7", + "version": "0.2.8", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.7/flytestdlib_0.2.7_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.8/flytestdlib_0.2.8_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "03799847f32b1694d8c51d6bcf0cde70eed22a3fca6516d8921644ec9b6f48b2" + "hash": "5deaa24c632a8aab444d7c6b02b04cc20f864cba9f190f84de1295f359bbbe82" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.7/flytestdlib_0.2.7_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.8/flytestdlib_0.2.8_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "6a2678cd81bde9130717c8bff4f7c1b1d3f28cfe33f705609ba2b95542645cbe" + "hash": "4526711f5d3504576bedc9c8db5ad41029f30c3e336e1d42231c4a8ad62af3a3" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 975e9ee33d694ce11368332a48e0752907a4eb1d Mon Sep 17 00:00:00 2001 From: Anand Swaminathan Date: Mon, 1 Jul 2019 14:32:37 -0700 Subject: [PATCH 0054/1918] Upgrade Client go and implement new methods (#25) * Upgrade Client go and implement new methods --- flytestdlib/Gopkg.lock | 28 ++++++++++++++-------------- flytestdlib/Gopkg.toml | 4 ++-- flytestdlib/promutils/workqueue.go | 20 ++++++++++++++++++++ 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/flytestdlib/Gopkg.lock b/flytestdlib/Gopkg.lock index c2d2438f89..0fcc602f6f 100644 --- a/flytestdlib/Gopkg.lock +++ b/flytestdlib/Gopkg.lock @@ -127,14 +127,6 @@ revision = "0ca9ea5df5451ffdf184b4428c902747c2c11cd7" version = "v1.0.0" -[[projects]] - branch = "master" - digest = "1:1ba1d79f2810270045c328ae5d674321db34e3aae468eb4233883b473c5c0467" - name = "github.com/golang/glog" - packages = ["."] - pruneopts = "UT" - revision = "23def4e6c14b4da8ac2ed8007337bc5eb5007998" - [[projects]] digest = "1:9d6dc4d6de69b330d0de86494d6db90c09848c003d5db748f40c925f865c8534" name = "github.com/golang/protobuf" @@ -465,7 +457,7 @@ version = "v2.2.2" [[projects]] - digest = "1:5922c4db083d03579c576df514f096003f422b602aeb30028aedd892b69a4876" + digest = "1:074fb0a8da1e416b8a201e8e664c303ae610f316ffd615b678b636d27c225412" name = "k8s.io/apimachinery" packages = [ "pkg/util/clock", @@ -474,16 +466,24 @@ "pkg/util/wait", ] pruneopts = "UT" - revision = "103fd098999dc9c0c88536f5c9ad2e5da39373ae" - version = "kubernetes-1.11.2" + revision = "2b1284ed4c93a43499e781493253e2ac5959c4fd" + version = "kubernetes-1.13.1" [[projects]] - digest = "1:8d66fef1249b9b2105840377af3bab078604d3c298058f563685e88d2a9e6ad3" + digest = "1:b6412f8acd9a9fc6fb67302c24966618b16501b9d769a20bee42ce61e510c92c" name = "k8s.io/client-go" packages 
= ["util/workqueue"] pruneopts = "UT" - revision = "1f13a808da65775f22cbf47862c4e5898d8f4ca1" - version = "kubernetes-1.11.2" + revision = "8d9ed539ba3134352c586810e749e58df4e94e4f" + version = "kubernetes-1.13.1" + +[[projects]] + digest = "1:c283ca5951eb7d723d3300762f96ff94c2ea11eaceb788279e2b7327f92e4f2a" + name = "k8s.io/klog" + packages = ["."] + pruneopts = "UT" + revision = "d98d8acdac006fb39831f1b25640813fef9c314f" + version = "v0.3.3" [solve-meta] analyzer-name = "dep" diff --git a/flytestdlib/Gopkg.toml b/flytestdlib/Gopkg.toml index 3d6cf08f6f..fdf90bba03 100644 --- a/flytestdlib/Gopkg.toml +++ b/flytestdlib/Gopkg.toml @@ -57,11 +57,11 @@ [[constraint]] name = "k8s.io/apimachinery" - version = "kubernetes-1.11.2" + version = "kubernetes-1.13.1" [[constraint]] name = "k8s.io/client-go" - version = "kubernetes-1.11.2" + version = "kubernetes-1.13.1" [[constraint]] name = "github.com/graymeta/stow" diff --git a/flytestdlib/promutils/workqueue.go b/flytestdlib/promutils/workqueue.go index 24de6a0053..1911c0894c 100644 --- a/flytestdlib/promutils/workqueue.go +++ b/flytestdlib/promutils/workqueue.go @@ -31,6 +31,26 @@ func init() { type prometheusMetricsProvider struct{} +func (prometheusMetricsProvider) NewUnfinishedWorkSecondsMetric(name string) workqueue.SettableGaugeMetric { + unfinishedWork := prometheus.NewGauge(prometheus.GaugeOpts{ + Subsystem: name, + Name: "unfinished_work_s", + Help: "How many seconds of work in progress in workqueue: " + name, + }) + prometheus.MustRegister(unfinishedWork) + return unfinishedWork +} + +func (prometheusMetricsProvider) NewLongestRunningProcessorMicrosecondsMetric(name string) workqueue.SettableGaugeMetric { + unfinishedWork := prometheus.NewGauge(prometheus.GaugeOpts{ + Subsystem: name, + Name: "longest_running_processor_us", + Help: "How many microseconds longest running processor from workqueue" + name + " takes.", + }) + prometheus.MustRegister(unfinishedWork) + return unfinishedWork +} + func (prometheusMetricsProvider) NewDepthMetric(name string) workqueue.GaugeMetric { depth := prometheus.NewGauge(prometheus.GaugeOpts{ Subsystem: name, From b5e6bfd17909777eddf7189a4e2ee16fea162ff3 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Mon, 1 Jul 2019 14:46:35 -0700 Subject: [PATCH 0055/1918] Scoop update for flytestdlib version v0.2.9 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 4cf7fabf81..25372948f8 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.8", + "version": "0.2.9", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.8/flytestdlib_0.2.8_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.9/flytestdlib_0.2.9_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "5deaa24c632a8aab444d7c6b02b04cc20f864cba9f190f84de1295f359bbbe82" + "hash": "d11a94af2328f0971b5010619d143c4f6f847e4b204fe671ba849a131ac68b7c" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.8/flytestdlib_0.2.8_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.9/flytestdlib_0.2.9_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "4526711f5d3504576bedc9c8db5ad41029f30c3e336e1d42231c4a8ad62af3a3" + "hash": "d78c7e1f7f3d35afaa49a770e0a4c1e55678923bc7bb5e7b7f698fe8efd2f349" } }, "homepage": 
"https://godoc.org/github.com/lyft/flytestdlib", From 00f5d1a38bcb37b8f28159b26aad6ed981ed7c76 Mon Sep 17 00:00:00 2001 From: Anand Swaminathan Date: Tue, 2 Jul 2019 16:55:52 -0700 Subject: [PATCH 0056/1918] Set provider only if all methods are implemented (#26) * Set provider only if all methods of interface are implemented --- flytestdlib/promutils/workqueue.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/flytestdlib/promutils/workqueue.go b/flytestdlib/promutils/workqueue.go index 1911c0894c..f7bdd055f0 100644 --- a/flytestdlib/promutils/workqueue.go +++ b/flytestdlib/promutils/workqueue.go @@ -17,6 +17,8 @@ limitations under the License. package promutils import ( + "fmt" + "k8s.io/client-go/util/workqueue" "github.com/prometheus/client_golang/prometheus" @@ -26,7 +28,14 @@ import ( // prometheus metrics. To use this package, you just have to import it. func init() { - workqueue.SetProvider(prometheusMetricsProvider{}) + var provider interface{} //nolint + provider = prometheusMetricsProvider{} + if p, casted := provider.(workqueue.MetricsProvider); casted { + workqueue.SetProvider(p) + } else { + // This case happens in future versions of client-go where the interface has added methods + fmt.Println("Warn: No metricsProvider set for the workqueue") + } } type prometheusMetricsProvider struct{} From ed2c0a30d3a98b4d9b95b273933e2ba3d02af5b7 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Tue, 2 Jul 2019 17:16:20 -0700 Subject: [PATCH 0057/1918] Scoop update for flytestdlib version v0.2.10 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 25372948f8..20e5425e3a 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.9", + "version": "0.2.10", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.9/flytestdlib_0.2.9_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.10/flytestdlib_0.2.10_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "d11a94af2328f0971b5010619d143c4f6f847e4b204fe671ba849a131ac68b7c" + "hash": "0581a128328931280f2d60ff885dd1a2063776a622a710135ce846d66be442e1" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.9/flytestdlib_0.2.9_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.10/flytestdlib_0.2.10_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "d78c7e1f7f3d35afaa49a770e0a4c1e55678923bc7bb5e7b7f698fe8efd2f349" + "hash": "061dd31294bc2498d8cb939f164ca7518e23ea14978d7ea06244c3d43ea3b9e6" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 112e471ac7c464d48aad313450cc17a70f0cc15a Mon Sep 17 00:00:00 2001 From: matthewphsmith Date: Mon, 8 Jul 2019 11:41:24 -0700 Subject: [PATCH 0058/1918] Add lp as an optional label for stats (#24) * Add lp as an optional label for stats * goimports * Add lp as log key * Allow for testing of singleton metric keys * remove init() function from test package since it is unclear what the effects of that would be based on ordering of tests --- flytestdlib/contextutils/context.go | 7 +++++++ flytestdlib/contextutils/context_test.go | 7 +++++++ flytestdlib/promutils/labeled/counter_test.go | 7 ++++++- flytestdlib/promutils/labeled/keys.go | 17 ++++++++++++++--- flytestdlib/promutils/labeled/keys_test.go | 3 ++- 
flytestdlib/promutils/labeled/stopwatch_test.go | 1 + flytestdlib/storage/cached_rawstore_test.go | 11 +++++++---- flytestdlib/utils/marshal_utils_test.go | 2 +- 8 files changed, 45 insertions(+), 10 deletions(-) diff --git a/flytestdlib/contextutils/context.go b/flytestdlib/contextutils/context.go index e5be11e219..fd0b5fac0b 100644 --- a/flytestdlib/contextutils/context.go +++ b/flytestdlib/contextutils/context.go @@ -22,6 +22,7 @@ const ( JobIDKey Key = "job_id" PhaseKey Key = "phase" RoutineLabelKey Key = "routine" + LaunchPlanIDKey Key = "lp" ) func (k Key) String() string { @@ -38,6 +39,7 @@ var logKeys = []Key{ TaskTypeKey, PhaseKey, RoutineLabelKey, + LaunchPlanIDKey, } // Gets a new context with namespace set. @@ -85,6 +87,11 @@ func WithWorkflowID(ctx context.Context, workflow string) context.Context { return context.WithValue(ctx, WorkflowIDKey, workflow) } +// Gets a new context with a launch plan ID set. +func WithLaunchPlanID(ctx context.Context, launchPlan string) context.Context { + return context.WithValue(ctx, LaunchPlanIDKey, launchPlan) +} + // Get new context with Project and Domain values set func WithProjectDomain(ctx context.Context, project, domain string) context.Context { c := context.WithValue(ctx, ProjectKey, project) diff --git a/flytestdlib/contextutils/context_test.go b/flytestdlib/contextutils/context_test.go index d65c2a2be3..e2effe3a6b 100644 --- a/flytestdlib/contextutils/context_test.go +++ b/flytestdlib/contextutils/context_test.go @@ -69,6 +69,13 @@ func TestWithWorkflowID(t *testing.T) { assert.Equal(t, "flyte", ctx.Value(WorkflowIDKey)) } +func TestWithLaunchPlanID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(LaunchPlanIDKey)) + ctx = WithLaunchPlanID(ctx, "flytelp") + assert.Equal(t, "flytelp", ctx.Value(LaunchPlanIDKey)) +} + func TestWithNodeID(t *testing.T) { ctx := context.Background() assert.Nil(t, ctx.Value(NodeIDKey)) diff --git a/flytestdlib/promutils/labeled/counter_test.go b/flytestdlib/promutils/labeled/counter_test.go index 130b8217a2..e427026b43 100644 --- a/flytestdlib/promutils/labeled/counter_test.go +++ b/flytestdlib/promutils/labeled/counter_test.go @@ -10,8 +10,9 @@ import ( ) func TestLabeledCounter(t *testing.T) { + UnsetMetricKeys() assert.NotPanics(t, func() { - SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) + SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, contextutils.LaunchPlanIDKey) }) scope := promutils.NewTestScope() @@ -28,4 +29,8 @@ func TestLabeledCounter(t *testing.T) { ctx = contextutils.WithTaskID(ctx, "task") c.Inc(ctx) c.Add(ctx, 1.0) + + ctx = contextutils.WithLaunchPlanID(ctx, "lp") + c.Inc(ctx) + c.Add(ctx, 1.0) } diff --git a/flytestdlib/promutils/labeled/keys.go b/flytestdlib/promutils/labeled/keys.go index d8c8683750..7727a0dfa3 100644 --- a/flytestdlib/promutils/labeled/keys.go +++ b/flytestdlib/promutils/labeled/keys.go @@ -15,11 +15,11 @@ var ( // Metric Keys to label metrics with. These keys get pulled from context if they are present. Use contextutils to fill // them in. - metricKeys = make([]contextutils.Key, 0) + metricKeys []contextutils.Key // :(, we have to create a separate list to satisfy the MustNewCounterVec API as it accepts string only - metricStringKeys = make([]string, 0) - metricKeysAreSet = sync.Once{} + metricStringKeys []string + metricKeysAreSet sync.Once ) // Sets keys to use with labeled metrics. 
The values of these keys will be pulled from context at runtime. @@ -45,3 +45,14 @@ func SetMetricKeys(keys ...contextutils.Key) { func GetUnlabeledMetricName(metricName string) string { return metricName + "_unlabeled" } + +// Warning: This function is not thread safe and should be used for testing only outside of this package. +func UnsetMetricKeys() { + metricKeys = make([]contextutils.Key, 0) + metricStringKeys = make([]string, 0) + metricKeysAreSet = sync.Once{} +} + +func init() { + UnsetMetricKeys() +} diff --git a/flytestdlib/promutils/labeled/keys_test.go b/flytestdlib/promutils/labeled/keys_test.go index 4a8600aea3..6699ab2af6 100644 --- a/flytestdlib/promutils/labeled/keys_test.go +++ b/flytestdlib/promutils/labeled/keys_test.go @@ -8,8 +8,9 @@ import ( ) func TestMetricKeys(t *testing.T) { + UnsetMetricKeys() input := []contextutils.Key{ - contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, + contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, contextutils.LaunchPlanIDKey, } assert.NotPanics(t, func() { SetMetricKeys(input...) }) diff --git a/flytestdlib/promutils/labeled/stopwatch_test.go b/flytestdlib/promutils/labeled/stopwatch_test.go index d5adf0eade..1d8a69d561 100644 --- a/flytestdlib/promutils/labeled/stopwatch_test.go +++ b/flytestdlib/promutils/labeled/stopwatch_test.go @@ -11,6 +11,7 @@ import ( ) func TestLabeledStopWatch(t *testing.T) { + UnsetMetricKeys() assert.NotPanics(t, func() { SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) }) diff --git a/flytestdlib/storage/cached_rawstore_test.go b/flytestdlib/storage/cached_rawstore_test.go index 316f999bf7..c5225aa790 100644 --- a/flytestdlib/storage/cached_rawstore_test.go +++ b/flytestdlib/storage/cached_rawstore_test.go @@ -18,11 +18,8 @@ import ( "github.com/stretchr/testify/assert" ) -func init() { - labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) -} - func TestNewCachedStore(t *testing.T) { + resetMetricKeys() t.Run("CachingDisabled", func(t *testing.T) { testScope := promutils.NewTestScope() @@ -50,6 +47,11 @@ func TestNewCachedStore(t *testing.T) { }) } +func resetMetricKeys() { + labeled.UnsetMetricKeys() + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + func dummyCacheStore(t *testing.T, store RawStore, scope promutils.Scope) *cachedRawStore { cfg := &Config{ Cache: CachingConfig{ @@ -86,6 +88,7 @@ func (d *dummyStore) WriteRaw(ctx context.Context, reference DataReference, size } func TestCachedRawStore(t *testing.T) { + resetMetricKeys() ctx := context.TODO() k1 := DataReference("k1") k2 := DataReference("k2") diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index 4ac0fc130b..4295f5a14d 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/struct" + structpb "github.com/golang/protobuf/ptypes/struct" ) type SimpleType struct { From bfca9591809fd778448910c91f08a11854b78d11 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Thu, 25 Jul 2019 21:00:00 -0700 Subject: [PATCH 0059/1918] Return errors from cached raw store (#28) * Return errors from cached raw store * 
gracefully ignore cache write errors in ReadProtobuf and WriteProtobuf * fix fixture * Wrapf the use of ErrExceedsLimit * nosec math/rand * add test with caching failures for protobuf store * log and add metrics to record failures not caused by ErrFailedToWriteCache --- flytestdlib/errors/error.go | 113 ++++++++++++++++++++ flytestdlib/errors/error_test.go | 47 ++++++++ flytestdlib/storage/cached_rawstore.go | 8 +- flytestdlib/storage/cached_rawstore_test.go | 26 +++++ flytestdlib/storage/protobuf_store.go | 36 ++++--- flytestdlib/storage/protobuf_store_test.go | 82 ++++++++++++++ flytestdlib/storage/stow_store.go | 4 +- flytestdlib/storage/utils.go | 20 +++- flytestdlib/storage/utils_test.go | 49 +++++++++ 9 files changed, 365 insertions(+), 20 deletions(-) create mode 100644 flytestdlib/errors/error.go create mode 100644 flytestdlib/errors/error_test.go create mode 100644 flytestdlib/storage/utils_test.go diff --git a/flytestdlib/errors/error.go b/flytestdlib/errors/error.go new file mode 100644 index 0000000000..6abb08d3cf --- /dev/null +++ b/flytestdlib/errors/error.go @@ -0,0 +1,113 @@ +// Contains utilities to use to create and consume simple errors. +package errors + +import ( + "fmt" + + "github.com/pkg/errors" +) + +// A generic error code type. +type ErrorCode = string + +type err struct { + code ErrorCode + message string +} + +func (e *err) Error() string { + return fmt.Sprintf("[%v] %v", e.code, e.message) +} + +func (e *err) Code() ErrorCode { + return e.code +} + +type errorWithCause struct { + *err + cause error +} + +func (e *errorWithCause) Error() string { + return fmt.Sprintf("%v, caused by: %v", e.err.Error(), errors.Cause(e)) +} + +func (e *errorWithCause) Cause() error { + return e.cause +} + +// Creates a new error using an error code and a message. +func Errorf(errorCode ErrorCode, msgFmt string, args ...interface{}) error { + return &err{ + code: errorCode, + message: fmt.Sprintf(msgFmt, args...), + } +} + +// Wraps a root cause error with another. This is useful to unify an error type in a package. +func Wrapf(code ErrorCode, cause error, msgFmt string, args ...interface{}) error { + return &errorWithCause{ + err: &err{ + code: code, + message: fmt.Sprintf(msgFmt, args...), + }, + cause: cause, + } +} + +// Gets the error code of the passed error if it has one. +func GetErrorCode(e error) (code ErrorCode, found bool) { + type coder interface { + Code() ErrorCode + } + + er, ok := e.(coder) + if ok { + return er.Code(), true + } + + return +} + +// Gets whether error is caused by another error with errCode. 
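A caller-side sketch of how these helpers compose, using the ErrFailedToWriteCache code referenced by the storage changes in this patch; cacheErr and reference are hypothetical placeholders.

    // Producer side: wrap the low-level failure with a stable error code.
    wrapped := errors.Wrapf(ErrFailedToWriteCache, cacheErr, "failed to cache metadata for %v", reference)

    // Consumer side: branch on the code rather than on the message text.
    if errors.IsCausedBy(wrapped, ErrFailedToWriteCache) {
        // The underlying write succeeded; only the cache update failed, so treat it as non-fatal.
    }
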
+func IsCausedBy(e error, errCode ErrorCode) bool { + type causer interface { + Cause() error + } + + for e != nil { + if code, found := GetErrorCode(e); found && code == errCode { + return true + } + + cause, ok := e.(causer) + if !ok { + break + } + + e = cause.Cause() + } + + return false +} + +func IsCausedByError(e, e2 error) bool { + type causer interface { + Cause() error + } + + for e != nil { + if e == e2 { + return true + } + + cause, ok := e.(causer) + if !ok { + break + } + + e = cause.Cause() + } + + return false +} diff --git a/flytestdlib/errors/error_test.go b/flytestdlib/errors/error_test.go new file mode 100644 index 0000000000..b3e3dd0a51 --- /dev/null +++ b/flytestdlib/errors/error_test.go @@ -0,0 +1,47 @@ +package errors + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorf(t *testing.T) { + e := Errorf("Code1", "msg") + assert.NotNil(t, e) + assert.Equal(t, "[Code1] msg", e.Error()) +} + +func TestWrapf(t *testing.T) { + e := Wrapf("Code1", fmt.Errorf("test error"), "msg") + assert.NotNil(t, e) + assert.Equal(t, "[Code1] msg, caused by: test error", e.Error()) +} + +func TestGetErrorCode(t *testing.T) { + e := Errorf("Code1", "msg") + assert.NotNil(t, e) + code, found := GetErrorCode(e) + assert.True(t, found) + assert.Equal(t, "Code1", code) +} + +func TestIsCausedBy(t *testing.T) { + e := Errorf("Code1", "msg") + assert.NotNil(t, e) + + e = Wrapf("Code2", e, "msg") + assert.True(t, IsCausedBy(e, "Code1")) + assert.True(t, IsCausedBy(e, "Code2")) +} + +func TestIsCausedByError(t *testing.T) { + eRoot := Errorf("Code1", "msg") + assert.NotNil(t, eRoot) + e1 := Wrapf("Code2", eRoot, "msg") + assert.True(t, IsCausedByError(e1, eRoot)) + e2 := Wrapf("Code3", e1, "msg") + assert.True(t, IsCausedByError(e2, eRoot)) + assert.True(t, IsCausedByError(e2, e1)) +} diff --git a/flytestdlib/storage/cached_rawstore.go b/flytestdlib/storage/cached_rawstore.go index 2b539d7bb1..e9d83df6cd 100644 --- a/flytestdlib/storage/cached_rawstore.go +++ b/flytestdlib/storage/cached_rawstore.go @@ -7,6 +7,8 @@ import ( "runtime/debug" "time" + "github.com/lyft/flytestdlib/errors" + "github.com/prometheus/client_golang/prometheus" "github.com/lyft/flytestdlib/promutils" @@ -75,11 +77,11 @@ func (s *cachedRawStore) ReadRaw(ctx context.Context, reference DataReference) ( err = s.cache.Set(key, b, 0) if err != nil { - // TODO Ignore errors in writing to cache? logger.Debugf(ctx, "Failed to Cache the metadata") + err = errors.Wrapf(ErrFailedToWriteCache, err, "Failed to Cache the metadata") } - return ioutils.NewBytesReadCloser(b), nil + return ioutils.NewBytesReadCloser(b), err } // Stores a raw byte array. @@ -94,9 +96,9 @@ func (s *cachedRawStore) WriteRaw(ctx context.Context, reference DataReference, err = s.cache.Set([]byte(reference), buf.Bytes(), neverExpire) if err != nil { s.metrics.CacheWriteError.Inc() + err = errors.Wrapf(ErrFailedToWriteCache, err, "Failed to Cache the metadata") } - // TODO ignore errors? 
return err } diff --git a/flytestdlib/storage/cached_rawstore_test.go b/flytestdlib/storage/cached_rawstore_test.go index c5225aa790..6949119627 100644 --- a/flytestdlib/storage/cached_rawstore_test.go +++ b/flytestdlib/storage/cached_rawstore_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "math/rand" "runtime/debug" "testing" @@ -92,8 +93,12 @@ func TestCachedRawStore(t *testing.T) { ctx := context.TODO() k1 := DataReference("k1") k2 := DataReference("k2") + bigK := DataReference("bigK") d1 := []byte("abc") d2 := []byte("xyz") + bigD := make([]byte, 1.5*1024*1024) + // #nosec G404 + rand.Read(bigD) writeCalled := false readCalled := false store := &dummyStore{ @@ -113,6 +118,11 @@ func TestCachedRawStore(t *testing.T) { assert.NoError(t, err) assert.Equal(t, d2, b) return nil + } else if reference == "bigK" { + b, err := ioutil.ReadAll(raw) + assert.NoError(t, err) + assert.Equal(t, bigD, b) + return nil } return fmt.Errorf("err") }, @@ -123,6 +133,8 @@ func TestCachedRawStore(t *testing.T) { readCalled = true if reference == "k1" { return ioutils.NewBytesReadCloser(d1), nil + } else if reference == "bigK" { + return ioutils.NewBytesReadCloser(bigD), nil } return nil, fmt.Errorf("err") }, @@ -182,4 +194,18 @@ func TestCachedRawStore(t *testing.T) { assert.False(t, readCalled) }) + t.Run("WriteAndReadBigData", func(t *testing.T) { + writeCalled = false + readCalled = false + err := cStore.WriteRaw(ctx, bigK, int64(len(bigD)), Options{}, bytes.NewReader(bigD)) + assert.True(t, writeCalled) + assert.True(t, IsFailedWriteToCache(err)) + + o, err := cStore.ReadRaw(ctx, bigK) + assert.True(t, IsFailedWriteToCache(err)) + b, err := ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, bigD, b) + assert.True(t, readCalled) + }) } diff --git a/flytestdlib/storage/protobuf_store.go b/flytestdlib/storage/protobuf_store.go index ba11d3311c..6b09a0e6eb 100644 --- a/flytestdlib/storage/protobuf_store.go +++ b/flytestdlib/storage/protobuf_store.go @@ -17,11 +17,13 @@ import ( ) type protoMetrics struct { - FetchLatency promutils.StopWatch - MarshalTime promutils.StopWatch - UnmarshalTime promutils.StopWatch - MarshalFailure prometheus.Counter - UnmarshalFailure prometheus.Counter + FetchLatency promutils.StopWatch + MarshalTime promutils.StopWatch + UnmarshalTime promutils.StopWatch + MarshalFailure prometheus.Counter + UnmarshalFailure prometheus.Counter + WriteFailureUnrelatedToCache prometheus.Counter + ReadFailureUnrelatedToCache prometheus.Counter } // Implements ProtobufStore to marshal and unmarshal protobufs to/from a RawStore @@ -32,7 +34,9 @@ type DefaultProtobufStore struct { func (s DefaultProtobufStore) ReadProtobuf(ctx context.Context, reference DataReference, msg proto.Message) error { rc, err := s.ReadRaw(ctx, reference) - if err != nil { + if err != nil && !IsFailedWriteToCache(err) { + logger.Errorf(ctx, "Failed to read from the raw store. Error: %v", err) + s.metrics.ReadFailureUnrelatedToCache.Inc() return errs.Wrap(err, fmt.Sprintf("path:%v", reference)) } @@ -68,18 +72,26 @@ func (s DefaultProtobufStore) WriteProtobuf(ctx context.Context, reference DataR return err } - return s.WriteRaw(ctx, reference, int64(len(raw)), opts, bytes.NewReader(raw)) + err = s.WriteRaw(ctx, reference, int64(len(raw)), opts, bytes.NewReader(raw)) + if err != nil && !IsFailedWriteToCache(err) { + logger.Errorf(ctx, "Failed to write to the raw store. 
Error: %v", err) + s.metrics.WriteFailureUnrelatedToCache.Inc() + return err + } + return nil } func NewDefaultProtobufStore(store RawStore, metricsScope promutils.Scope) DefaultProtobufStore { return DefaultProtobufStore{ RawStore: store, metrics: &protoMetrics{ - FetchLatency: metricsScope.MustNewStopWatch("proto_fetch", "Time to read data before unmarshalling", time.Millisecond), - MarshalTime: metricsScope.MustNewStopWatch("marshal", "Time incurred in marshalling data before writing", time.Millisecond), - UnmarshalTime: metricsScope.MustNewStopWatch("unmarshal", "Time incurred in unmarshalling received data", time.Millisecond), - MarshalFailure: metricsScope.MustNewCounter("marshal_failure", "Failures when marshalling"), - UnmarshalFailure: metricsScope.MustNewCounter("unmarshal_failure", "Failures when unmarshalling"), + FetchLatency: metricsScope.MustNewStopWatch("proto_fetch", "Time to read data before unmarshalling", time.Millisecond), + MarshalTime: metricsScope.MustNewStopWatch("marshal", "Time incurred in marshalling data before writing", time.Millisecond), + UnmarshalTime: metricsScope.MustNewStopWatch("unmarshal", "Time incurred in unmarshalling received data", time.Millisecond), + MarshalFailure: metricsScope.MustNewCounter("marshal_failure", "Failures when marshalling"), + UnmarshalFailure: metricsScope.MustNewCounter("unmarshal_failure", "Failures when unmarshalling"), + WriteFailureUnrelatedToCache: metricsScope.MustNewCounter("write_failure_unrelated_to_cache", "Raw store write failures that are not caused by ErrFailedToWriteCache"), + ReadFailureUnrelatedToCache: metricsScope.MustNewCounter("read_failure_unrelated_to_cache", "Raw store read failures that are not caused by ErrFailedToWriteCache"), }, } } diff --git a/flytestdlib/storage/protobuf_store_test.go b/flytestdlib/storage/protobuf_store_test.go index 160239bb73..44ac0c4084 100644 --- a/flytestdlib/storage/protobuf_store_test.go +++ b/flytestdlib/storage/protobuf_store_test.go @@ -2,11 +2,15 @@ package storage import ( "context" + "fmt" + "io" + "math/rand" "testing" "github.com/lyft/flytestdlib/promutils" "github.com/golang/protobuf/proto" + errs "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -14,6 +18,10 @@ type mockProtoMessage struct { X int64 `protobuf:"varint,2,opt,name=x,json=x,proto3" json:"x,omitempty"` } +type mockBigDataProtoMessage struct { + X []byte `protobuf:"bytes,1,opt,name=X,proto3" json:"X,omitempty"` +} + func (mockProtoMessage) Reset() { } @@ -24,6 +32,16 @@ func (m mockProtoMessage) String() string { func (mockProtoMessage) ProtoMessage() { } +func (mockBigDataProtoMessage) Reset() { +} + +func (m mockBigDataProtoMessage) String() string { + return proto.CompactTextString(m) +} + +func (mockBigDataProtoMessage) ProtoMessage() { +} + func TestDefaultProtobufStore_ReadProtobuf(t *testing.T) { t.Run("Read after Write", func(t *testing.T) { testScope := promutils.NewTestScope() @@ -39,3 +57,67 @@ func TestDefaultProtobufStore_ReadProtobuf(t *testing.T) { assert.Equal(t, int64(5), m.X) }) } + +func TestDefaultProtobufStore_BigDataReadAfterWrite(t *testing.T) { + t.Run("Read after Write with Big Data", func(t *testing.T) { + testScope := promutils.NewTestScope() + + s, err := NewDataStore( + &Config{ + Type: TypeMemory, + Cache: CachingConfig{ + MaxSizeMegabytes: 1, + TargetGCPercent: 20, + }, + }, testScope) + assert.NoError(t, err) + + bigD := make([]byte, 1.5*1024*1024) + // #nosec G404 + rand.Read(bigD) + + mockMessage := &mockBigDataProtoMessage{X: bigD} + + err = 
s.WriteProtobuf(context.TODO(), DataReference("bigK"), Options{}, mockMessage) + assert.NoError(t, err) + + m := &mockBigDataProtoMessage{} + err = s.ReadProtobuf(context.TODO(), DataReference("bigK"), m) + assert.NoError(t, err) + assert.Equal(t, bigD, m.X) + + }) +} + +func TestDefaultProtobufStore_HardErrors(t *testing.T) { + ctx := context.TODO() + k1 := DataReference("k1") + dummyHeadErrorMsg := "Dummy head error" + dummyWriteErrorMsg := "Dummy write error" + dummyReadErrorMsg := "Dummy read error" + store := &dummyStore{ + HeadCb: func(ctx context.Context, reference DataReference) (Metadata, error) { + return MemoryMetadata{}, fmt.Errorf(dummyHeadErrorMsg) + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + return fmt.Errorf(dummyWriteErrorMsg) + }, + ReadRawCb: func(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + return nil, fmt.Errorf(dummyReadErrorMsg) + }, + } + testScope := promutils.NewTestScope() + pbErroneousStore := NewDefaultProtobufStore(store, testScope) + t.Run("Test if hard write errors are handled correctly", func(t *testing.T) { + err := pbErroneousStore.WriteProtobuf(ctx, k1, Options{}, &mockProtoMessage{X: 5}) + assert.False(t, IsFailedWriteToCache(err)) + assert.Equal(t, dummyWriteErrorMsg, errs.Cause(err).Error()) + }) + + t.Run("Test if hard read errors are handled correctly", func(t *testing.T) { + m := &mockProtoMessage{} + err := pbErroneousStore.ReadProtobuf(ctx, k1, m) + assert.False(t, IsFailedWriteToCache(err)) + assert.Equal(t, dummyReadErrorMsg, errs.Cause(err).Error()) + }) +} diff --git a/flytestdlib/storage/stow_store.go b/flytestdlib/storage/stow_store.go index b13170bf09..e156f376da 100644 --- a/flytestdlib/storage/stow_store.go +++ b/flytestdlib/storage/stow_store.go @@ -5,6 +5,8 @@ import ( "io" "time" + "github.com/lyft/flytestdlib/errors" + "github.com/prometheus/client_golang/prometheus" "github.com/lyft/flytestdlib/promutils" @@ -116,7 +118,7 @@ func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.Re } if sizeBytes/MiB > GetConfig().Limits.GetLimitMegabytes { - return nil, ErrExceedsLimit + return nil, errors.Wrapf(ErrExceedsLimit, err, "limit exceeded") } return item.Open() diff --git a/flytestdlib/storage/utils.go b/flytestdlib/storage/utils.go index 62fb2aa22c..9b790c6120 100644 --- a/flytestdlib/storage/utils.go +++ b/flytestdlib/storage/utils.go @@ -1,18 +1,26 @@ package storage import ( - "fmt" "os" + errors2 "github.com/lyft/flytestdlib/errors" + "github.com/graymeta/stow" "github.com/pkg/errors" ) -var ErrExceedsLimit = fmt.Errorf("limit exceeded") +var ( + ErrExceedsLimit errors2.ErrorCode = "LIMIT_EXCEEDED" + ErrFailedToWriteCache errors2.ErrorCode = "CACHE_WRITE_FAILED" +) // Gets a value indicating whether the underlying error is a Not Found error. func IsNotFound(err error) bool { - if root := errors.Cause(err); root == stow.ErrNotFound || os.IsNotExist(root) { + if root := errors.Cause(err); os.IsNotExist(root) { + return true + } + + if errors2.IsCausedByError(err, stow.ErrNotFound) { return true } @@ -30,5 +38,9 @@ func IsExists(err error) bool { // Gets a value indicating whether the root cause of error is a "limit exceeded" error. 
func IsExceedsLimit(err error) bool { - return errors.Cause(err) == ErrExceedsLimit + return errors2.IsCausedBy(err, ErrExceedsLimit) +} + +func IsFailedWriteToCache(err error) bool { + return errors2.IsCausedBy(err, ErrFailedToWriteCache) } diff --git a/flytestdlib/storage/utils_test.go b/flytestdlib/storage/utils_test.go new file mode 100644 index 0000000000..8618310f4a --- /dev/null +++ b/flytestdlib/storage/utils_test.go @@ -0,0 +1,49 @@ +package storage + +import ( + "os" + "syscall" + "testing" + + "github.com/graymeta/stow" + flyteerrors "github.com/lyft/flytestdlib/errors" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestIsNotFound(t *testing.T) { + sysError := &os.PathError{Err: syscall.ENOENT} + assert.True(t, IsNotFound(sysError)) + flyteError := errors.Wrap(sysError, "Wrapping \"system not found\" error") + assert.True(t, IsNotFound(flyteError)) + secondLevelError := errors.Wrap(flyteError, "Higher level error") + assert.True(t, IsNotFound(secondLevelError)) + + // more for stow errors + stowNotFoundError := stow.ErrNotFound + assert.True(t, IsNotFound(stowNotFoundError)) + flyteError = errors.Wrap(stowNotFoundError, "Wrapping stow.ErrNotFound") + assert.True(t, IsNotFound(flyteError)) + secondLevelError = errors.Wrap(flyteError, "Higher level error wrapper of the stow.ErrNotFound error") + assert.True(t, IsNotFound(secondLevelError)) +} + +func TestIsExceedsLimit(t *testing.T) { + sysError := &os.PathError{Err: syscall.ENOENT} + exceedsLimitError := flyteerrors.Wrapf(ErrExceedsLimit, sysError, "An error wrapped in ErrExceedsLimits") + failedToWriteCacheError := flyteerrors.Wrapf(ErrFailedToWriteCache, sysError, "An error wrapped in ErrFailedToWriteCache") + + assert.True(t, IsExceedsLimit(exceedsLimitError)) + assert.False(t, IsExceedsLimit(failedToWriteCacheError)) + assert.False(t, IsExceedsLimit(sysError)) +} + +func TestIsFailedWriteToCache(t *testing.T) { + sysError := &os.PathError{Err: syscall.ENOENT} + exceedsLimitError := flyteerrors.Wrapf(ErrExceedsLimit, sysError, "An error wrapped in ErrExceedsLimits") + failedToWriteCacheError := flyteerrors.Wrapf(ErrFailedToWriteCache, sysError, "An error wrapped in ErrFailedToWriteCache") + + assert.False(t, IsFailedWriteToCache(exceedsLimitError)) + assert.True(t, IsFailedWriteToCache(failedToWriteCacheError)) + assert.False(t, IsFailedWriteToCache(sysError)) +} From e0a8732c22e0d55b042b55ef4df1d54e88202808 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Thu, 25 Jul 2019 21:07:29 -0700 Subject: [PATCH 0060/1918] Scoop update for flytestdlib version v1.2.4 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 20e5425e3a..e43a4c2528 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.10", + "version": "1.2.4", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.10/flytestdlib_0.2.10_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v1.2.4/flytestdlib_1.2.4_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "0581a128328931280f2d60ff885dd1a2063776a622a710135ce846d66be442e1" + "hash": "6b52de6cf834a60a07d7755653492541ee4e9079aa594fe2e7157c11e4cfcdfb" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.10/flytestdlib_0.2.10_Windows_x86_64.zip", + "url": 
"https://github.com/lyft/flytestdlib/releases/download/v1.2.4/flytestdlib_1.2.4_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "061dd31294bc2498d8cb939f164ca7518e23ea14978d7ea06244c3d43ea3b9e6" + "hash": "cd32add4a68eb23020912d81e9c24ec00f742ab11584f0c15749a96c7937eff9" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From a8b264f7a06b2369ac4ec62f9f6253a7adfaef2e Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Thu, 25 Jul 2019 21:08:45 -0700 Subject: [PATCH 0061/1918] Scoop update for flytestdlib version v0.2.11 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index e43a4c2528..0246a2c4ef 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "1.2.4", + "version": "0.2.11", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v1.2.4/flytestdlib_1.2.4_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.11/flytestdlib_0.2.11_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "6b52de6cf834a60a07d7755653492541ee4e9079aa594fe2e7157c11e4cfcdfb" + "hash": "4ace26d248fea1f01e59cb613d8353c33800a7f8420dc70b48d50e71d8602ce3" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v1.2.4/flytestdlib_1.2.4_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.11/flytestdlib_0.2.11_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "cd32add4a68eb23020912d81e9c24ec00f742ab11584f0c15749a96c7937eff9" + "hash": "665370f8394034a2dbe9334b1cd698624da2cfaeb353c67d404f10a365042c00" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 0ff342be41ed431781eebfaa9acb99f2769c4296 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Fri, 26 Jul 2019 22:57:14 -0700 Subject: [PATCH 0062/1918] Generic StowStore configuration (#29) * Work in progress: Generic Stowstore * Simplifying to use only stowstore * added missing functions * removing deleted reference * Updated pflags * Fixed --- flytestdlib/Gopkg.lock | 396 +++++++++++++++--- flytestdlib/Gopkg.toml | 6 +- flytestdlib/cli/pflags/api/generator.go | 5 + flytestdlib/cli/pflags/api/sample.go | 1 + .../cli/pflags/api/testdata/testtype.go | 2 +- flytestdlib/logger/config_flags.go | 13 +- flytestdlib/logger/config_flags_test.go | 33 ++ flytestdlib/logger/config_test.go | 17 + flytestdlib/storage/config.go | 9 +- flytestdlib/storage/config_flags.go | 15 +- flytestdlib/storage/localstore_test.go | 5 +- flytestdlib/storage/rawstores.go | 5 +- flytestdlib/storage/s3store.go | 103 ----- flytestdlib/storage/s3stsore_test.go | 26 -- flytestdlib/storage/stow_store_test.go | 6 +- flytestdlib/storage/stowstore.go | 130 ++++++ flytestdlib/storage/stowstore_test.go | 53 +++ 17 files changed, 618 insertions(+), 207 deletions(-) delete mode 100644 flytestdlib/storage/s3store.go delete mode 100644 flytestdlib/storage/s3stsore_test.go create mode 100644 flytestdlib/storage/stowstore.go create mode 100644 flytestdlib/storage/stowstore_test.go diff --git a/flytestdlib/Gopkg.lock b/flytestdlib/Gopkg.lock index 0fcc602f6f..38e3251581 100644 --- a/flytestdlib/Gopkg.lock +++ b/flytestdlib/Gopkg.lock @@ -2,7 +2,46 @@ [[projects]] - digest = "1:7fbdc0ca5fc0b0bb66b81ec2fdca82fbe64416742267f11aceb8ed56e6ca3121" + digest = "1:80004fcc5cf64e591486b3e11b406f1e0d17bf85d475d64203c8494f5da4fcd1" + name = "cloud.google.com/go" + packages = 
["compute/metadata"] + pruneopts = "UT" + revision = "cdaaf98f9226c39dc162b8e55083b2fbc67b4674" + version = "v0.43.0" + +[[projects]] + digest = "1:6b1426cad7057b717351eacf5b6fe70f053f11aac1ce254bbf2fd72c031719eb" + name = "contrib.go.opencensus.io/exporter/ocagent" + packages = ["."] + pruneopts = "UT" + revision = "dcb33c7f3b7cfe67e8a2cea10207ede1b7c40764" + version = "v0.4.12" + +[[projects]] + digest = "1:94d4ae958b3d2ab476bef4bed53c1dcc3cb0fb2639bd45dd08b40e57139192e5" + name = "github.com/Azure/azure-sdk-for-go" + packages = ["storage"] + pruneopts = "UT" + revision = "2d49bb8f2cee530cc16f1f1a9f0aae763dee257d" + version = "v10.2.1-beta" + +[[projects]] + digest = "1:3818ae0b615fadcb7ae2291f0147c2db42775398347123c2fe4de1d54499b9da" + name = "github.com/Azure/go-autorest" + packages = [ + "autorest", + "autorest/adal", + "autorest/azure", + "autorest/date", + "logger", + "tracing", + ] + pruneopts = "UT" + revision = "2913f263500c4a5b23dada1b46ccd22ac972315f" + version = "v12.3.0" + +[[projects]] + digest = "1:5b029601017603a512847d8ed857cca57b0af863d5289b4c8493d3f25b5425d0" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -32,17 +71,21 @@ "private/protocol", "private/protocol/eventstream", "private/protocol/eventstream/eventstreamapi", + "private/protocol/json/jsonutil", "private/protocol/query", "private/protocol/query/queryutil", "private/protocol/rest", "private/protocol/restxml", "private/protocol/xml/xmlutil", "service/s3", + "service/s3/s3iface", + "service/s3/s3manager", "service/sts", + "service/sts/stsiface", ] pruneopts = "UT" - revision = "81f3829f5a9d041041bdf56e55926691309d7699" - version = "v1.16.26" + revision = "77b6968988559d355ab953b6436eafe6a76e4a92" + version = "v1.21.4" [[projects]] branch = "master" @@ -53,12 +96,27 @@ revision = "d1e3d6079fc16f8f542183fb5b2fdc11d9f00866" [[projects]] - branch = "master" digest = "1:d6afaeed1502aa28e80a4ed0981d570ad91b2579193404256ce672ed0a609e0d" name = "github.com/beorn7/perks" packages = ["quantile"] pruneopts = "UT" - revision = "3a771d992973f24aa725d07868b467d1ddfceafb" + revision = "4b2b341e8d7715fae06375aa633dbb6e91b3fb46" + version = "v1.0.0" + +[[projects]] + digest = "1:8f5acd4d4462b5136af644d25101f0968a7a94ee90fcb2059cec5b7cc42e0b20" + name = "github.com/census-instrumentation/opencensus-proto" + packages = [ + "gen-go/agent/common/v1", + "gen-go/agent/metrics/v1", + "gen-go/agent/trace/v1", + "gen-go/metrics/v1", + "gen-go/resource/v1", + "gen-go/trace/v1", + ] + pruneopts = "UT" + revision = "d89fa54de508111353cb0b06403c00569be780d8" + version = "v0.2.1" [[projects]] digest = "1:998cf998358a303ac2430c386ba3fd3398477d6013153d3c6e11432765cc9ae6" @@ -69,12 +127,12 @@ version = "v2.0.0" [[projects]] - digest = "1:04179a5bcbecdb18f06cca42e3808ae8560f86ad7fe470fde21206008f0c5e26" + digest = "1:00eb5d8bd96289512920ac43367d5bee76bbca2062da34862a98b26b92741896" name = "github.com/coocood/freecache" packages = ["."] pruneopts = "UT" - revision = "f3233c8095b26cd0dea0b136b931708c05defa08" - version = "v1.0.1" + revision = "3c79a0a23c1940ab4479332fb3e0127265650ce3" + version = "v1.1.0" [[projects]] digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" @@ -84,6 +142,14 @@ revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" version = "v1.1.1" +[[projects]] + digest = "1:76dc72490af7174349349838f2fe118996381b31ea83243812a97e5a0fd5ed55" + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + pruneopts = "UT" + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + 
[[projects]] branch = "master" digest = "1:dc8bf44b7198605c83a4f2bb36a92c4d9f71eab2e8cf8094ce31b0297dd8ea89" @@ -113,11 +179,11 @@ [[projects]] branch = "master" - digest = "1:b9414457752702c53f6afd3838da3d89b9513ada40cdbe9603bdf54b1ceb5014" + digest = "1:78a5b63751bd99054bee07a498f6aa54da0a909922f9365d1aa3339091efa70a" name = "github.com/fsnotify/fsnotify" packages = ["."] pruneopts = "UT" - revision = "ccc981bf80385c528a65fbfdd49bf2d8da22aa23" + revision = "1485a34d5d5723fea214f5710708e19a831720e4" [[projects]] digest = "1:2cd7915ab26ede7d95b8749e6b1f933f1c6d5398030684e6505940a10f31cfda" @@ -128,31 +194,52 @@ version = "v1.0.0" [[projects]] - digest = "1:9d6dc4d6de69b330d0de86494d6db90c09848c003d5db748f40c925f865c8534" + digest = "1:b532ee3f683c057e797694b5bfeb3827d89e6adf41c53dbc80e549bca76364ea" name = "github.com/golang/protobuf" packages = [ "jsonpb", "proto", + "protoc-gen-go/descriptor", + "protoc-gen-go/generator", + "protoc-gen-go/generator/internal/remap", + "protoc-gen-go/plugin", "ptypes", "ptypes/any", "ptypes/duration", "ptypes/struct", "ptypes/timestamp", + "ptypes/wrappers", ] pruneopts = "UT" - revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" - version = "v1.2.0" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" [[projects]] - digest = "1:3dab0e385faed192353d2150f6a041f4607f04a0e885f4a5a824eee6b676b4b9" + digest = "1:16e1cbd76f0d4152b5573f08f38b451748f74ec59b99a004a7481342b3fc05af" name = "github.com/graymeta/stow" packages = [ ".", + "azure", + "google", "local", + "oracle", "s3", + "swift", + ] + pruneopts = "UT" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" + +[[projects]] + digest = "1:3b341cd71012c63aacddabfc70b9110be8e30c553349552ad3f77242843f2d03" + name = "github.com/grpc-ecosystem/grpc-gateway" + packages = [ + "internal", + "runtime", + "utilities", ] pruneopts = "UT" - revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + revision = "ad529a448ba494a88058f9e5be0988713174ac86" + version = "v1.9.5" [[projects]] digest = "1:d15ee511aa0f56baacc1eb4c6b922fa1c03b38413b6be18166b996d82a0156ea" @@ -200,23 +287,23 @@ revision = "c2b33e84" [[projects]] - digest = "1:0a69a1c0db3591fcefb47f115b224592c8dfa4368b7ba9fae509d5e16cdc95c8" + digest = "1:31e761d97c76151dde79e9d28964a812c46efc5baee4085b86f68f0c654450de" name = "github.com/konsorten/go-windows-terminal-sequences" packages = ["."] pruneopts = "UT" - revision = "5c8c8bd35d3832f5d134ae1e1e375b69a4d25242" - version = "v1.0.1" + revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" + version = "v1.0.2" [[projects]] - digest = "1:53e8c5c79716437e601696140e8b1801aae4204f4ec54a504333702a49572c4f" + digest = "1:2a0da3440db3f2892609d99cd0389c2776a3fef24435f7b7b58bfc9030aa86ca" name = "github.com/magiconair/properties" packages = [ ".", "assert", ] pruneopts = "UT" - revision = "c2353362d570a7bfa228149c62842019201cfb71" - version = "v1.8.0" + revision = "de8848e004dd33dc07a2947b3d76f618a7fc7ef1" + version = "v1.8.1" [[projects]] digest = "1:c658e84ad3916da105a761660dcaeb01e63416c8ec7bc62256a9b411a05fcd67" @@ -227,12 +314,12 @@ version = "v0.0.9" [[projects]] - digest = "1:0981502f9816113c9c8c4ac301583841855c8cf4da8c72f696b3ebedf6d0e4e5" + digest = "1:9b90c7639a41697f3d4ad12d7d67dfacc9a7a4a6e0bbfae4fc72d0da57c28871" name = "github.com/mattn/go-isatty" packages = ["."] pruneopts = "UT" - revision = "6ca4dbf54d38eea1a992b3c722a76a5d1c4cb25c" - version = "v0.0.4" + revision = "1311e847b0cb909da63b5fecfb5370aa66236465" + version = "v0.0.8" [[projects]] digest = 
"1:ff5ebae34cfbf047d505ee150de27e60570e8c394b3b8fdbb720ff6ac71985fc" @@ -251,12 +338,20 @@ version = "v1.1.2" [[projects]] - digest = "1:95741de3af260a92cc5c7f3f3061e85273f5a81b5db20d4bd68da74bd521675e" + branch = "master" + digest = "1:b20de3cce4fa037405a51f29d69872915d88d5358820ec71e3537e1a29b2d8d5" + name = "github.com/ncw/swift" + packages = ["."] + pruneopts = "UT" + revision = "be076bb68c47b6c17b06258ea7778cfeaf185ea5" + +[[projects]] + digest = "1:93131d8002d7025da13582877c32d1fc302486775a1b06f62241741006428c5e" name = "github.com/pelletier/go-toml" packages = ["."] pruneopts = "UT" - revision = "c01d1270ff3e442a8a57cddc1c92dc1138598194" - version = "v1.2.0" + revision = "728039f679cbcd4f6a54e080d2219a4c4928c546" + version = "v1.4.0" [[projects]] digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" @@ -275,7 +370,7 @@ version = "v1.0.0" [[projects]] - digest = "1:93a746f1060a8acbcf69344862b2ceced80f854170e1caae089b2834c5fbf7f4" + digest = "1:e89f2cdede55684adbe44b5566f55838ad2aee1dff348d14b73ccf733607b671" name = "github.com/prometheus/client_golang" packages = [ "prometheus", @@ -283,8 +378,8 @@ "prometheus/promhttp", ] pruneopts = "UT" - revision = "505eaef017263e299324067d40ca2c48f6a2cf50" - version = "v0.9.2" + revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" + version = "v0.9.4" [[projects]] branch = "master" @@ -295,7 +390,7 @@ revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" [[projects]] - digest = "1:35cf6bdf68db765988baa9c4f10cc5d7dda1126a54bd62e252dbcd0b1fc8da90" + digest = "1:8dcedf2e8f06c7f94e48267dea0bc0be261fa97b377f3ae3e87843a92a549481" name = "github.com/prometheus/common" packages = [ "expfmt", @@ -303,40 +398,46 @@ "model", ] pruneopts = "UT" - revision = "cfeb6f9992ffa54aaa4f2170ade4067ee478b250" - version = "v0.2.0" + revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" + version = "v0.6.0" [[projects]] - branch = "master" - digest = "1:5833c61ebbd625a6bad8e5a1ada2b3e13710cf3272046953a2c8915340fe60a3" + digest = "1:366f5aa02ff6c1e2eccce9ca03a22a6d983da89eecff8a89965401764534eb7c" name = "github.com/prometheus/procfs" packages = [ ".", - "internal/util", - "nfs", - "xfs", + "internal/fs", ] pruneopts = "UT" - revision = "316cf8ccfec56d206735d46333ca162eb374da8b" + revision = "3f98efb27840a48a7a2898ec80be07674d19f9c8" + version = "v0.0.3" + +[[projects]] + digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925" + name = "github.com/satori/uuid" + packages = ["."] + pruneopts = "UT" + revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" + version = "v1.2.0" [[projects]] - digest = "1:87c2e02fb01c27060ccc5ba7c5a407cc91147726f8f40b70cceeedbc52b1f3a8" + digest = "1:04457f9f6f3ffc5fea48e71d62f2ca256637dee0a04d710288e27e05c8b41976" name = "github.com/sirupsen/logrus" packages = ["."] pruneopts = "UT" - revision = "e1e72e9de974bd926e5c56f83753fba2df402ce5" - version = "v1.3.0" + revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" + version = "v1.4.2" [[projects]] - digest = "1:3e39bafd6c2f4bf3c76c3bfd16a2e09e016510ad5db90dc02b88e2f565d6d595" + digest = "1:bb495ec276ab82d3dd08504bbc0594a65de8c3b22c6f2aaa92d05b73fbf3a82e" name = "github.com/spf13/afero" packages = [ ".", "mem", ] pruneopts = "UT" - revision = "f4711e4db9e9a1d3887343acb72b2bbfc2f686f5" - version = "v1.2.1" + revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424" + version = "v1.2.2" [[projects]] digest = "1:08d65904057412fc0270fc4812a1c90c594186819243160dc779a402d4b6d0bc" @@ -347,20 +448,20 @@ version = "v1.3.0" [[projects]] - digest = 
"1:645cabccbb4fa8aab25a956cbcbdf6a6845ca736b2c64e197ca7cbb9d210b939" + digest = "1:e096613fb7cf34743d49af87d197663cfccd61876e2219853005a57baedfa562" name = "github.com/spf13/cobra" packages = ["."] pruneopts = "UT" - revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385" - version = "v0.0.3" + revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" + version = "v0.0.5" [[projects]] - digest = "1:68ea4e23713989dc20b1bded5d9da2c5f9be14ff9885beef481848edd18c26cb" + digest = "1:1b753ec16506f5864d26a28b43703c58831255059644351bbcb019b843950900" name = "github.com/spf13/jwalterweatherman" packages = ["."] pruneopts = "UT" - revision = "4a4406e478ca629068e7768fc33f3f044173c0a6" - version = "v1.0.0" + revision = "94f6ae3ed3bceceafa716478c5fbf8d29ca601a1" + version = "v1.1.0" [[projects]] digest = "1:c1b1102241e7f645bc8e0c22ae352e8f0dc6484b6cb4d132fa9f24174e0119e2" @@ -385,39 +486,104 @@ revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" version = "v1.3.0" +[[projects]] + digest = "1:4c93890bbbb5016505e856cb06b5c5a2ff5b7217584d33f2a9071ebef4b5d473" + name = "go.opencensus.io" + packages = [ + ".", + "internal", + "internal/tagencoding", + "metric/metricdata", + "metric/metricproducer", + "plugin/ocgrpc", + "plugin/ochttp", + "plugin/ochttp/propagation/b3", + "plugin/ochttp/propagation/tracecontext", + "resource", + "stats", + "stats/internal", + "stats/view", + "tag", + "trace", + "trace/internal", + "trace/propagation", + "trace/tracestate", + ] + pruneopts = "UT" + revision = "43463a80402d8447b7fce0d2c58edf1687ff0b58" + version = "v0.19.3" + [[projects]] branch = "master" - digest = "1:fde12c4da6237363bf36b81b59aa36a43d28061167ec4acb0d41fc49464e28b9" - name = "golang.org/x/crypto" - packages = ["ssh/terminal"] + digest = "1:eae689808191546269bf951e1e66e0b6bc468be58a0498c0f037feeef2c67bab" + name = "golang.org/x/net" + packages = [ + "context", + "context/ctxhttp", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] pruneopts = "UT" - revision = "b01c7a72566457eb1420261cdafef86638fc3861" + revision = "ca1201d0de80cfde86cb01aea620983605dfe99b" [[projects]] branch = "master" - digest = "1:7941e2f16c0833b438cbef7fccfe4f8346f9f7876b42b29717a75d7e8c4800cb" - name = "golang.org/x/sys" + digest = "1:31e33f76456ccf54819ab4a646cf01271d1a99d7712ab84bf1a9e7b61cd2031b" + name = "golang.org/x/oauth2" packages = [ - "unix", - "windows", + ".", + "google", + "internal", + "jws", + "jwt", ] pruneopts = "UT" - revision = "aca44879d5644da7c5b8ec6a1115e9b6ea6c40d9" + revision = "0f29369cfe4552d0e4bcddc57cc75f4d7e672a33" + +[[projects]] + branch = "master" + digest = "1:382bb5a7fb4034db3b6a2d19e5a4a6bcf52f4750530603c01ca18a172fa3089b" + name = "golang.org/x/sync" + packages = ["semaphore"] + pruneopts = "UT" + revision = "112230192c580c3556b8cee6403af37a4fc5f28c" [[projects]] - digest = "1:8029e9743749d4be5bc9f7d42ea1659471767860f0cdc34d37c3111bd308a295" + branch = "master" + digest = "1:5632b0c4d972da51b5914f09fc5c1a8535e9d8d5d937e95ef83c423a0dd67f13" + name = "golang.org/x/sys" + packages = ["unix"] + pruneopts = "UT" + revision = "fae7ac547cb717d141c433a2a173315e216b64c4" + +[[projects]] + digest = "1:8d8faad6b12a3a4c819a3f9618cb6ee1fa1cfc33253abeeea8b55336721e3405" name = "golang.org/x/text" packages = [ + "collate", + "collate/build", + "internal/colltab", "internal/gen", + "internal/language", + "internal/language/compact", + "internal/tag", "internal/triegen", "internal/ucd", + "language", + "secure/bidirule", "transform", + "unicode/bidi", "unicode/cldr", 
"unicode/norm", + "unicode/rangetable", ] pruneopts = "UT" - revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" - version = "v0.3.0" + revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" + version = "v0.3.2" [[projects]] branch = "master" @@ -425,16 +591,15 @@ name = "golang.org/x/time" packages = ["rate"] pruneopts = "UT" - revision = "85acf8d2951cb2a3bde7632f9ff273ef0379bcbd" + revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" [[projects]] branch = "master" - digest = "1:86d002f2c67e364e097c5047f517ab38cdef342c3c20be53974c1bfd5b191d30" + digest = "1:03cb7931c1d82eb61d91e3a29d672c7e1a5b97dbb64fc7ece8b4ec1016dba4ac" name = "golang.org/x/tools" packages = [ "go/ast/astutil", "go/gcexportdata", - "go/internal/cgo", "go/internal/gcimporter", "go/internal/packagesdriver", "go/packages", @@ -442,11 +607,104 @@ "imports", "internal/fastwalk", "internal/gopathwalk", + "internal/imports", "internal/module", "internal/semver", ] pruneopts = "UT" - revision = "58ecf64b2ccd4e014267d2ea143d23c617ee7e4c" + revision = "8aa4eac1a7c108f74782725e959bedd2b844f738" + +[[projects]] + branch = "master" + digest = "1:6fa791c8aeb0da453e2d3208a20e282bb8c15551a1b91e6672fec99f636a53ca" + name = "google.golang.org/api" + packages = [ + "gensupport", + "googleapi", + "googleapi/internal/uritemplates", + "googleapi/transport", + "internal", + "option", + "storage/v1", + "support/bundler", + "transport/http", + "transport/http/internal/propagation", + ] + pruneopts = "UT" + revision = "069bea57b1be6ad0671a49ea7a1128025a22b73f" + +[[projects]] + digest = "1:498b722d33dde4471e7d6e5d88a5e7132d2a8306fea5ff5ee82d1f418b4f41ed" + name = "google.golang.org/appengine" + packages = [ + ".", + "internal", + "internal/app_identity", + "internal/base", + "internal/datastore", + "internal/log", + "internal/modules", + "internal/remote_api", + "internal/urlfetch", + "urlfetch", + ] + pruneopts = "UT" + revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" + version = "v1.6.1" + +[[projects]] + branch = "master" + digest = "1:3565a93b7692277a5dea355bc47bd6315754f3246ed07a224be6aec28972a805" + name = "google.golang.org/genproto" + packages = [ + "googleapis/api/httpbody", + "googleapis/rpc/status", + "protobuf/field_mask", + ] + pruneopts = "UT" + revision = "c506a9f9061087022822e8da603a52fc387115a8" + +[[projects]] + digest = "1:cf01ae0753310464677058b125fa31e74fd943781782ada503180ad784fc83d3" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/balancerload", + "internal/binarylog", + "internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "serviceconfig", + "stats", + "status", + "tap", + ] + pruneopts = "UT" + revision = "1d89a3c832915b2314551c1d2a506874d62e53f7" + version = "v1.22.0" [[projects]] digest = "1:4d2e5a73dc1500038e504a8d78b986630e3626dc027bc030ba5c75da257cdb96" @@ -506,8 +764,12 @@ "github.com/golang/protobuf/ptypes/struct", "github.com/golang/protobuf/ptypes/timestamp", "github.com/graymeta/stow", + "github.com/graymeta/stow/azure", + "github.com/graymeta/stow/google", "github.com/graymeta/stow/local", + "github.com/graymeta/stow/oracle", 
"github.com/graymeta/stow/s3", + "github.com/graymeta/stow/swift", "github.com/hashicorp/golang-lru", "github.com/magiconair/properties/assert", "github.com/mitchellh/mapstructure", diff --git a/flytestdlib/Gopkg.toml b/flytestdlib/Gopkg.toml index fdf90bba03..c5a397795f 100644 --- a/flytestdlib/Gopkg.toml +++ b/flytestdlib/Gopkg.toml @@ -55,6 +55,10 @@ branch = "master" name = "golang.org/x/time" +[[override]] + branch = "master" + name = "golang.org/x/net" + [[constraint]] name = "k8s.io/apimachinery" version = "kubernetes-1.13.1" @@ -65,7 +69,7 @@ [[constraint]] name = "github.com/graymeta/stow" - revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" [prune] go-tests = true diff --git a/flytestdlib/cli/pflags/api/generator.go b/flytestdlib/cli/pflags/api/generator.go index 93a3a040b2..cf043fadb8 100644 --- a/flytestdlib/cli/pflags/api/generator.go +++ b/flytestdlib/cli/pflags/api/generator.go @@ -145,6 +145,11 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue tag.Name = v.Name() } + if tag.DefaultValue == "-" { + logger.Infof(ctx, "Skipping field [%s], as '-' value detected", tag.Name) + continue + } + typ := v.Type() ptr, isPtr := typ.(*types.Pointer) if isPtr { diff --git a/flytestdlib/cli/pflags/api/sample.go b/flytestdlib/cli/pflags/api/sample.go index f1b40e5b74..ca805e7040 100644 --- a/flytestdlib/cli/pflags/api/sample.go +++ b/flytestdlib/cli/pflags/api/sample.go @@ -19,6 +19,7 @@ type TestType struct { StringArray []string `json:"strs" pflag:"[]string{\"12\"%2C\"1\"}"` ComplexJSONArray []ComplexJSONType `json:"complexArr"` StringToJSON ComplexJSONType `json:"c" pflag:",I'm a complex type but can be converted from string."` + IgnoredMap map[string]string `json:"ignored-map" pflag:"-,"` StorageConfig storage.Config `json:"storage"` IntValue *int `json:"i"` } diff --git a/flytestdlib/cli/pflags/api/testdata/testtype.go b/flytestdlib/cli/pflags/api/testdata/testtype.go index ff6b7e63ba..c960f10044 100755 --- a/flytestdlib/cli/pflags/api/testdata/testtype.go +++ b/flytestdlib/cli/pflags/api/testdata/testtype.go @@ -48,7 +48,7 @@ func (cfg TestType) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strs"), []string{"12", "1"}, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "complexArr"), []string{}, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), DefaultTestType.mustMarshalJSON(DefaultTestType.StringToJSON), "I'm a complex type but can be converted from string.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), DefaultTestType.StorageConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), DefaultTestType.StorageConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem/stow].") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), DefaultTestType.StorageConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), DefaultTestType.StorageConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), DefaultTestType.StorageConfig.Connection.AccessKey, "Access key to use. 
Only required when authtype is set to accesskey.") diff --git a/flytestdlib/logger/config_flags.go b/flytestdlib/logger/config_flags.go index 27be2e3440..4e2c7795e4 100755 --- a/flytestdlib/logger/config_flags.go +++ b/flytestdlib/logger/config_flags.go @@ -4,9 +4,11 @@ package logger import ( - "fmt" + "encoding/json" "reflect" + "fmt" + "github.com/spf13/pflag" ) @@ -26,6 +28,15 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { diff --git a/flytestdlib/logger/config_flags_test.go b/flytestdlib/logger/config_flags_test.go index 853aeac0b0..03b9878515 100755 --- a/flytestdlib/logger/config_flags_test.go +++ b/flytestdlib/logger/config_flags_test.go @@ -221,3 +221,36 @@ func TestConfig_elemValueOrNil(t *testing.T) { }) } } + +func TestConfig_mustMarshalJSON(t *testing.T) { + type fields struct { + IncludeSourceCode bool + Mute bool + Level Level + Formatter FormatterConfig + } + type args struct { + v json.Marshaler + } + tests := []struct { + name string + fields fields + args args + want string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Config{ + IncludeSourceCode: tt.fields.IncludeSourceCode, + Mute: tt.fields.Mute, + Level: tt.fields.Level, + Formatter: tt.fields.Formatter, + } + if got := c.mustMarshalJSON(tt.args.v); got != tt.want { + t.Errorf("Config.mustMarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/flytestdlib/logger/config_test.go b/flytestdlib/logger/config_test.go index 08be8dec6a..6da9415904 100644 --- a/flytestdlib/logger/config_test.go +++ b/flytestdlib/logger/config_test.go @@ -1,6 +1,7 @@ package logger import ( + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -22,3 +23,19 @@ func TestSetConfig(t *testing.T) { }) } } + +func TestGetConfig(t *testing.T) { + tests := []struct { + name string + want *Config + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetConfig(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetConfig() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/flytestdlib/storage/config.go b/flytestdlib/storage/config.go index 780ef12ac9..863931ac17 100644 --- a/flytestdlib/storage/config.go +++ b/flytestdlib/storage/config.go @@ -20,6 +20,7 @@ const ( TypeS3 Type = "s3" TypeLocal Type = "local" TypeMinio Type = "minio" + TypeStow Type = "stow" ) const ( @@ -43,8 +44,9 @@ var ( // A common storage config. type Config struct { - Type Type `json:"type" pflag:",Sets the type of storage to configure [s3/minio/local/mem]."` + Type Type `json:"type" pflag:",Sets the type of storage to configure [s3/minio/local/mem/stow]."` Connection ConnectionConfig `json:"connection"` + Stow *StowConfig `json:"stow,omitempty"` InitContainer string `json:"container" pflag:",Initial container to create -if it doesn't exist-.'"` // Caching is recommended to improve the performance of underlying systems. It caches the metadata and resolving // inputs is accelerated. The size of the cache is large so understand how to configure the cache. 
@@ -64,6 +66,11 @@ type ConnectionConfig struct { DisableSSL bool `json:"disable-ssl" pflag:",Disables SSL connection. Should only be used for development."` } +type StowConfig struct { + Kind string `json:"kind,omitempty" pflag:"-,Kind of Stow backend to use. Refer to github/graymeta/stow"` + Config map[string]string `json:"config,omitempty" pflag:"-,Configuration for stow backend. Refer to github/graymeta/stow"` +} + type CachingConfig struct { // Maximum size of the cache where the Blob store data is cached in-memory // Refer to https://github.com/coocood/freecache to understand how to set the value diff --git a/flytestdlib/storage/config_flags.go b/flytestdlib/storage/config_flags.go index 4cde49bec6..50f634eab2 100755 --- a/flytestdlib/storage/config_flags.go +++ b/flytestdlib/storage/config_flags.go @@ -4,9 +4,11 @@ package storage import ( - "fmt" + "encoding/json" "reflect" + "fmt" + "github.com/spf13/pflag" ) @@ -26,11 +28,20 @@ func (Config) elemValueOrNil(v interface{}) interface{} { return v } +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Sets the type of storage to configure [s3/minio/local/mem/stow].") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), defaultConfig.Connection.Endpoint.String(), "URL for storage client to connect to.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), defaultConfig.Connection.AuthType, "Auth Type to use [iam, accesskey].") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), defaultConfig.Connection.AccessKey, "Access key to use. 
Only required when authtype is set to accesskey.") diff --git a/flytestdlib/storage/localstore_test.go b/flytestdlib/storage/localstore_test.go index 31f8aabeb2..7295847d5e 100644 --- a/flytestdlib/storage/localstore_test.go +++ b/flytestdlib/storage/localstore_test.go @@ -7,7 +7,9 @@ import ( "path/filepath" "testing" + "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/config" "github.com/lyft/flytestdlib/internal/utils" @@ -15,6 +17,7 @@ import ( ) func TestNewLocalStore(t *testing.T) { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) t.Run("Valid config", func(t *testing.T) { testScope := promutils.NewTestScope() store, err := newLocalRawStore(&Config{ @@ -28,7 +31,7 @@ func TestNewLocalStore(t *testing.T) { assert.NotNil(t, store) // Stow local store expects the full path after the container portion (looks like a bug to me) - rc, err := store.ReadRaw(context.TODO(), DataReference("file://testdata/testdata/config.yaml")) + rc, err := store.ReadRaw(context.TODO(), DataReference("file://testdata/config.yaml")) assert.NoError(t, err) assert.NotNil(t, rc) assert.NoError(t, rc.Close()) diff --git a/flytestdlib/storage/rawstores.go b/flytestdlib/storage/rawstores.go index 4a6d35573c..bb5af8aa24 100644 --- a/flytestdlib/storage/rawstores.go +++ b/flytestdlib/storage/rawstores.go @@ -11,8 +11,9 @@ type dataStoreCreateFn func(cfg *Config, metricsScope promutils.Scope) (RawStore var stores = map[string]dataStoreCreateFn{ TypeMemory: NewInMemoryRawStore, TypeLocal: newLocalRawStore, - TypeMinio: newS3RawStore, - TypeS3: newS3RawStore, + TypeMinio: newStowRawStore, + TypeS3: newStowRawStore, + TypeStow: newStowRawStore, } // Creates a new Data Store with the supplied config. 
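For context on how the generic stow wiring above is meant to be consumed, here is a minimal, hypothetical usage sketch (illustrative only, not part of the diff). It assumes the TypeStow, StowConfig and NewDataStore names introduced in this patch, the graymeta/stow google config keys exercised in stowstore_test.go, and placeholder project, scope and container values; a real deployment would need valid credentials for whichever backend kind is chosen.

// Hypothetical sketch: build a stow-backed DataStore from the new generic config.
// Assumptions: placeholder container/project/scope values; NewTestScope is used
// only to keep the example self-contained, a service would create its own scope.
package main

import (
	"bytes"
	"context"

	"github.com/graymeta/stow/google"
	"github.com/lyft/flytestdlib/promutils"
	"github.com/lyft/flytestdlib/storage"
)

func main() {
	cfg := &storage.Config{
		Type:          storage.TypeStow,
		InitContainer: "my-container", // created on first use if it does not already exist
		Stow: &storage.StowConfig{
			Kind: google.Kind, // any kind with an entry in fQNFn: s3, google, oracle, swift, azure
			Config: map[string]string{
				google.ConfigProjectId: "my-gcp-project",
				google.ConfigScopes:    "https://www.googleapis.com/auth/devstorage.read_write",
			},
		},
	}

	store, err := storage.NewDataStore(cfg, promutils.NewTestScope())
	if err != nil {
		panic(err) // dialing the backend fails here without real credentials
	}

	// The returned store exposes the usual RawStore/ProtobufStore surface.
	_ = store.WriteRaw(context.TODO(), "gs://my-container/hello", 5,
		storage.Options{}, bytes.NewReader([]byte("hello")))
}

The same Config shape covers s3, google, oracle, swift and azure by switching Kind and the config map, which is what allows rawstores.go to route TypeS3, TypeMinio and TypeStow to the single newStowRawStore constructor.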
diff --git a/flytestdlib/storage/s3store.go b/flytestdlib/storage/s3store.go deleted file mode 100644 index c2a6d9e290..0000000000 --- a/flytestdlib/storage/s3store.go +++ /dev/null @@ -1,103 +0,0 @@ -package storage - -import ( - "context" - "fmt" - - "github.com/lyft/flytestdlib/promutils" - - "github.com/aws/aws-sdk-go/aws/awserr" - awsS3 "github.com/aws/aws-sdk-go/service/s3" - "github.com/lyft/flytestdlib/logger" - "github.com/pkg/errors" - - "github.com/graymeta/stow" - "github.com/graymeta/stow/s3" -) - -func getStowConfigMap(cfg *Config) stow.ConfigMap { - // Non-nullable fields - stowConfig := stow.ConfigMap{ - s3.ConfigAuthType: cfg.Connection.AuthType, - s3.ConfigRegion: cfg.Connection.Region, - } - - // Fields that differ between minio and real S3 - if endpoint := cfg.Connection.Endpoint.String(); endpoint != "" { - stowConfig[s3.ConfigEndpoint] = endpoint - } - - if accessKey := cfg.Connection.AccessKey; accessKey != "" { - stowConfig[s3.ConfigAccessKeyID] = accessKey - } - - if secretKey := cfg.Connection.SecretKey; secretKey != "" { - stowConfig[s3.ConfigSecretKey] = secretKey - } - - if disableSsl := cfg.Connection.DisableSSL; disableSsl { - stowConfig[s3.ConfigDisableSSL] = "True" - } - - return stowConfig - -} - -func s3FQN(bucket string) DataReference { - return DataReference(fmt.Sprintf("s3://%s", bucket)) -} - -func newS3RawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) { - if cfg.InitContainer == "" { - return nil, fmt.Errorf("initContainer is required") - } - - loc, err := stow.Dial(s3.Kind, getStowConfigMap(cfg)) - - if err != nil { - return emptyStore, fmt.Errorf("unable to configure the storage for s3. Error: %v", err) - } - - c, err := loc.Container(cfg.InitContainer) - if err != nil { - if IsNotFound(err) || awsBucketIsNotFound(err) { - c, err := loc.CreateContainer(cfg.InitContainer) - if err != nil { - // If the container's already created, move on. Otherwise, fail with error. - if awsBucketAlreadyExists(err) { - logger.Infof(context.TODO(), "Storage init-container already exists [%v].", cfg.InitContainer) - return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) - } - return emptyStore, fmt.Errorf("unable to initialize container [%v]. 
Error: %v", cfg.InitContainer, err) - } - return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) - } - return emptyStore, err - } - - return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) -} - -func awsBucketIsNotFound(err error) bool { - if IsNotFound(err) { - return true - } - - if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { - return awsErr.Code() == awsS3.ErrCodeNoSuchBucket - } - - return false -} - -func awsBucketAlreadyExists(err error) bool { - if IsExists(err) { - return true - } - - if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { - return awsErr.Code() == awsS3.ErrCodeBucketAlreadyOwnedByYou - } - - return false -} diff --git a/flytestdlib/storage/s3stsore_test.go b/flytestdlib/storage/s3stsore_test.go deleted file mode 100644 index 2e8674a834..0000000000 --- a/flytestdlib/storage/s3stsore_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package storage - -import ( - "testing" - - "github.com/lyft/flytestdlib/promutils" - - "github.com/lyft/flytestdlib/config" - "github.com/lyft/flytestdlib/internal/utils" - "github.com/stretchr/testify/assert" -) - -func TestNewS3RawStore(t *testing.T) { - t.Run("Missing required config", func(t *testing.T) { - testScope := promutils.NewTestScope() - _, err := NewDataStore(&Config{ - Type: TypeMinio, - InitContainer: "some-container", - Connection: ConnectionConfig{ - Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, - }, - }, testScope) - - assert.Error(t, err) - }) -} diff --git a/flytestdlib/storage/stow_store_test.go b/flytestdlib/storage/stow_store_test.go index f0fd282340..bb36827acd 100644 --- a/flytestdlib/storage/stow_store_test.go +++ b/flytestdlib/storage/stow_store_test.go @@ -102,7 +102,8 @@ func (mockStowItem) Metadata() (map[string]interface{}, error) { func TestStowStore_ReadRaw(t *testing.T) { t.Run("Happy Path", func(t *testing.T) { testScope := promutils.NewTestScope() - s, err := NewStowRawStore(s3FQN("container"), newMockStowContainer("container"), testScope) + fn := fQNFn["s3"] + s, err := NewStowRawStore(fn("container"), newMockStowContainer("container"), testScope) assert.NoError(t, err) err = s.WriteRaw(context.TODO(), DataReference("s3://container/path"), 0, Options{}, bytes.NewReader([]byte{})) assert.NoError(t, err) @@ -119,7 +120,8 @@ func TestStowStore_ReadRaw(t *testing.T) { t.Run("Exceeds limit", func(t *testing.T) { testScope := promutils.NewTestScope() - s, err := NewStowRawStore(s3FQN("container"), newMockStowContainer("container"), testScope) + fn := fQNFn["s3"] + s, err := NewStowRawStore(fn("container"), newMockStowContainer("container"), testScope) assert.NoError(t, err) err = s.WriteRaw(context.TODO(), DataReference("s3://container/path"), 3*MiB, Options{}, bytes.NewReader([]byte{})) assert.NoError(t, err) diff --git a/flytestdlib/storage/stowstore.go b/flytestdlib/storage/stowstore.go new file mode 100644 index 0000000000..9249ddec21 --- /dev/null +++ b/flytestdlib/storage/stowstore.go @@ -0,0 +1,130 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go/aws/awserr" + awsS3 "github.com/aws/aws-sdk-go/service/s3" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + + "github.com/pkg/errors" + + "github.com/graymeta/stow" + "github.com/graymeta/stow/azure" + "github.com/graymeta/stow/google" + "github.com/graymeta/stow/oracle" + "github.com/graymeta/stow/s3" + "github.com/graymeta/stow/swift" +) + +var fQNFn = map[string]func(string) DataReference{ + s3.Kind: func(bucket string) DataReference { + return 
DataReference(fmt.Sprintf("s3://%s", bucket)) + }, + google.Kind: func(bucket string) DataReference { + return DataReference(fmt.Sprintf("gs://%s", bucket)) + }, + oracle.Kind: func(bucket string) DataReference { + return DataReference(fmt.Sprintf("os://%s", bucket)) + }, + swift.Kind: func(bucket string) DataReference { + return DataReference(fmt.Sprintf("sw://%s", bucket)) + }, + azure.Kind: func(bucket string) DataReference { + return DataReference(fmt.Sprintf("afs://%s", bucket)) + }, +} + +func awsBucketIsNotFound(err error) bool { + if IsNotFound(err) { + return true + } + + if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { + return awsErr.Code() == awsS3.ErrCodeNoSuchBucket + } + + return false +} + +func awsBucketAlreadyExists(err error) bool { + if IsExists(err) { + return true + } + + if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { + return awsErr.Code() == awsS3.ErrCodeBucketAlreadyOwnedByYou + } + + return false +} + +func newStowRawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) { + if cfg.InitContainer == "" { + return nil, fmt.Errorf("initContainer is required") + } + + var cfgMap stow.ConfigMap + var kind string + if cfg.Stow != nil { + kind = cfg.Stow.Kind + cfgMap = stow.ConfigMap(cfg.Stow.Config) + } else { + logger.Warnf(context.TODO(), "stow configuration section missing, defaulting to legacy s3/minio connection config") + // This is for supporting legacy configurations which configure S3 via connection config + kind = s3.Kind + cfgMap = legacyS3ConfigMap(cfg.Connection) + } + + fn, ok := fQNFn[kind] + if !ok { + return nil, errors.Errorf("unsupported stow.kind [%s], add support in flytestdlib?", kind) + } + loc, err := stow.Dial(kind, cfgMap) + if err != nil { + return emptyStore, fmt.Errorf("unable to configure the storage for %s. Error: %v", kind, err) + } + c, err := loc.Container(cfg.InitContainer) + if err != nil { + if IsNotFound(err) || awsBucketIsNotFound(err) { + c, err := loc.CreateContainer(cfg.InitContainer) + // If the container's already created, move on. Otherwise, fail with error. + if err != nil && !awsBucketAlreadyExists(err) { + return emptyStore, fmt.Errorf("unable to initialize container [%v]. 
Error: %v", cfg.InitContainer, err) + } + return NewStowRawStore(fn(c.Name()), c, metricsScope) + } + return emptyStore, err + } + return NewStowRawStore(fn(c.Name()), c, metricsScope) +} + +func legacyS3ConfigMap(cfg ConnectionConfig) stow.ConfigMap { + // Non-nullable fields + stowConfig := stow.ConfigMap{ + s3.ConfigAuthType: cfg.AuthType, + s3.ConfigRegion: cfg.Region, + } + + // Fields that differ between minio and real S3 + if endpoint := cfg.Endpoint.String(); endpoint != "" { + stowConfig[s3.ConfigEndpoint] = endpoint + } + + if accessKey := cfg.AccessKey; accessKey != "" { + stowConfig[s3.ConfigAccessKeyID] = accessKey + } + + if secretKey := cfg.SecretKey; secretKey != "" { + stowConfig[s3.ConfigSecretKey] = secretKey + } + + if disableSsl := cfg.DisableSSL; disableSsl { + stowConfig[s3.ConfigDisableSSL] = "True" + } + + return stowConfig +} diff --git a/flytestdlib/storage/stowstore_test.go b/flytestdlib/storage/stowstore_test.go new file mode 100644 index 0000000000..283ee71cd4 --- /dev/null +++ b/flytestdlib/storage/stowstore_test.go @@ -0,0 +1,53 @@ +package storage + +import ( + "testing" + + "github.com/graymeta/stow/google" + "github.com/stretchr/testify/assert" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/lyft/flytestdlib/promutils" +) + +func Test_newStowRawStore(t *testing.T) { + type args struct { + cfg *Config + metricsScope promutils.Scope + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"fail", args{&Config{}, promutils.NewTestScope()}, true}, + {"google", args{&Config{ + InitContainer: "flyte", + Stow: &StowConfig{ + Kind: google.Kind, + Config: map[string]string{ + google.ConfigProjectId: "x", + google.ConfigScopes: "y", + }, + }, + }, promutils.NewTestScope()}, true}, + {"minio", args{&Config{ + Type: TypeMinio, + InitContainer: "some-container", + Connection: ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, + }, + }, promutils.NewTestScope()}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newStowRawStore(tt.args.cfg, tt.args.metricsScope) + if tt.wantErr { + assert.Error(t, err, "newStowRawStore() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.NotNil(t, got, "Expected rawstore, found nil!") + }) + } +} From c44f6af150004aba59dc9d750947c741c20f6614 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Fri, 26 Jul 2019 23:05:36 -0700 Subject: [PATCH 0063/1918] Scoop update for flytestdlib version v0.2.12 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 0246a2c4ef..94df44543b 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.11", + "version": "0.2.12", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.11/flytestdlib_0.2.11_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.12/flytestdlib_0.2.12_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "4ace26d248fea1f01e59cb613d8353c33800a7f8420dc70b48d50e71d8602ce3" + "hash": "26e3f29078de38d766dee6407a21a7302d0d1b8020ed8d11232c497ebcf6edf6" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.11/flytestdlib_0.2.11_Windows_x86_64.zip", + "url": 
"https://github.com/lyft/flytestdlib/releases/download/v0.2.12/flytestdlib_0.2.12_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "665370f8394034a2dbe9334b1cd698624da2cfaeb353c67d404f10a365042c00" + "hash": "5158ee3d8c8e2d7f0b5a724230d5374cfd3dfcd71c6061cf1c52277c3421e0f0" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From be62dfd236ff99b888e1e01bbc84c6fb769829fb Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 5 Aug 2019 12:04:34 -0700 Subject: [PATCH 0064/1918] Better Error logging in case containers misconfigured - Configured container and the data reference may not match. - This causes the metric BadContainer to be incremented, but the logs do not really indicate the problem. --- flytestdlib/storage/stow_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytestdlib/storage/stow_store.go b/flytestdlib/storage/stow_store.go index e156f376da..998fc8a876 100644 --- a/flytestdlib/storage/stow_store.go +++ b/flytestdlib/storage/stow_store.go @@ -53,7 +53,7 @@ func (s StowMetadata) Exists() bool { func (s *StowStore) getContainer(container string) (c stow.Container, err error) { if s.Container.Name() != container { s.metrics.BadContainer.Inc() - return nil, stow.ErrNotFound + return nil, errs.Wrapf(stow.ErrNotFound, "Conf container:%v != Passed Container:%v", s.Container.Name(), container) } return s.Container, nil From 96c7a8af5c0b5a51276c762a4d6b9d2918a05a5a Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Mon, 5 Aug 2019 12:16:12 -0700 Subject: [PATCH 0065/1918] Scoop update for flytestdlib version v0.2.13 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 94df44543b..3c77a94c12 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.12", + "version": "0.2.13", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.12/flytestdlib_0.2.12_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.13/flytestdlib_0.2.13_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "26e3f29078de38d766dee6407a21a7302d0d1b8020ed8d11232c497ebcf6edf6" + "hash": "6786dfc4cc28a253bc2ec1398e38c53970504bc18c4c7f830f25f4a50736d051" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.12/flytestdlib_0.2.12_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.13/flytestdlib_0.2.13_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "5158ee3d8c8e2d7f0b5a724230d5374cfd3dfcd71c6061cf1c52277c3421e0f0" + "hash": "9c753bd02018725a1ec9c277da2556ef929307ec4620ed2ac60e9d700cb22e56" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From faa4fd3437ea16df57368e24c025f76693f99235 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Mon, 5 Aug 2019 15:45:34 -0700 Subject: [PATCH 0066/1918] Fix exception handling of caching errors for copyImpl (#30) * fix exception handling of caching errors for copyImpl * add test for the fix in copyImpl --- flytestdlib/storage/copy_impl.go | 34 +++++++++++---- flytestdlib/storage/copy_impl_test.go | 60 +++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/flytestdlib/storage/copy_impl.go b/flytestdlib/storage/copy_impl.go index 4ac132a229..178938f420 100644 --- a/flytestdlib/storage/copy_impl.go +++ b/flytestdlib/storage/copy_impl.go @@ -2,12 +2,17 @@ package 
storage import ( "context" + "fmt" "io" "time" + "github.com/lyft/flytestdlib/logger" + "github.com/prometheus/client_golang/prometheus" + "github.com/lyft/flytestdlib/ioutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" + errs "github.com/pkg/errors" ) type copyImpl struct { @@ -16,8 +21,10 @@ type copyImpl struct { } type copyMetrics struct { - CopyLatency labeled.StopWatch - ComputeLengthLatency labeled.StopWatch + CopyLatency labeled.StopWatch + ComputeLengthLatency labeled.StopWatch + WriteFailureUnrelatedToCache prometheus.Counter + ReadFailureUnrelatedToCache prometheus.Counter } // A naiive implementation for copy that reads all data locally then writes them to destination. @@ -25,8 +32,11 @@ type copyMetrics struct { // https://docs.aws.amazon.com/AmazonS3/latest/dev/CopyingObjectUsingREST.html func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference, opts Options) error { rc, err := c.rawStore.ReadRaw(ctx, source) - if err != nil { - return err + + if err != nil && !IsFailedWriteToCache(err) { + logger.Errorf(ctx, "Failed to read from the raw store. Error: %v", err) + c.metrics.ReadFailureUnrelatedToCache.Inc() + return errs.Wrap(err, fmt.Sprintf("path:%v", destination)) } length := int64(0) @@ -43,13 +53,23 @@ func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference length = int64(len(raw)) } - return c.rawStore.WriteRaw(ctx, destination, length, Options{}, rc) + err = c.rawStore.WriteRaw(ctx, destination, length, Options{}, rc) + + if err != nil && !IsFailedWriteToCache(err) { + logger.Errorf(ctx, "Failed to write to the raw store. Error: %v", err) + c.metrics.WriteFailureUnrelatedToCache.Inc() + return err + } + + return nil } func newCopyMetrics(scope promutils.Scope) copyMetrics { return copyMetrics{ - CopyLatency: labeled.NewStopWatch("overall", "Overall copy latency", time.Millisecond, scope, labeled.EmitUnlabeledMetric), - ComputeLengthLatency: labeled.NewStopWatch("length", "Latency involved in computing length of content before writing.", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + CopyLatency: labeled.NewStopWatch("overall", "Overall copy latency", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + ComputeLengthLatency: labeled.NewStopWatch("length", "Latency involved in computing length of content before writing.", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + WriteFailureUnrelatedToCache: scope.MustNewCounter("write_failure_unrelated_to_cache", "Raw store write failures that are not caused by ErrFailedToWriteCache"), + ReadFailureUnrelatedToCache: scope.MustNewCounter("read_failure_unrelated_to_cache", "Raw store read failures that are not caused by ErrFailedToWriteCache"), } } diff --git a/flytestdlib/storage/copy_impl_test.go b/flytestdlib/storage/copy_impl_test.go index fc8d78cd1a..ccc44924d6 100644 --- a/flytestdlib/storage/copy_impl_test.go +++ b/flytestdlib/storage/copy_impl_test.go @@ -2,9 +2,13 @@ package storage import ( "context" + "fmt" "io" + "math/rand" "testing" + "github.com/lyft/flytestdlib/errors" + "github.com/lyft/flytestdlib/ioutils" "github.com/lyft/flytestdlib/promutils" "github.com/stretchr/testify/assert" @@ -40,6 +44,7 @@ func newNotSeekerReader(bytesCount int) *notSeekerReader { } func TestCopyRaw(t *testing.T) { + resetMetricKeys() t.Run("Called", func(t *testing.T) { readerCalled := false writerCalled := false @@ -80,3 +85,58 @@ func TestCopyRaw(t *testing.T) { assert.True(t, writerCalled) }) } + +func 
TestCopyRaw_CachingErrorHandling(t *testing.T) { + resetMetricKeys() + t.Run("CopyRaw with Caching Error", func(t *testing.T) { + readerCalled := false + writerCalled := false + bigD := make([]byte, 1.5*1024*1024) + // #nosec G404 + rand.Read(bigD) + dummyErrorMsg := "Dummy caching error" + + store := dummyStore{ + ReadRawCb: func(ctx context.Context, reference DataReference) (closer io.ReadCloser, e error) { + readerCalled = true + return ioutils.NewBytesReadCloser(bigD), errors.Wrapf(ErrFailedToWriteCache, fmt.Errorf(dummyErrorMsg), "Failed to Cache the metadata") + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + writerCalled = true + return errors.Wrapf(ErrFailedToWriteCache, fmt.Errorf(dummyErrorMsg), "Failed to Cache the metadata") + }, + } + + copier := newCopyImpl(&store, promutils.NewTestScope()) + assert.NoError(t, copier.CopyRaw(context.Background(), DataReference("source.pb"), DataReference("dest.pb"), Options{})) + assert.True(t, readerCalled) + assert.True(t, writerCalled) + }) + + t.Run("CopyRaw with Hard Error", func(t *testing.T) { + readerCalled := false + writerCalled := false + bigD := make([]byte, 1.5*1024*1024) + // #nosec G404 + rand.Read(bigD) + dummyErrorMsg := "Dummy non-caching error" + + store := dummyStore{ + ReadRawCb: func(ctx context.Context, reference DataReference) (closer io.ReadCloser, e error) { + readerCalled = true + return ioutils.NewBytesReadCloser(bigD), fmt.Errorf(dummyErrorMsg) + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + writerCalled = true + return fmt.Errorf(dummyErrorMsg) + }, + } + + copier := newCopyImpl(&store, promutils.NewTestScope()) + err := copier.CopyRaw(context.Background(), DataReference("source.pb"), DataReference("dest.pb"), Options{}) + assert.True(t, readerCalled) + // writerCalled should be false because CopyRaw should error out right after c.rawstore.ReadRaw() when the underlying error is a hard error + assert.False(t, writerCalled) + assert.False(t, IsFailedWriteToCache(err)) + }) +} From cbe9f5f69880443bc5b91953ef445a5c9e3d52f0 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Mon, 5 Aug 2019 15:54:52 -0700 Subject: [PATCH 0067/1918] Scoop update for flytestdlib version v0.2.14 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 3c77a94c12..691f6d3697 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.13", + "version": "0.2.14", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.13/flytestdlib_0.2.13_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.14/flytestdlib_0.2.14_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "6786dfc4cc28a253bc2ec1398e38c53970504bc18c4c7f830f25f4a50736d051" + "hash": "2bc335843e23bd4644cac685d58433b46e54990fe6102c77172b3b811ba78ac8" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.13/flytestdlib_0.2.13_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.14/flytestdlib_0.2.14_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "9c753bd02018725a1ec9c277da2556ef929307ec4620ed2ac60e9d700cb22e56" + "hash": "44acfc95227ebd871d69298b93c7832cfbe48a545d693a818d888fe379d0c656" } }, "homepage": 
"https://godoc.org/github.com/lyft/flytestdlib", From 34229de335cd5bccc216ff5559694fd7884c27a6 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Tue, 6 Aug 2019 10:06:54 -0700 Subject: [PATCH 0068/1918] improve msg in CopyRaw (#31) --- flytestdlib/storage/copy_impl.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytestdlib/storage/copy_impl.go b/flytestdlib/storage/copy_impl.go index 178938f420..9bf39e1d66 100644 --- a/flytestdlib/storage/copy_impl.go +++ b/flytestdlib/storage/copy_impl.go @@ -34,7 +34,7 @@ func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference rc, err := c.rawStore.ReadRaw(ctx, source) if err != nil && !IsFailedWriteToCache(err) { - logger.Errorf(ctx, "Failed to read from the raw store. Error: %v", err) + logger.Errorf(ctx, "Failed to read from the raw store when copying. Error: %v", err) c.metrics.ReadFailureUnrelatedToCache.Inc() return errs.Wrap(err, fmt.Sprintf("path:%v", destination)) } @@ -56,7 +56,7 @@ func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference err = c.rawStore.WriteRaw(ctx, destination, length, Options{}, rc) if err != nil && !IsFailedWriteToCache(err) { - logger.Errorf(ctx, "Failed to write to the raw store. Error: %v", err) + logger.Errorf(ctx, "Failed to write to the raw store when copying. Error: %v", err) c.metrics.WriteFailureUnrelatedToCache.Inc() return err } From 91b48d8e3be5cf0c65281e725cb67cd044e18c55 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Tue, 6 Aug 2019 10:18:09 -0700 Subject: [PATCH 0069/1918] Scoop update for flytestdlib version v0.2.15 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index 691f6d3697..f07bee89c7 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.14", + "version": "0.2.15", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.14/flytestdlib_0.2.14_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.15/flytestdlib_0.2.15_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "2bc335843e23bd4644cac685d58433b46e54990fe6102c77172b3b811ba78ac8" + "hash": "8fa9e2f415e82a14a9c7283be8809ada61cd6da68878228139aeed1d471dc0d9" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.14/flytestdlib_0.2.14_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.15/flytestdlib_0.2.15_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "44acfc95227ebd871d69298b93c7832cfbe48a545d693a818d888fe379d0c656" + "hash": "faf6d8f7f20a3d2453b83eb3d739a3e3c4e6e0e08ddbd1182cb464bd4c2ae5da" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 502d96894ca62eb12d410550c9921e0da5f13f45 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 8 Aug 2019 19:22:13 -0700 Subject: [PATCH 0070/1918] Improve logging in protobuf_store --- flytestdlib/storage/protobuf_store.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytestdlib/storage/protobuf_store.go b/flytestdlib/storage/protobuf_store.go index 6b09a0e6eb..9b2047bbbb 100644 --- a/flytestdlib/storage/protobuf_store.go +++ b/flytestdlib/storage/protobuf_store.go @@ -35,7 +35,7 @@ type DefaultProtobufStore struct { func (s DefaultProtobufStore) ReadProtobuf(ctx context.Context, reference DataReference, msg proto.Message) error { rc, err := 
s.ReadRaw(ctx, reference) if err != nil && !IsFailedWriteToCache(err) { - logger.Errorf(ctx, "Failed to read from the raw store. Error: %v", err) + logger.Errorf(ctx, "Failed to read from the raw store [%s] Error: %v", reference, err) s.metrics.ReadFailureUnrelatedToCache.Inc() return errs.Wrap(err, fmt.Sprintf("path:%v", reference)) } @@ -74,7 +74,7 @@ func (s DefaultProtobufStore) WriteProtobuf(ctx context.Context, reference DataR err = s.WriteRaw(ctx, reference, int64(len(raw)), opts, bytes.NewReader(raw)) if err != nil && !IsFailedWriteToCache(err) { - logger.Errorf(ctx, "Failed to write to the raw store. Error: %v", err) + logger.Errorf(ctx, "Failed to write to the raw store [%s] Error: %v", reference, err) s.metrics.WriteFailureUnrelatedToCache.Inc() return err } From 80f8ea736aacf854e5eafb1cf7df8390fd9cc765 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Mon, 12 Aug 2019 14:32:24 -0700 Subject: [PATCH 0071/1918] add source and destination in CopyRaw's error msg (#33) --- flytestdlib/storage/copy_impl.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytestdlib/storage/copy_impl.go b/flytestdlib/storage/copy_impl.go index 9bf39e1d66..56c7d54717 100644 --- a/flytestdlib/storage/copy_impl.go +++ b/flytestdlib/storage/copy_impl.go @@ -34,7 +34,7 @@ func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference rc, err := c.rawStore.ReadRaw(ctx, source) if err != nil && !IsFailedWriteToCache(err) { - logger.Errorf(ctx, "Failed to read from the raw store when copying. Error: %v", err) + logger.Errorf(ctx, "Failed to read from the raw store when copying [%v] to [%v]. Error: %v", source, destination, err) c.metrics.ReadFailureUnrelatedToCache.Inc() return errs.Wrap(err, fmt.Sprintf("path:%v", destination)) } From 51c549ce971f021fe7d2177b8dedf19751b06c03 Mon Sep 17 00:00:00 2001 From: goreleaserbot Date: Mon, 12 Aug 2019 14:40:41 -0700 Subject: [PATCH 0072/1918] Scoop update for flytestdlib version v0.2.16 --- flytestdlib/flytestdlib.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytestdlib/flytestdlib.json b/flytestdlib/flytestdlib.json index f07bee89c7..6b17bde2f0 100644 --- a/flytestdlib/flytestdlib.json +++ b/flytestdlib/flytestdlib.json @@ -1,19 +1,19 @@ { - "version": "0.2.15", + "version": "0.2.16", "architecture": { "32bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.15/flytestdlib_0.2.15_Windows_i386.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.16/flytestdlib_0.2.16_Windows_i386.zip", "bin": [ "pflags.exe" ], - "hash": "8fa9e2f415e82a14a9c7283be8809ada61cd6da68878228139aeed1d471dc0d9" + "hash": "042389e9e0d08df89d771962b10ed193cd2834a8328d90f307a2acf3bbbb4094" }, "64bit": { - "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.15/flytestdlib_0.2.15_Windows_x86_64.zip", + "url": "https://github.com/lyft/flytestdlib/releases/download/v0.2.16/flytestdlib_0.2.16_Windows_x86_64.zip", "bin": [ "pflags.exe" ], - "hash": "faf6d8f7f20a3d2453b83eb3d739a3e3c4e6e0e08ddbd1182cb464bd4c2ae5da" + "hash": "030581539f3e55c525bbf0157058701cfd7281ef0b40156c4b4049269bc0093b" } }, "homepage": "https://godoc.org/github.com/lyft/flytestdlib", From 9f01f564bf30916b55805e04814108364d27870d Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 20 Aug 2019 22:36:57 +0200 Subject: [PATCH 0073/1918] Initial Commit Initial commit including basic flyte plugins and plugins interface. 
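For orientation, the sketch below shows the general shape of the task-plugin contract this commit introduces under go/tasks/v1: an executor that the engine registers per task type, starts, and then polls for status. All names and signatures here (Executor, TaskStatus, Register, containerExecutor) are illustrative stand-ins chosen for this note, not the definitions added in this patch; the authoritative interfaces live in go/tasks/v1/types/task.go, go/tasks/v1/registry.go, and go/tasks/v1/flytek8s/plugin_executor.go in the diff below.

package main

import (
	"context"
	"fmt"
)

// TaskStatus is a hypothetical, simplified stand-in for the status a plugin
// reports back to the engine after each evaluation round.
type TaskStatus string

const (
	TaskStatusRunning   TaskStatus = "Running"
	TaskStatusSucceeded TaskStatus = "Succeeded"
)

// Executor is a hypothetical, trimmed-down task-plugin interface: the engine
// asks it to start work for a task and then polls it for status.
type Executor interface {
	GetID() string
	StartTask(ctx context.Context, taskID string) error
	CheckTaskStatus(ctx context.Context, taskID string) (TaskStatus, error)
}

// registry maps a task type (e.g. "container") to the executor handling it.
var registry = map[string]Executor{}

// Register wires an executor into the registry keyed by task type.
func Register(taskType string, e Executor) { registry[taskType] = e }

// containerExecutor is a toy plugin that pretends every task succeeds immediately.
type containerExecutor struct{}

func (containerExecutor) GetID() string { return "container" }

func (containerExecutor) StartTask(ctx context.Context, taskID string) error {
	fmt.Printf("starting task %s\n", taskID)
	return nil
}

func (containerExecutor) CheckTaskStatus(ctx context.Context, taskID string) (TaskStatus, error) {
	return TaskStatusSucceeded, nil
}

func main() {
	Register("container", containerExecutor{})
	e := registry["container"]
	if err := e.StartTask(context.Background(), "task-1"); err != nil {
		panic(err)
	}
	status, _ := e.CheckTaskStatus(context.Background(), "task-1")
	fmt.Println(e.GetID(), status) // container Succeeded
}

Judging from the file list, the interfaces actually added in this commit also cover task termination, event recording, and task context plumbing; the sketch only captures the register/start/poll lifecycle that the flytek8s plugin executor and the k8s plugins (container, sidecar, spark) build on.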
--- flyteplugins/CODE_OF_CONDUCT.md | 3 + flyteplugins/Gopkg.lock | 1057 +++++++++++++++++ flyteplugins/Gopkg.toml | 52 + flyteplugins/LICENSE | 202 ++++ flyteplugins/Makefile | 14 + flyteplugins/NOTICE | 4 + flyteplugins/README.md | 2 + .../lyft/golang_test_targets/Makefile | 38 + .../lyft/golang_test_targets/Readme.rst | 31 + .../lyft/golang_test_targets/goimports | 8 + .../lyft/golangci_file/.golangci.yml | 30 + .../boilerplate/lyft/golangci_file/Readme.rst | 8 + .../boilerplate/lyft/golangci_file/update.sh | 14 + flyteplugins/boilerplate/update.cfg | 2 + flyteplugins/boilerplate/update.sh | 53 + flyteplugins/go/tasks/loader.go | 28 + flyteplugins/go/tasks/v1/config/config.go | 29 + .../go/tasks/v1/config/config_flags.go | 46 + .../go/tasks/v1/config/config_flags_test.go | 124 ++ flyteplugins/go/tasks/v1/config_load_test.go | 94 ++ flyteplugins/go/tasks/v1/errors/errors.go | 76 ++ .../go/tasks/v1/events/event_utils.go | 100 ++ .../go/tasks/v1/events/event_utils_test.go | 193 +++ flyteplugins/go/tasks/v1/factory.go | 143 +++ .../go/tasks/v1/flytek8s/config/config.go | 59 + .../go/tasks/v1/flytek8s/constants.go | 11 + .../go/tasks/v1/flytek8s/container_helper.go | 123 ++ .../v1/flytek8s/container_helper_test.go | 98 ++ .../go/tasks/v1/flytek8s/k8s_resource_adds.go | 154 +++ .../v1/flytek8s/k8s_resource_adds_test.go | 283 +++++ .../go/tasks/v1/flytek8s/mocks/Cache.go | 25 + .../go/tasks/v1/flytek8s/mocks/Handler.go | 27 + .../go/tasks/v1/flytek8s/mocks/K8sResource.go | 383 ++++++ .../v1/flytek8s/mocks/K8sResourceHandler.go | 91 ++ .../tasks/v1/flytek8s/mocks/RuntimeClient.go | 142 +++ .../go/tasks/v1/flytek8s/mux_handler.go | 163 +++ .../go/tasks/v1/flytek8s/plugin_executor.go | 357 ++++++ .../tasks/v1/flytek8s/plugin_executor_test.go | 608 ++++++++++ .../go/tasks/v1/flytek8s/plugin_iface.go | 38 + .../go/tasks/v1/flytek8s/pod_helper.go | 159 +++ .../go/tasks/v1/flytek8s/pod_helper_test.go | 317 +++++ flyteplugins/go/tasks/v1/flytek8s/utils.go | 39 + .../go/tasks/v1/flytek8s/utils_test.go | 30 + .../go/tasks/v1/k8splugins/container.go | 113 ++ .../go/tasks/v1/k8splugins/container_test.go | 231 ++++ .../v1/k8splugins/mocks/AutoRefreshCache.go | 56 + .../tasks/v1/k8splugins/mocks/sidecar_custom | 59 + .../go/tasks/v1/k8splugins/sidecar.go | 196 +++ .../go/tasks/v1/k8splugins/sidecar_test.go | 231 ++++ flyteplugins/go/tasks/v1/k8splugins/spark.go | 304 +++++ .../go/tasks/v1/k8splugins/spark_test.go | 270 +++++ .../go/tasks/v1/k8splugins/waitable_task.go | 552 +++++++++ .../tasks/v1/k8splugins/waitable_task_test.go | 393 ++++++ flyteplugins/go/tasks/v1/logs/config.go | 33 + .../go/tasks/v1/logs/logconfig_flags.go | 53 + .../go/tasks/v1/logs/logconfig_flags_test.go | 278 +++++ .../go/tasks/v1/logs/logging_utils.go | 61 + .../go/tasks/v1/logs/logging_utils_test.go | 205 ++++ .../v1/qubole/client/mocks/QuboleClient.go | 70 ++ .../tasks/v1/qubole/client/qubole_client.go | 253 ++++ .../v1/qubole/client/qubole_client_test.go | 142 +++ .../tasks/v1/qubole/client/qubole_status.go | 38 + .../go/tasks/v1/qubole/config/config.go | 46 + .../go/tasks/v1/qubole/config/config_flags.go | 53 + .../v1/qubole/config/config_flags_test.go | 278 +++++ .../go/tasks/v1/qubole/hive_executor.go | 574 +++++++++ .../go/tasks/v1/qubole/hive_executor_test.go | 403 +++++++ .../tasks/v1/qubole/mocks/AutoRefreshCache.go | 56 + .../go/tasks/v1/qubole/qubole_work.go | 202 ++++ .../go/tasks/v1/qubole/qubole_work_test.go | 177 +++ .../go/tasks/v1/qubole/secrets_manager.go | 56 + .../go/tasks/v1/qubole/test_helper.go | 116 ++ 
flyteplugins/go/tasks/v1/registry.go | 24 + .../v1/resourcemanager/lookaside_buffer.go | 18 + .../mocks/execution_lookside_buffer.go | 46 + .../resourcemanager/mocks/resource_manager.go | 47 + .../mocks/resource_manager_ext.go | 44 + .../tasks/v1/resourcemanager/redis_client.go | 24 + .../resourcemanager/redis_lookaside_buffer.go | 46 + .../redis_lookaside_buffer_test.go | 10 + .../resourcemanager/redis_resource_manager.go | 122 ++ .../v1/resourcemanager/resource_manager.go | 68 ++ .../resourcemanager/resource_manager_test.go | 88 ++ flyteplugins/go/tasks/v1/testdata/config.yaml | 63 + .../go/tasks/v1/types/mocks/EventRecorder.go | 26 + .../go/tasks/v1/types/mocks/Executor.go | 141 +++ .../go/tasks/v1/types/mocks/TaskContext.go | 234 ++++ .../tasks/v1/types/mocks/TaskExecutionID.go | 39 + .../go/tasks/v1/types/mocks/TaskOverrides.go | 44 + .../go/tasks/v1/types/outputs_resolver.go | 46 + flyteplugins/go/tasks/v1/types/status.go | 131 ++ flyteplugins/go/tasks/v1/types/status_test.go | 34 + flyteplugins/go/tasks/v1/types/task.go | 92 ++ .../go/tasks/v1/types/task_context.go | 44 + flyteplugins/go/tasks/v1/types/task_test.go | 52 + .../go/tasks/v1/utils/marshal_utils.go | 66 + flyteplugins/go/tasks/v1/utils/template.go | 152 +++ .../go/tasks/v1/utils/template_test.go | 239 ++++ .../go/tasks/v1/utils/transformers.go | 26 + .../go/tasks/v1/utils/transformers_test.go | 18 + flyteplugins/tests/hive_integration_test.go | 201 ++++ .../tests/redis_lookaside_buffer_test.go | 35 + .../tests/redis_resource_manager_test.go | 36 + 103 files changed, 13213 insertions(+) create mode 100755 flyteplugins/CODE_OF_CONDUCT.md create mode 100755 flyteplugins/Gopkg.lock create mode 100755 flyteplugins/Gopkg.toml create mode 100755 flyteplugins/LICENSE create mode 100755 flyteplugins/Makefile create mode 100755 flyteplugins/NOTICE create mode 100755 flyteplugins/README.md create mode 100755 flyteplugins/boilerplate/lyft/golang_test_targets/Makefile create mode 100755 flyteplugins/boilerplate/lyft/golang_test_targets/Readme.rst create mode 100755 flyteplugins/boilerplate/lyft/golang_test_targets/goimports create mode 100755 flyteplugins/boilerplate/lyft/golangci_file/.golangci.yml create mode 100755 flyteplugins/boilerplate/lyft/golangci_file/Readme.rst create mode 100755 flyteplugins/boilerplate/lyft/golangci_file/update.sh create mode 100755 flyteplugins/boilerplate/update.cfg create mode 100755 flyteplugins/boilerplate/update.sh create mode 100755 flyteplugins/go/tasks/loader.go create mode 100755 flyteplugins/go/tasks/v1/config/config.go create mode 100755 flyteplugins/go/tasks/v1/config/config_flags.go create mode 100755 flyteplugins/go/tasks/v1/config/config_flags_test.go create mode 100755 flyteplugins/go/tasks/v1/config_load_test.go create mode 100755 flyteplugins/go/tasks/v1/errors/errors.go create mode 100755 flyteplugins/go/tasks/v1/events/event_utils.go create mode 100755 flyteplugins/go/tasks/v1/events/event_utils_test.go create mode 100755 flyteplugins/go/tasks/v1/factory.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/config/config.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/constants.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/container_helper.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/container_helper_test.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds_test.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/mocks/Cache.go create mode 100755 
flyteplugins/go/tasks/v1/flytek8s/mocks/Handler.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResource.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/mocks/RuntimeClient.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/mux_handler.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/pod_helper.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/pod_helper_test.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/utils.go create mode 100755 flyteplugins/go/tasks/v1/flytek8s/utils_test.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/container.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/container_test.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/mocks/AutoRefreshCache.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/mocks/sidecar_custom create mode 100755 flyteplugins/go/tasks/v1/k8splugins/sidecar.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/sidecar_test.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/spark.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/spark_test.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/waitable_task.go create mode 100755 flyteplugins/go/tasks/v1/k8splugins/waitable_task_test.go create mode 100755 flyteplugins/go/tasks/v1/logs/config.go create mode 100755 flyteplugins/go/tasks/v1/logs/logconfig_flags.go create mode 100755 flyteplugins/go/tasks/v1/logs/logconfig_flags_test.go create mode 100755 flyteplugins/go/tasks/v1/logs/logging_utils.go create mode 100755 flyteplugins/go/tasks/v1/logs/logging_utils_test.go create mode 100755 flyteplugins/go/tasks/v1/qubole/client/mocks/QuboleClient.go create mode 100755 flyteplugins/go/tasks/v1/qubole/client/qubole_client.go create mode 100755 flyteplugins/go/tasks/v1/qubole/client/qubole_client_test.go create mode 100755 flyteplugins/go/tasks/v1/qubole/client/qubole_status.go create mode 100755 flyteplugins/go/tasks/v1/qubole/config/config.go create mode 100755 flyteplugins/go/tasks/v1/qubole/config/config_flags.go create mode 100755 flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go create mode 100755 flyteplugins/go/tasks/v1/qubole/hive_executor.go create mode 100755 flyteplugins/go/tasks/v1/qubole/hive_executor_test.go create mode 100755 flyteplugins/go/tasks/v1/qubole/mocks/AutoRefreshCache.go create mode 100755 flyteplugins/go/tasks/v1/qubole/qubole_work.go create mode 100755 flyteplugins/go/tasks/v1/qubole/qubole_work_test.go create mode 100755 flyteplugins/go/tasks/v1/qubole/secrets_manager.go create mode 100755 flyteplugins/go/tasks/v1/qubole/test_helper.go create mode 100755 flyteplugins/go/tasks/v1/registry.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/lookaside_buffer.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/mocks/execution_lookside_buffer.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager_ext.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/redis_client.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer.go create mode 100755 
flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer_test.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/redis_resource_manager.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/resource_manager.go create mode 100755 flyteplugins/go/tasks/v1/resourcemanager/resource_manager_test.go create mode 100755 flyteplugins/go/tasks/v1/testdata/config.yaml create mode 100755 flyteplugins/go/tasks/v1/types/mocks/EventRecorder.go create mode 100755 flyteplugins/go/tasks/v1/types/mocks/Executor.go create mode 100755 flyteplugins/go/tasks/v1/types/mocks/TaskContext.go create mode 100755 flyteplugins/go/tasks/v1/types/mocks/TaskExecutionID.go create mode 100755 flyteplugins/go/tasks/v1/types/mocks/TaskOverrides.go create mode 100755 flyteplugins/go/tasks/v1/types/outputs_resolver.go create mode 100755 flyteplugins/go/tasks/v1/types/status.go create mode 100755 flyteplugins/go/tasks/v1/types/status_test.go create mode 100755 flyteplugins/go/tasks/v1/types/task.go create mode 100755 flyteplugins/go/tasks/v1/types/task_context.go create mode 100755 flyteplugins/go/tasks/v1/types/task_test.go create mode 100755 flyteplugins/go/tasks/v1/utils/marshal_utils.go create mode 100755 flyteplugins/go/tasks/v1/utils/template.go create mode 100755 flyteplugins/go/tasks/v1/utils/template_test.go create mode 100755 flyteplugins/go/tasks/v1/utils/transformers.go create mode 100755 flyteplugins/go/tasks/v1/utils/transformers_test.go create mode 100755 flyteplugins/tests/hive_integration_test.go create mode 100755 flyteplugins/tests/redis_lookaside_buffer_test.go create mode 100755 flyteplugins/tests/redis_resource_manager_test.go diff --git a/flyteplugins/CODE_OF_CONDUCT.md b/flyteplugins/CODE_OF_CONDUCT.md new file mode 100755 index 0000000000..803d8a77f3 --- /dev/null +++ b/flyteplugins/CODE_OF_CONDUCT.md @@ -0,0 +1,3 @@ +This project is governed by [Lyft's code of +conduct](https://github.com/lyft/code-of-conduct). All contributors +and participants agree to abide by its terms. diff --git a/flyteplugins/Gopkg.lock b/flyteplugins/Gopkg.lock new file mode 100755 index 0000000000..dd76c0bb7d --- /dev/null +++ b/flyteplugins/Gopkg.lock @@ -0,0 +1,1057 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + digest = "1:f45299a845f297e482104076e5ae4b1b0885cafb227098d2d5675b2cc65084a5" + name = "github.com/GoogleCloudPlatform/spark-on-k8s-operator" + packages = [ + "pkg/apis/sparkoperator.k8s.io", + "pkg/apis/sparkoperator.k8s.io/v1beta1", + ] + pruneopts = "" + revision = "21894ac2fe2a4e64632ef620c8a4da776a7b6b87" + source = "https://github.com/lyft/spark-on-k8s-operator" + version = "v0.1.1" + +[[projects]] + digest = "1:60942d250d0e06d3722ddc8e22bc52f8cef7961ba6d8d3e95327a32b6b024a7b" + name = "github.com/appscode/jsonpatch" + packages = ["."] + pruneopts = "" + revision = "7c0e3b262f30165a8ec3d0b4c6059fd92703bfb2" + version = "1.0.0" + +[[projects]] + digest = "1:e54184af8a1457b632aae19f35b241b4fe48f18765f7c80d55d7ef2c0d19d774" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", + "private/protocol/eventstream/eventstreamapi", + "private/protocol/json/jsonutil", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/s3", + "service/sts", + ] + pruneopts = "" + revision = "eb8216aeaa74d4010569c51ae6238919c172ed82" + version = "v1.19.44" + +[[projects]] + digest = "1:0d3deb8a6da8ffba5635d6fb1d2144662200def6c9d82a35a6d05d6c2d4a48f9" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "" + revision = "4b2b341e8d7715fae06375aa633dbb6e91b3fb46" + version = "v1.0.0" + +[[projects]] + digest = "1:f6485831252319cd6ca29fc170adecf1eb81bf1e805f62f44eb48564ce2485fe" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:193f6d32d751f26540aa8eeedc114ce0a51f9e77b6c22dda3a4db4e5f65aec66" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "" + revision = "3c79a0a23c1940ab4479332fb3e0127265650ce3" + version = "v1.1.0" + +[[projects]] + digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + digest = "1:46ddeb9dd35d875ac7568c4dc1fc96ce424e034bdbb984239d8ffc151398ec01" + name = "github.com/evanphx/json-patch" + packages = ["."] + pruneopts = "" + revision = "026c730a0dcc5d11f93f1cf1cc65b01247ea7b6f" + version = "v4.5.0" + +[[projects]] + digest = "1:e988ed0ca0d81f4d28772760c02ee95084961311291bdfefc1b04617c178b722" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + branch = "master" + digest = "1:135223bf2c128b2158178ee48779ac9983b003634864d46b73e913c95f7a847e" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "" + revision = "1485a34d5d5723fea214f5710708e19a831720e4" + +[[projects]] + digest = 
"1:65587005c6fa4293c0b8a2e457e689df7fda48cc5e1f5449ea2c1e7784551558" + name = "github.com/go-logr/logr" + packages = ["."] + pruneopts = "" + revision = "9fb12b3b21c5415d16ac18dc5cd42c1cfdd40c4e" + version = "v0.1.0" + +[[projects]] + digest = "1:d81dfed1aa731d8e4a45d87154ec15ef18da2aa80fa9a2f95bec38577a244a99" + name = "github.com/go-logr/zapr" + packages = ["."] + pruneopts = "" + revision = "03f06a783fbb7dfaf3f629c7825480e43a7105e6" + version = "v0.1.1" + +[[projects]] + digest = "1:c2db84082861ca42d0b00580d28f4b31aceec477a00a38e1a057fb3da75c8adc" + name = "github.com/go-redis/redis" + packages = [ + ".", + "internal", + "internal/consistenthash", + "internal/hashtag", + "internal/pool", + "internal/proto", + "internal/util", + ] + pruneopts = "" + revision = "75795aa4236dc7341eefac3bbe945e68c99ef9df" + version = "v6.15.3" + +[[projects]] + digest = "1:fd53b471edb4c28c7d297f617f4da0d33402755f58d6301e7ca1197ef0a90937" + name = "github.com/gogo/protobuf" + packages = [ + "proto", + "sortkeys", + ] + pruneopts = "" + revision = "ba06b47c162d49f2af050fb4c75bcbc86a159d5c" + version = "v1.2.1" + +[[projects]] + branch = "master" + digest = "1:f9714c0c017f2b821bccceeec2c7a93d29638346bb546c36ca5f90e751f91b9e" + name = "github.com/golang/groupcache" + packages = ["lru"] + pruneopts = "" + revision = "5b532d6fd5efaf7fa130d4e859a2fde0fc3a9e1b" + +[[projects]] + digest = "1:529d738b7976c3848cae5cf3a8036440166835e389c1f617af701eeb12a0518d" + name = "github.com/golang/protobuf" + packages = [ + "jsonpb", + "proto", + "protoc-gen-go/descriptor", + "protoc-gen-go/generator", + "protoc-gen-go/generator/internal/remap", + "protoc-gen-go/plugin", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + "ptypes/wrappers", + ] + pruneopts = "" + revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" + version = "v1.3.1" + +[[projects]] + digest = "1:1e5b1e14524ed08301977b7b8e10c719ed853cbf3f24ecb66fae783a46f207a6" + name = "github.com/google/btree" + packages = ["."] + pruneopts = "" + revision = "4030bb1f1f0c35b30ca7009e9ebd06849dd45306" + version = "v1.0.0" + +[[projects]] + digest = "1:8d4a577a9643f713c25a32151c0f26af7228b4b97a219b5ddb7fd38d16f6e673" + name = "github.com/google/gofuzz" + packages = ["."] + pruneopts = "" + revision = "f140a6486e521aad38f5917de355cbf147cc0496" + version = "v1.0.0" + +[[projects]] + digest = "1:16b2837c8b3cf045fa2cdc82af0cf78b19582701394484ae76b2c3bc3c99ad73" + name = "github.com/googleapis/gnostic" + packages = [ + "OpenAPIv2", + "compiler", + "extensions", + ] + pruneopts = "" + revision = "7c663266750e7d82587642f65e60bc4083f1f84e" + version = "v0.2.0" + +[[projects]] + digest = "1:94697ef521414e9814e038c512699e3ef984519a301b7a499b00cf851c928b29" + name = "github.com/graymeta/stow" + packages = [ + ".", + "local", + "s3", + ] + pruneopts = "" + revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + +[[projects]] + branch = "master" + digest = "1:326d7083af3723768cd8150db99b8ac730837b05ef290d5a042562905cc26210" + name = "github.com/gregjones/httpcache" + packages = [ + ".", + "diskcache", + ] + pruneopts = "" + revision = "3befbb6ad0cc97d4c25d851e9528915809e1a22f" + +[[projects]] + digest = "1:9a0b2dd1f882668a3d7fbcd424eed269c383a16f1faa3a03d14e0dd5fba571b1" + name = "github.com/grpc-ecosystem/go-grpc-middleware" + packages = [ + ".", + "retry", + "util/backoffutils", + "util/metautils", + ] + pruneopts = "" + revision = "c250d6563d4d4c20252cd865923440e829844f4e" + version = "v1.0.0" + +[[projects]] + digest = 
"1:e24dc5ef44694848785de507f439a24e9e6d96d7b43b8cf3d6cfa857aa1e2186" + name = "github.com/grpc-ecosystem/go-grpc-prometheus" + packages = ["."] + pruneopts = "" + revision = "c225b8c3b01faf2899099b768856a9e916e5087b" + version = "v1.2.0" + +[[projects]] + digest = "1:dee8ec16fa714522c6cad579dfeeba3caf9644d93b8b452cd7138584402c81f7" + name = "github.com/grpc-ecosystem/grpc-gateway" + packages = [ + "internal", + "protoc-gen-swagger/options", + "runtime", + "utilities", + ] + pruneopts = "" + revision = "8fd5fd9d19ce68183a6b0934519dfe7fe6269612" + version = "v1.9.0" + +[[projects]] + digest = "1:85f8f8d390a03287a563e215ea6bd0610c858042731a8b42062435a0dcbc485f" + name = "github.com/hashicorp/golang-lru" + packages = [ + ".", + "simplelru", + ] + pruneopts = "" + revision = "7087cb70de9f7a8bc0a10c375cb0d2280a8edf9c" + version = "v0.5.1" + +[[projects]] + digest = "1:d14365c51dd1d34d5c79833ec91413bfbb166be978724f15701e17080dc06dec" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:31bfd110d31505e9ffbc9478e31773bf05bf02adcaeb9b139af42684f9294c13" + name = "github.com/imdario/mergo" + packages = ["."] + pruneopts = "" + revision = "7c29201646fa3de8506f701213473dd407f19646" + version = "v0.3.7" + +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "" + revision = "c2b33e84" + +[[projects]] + digest = "1:12d3de2c11e54ea37d7f00daf85088ad5e61ec4e8a1f828d6c8b657976856be7" + name = "github.com/json-iterator/go" + packages = ["."] + pruneopts = "" + revision = "0ff49de124c6f76f8494e194af75bde0f1a49a29" + version = "v1.1.6" + +[[projects]] + digest = "1:0f51cee70b0d254dbc93c22666ea2abf211af81c1701a96d04e2284b408621db" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "" + revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" + version = "v1.0.2" + +[[projects]] + digest = "1:2b5f0e6bc8fb862fed5bccf9fbb1ab819c8b3f8a21e813fe442c06aec3bb3e86" + name = "github.com/lyft/flyteidl" + packages = [ + "clients/go/admin", + "clients/go/admin/mocks", + "clients/go/coreutils", + "clients/go/coreutils/logs", + "clients/go/events/errors", + "gen/pb-go/flyteidl/admin", + "gen/pb-go/flyteidl/core", + "gen/pb-go/flyteidl/event", + "gen/pb-go/flyteidl/plugins", + "gen/pb-go/flyteidl/service", + ] + pruneopts = "" + revision = "793b09d190148236f41ad8160b5cec9a3325c16f" + source = "git@github.com:lyft/flyteidl" + version = "v0.1.0" + +[[projects]] + digest = "1:c368fe9a00a38c8702e24475dd3a8348d2a191892ef9030aceb821f8c035b737" + name = "github.com/lyft/flytestdlib" + packages = [ + "atomic", + "config", + "config/files", + "config/viper", + "contextutils", + "errors", + "ioutils", + "logger", + "promutils", + "promutils/labeled", + "sets", + "storage", + "utils", + ] + pruneopts = "" + revision = "c0e1a9369cb442d70093564fbbc21d8298f5aeb6" + source = "git@github.com:lyft/flytestdlib" + version = "v0.2.11" + +[[projects]] + digest = 
"1:ae39921edb7f801f7ce1b6b5484f9715a1dd2b52cb645daef095cd10fd6ee774" + name = "github.com/magiconair/properties" + packages = [ + ".", + "assert", + ] + pruneopts = "" + revision = "de8848e004dd33dc07a2947b3d76f618a7fc7ef1" + version = "v1.8.1" + +[[projects]] + digest = "1:9ea83adf8e96d6304f394d40436f2eb44c1dc3250d223b74088cc253a6cd0a1c" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:dbfae9da5a674236b914e486086671145b37b5e3880a38da906665aede3c9eab" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "" + revision = "1311e847b0cb909da63b5fecfb5370aa66236465" + version = "v0.0.8" + +[[projects]] + digest = "1:63722a4b1e1717be7b98fc686e0b30d5e7f734b9e93d7dee86293b6deab7ea28" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:bcc46a0fbd9e933087bef394871256b5c60269575bb661935874729c65bbbf60" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:0c0ff2a89c1bb0d01887e1dac043ad7efbf3ec77482ef058ac423d13497e16fd" + name = "github.com/modern-go/concurrent" + packages = ["."] + pruneopts = "" + revision = "bacd9c7ef1dd9b15be4a9909b8ac7a4e313eec94" + version = "1.0.3" + +[[projects]] + digest = "1:e32bdbdb7c377a07a9a46378290059822efdce5c8d96fe71940d87cb4f918855" + name = "github.com/modern-go/reflect2" + packages = ["."] + pruneopts = "" + revision = "4b7aa43c6742a2c18fdef89dd197aaae7dac7ccd" + version = "1.0.1" + +[[projects]] + digest = "1:3d2c33720d4255686b9f4a7e4d3b94938ee36063f14705c5eb0f73347ed4c496" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "" + revision = "728039f679cbcd4f6a54e080d2219a4c4928c546" + version = "v1.4.0" + +[[projects]] + branch = "master" + digest = "1:5f0faa008e8ff4221b55a1a5057c8b02cb2fd68da6a65c9e31c82b72cbc836d0" + name = "github.com/petar/GoLLRB" + packages = ["llrb"] + pruneopts = "" + revision = "33fb24c13b99c46c93183c291836c573ac382536" + +[[projects]] + digest = "1:4709c61d984ef9ba99b037b047546d8a576ae984fb49486e48d99658aa750cd5" + name = "github.com/peterbourgon/diskv" + packages = ["."] + pruneopts = "" + revision = "0be1b92a6df0e4f5cb0a5d15fb7f643d0ad93ce6" + version = "v3.0.0" + +[[projects]] + digest = "1:1d7e1867c49a6dd9856598ef7c3123604ea3daabf5b83f303ff457bcbc410b1d" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "" + revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:6894aa393989b5e59d9936b8b1197dc261c2c200057b92dec34007b06e9856ae" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + ] + pruneopts = "" + revision = "50c4339db732beb2165735d2cde0bff78eb3c5a5" + version = "v0.9.3" + +[[projects]] + branch = "master" + digest = "1:cd67319ee7536399990c4b00fae07c3413035a53193c644549a676091507cadc" + name = "github.com/prometheus/client_model" + packages = ["go"] + pruneopts = "" + revision = 
"fd36f4220a901265f90734c3183c5f0c91daa0b8" + +[[projects]] + digest = "1:e6315869762add748defb9e0fcc537738f78cabeaf70b2788aba9db13097b6e9" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "" + revision = "17f5ca1748182ddf24fc33a5a7caaaf790a52fcc" + version = "v0.4.1" + +[[projects]] + digest = "1:fea688256dfff79e9a0e24be47c4acf51347fcff52a5dfca7b251932a52c67e0" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/fs", + ] + pruneopts = "" + revision = "833678b5bb319f2d20a475cb165c6cc59c2cc77c" + version = "v0.0.2" + +[[projects]] + digest = "1:1a405cddcf3368445051fb70ab465ae99da56ad7be8d8ca7fc52159d1c2d873c" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "" + revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" + version = "v1.4.2" + +[[projects]] + digest = "1:956f655c87b7255c6b1ae6c203ebb0af98cf2a13ef2507e34c9bf1c0332ac0f5" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "" + revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424" + version = "v1.2.2" + +[[projects]] + digest = "1:ae3493c780092be9d576a1f746ab967293ec165e8473425631f06658b6212afc" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = "1:78715f4ed019d19795e67eed1dc63f525461d925616b1ed02b72582c01362440" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "" + revision = "67fc4837d267bc9bfd6e47f77783fcc3dffc68de" + version = "v0.0.4" + +[[projects]] + digest = "1:cc15ae4fbdb02ce31f3392361a70ac041f4f02e0485de8ffac92bd8033e3d26e" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "" + revision = "94f6ae3ed3bceceafa716478c5fbf8d29ca601a1" + version = "v1.1.0" + +[[projects]] + digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:c25a789c738f7cc8ec7f34026badd4e117853f329334a5aa45cf5d0727d7d442" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = "1:711eebe744c0151a9d09af2315f0bb729b2ec7637ef4c410fa90a18ef74b65b6" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c" + name = "github.com/stretchr/testify" + packages = [ + "assert", + "mock", + ] + pruneopts = "" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + +[[projects]] + digest = "1:e6ff7840319b6fda979a918a8801005ec2049abca62af19211d96971d8ec3327" + name = "go.uber.org/atomic" + packages = ["."] + pruneopts = "" + revision = "df976f2515e274675050de7b3f42545de80594fd" + version = "v1.4.0" + +[[projects]] + digest = "1:22c7effcb4da0eacb2bb1940ee173fac010e9ef3c691f5de4b524d538bd980f5" + name = "go.uber.org/multierr" + packages = ["."] + pruneopts = "" + revision = "3c4937480c32f4c13a875a1829af76c98ca3d40a" + version = "v1.1.0" + +[[projects]] + digest = "1:984e93aca9088b440b894df41f2043b6a3db8f9cf30767032770bfc4796993b0" + name = "go.uber.org/zap" + packages = [ + ".", + "buffer", + "internal/bufferpool", + "internal/color", + 
"internal/exit", + "zapcore", + ] + pruneopts = "" + revision = "27376062155ad36be76b0f12cf1572a221d3a48c" + version = "v1.10.0" + +[[projects]] + branch = "master" + digest = "1:9d150270ca2c3356f2224a0878daa1652e4d0b25b345f18b4f6e156cc4b8ec5e" + name = "golang.org/x/crypto" + packages = ["ssh/terminal"] + pruneopts = "" + revision = "f99c8df09eb5bff426315721bfa5f16a99cad32c" + +[[projects]] + branch = "master" + digest = "1:d168befeef1eb51a25ab229b1bb411ae07c7bef22ebee588c290faf3bdf4ae27" + name = "golang.org/x/net" + packages = [ + "context", + "context/ctxhttp", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "" + revision = "1492cefac77f61bc789c00f41ead8f8d7307cd21" + +[[projects]] + branch = "master" + digest = "1:01bdbbc604dcd5afb6f66a717f69ad45e9643c72d5bc11678d44ffa5c50f9e42" + name = "golang.org/x/oauth2" + packages = [ + ".", + "internal", + ] + pruneopts = "" + revision = "0f29369cfe4552d0e4bcddc57cc75f4d7e672a33" + +[[projects]] + branch = "master" + digest = "1:4b923bc8024a3154f2c1e072d37133d17326e8a6a61bb03102e2f14b8af7a067" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows", + ] + pruneopts = "" + revision = "5da285871e9c6a1c3acade75bea3282d33f55ebd" + +[[projects]] + digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" + name = "golang.org/x/text" + packages = [ + "collate", + "collate/build", + "internal/colltab", + "internal/gen", + "internal/language", + "internal/language/compact", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", + "transform", + "unicode/bidi", + "unicode/cldr", + "unicode/norm", + "unicode/rangetable", + ] + pruneopts = "" + revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" + version = "v0.3.2" + +[[projects]] + branch = "master" + digest = "1:9522af4be529c108010f95b05f1022cb872f2b9ff8b101080f554245673466e1" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "" + revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" + +[[projects]] + digest = "1:47f391ee443f578f01168347818cb234ed819521e49e4d2c8dd2fb80d48ee41a" + name = "google.golang.org/appengine" + packages = [ + "internal", + "internal/base", + "internal/datastore", + "internal/log", + "internal/remote_api", + "internal/urlfetch", + "urlfetch", + ] + pruneopts = "" + revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" + version = "v1.6.1" + +[[projects]] + branch = "master" + digest = "1:52c6a188a1f16480607d7753e02cbd5ad43089d5550c725f32f46d376c946f37" + name = "google.golang.org/genproto" + packages = [ + "googleapis/api/annotations", + "googleapis/api/httpbody", + "googleapis/rpc/status", + "protobuf/field_mask", + ] + pruneopts = "" + revision = "eb0b1bdb6ae60fcfc41b8d907b50dfb346112301" + +[[projects]] + digest = "1:6881653b963cd12dc1a9824aed5e122d0ff38e53e3ee07862f969a56ad2f2e9c" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/balancerload", + "internal/binarylog", + "internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "stats", + "status", + "tap", + ] + pruneopts = "" + 
revision = "501c41df7f472c740d0674ff27122f3f48c80ce7" + version = "v1.21.1" + +[[projects]] + digest = "1:75fb3fcfc73a8c723efde7777b40e8e8ff9babf30d8c56160d01beffea8a95a6" + name = "gopkg.in/inf.v0" + packages = ["."] + pruneopts = "" + revision = "d2d2541c53f18d2a059457998ce2876cc8e67cbf" + version = "v0.9.1" + +[[projects]] + digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + digest = "1:73ee122857f257aa507ebae097783fe08ad8af49398e5b3876787325411f1a4b" + name = "k8s.io/api" + packages = [ + "admission/v1beta1", + "admissionregistration/v1alpha1", + "admissionregistration/v1beta1", + "apps/v1", + "apps/v1beta1", + "apps/v1beta2", + "auditregistration/v1alpha1", + "authentication/v1", + "authentication/v1beta1", + "authorization/v1", + "authorization/v1beta1", + "autoscaling/v1", + "autoscaling/v2beta1", + "autoscaling/v2beta2", + "batch/v1", + "batch/v1beta1", + "batch/v2alpha1", + "certificates/v1beta1", + "coordination/v1beta1", + "core/v1", + "events/v1beta1", + "extensions/v1beta1", + "networking/v1", + "policy/v1beta1", + "rbac/v1", + "rbac/v1alpha1", + "rbac/v1beta1", + "scheduling/v1alpha1", + "scheduling/v1beta1", + "settings/v1alpha1", + "storage/v1", + "storage/v1alpha1", + "storage/v1beta1", + ] + pruneopts = "" + revision = "27b77cf22008a0bf6e510b2a500b8885805a1b68" + source = "https://github.com/lyft/api" + +[[projects]] + digest = "1:a3bee4b1e4013573fc15631b51a7b7e0d580497e6fec63dc3724b370e624569f" + name = "k8s.io/apimachinery" + packages = [ + "pkg/api/errors", + "pkg/api/meta", + "pkg/api/resource", + "pkg/apis/meta/internalversion", + "pkg/apis/meta/v1", + "pkg/apis/meta/v1/unstructured", + "pkg/apis/meta/v1beta1", + "pkg/conversion", + "pkg/conversion/queryparams", + "pkg/fields", + "pkg/labels", + "pkg/runtime", + "pkg/runtime/schema", + "pkg/runtime/serializer", + "pkg/runtime/serializer/json", + "pkg/runtime/serializer/protobuf", + "pkg/runtime/serializer/recognizer", + "pkg/runtime/serializer/streaming", + "pkg/runtime/serializer/versioning", + "pkg/selection", + "pkg/types", + "pkg/util/cache", + "pkg/util/clock", + "pkg/util/diff", + "pkg/util/errors", + "pkg/util/framer", + "pkg/util/intstr", + "pkg/util/json", + "pkg/util/mergepatch", + "pkg/util/naming", + "pkg/util/net", + "pkg/util/rand", + "pkg/util/runtime", + "pkg/util/sets", + "pkg/util/strategicpatch", + "pkg/util/validation", + "pkg/util/validation/field", + "pkg/util/wait", + "pkg/util/yaml", + "pkg/version", + "pkg/watch", + "third_party/forked/golang/json", + "third_party/forked/golang/reflect", + ] + pruneopts = "" + revision = "695912cabb3a9c08353fbe628115867d47f56b1f" + source = "https://github.com/lyft/apimachinery" + +[[projects]] + digest = "1:d603c9957fa66c90792d45fe0205d484da2ea364a01069c20890f2640b4a0fd5" + name = "k8s.io/client-go" + packages = [ + "discovery", + "dynamic", + "kubernetes/scheme", + "pkg/apis/clientauthentication", + "pkg/apis/clientauthentication/v1alpha1", + "pkg/apis/clientauthentication/v1beta1", + "pkg/version", + "plugin/pkg/client/auth/exec", + "rest", + "rest/watch", + "restmapper", + "testing", + "tools/auth", + "tools/cache", + "tools/clientcmd", + "tools/clientcmd/api", + "tools/clientcmd/api/latest", + "tools/clientcmd/api/v1", + "tools/metrics", + "tools/pager", + "tools/record", + "tools/reference", + "transport", + "util/buffer", + "util/cert", + "util/connrotation", + 
"util/flowcontrol", + "util/homedir", + "util/integer", + "util/retry", + "util/workqueue", + ] + pruneopts = "" + revision = "8d9ed539ba3134352c586810e749e58df4e94e4f" + version = "kubernetes-1.13.1" + +[[projects]] + digest = "1:9eaf86f4f6fb4a8f177220d488ef1e3255d06a691cca95f14ef085d4cd1cef3c" + name = "k8s.io/klog" + packages = ["."] + pruneopts = "" + revision = "d98d8acdac006fb39831f1b25640813fef9c314f" + version = "v0.3.3" + +[[projects]] + branch = "master" + digest = "1:d2aae07aa745223592ae668f6eb6c2ca0242d66a6dcf16b1e8e2711a79aad0f1" + name = "k8s.io/kube-openapi" + packages = ["pkg/util/proto"] + pruneopts = "" + revision = "db7b694dc208eead64d38030265f702db593fcf2" + +[[projects]] + digest = "1:5c1664b5783da5772e29bc7c2fbe369dc0b1d2f11b7935c6adc283d9aa839355" + name = "sigs.k8s.io/controller-runtime" + packages = [ + "pkg/cache", + "pkg/cache/informertest", + "pkg/cache/internal", + "pkg/client", + "pkg/client/apiutil", + "pkg/client/config", + "pkg/client/fake", + "pkg/controller/controllertest", + "pkg/event", + "pkg/handler", + "pkg/internal/objectutil", + "pkg/predicate", + "pkg/reconcile", + "pkg/runtime/inject", + "pkg/runtime/log", + "pkg/source", + "pkg/source/internal", + "pkg/webhook/admission/types", + ] + pruneopts = "" + revision = "477bf4f046c31c351b46fa00262bc814ac0bbca1" + version = "v0.1.11" + +[[projects]] + digest = "1:321081b4a44256715f2b68411d8eda9a17f17ebfe6f0cc61d2cc52d11c08acfa" + name = "sigs.k8s.io/yaml" + packages = ["."] + pruneopts = "" + revision = "fd68e9863619f6ec2fdd8625fe1f02e7c877e480" + version = "v1.1.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta1", + "github.com/go-redis/redis", + "github.com/golang/protobuf/jsonpb", + "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/ptypes", + "github.com/golang/protobuf/ptypes/struct", + "github.com/lyft/flyteidl/clients/go/admin", + "github.com/lyft/flyteidl/clients/go/admin/mocks", + "github.com/lyft/flyteidl/clients/go/coreutils", + "github.com/lyft/flyteidl/clients/go/coreutils/logs", + "github.com/lyft/flyteidl/clients/go/events/errors", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service", + "github.com/lyft/flytestdlib/config", + "github.com/lyft/flytestdlib/config/viper", + "github.com/lyft/flytestdlib/contextutils", + "github.com/lyft/flytestdlib/logger", + "github.com/lyft/flytestdlib/promutils", + "github.com/lyft/flytestdlib/promutils/labeled", + "github.com/lyft/flytestdlib/sets", + "github.com/lyft/flytestdlib/storage", + "github.com/lyft/flytestdlib/utils", + "github.com/magiconair/properties/assert", + "github.com/mitchellh/mapstructure", + "github.com/pkg/errors", + "github.com/prometheus/client_golang/prometheus", + "github.com/spf13/pflag", + "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", + "google.golang.org/grpc", + "k8s.io/api/apps/v1", + "k8s.io/api/batch/v1", + "k8s.io/api/core/v1", + "k8s.io/apimachinery/pkg/api/errors", + "k8s.io/apimachinery/pkg/api/meta", + "k8s.io/apimachinery/pkg/api/resource", + "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured", + "k8s.io/apimachinery/pkg/runtime", + "k8s.io/apimachinery/pkg/runtime/schema", + 
"k8s.io/apimachinery/pkg/types", + "k8s.io/apimachinery/pkg/util/json", + "k8s.io/apimachinery/pkg/util/rand", + "k8s.io/apimachinery/pkg/util/sets", + "k8s.io/client-go/kubernetes/scheme", + "k8s.io/client-go/tools/record", + "k8s.io/client-go/util/workqueue", + "sigs.k8s.io/controller-runtime/pkg/cache", + "sigs.k8s.io/controller-runtime/pkg/cache/informertest", + "sigs.k8s.io/controller-runtime/pkg/client", + "sigs.k8s.io/controller-runtime/pkg/client/config", + "sigs.k8s.io/controller-runtime/pkg/client/fake", + "sigs.k8s.io/controller-runtime/pkg/event", + "sigs.k8s.io/controller-runtime/pkg/handler", + "sigs.k8s.io/controller-runtime/pkg/runtime/inject", + "sigs.k8s.io/controller-runtime/pkg/source", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/flyteplugins/Gopkg.toml b/flyteplugins/Gopkg.toml new file mode 100755 index 0000000000..7969a7306a --- /dev/null +++ b/flyteplugins/Gopkg.toml @@ -0,0 +1,52 @@ +required = ["github.com/lyft/flytestdlib/storage", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event", + "sigs.k8s.io/controller-runtime/pkg/client/config", + "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/client-go/tools/record", + ] +ignored = ["k8s.io/spark-on-k8s-operator", + ] + +[[constraint]] + name = "github.com/lyft/flyteidl" + source = "git@github.com:lyft/flyteidl" + version = "^0.1.x" + +[[constraint]] + name = "github.com/lyft/flytestdlib" + source = "git@github.com:lyft/flytestdlib" + version = "^0.2.x" + +[[constraint]] + name = "sigs.k8s.io/controller-runtime" + version = "^0.1.0" + +# Type resource.Quantity unmarshals with json, but not jsonpb (https://github.com/kubernetes/apimachinery/issues/59). +# Because we embed k8s protos in our own proto definitions we need to be able to call jsonpb marshal/unmarshal in order +# to use the k8s types in our our struct messages. +[[override]] + name = "k8s.io/apimachinery" + source = "https://github.com/lyft/apimachinery" + revision = "695912cabb3a9c08353fbe628115867d47f56b1f" + +[[constraint]] + name = "k8s.io/client-go" + version = "kubernetes-1.13.1" + +[[constraint]] + name = "github.com/GoogleCloudPlatform/spark-on-k8s-operator" + version = "^0.1.x" + source = "https://github.com/lyft/spark-on-k8s-operator" + +[[override]] + name = "k8s.io/api" + revision = "27b77cf22008a0bf6e510b2a500b8885805a1b68" + source = "https://github.com/lyft/api" + +[[override]] + name = "github.com/prometheus/client_golang" + version = "^0.9.0" + +[[override]] + name = "github.com/stretchr/objx" + version = "0.1.1" diff --git a/flyteplugins/LICENSE b/flyteplugins/LICENSE new file mode 100755 index 0000000000..bed437514f --- /dev/null +++ b/flyteplugins/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/flyteplugins/Makefile b/flyteplugins/Makefile new file mode 100755 index 0000000000..b3238b52fe --- /dev/null +++ b/flyteplugins/Makefile @@ -0,0 +1,14 @@ +export REPOSITORY=flyteplugins +include boilerplate/lyft/golang_test_targets/Makefile + +.PHONY: update_boilerplate +update_boilerplate: + @boilerplate/update.sh + +generate: + which pflags || (go get github.com/lyft/flytestdlib/cli/pflags) + which mockery || (go get github.com/vektra/mockery/cmd/mockery) + @go generate ./... + +clean: + rm -rf bin diff --git a/flyteplugins/NOTICE b/flyteplugins/NOTICE new file mode 100755 index 0000000000..c3aef1fa32 --- /dev/null +++ b/flyteplugins/NOTICE @@ -0,0 +1,4 @@ +flyteplugins +Copyright 2019 Lyft Inc. + +This product includes software developed at Lyft Inc. diff --git a/flyteplugins/README.md b/flyteplugins/README.md new file mode 100755 index 0000000000..d2002cc54d --- /dev/null +++ b/flyteplugins/README.md @@ -0,0 +1,2 @@ +# flyteplugins +Plugins contributed by flyte community. diff --git a/flyteplugins/boilerplate/lyft/golang_test_targets/Makefile b/flyteplugins/boilerplate/lyft/golang_test_targets/Makefile new file mode 100755 index 0000000000..6c1e527fd6 --- /dev/null +++ b/flyteplugins/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,38 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +DEP_SHA=1f7c19e5f52f49ffb9f956f64c010be14683468b + +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.16.0 + golangci-lint run --exclude deprecated + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl "https://raw.githubusercontent.com/golang/dep/${DEP_SHA}/install.sh" | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... -race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/flyteplugins/boilerplate/lyft/golang_test_targets/Readme.rst b/flyteplugins/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100755 index 0000000000..acc5744f59 --- /dev/null +++ b/flyteplugins/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` install golang dependencies. + +Provides a ``lint`` make target that uses golangci to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. 
+ +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). + +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/flyteplugins/boilerplate/lyft/golang_test_targets/goimports b/flyteplugins/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 0000000000..160525a8cc --- /dev/null +++ b/flyteplugins/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/flyteplugins/boilerplate/lyft/golangci_file/.golangci.yml b/flyteplugins/boilerplate/lyft/golangci_file/.golangci.yml new file mode 100755 index 0000000000..a414f33f79 --- /dev/null +++ b/flyteplugins/boilerplate/lyft/golangci_file/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/flyteplugins/boilerplate/lyft/golangci_file/Readme.rst b/flyteplugins/boilerplate/lyft/golangci_file/Readme.rst new file mode 100755 index 0000000000..ba5d2b61ce --- /dev/null +++ b/flyteplugins/boilerplate/lyft/golangci_file/Readme.rst @@ -0,0 +1,8 @@ +GolangCI File +~~~~~~~~~~~~~ + +Provides a ``.golangci`` file with the linters we've agreed upon. + +**To Enable:** + +Add ``lyft/golangci_file`` to your ``boilerplate/update.cfg`` file. diff --git a/flyteplugins/boilerplate/lyft/golangci_file/update.sh b/flyteplugins/boilerplate/lyft/golangci_file/update.sh new file mode 100755 index 0000000000..9e9e6c1f46 --- /dev/null +++ b/flyteplugins/boilerplate/lyft/golangci_file/update.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +# Clone the .golangci file +echo " - copying ${DIR}/.golangci to the root directory." 
+cp ${DIR}/.golangci.yml ${DIR}/../../../.golangci.yml diff --git a/flyteplugins/boilerplate/update.cfg b/flyteplugins/boilerplate/update.cfg new file mode 100755 index 0000000000..f861a23ccd --- /dev/null +++ b/flyteplugins/boilerplate/update.cfg @@ -0,0 +1,2 @@ +lyft/golang_test_targets +lyft/golangci_file diff --git a/flyteplugins/boilerplate/update.sh b/flyteplugins/boilerplate/update.sh new file mode 100755 index 0000000000..bea661d9a0 --- /dev/null +++ b/flyteplugins/boilerplate/update.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +OUT="$(mktemp -d)" +git clone git@github.com:lyft/boilerplate.git "${OUT}" + +echo "Updating the update.sh script." +cp "${OUT}/boilerplate/update.sh" "${DIR}/update.sh" +echo "" + + +CONFIG_FILE="${DIR}/update.cfg" +README="https://github.com/lyft/boilerplate/blob/master/Readme.rst" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "$CONFIG_FILE not found." + echo "This file is required in order to select which features to include." + echo "See $README for more details." + exit 1 +fi + +if [ -z "$REPOSITORY" ]; then + echo '$REPOSITORY is required to run this script' + echo "See $README for more details." + exit 1 +fi + +while read directory; do + echo "***********************************************************************************" + echo "$directory is configured in update.cfg." + echo "-----------------------------------------------------------------------------------" + echo "syncing files from source." + dir_path="${OUT}/boilerplate/${directory}" + rm -rf "${DIR}/${directory}" + mkdir -p $(dirname "${DIR}/${directory}") + cp -r "$dir_path" "${DIR}/${directory}" + if [ -f "${DIR}/${directory}/update.sh" ]; then + echo "executing ${DIR}/${directory}/update.sh" + "${DIR}/${directory}/update.sh" + fi + echo "***********************************************************************************" + echo "" +done < "$CONFIG_FILE" + +rm -rf "${OUT}" diff --git a/flyteplugins/go/tasks/loader.go b/flyteplugins/go/tasks/loader.go new file mode 100755 index 0000000000..526ce28dae --- /dev/null +++ b/flyteplugins/go/tasks/loader.go @@ -0,0 +1,28 @@ +// This package contains various task plugins that offer an extensibility point into how propeller executes tasks. +// Check the documentation of either version for examples on how to develop and register plugins. +package tasks + +import ( + "context" + + "github.com/pkg/errors" + + v1 "github.com/lyft/flyteplugins/go/tasks/v1" + + // This is a temporary solution to invoke init() methods on all plugins. Ideally this step should happen dynamically + // based on a config. 
+ _ "github.com/lyft/flyteplugins/go/tasks/v1/k8splugins" + _ "github.com/lyft/flyteplugins/go/tasks/v1/qubole" +) + +func Load(ctx context.Context) error { + if err := v1.RunAllLoaders(ctx); err != nil { + return err + } + + if len(v1.ListAllTaskExecutors()) == 0 { + return errors.Errorf("No Task Executor defined.") + } + + return nil +} diff --git a/flyteplugins/go/tasks/v1/config/config.go b/flyteplugins/go/tasks/v1/config/config.go new file mode 100755 index 0000000000..01576f1dd0 --- /dev/null +++ b/flyteplugins/go/tasks/v1/config/config.go @@ -0,0 +1,29 @@ +package config + +import ( + "github.com/lyft/flytestdlib/config" +) + +//go:generate pflags Config + +const configSectionKey = "plugins" + +var ( + // Root config section. If you are a plugin developer and your plugin needs a config, you should register + // your config as a subsection for this root section. + rootSection = config.MustRegisterSection(configSectionKey, &Config{}) +) + +// Top level plugins config. +type Config struct { + EnabledPlugins []string `json:"enabled-plugins" pflag:"[]string{\"*\"},List of enabled plugins, default value is to enable all plugins."` +} + +// Retrieves the current config value or default. +func GetConfig() *Config { + return rootSection.GetConfig().(*Config) +} + +func MustRegisterSubSection(subSectionKey string, section config.Config) config.Section { + return rootSection.MustRegisterSection(subSectionKey, section) +} diff --git a/flyteplugins/go/tasks/v1/config/config_flags.go b/flyteplugins/go/tasks/v1/config/config_flags.go new file mode 100755 index 0000000000..db49c62906 --- /dev/null +++ b/flyteplugins/go/tasks/v1/config/config_flags.go @@ -0,0 +1,46 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "enabled-plugins"), []string{"*"}, "List of enabled plugins, default value is to enable all plugins.") + return cmdFlags +} diff --git a/flyteplugins/go/tasks/v1/config/config_flags_test.go b/flyteplugins/go/tasks/v1/config/config_flags_test.go new file mode 100755 index 0000000000..9d1d5c8d14 --- /dev/null +++ b/flyteplugins/go/tasks/v1/config/config_flags_test.go @@ -0,0 +1,124 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. 
+ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_enabled-plugins", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vStringSlice, err := cmdFlags.GetStringSlice("enabled-plugins"); err == nil { + assert.Equal(t, []string([]string{"*"}), vStringSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := join_Config([]string{"*"}, ",") + + cmdFlags.Set("enabled-plugins", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("enabled-plugins"); err == nil { + testDecodeSlice_Config(t, join_Config(vStringSlice, ","), &actual.EnabledPlugins) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flyteplugins/go/tasks/v1/config_load_test.go 
b/flyteplugins/go/tasks/v1/config_load_test.go new file mode 100755 index 0000000000..5a51a7e71c --- /dev/null +++ b/flyteplugins/go/tasks/v1/config_load_test.go @@ -0,0 +1,94 @@ +package v1_test + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/v1/config" + flyteK8sConfig "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/k8splugins" + "github.com/lyft/flyteplugins/go/tasks/v1/logs" + quboleConfig "github.com/lyft/flyteplugins/go/tasks/v1/qubole/config" +) + +func TestLoadConfig(t *testing.T) { + configAccessor := viper.NewAccessor(config.Options{ + StrictMode: true, + SearchPaths: []string{"testdata/config.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + t.Run("root-config-test", func(t *testing.T) { + assert.Equal(t, 1, len(pluginsConfig.GetConfig().EnabledPlugins)) + }) + + t.Run("k8s-config-test", func(t *testing.T) { + + k8sConfig := flyteK8sConfig.GetK8sPluginConfig() + assert.True(t, k8sConfig.InjectFinalizer) + assert.Equal(t, map[string]string{ + "annotationKey1": "annotationValue1", + "annotationKey2": "annotationValue2", + "cluster-autoscaler.kubernetes.io/safe-to-evict": "false", + }, k8sConfig.DefaultAnnotations) + assert.Equal(t, map[string]string{ + "label1": "labelValue1", + "label2": "labelValue2", + }, k8sConfig.DefaultLabels) + assert.Equal(t, map[string]string{ + "AWS_METADATA_SERVICE_NUM_ATTEMPTS": "20", + "AWS_METADATA_SERVICE_TIMEOUT": "5", + "FLYTE_AWS_ACCESS_KEY_ID": "minio", + "FLYTE_AWS_ENDPOINT": "http://minio.flyte:9000", + "FLYTE_AWS_SECRET_ACCESS_KEY": "miniostorage", + }, k8sConfig.DefaultEnvVars) + assert.NotNil(t, k8sConfig.ResourceTolerations) + assert.Contains(t, k8sConfig.ResourceTolerations, v1.ResourceName("nvidia.com/gpu")) + assert.Contains(t, k8sConfig.ResourceTolerations, v1.ResourceStorage) + tolGPU := v1.Toleration{ + Key: "flyte/gpu", + Value: "dedicated", + Operator: v1.TolerationOpEqual, + Effect: v1.TaintEffectNoSchedule, + } + + tolStorage := v1.Toleration{ + Key: "storage", + Value: "special", + Operator: v1.TolerationOpEqual, + Effect: v1.TaintEffectPreferNoSchedule, + } + + assert.Equal(t, []v1.Toleration{tolGPU}, k8sConfig.ResourceTolerations[v1.ResourceName("nvidia.com/gpu")]) + assert.Equal(t, []v1.Toleration{tolStorage}, k8sConfig.ResourceTolerations[v1.ResourceStorage]) + assert.Equal(t, "1000m", k8sConfig.DefaultCpuRequest) + assert.Equal(t, "1024Mi", k8sConfig.DefaultMemoryRequest) + }) + + t.Run("logs-config-test", func(t *testing.T) { + assert.NotNil(t, logs.GetLogConfig()) + assert.True(t, logs.GetLogConfig().IsKubernetesEnabled) + }) + + t.Run("spark-config-test", func(t *testing.T) { + assert.NotNil(t, k8splugins.GetSparkConfig()) + assert.NotNil(t, k8splugins.GetSparkConfig().DefaultSparkConfig) + }) + + t.Run("qubole-config-test", func(t *testing.T) { + assert.NotNil(t, quboleConfig.GetQuboleConfig()) + assert.Equal(t, "redis-resource-manager.flyte:6379", quboleConfig.GetQuboleConfig().RedisHostPath) + }) + + t.Run("waitable-config-test", func(t *testing.T) { + assert.NotNil(t, k8splugins.GetWaitableConfig()) + assert.Equal(t, "http://localhost:30081/console", k8splugins.GetWaitableConfig().ConsoleURI.String()) + }) +} diff --git a/flyteplugins/go/tasks/v1/errors/errors.go b/flyteplugins/go/tasks/v1/errors/errors.go new file mode 
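
As a usage sketch of the accessor exercised by the config-loading test above: how a process might load the plugins configuration at startup. The config file path is an assumption.

package main

import (
    "context"

    "github.com/lyft/flytestdlib/config"
    "github.com/lyft/flytestdlib/config/viper"
)

// loadPluginConfig wires up a strict viper-backed accessor and loads all registered
// config sections (including the plugins section) from the given file.
func loadPluginConfig(ctx context.Context) error {
    accessor := viper.NewAccessor(config.Options{
        StrictMode:  true,
        SearchPaths: []string{"/etc/flyte/config/plugins.yaml"}, // assumed path
    })

    return accessor.UpdateConfig(ctx)
}
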
100755 index 0000000000..664f202cb5 --- /dev/null +++ b/flyteplugins/go/tasks/v1/errors/errors.go @@ -0,0 +1,76 @@ +package errors + +import ( + "fmt" + + "github.com/pkg/errors" +) + +type ErrorCode = string + +const ( + TaskFailedWithError ErrorCode = "TaskFailedWithError" + DownstreamSystemError ErrorCode = "DownstreamSystemError" + TaskFailedUnknownError ErrorCode = "TaskFailedUnknownError" + BadTaskSpecification ErrorCode = "BadTaskSpecification" + TaskEventRecordingFailed ErrorCode = "TaskEventRecordingFailed" + MetadataAccessFailed ErrorCode = "MetadataAccessFailed" + MetadataTooLarge ErrorCode = "MetadataTooLarge" + PluginInitializationFailed ErrorCode = "PluginInitializationFailed" + CacheFailed ErrorCode = "AutoRefreshCacheFailed" + RuntimeFailure ErrorCode = "RuntimeFailure" +) + +type TaskError struct { + Code string + Message string +} + +func (e *TaskError) Error() string { + return fmt.Sprintf("task failed, %v: %v", e.Code, e.Message) +} + +type TaskErrorWithCause struct { + *TaskError + cause error +} + +func (e *TaskErrorWithCause) Error() string { + return fmt.Sprintf("%v, caused by: %v", e.TaskError.Error(), errors.Cause(e)) +} + +func (e *TaskErrorWithCause) Cause() error { + return e.cause +} + +func Errorf(errorCode ErrorCode, msgFmt string, args ...interface{}) *TaskError { + return &TaskError{ + Code: errorCode, + Message: fmt.Sprintf(msgFmt, args...), + } +} + +func Wrapf(errorCode ErrorCode, err error, msgFmt string, args ...interface{}) *TaskErrorWithCause { + return &TaskErrorWithCause{ + TaskError: Errorf(errorCode, msgFmt, args...), + cause: err, + } +} + +func GetErrorCode(err error) (code ErrorCode, isTaskError bool) { + isTaskError = false + e, ok := err.(*TaskError) + if ok { + code = e.Code + isTaskError = true + return + } + + e2, ok := err.(*TaskError) + if ok { + code = e2.Code + isTaskError = true + return + } + return +} diff --git a/flyteplugins/go/tasks/v1/events/event_utils.go b/flyteplugins/go/tasks/v1/events/event_utils.go new file mode 100755 index 0000000000..44e222e4dd --- /dev/null +++ b/flyteplugins/go/tasks/v1/events/event_utils.go @@ -0,0 +1,100 @@ +package events + +import ( + "time" + + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" +) + +// Additional info that should be sent to the front end. The Information is sent to the front-end if it meets certain +// criterion, for example currently, it is sent only if an event was not already sent for +type TaskEventInfo struct { + // log information for the task execution + Logs []*core.TaskLog + // Set this value to the intended time when the status occurred at. If not provided, will be defaulted to the current + // time at the time of publishing the event. 
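
A minimal usage sketch of the error helpers defined in this errors package; the message text is illustrative.

package example

import (
    "fmt"

    "github.com/lyft/flyteplugins/go/tasks/v1/errors"
)

// rejectBadSpec shows a plugin raising a coded error and a caller (such as the event
// utilities that follow) recovering the code from it.
func rejectBadSpec() {
    err := errors.Errorf(errors.BadTaskSpecification, "container image is missing")
    if code, ok := errors.GetErrorCode(err); ok {
        fmt.Println("error code:", code) // prints BadTaskSpecification
    }
}
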
+ OccurredAt *time.Time + // Custom Event information that the plugin would like to expose to the front-end + CustomInfo *structpb.Struct +} + +// Convert all TaskStatus to an ExecutionPhase that is common to Flyte and Admin understands +// NOTE: if we add a TaskStatus entry, we should add it here too +func convertTaskPhaseToExecutionStatus(status types.TaskPhase) core.TaskExecution_Phase { + switch status { + case types.TaskPhaseRunning: + return core.TaskExecution_RUNNING + case types.TaskPhaseSucceeded: + return core.TaskExecution_SUCCEEDED + case types.TaskPhaseRetryableFailure, types.TaskPhasePermanentFailure: + return core.TaskExecution_FAILED + case types.TaskPhaseQueued: + return core.TaskExecution_QUEUED + default: + return core.TaskExecution_UNDEFINED + } +} + +func CreateEvent(taskCtx types.TaskContext, taskStatus types.TaskStatus, info *TaskEventInfo) *event.TaskExecutionEvent { + + newTaskExecutionPhase := convertTaskPhaseToExecutionStatus(taskStatus.Phase) + taskExecutionID := taskCtx.GetTaskExecutionID().GetID() + + occurredAt := ptypes.TimestampNow() + logs := make([]*core.TaskLog, 0) + var customInfo *structpb.Struct + + if info != nil { + customInfo = info.CustomInfo + if info.OccurredAt != nil { + t, err := ptypes.TimestampProto(*info.OccurredAt) + if err != nil { + occurredAt = t + } + } + + logs = append(logs, info.Logs...) + } + + taskEvent := &event.TaskExecutionEvent{ + Phase: newTaskExecutionPhase, + PhaseVersion: taskStatus.PhaseVersion, + RetryAttempt: taskCtx.GetTaskExecutionID().GetID().RetryAttempt, + InputUri: taskCtx.GetInputsFile().String(), + OccurredAt: occurredAt, + Logs: logs, + CustomInfo: customInfo, + TaskId: taskExecutionID.TaskId, + ParentNodeExecutionId: taskExecutionID.NodeExecutionId, + } + + if newTaskExecutionPhase == core.TaskExecution_FAILED { + errorCode := "UnknownTaskError" + message := "unknown reason" + if taskStatus.Err != nil { + ec, ok := errors.GetErrorCode(taskStatus.Err) + if ok { + errorCode = ec + } + message = taskStatus.Err.Error() + } + taskEvent.OutputResult = &event.TaskExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errorCode, + Message: message, + ErrorUri: taskCtx.GetErrorFile().String(), + }, + } + } else if newTaskExecutionPhase == core.TaskExecution_SUCCEEDED { + taskEvent.OutputResult = &event.TaskExecutionEvent_OutputUri{ + OutputUri: taskCtx.GetOutputsFile().String(), + } + } + return taskEvent +} diff --git a/flyteplugins/go/tasks/v1/events/event_utils_test.go b/flyteplugins/go/tasks/v1/events/event_utils_test.go new file mode 100755 index 0000000000..61b091db9e --- /dev/null +++ b/flyteplugins/go/tasks/v1/events/event_utils_test.go @@ -0,0 +1,193 @@ +package events + +import ( + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +const inputs = storage.DataReference("inputs.pb") +const outputs = storage.DataReference("outputs.pb") +const errorsFile = storage.DataReference("errorsFile.pb") + +var testTaskExecutionIdentifier = core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "proj", + Domain: "domain", + Name: "name", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: 
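
To make the event-creation flow above concrete, a sketch of how a plugin might report a RUNNING phase with a log link and an explicit timestamp. The taskCtx is assumed to be supplied by the propeller runtime, and the log URI and name are illustrative.

package example

import (
    "time"

    "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
    "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event"
    "github.com/lyft/flyteplugins/go/tasks/v1/events"
    "github.com/lyft/flyteplugins/go/tasks/v1/types"
)

// buildRunningEvent attaches a log link and an explicit occurrence time to the event.
func buildRunningEvent(taskCtx types.TaskContext, logURI string) *event.TaskExecutionEvent {
    now := time.Now()
    info := &events.TaskEventInfo{
        Logs:       []*core.TaskLog{{Uri: logURI, Name: "user-logs"}},
        OccurredAt: &now,
    }

    return events.CreateEvent(taskCtx, types.TaskStatusRunning, info)
}
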
"nodeId", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "proj", + Domain: "domain", + Name: "name", + }, + }, +} + +type mockTaskExecutionIdentifier struct{} + +func (m mockTaskExecutionIdentifier) GetID() core.TaskExecutionIdentifier { + return testTaskExecutionIdentifier +} + +func (m mockTaskExecutionIdentifier) GetGeneratedName() string { + return "task-exec-name" +} + +func TestEventsPublisher_Queued(t *testing.T) { + startedAt := time.Now() + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + + e := CreateEvent(taskCtx, types.TaskStatusQueued, nil) + assert.Equal(t, e.GetPhase(), core.TaskExecution_QUEUED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.Equal(t, e.GetInputUri(), string(inputs)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) +} + +func TestEventsPublisher_Running(t *testing.T) { + startedAt := time.Now() + + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + + e := CreateEvent(taskCtx, types.TaskStatusRunning, nil) + assert.Equal(t, e.GetPhase(), core.TaskExecution_RUNNING) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.Equal(t, e.GetInputUri(), string(inputs)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) +} + +func TestEventsPublisher_Success(t *testing.T) { + startedAt := time.Now() + + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + taskCtx.On("GetOutputsFile").Return(outputs) + + e := CreateEvent(taskCtx, types.TaskStatusSucceeded, nil) + assert.Equal(t, e.GetPhase(), core.TaskExecution_SUCCEEDED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.Equal(t, e.GetOutputUri(), string(outputs)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) +} + +func TestEventsPublisher_PermanentFailed(t *testing.T) { + startedAt := time.Now() + + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + taskCtx.On("GetOutputsFile").Return(outputs) + taskCtx.On("GetErrorFile").Return(errorsFile) + + err := errors.Errorf("test", "failed") + e := CreateEvent(taskCtx, types.TaskStatusPermanentFailure(err), nil) + assert.Equal(t, e.GetPhase(), core.TaskExecution_FAILED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.NotNil(t, e.GetError()) + assert.Equal(t, e.GetError().Code, "test") + assert.Equal(t, e.GetError().Message, "task failed, test: failed") + assert.Equal(t, e.GetError().GetErrorUri(), string(errorsFile)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) +} + +func TestEventsPublisher_RetryableFailed(t *testing.T) { + startedAt := time.Now() + + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + 
taskCtx.On("GetOutputsFile").Return(outputs) + taskCtx.On("GetErrorFile").Return(errorsFile) + + err := errors.Errorf("test", "failed") + e := CreateEvent(taskCtx, types.TaskStatusRetryableFailure(err), nil) + assert.Equal(t, e.GetPhase(), core.TaskExecution_FAILED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.NotNil(t, e.GetError()) + assert.Equal(t, e.GetError().Code, "test") + assert.Equal(t, e.GetError().Message, "task failed, test: failed") + assert.Equal(t, e.GetError().GetErrorUri(), string(errorsFile)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) +} + +func TestEventsPublisher_FailedNilError(t *testing.T) { + startedAt := time.Now() + + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + taskCtx.On("GetOutputsFile").Return(outputs) + taskCtx.On("GetErrorFile").Return(errorsFile) + + e := CreateEvent(taskCtx, types.TaskStatusRetryableFailure(nil), nil) + assert.Equal(t, e.GetPhase(), core.TaskExecution_FAILED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.NotNil(t, e.GetError()) + assert.Equal(t, e.GetError().Code, "UnknownTaskError") + assert.Equal(t, e.GetError().Message, "unknown reason") + assert.Equal(t, e.GetError().GetErrorUri(), string(errorsFile)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) +} + +func TestEventsPublisher_WithCustomInfo(t *testing.T) { + startedAt := time.Now() + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetTaskExecutionID").Return(mockTaskExecutionIdentifier{}) + taskCtx.On("GetInputsFile").Return(inputs) + + t.Run("emptyInfo", func(t *testing.T) { + e := CreateEvent(taskCtx, types.TaskStatusQueued, &TaskEventInfo{}) + assert.Equal(t, e.GetPhase(), core.TaskExecution_QUEUED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.Equal(t, e.GetInputUri(), string(inputs)) + assert.WithinDuration(t, startedAt, time.Now(), time.Millisecond*5) + }) + + t.Run("withInfo", func(t *testing.T) { + n := time.Now() + s := structpb.Struct{} + e := CreateEvent(taskCtx, types.TaskStatusQueued, &TaskEventInfo{ + Logs: []*core.TaskLog{{Uri: "l1"}, {Uri: "l2"}}, + OccurredAt: &n, + CustomInfo: &s, + }) + assert.Equal(t, e.GetPhase(), core.TaskExecution_QUEUED) + assert.Equal(t, e.GetTaskId(), testTaskExecutionIdentifier.TaskId) + assert.Equal(t, e.GetParentNodeExecutionId(), testTaskExecutionIdentifier.NodeExecutionId) + assert.Equal(t, e.GetInputUri(), string(inputs)) + o, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + assert.Equal(t, n.Unix(), o.Unix()) + assert.Equal(t, len(e.Logs), 2) + assert.Equal(t, e.Logs[0].Uri, "l1") + assert.Equal(t, e.Logs[1].Uri, "l2") + assert.Equal(t, e.CustomInfo, &s) + + }) +} diff --git a/flyteplugins/go/tasks/v1/factory.go b/flyteplugins/go/tasks/v1/factory.go new file mode 100755 index 0000000000..e58de4b437 --- /dev/null +++ b/flyteplugins/go/tasks/v1/factory.go @@ -0,0 +1,143 @@ +package v1 + +import ( + "context" + "fmt" + "github.com/lyft/flyteplugins/go/tasks/v1/config" + "github.com/lyft/flytestdlib/logger" + "sync" + "time" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + 
"github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" + "github.com/lyft/flytestdlib/sets" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/runtime" +) + +type taskFactory struct { + registeredTasksForTypes map[types.TaskType]types.Executor + registeredDefault types.Executor + registrationLock sync.RWMutex +} + +var taskFactorySingleton = &taskFactory{ + registeredTasksForTypes: make(map[types.TaskType]types.Executor), +} + +func GetTaskExecutor(taskType types.TaskType) (types.Executor, error) { + taskFactorySingleton.registrationLock.RLock() + defer taskFactorySingleton.registrationLock.RUnlock() + + e, ok := taskFactorySingleton.registeredTasksForTypes[taskType] + if !ok { + exec := taskFactorySingleton.registeredDefault + if exec != nil { + logger.Debugf(context.TODO(), "All registered plugins [%v]", taskFactorySingleton.registeredTasksForTypes) + logger.Debugf(context.TODO(), "Using default plugin for [%s]", taskType) + return exec, nil + } + + return nil, errors.Errorf("No Executor defined for TaskType [%v] and no default executor is configured.", taskType) + } + logger.Debugf(context.TODO(), "Using plugin [%s] for task type [%s]", e.GetID(), taskType) + + return e, nil +} + +func ListAllTaskExecutors() []types.Executor { + taskFactorySingleton.registrationLock.RLock() + defer taskFactorySingleton.registrationLock.RUnlock() + + taskExecutors := sets.NewGeneric() + if taskFactorySingleton.registeredDefault != nil { + taskExecutors.Insert(taskFactorySingleton.registeredDefault) + } + + for _, v := range taskFactorySingleton.registeredTasksForTypes { + taskExecutors.Insert(v) + } + setList := taskExecutors.UnsortedList() + tList := make([]types.Executor, 0, len(setList)) + for _, t := range setList { + tList = append(tList, t.(types.Executor)) + } + return tList +} + +func GetAllTaskTypeExecutors() map[types.TaskType]types.Executor { + taskFactorySingleton.registrationLock.RLock() + defer taskFactorySingleton.registrationLock.RUnlock() + m := make(map[types.TaskType]types.Executor, len(taskFactorySingleton.registeredTasksForTypes)) + for k, v := range taskFactorySingleton.registeredTasksForTypes { + m[k] = v + } + + return m +} + +func isEnabled(enabledPlugins []string, pluginToCheck string) bool { + return enabledPlugins != nil && len(enabledPlugins) >= 1 && + (enabledPlugins[0] == "*" || utils.Contains(enabledPlugins, pluginToCheck)) +} + +func RegisterAsDefault(executor types.Executor) error { + enabledPlugins := config.GetConfig().EnabledPlugins + if isEnabled(enabledPlugins, executor.GetID()) { + taskFactorySingleton.registrationLock.Lock() + defer taskFactorySingleton.registrationLock.Unlock() + if existingDefault := taskFactorySingleton.registeredDefault; existingDefault != nil { + return fmt.Errorf("a default Executor already exists. 
Existing Plugin Id [%v], Proposed Plugin [%v]", + existingDefault.GetID(), executor.GetID()) + } + + taskFactorySingleton.registeredDefault = executor + } + + return nil +} + +func RegisterForTaskTypes(executor types.Executor, taskTypes ...types.TaskType) error { + logger.InfofNoCtx("Request to register executor [%v] for types [%+v]", executor.GetID(), taskTypes) + enabledPlugins := config.GetConfig().EnabledPlugins + if isEnabled(enabledPlugins, executor.GetID()) { + logger.InfofNoCtx("Executor [%v] is enabled, attempting to register it.", executor.GetID()) + taskFactorySingleton.registrationLock.Lock() + defer taskFactorySingleton.registrationLock.Unlock() + for _, t := range taskTypes { + x, ok := taskFactorySingleton.registeredTasksForTypes[t] + if ok { + return fmt.Errorf("an Executor already exists for TaskType [%v]. Existing Plugin Id [%v], Proposed Plugin [%v]", t, x.GetID(), executor.GetID()) + } + + logger.InfofNoCtx("Registering type [%s] with executor [%s]", t, executor.GetID()) + taskFactorySingleton.registeredTasksForTypes[t] = executor + } + } else { + logger.InfofNoCtx("Executor [%v] is not enabled, not registering it. Enabled Plugins: [%+v]", executor.GetID(), enabledPlugins) + } + + return nil +} + +func K8sRegisterAsDefault(id string, resourceToWatch runtime.Object, resyncPeriod time.Duration, handler flytek8s.K8sResourceHandler) error { + exec := flytek8s.NewK8sTaskExecutorForResource(id, resourceToWatch, handler, resyncPeriod) + return RegisterAsDefault(exec) +} + +func K8sRegisterForTaskTypes(id string, resourceToWatch runtime.Object, resyncPeriod time.Duration, handler flytek8s.K8sResourceHandler, taskTypes ...types.TaskType) error { + exec := flytek8s.NewK8sTaskExecutorForResource(id, resourceToWatch, handler, resyncPeriod) + return RegisterForTaskTypes(exec, taskTypes...) +} + +// Clears all registered plugins. This should only be called after all packages/plugins have been safely loaded. Otherwise, +// init-based registrations will sneak through. +func ClearRegistry(_ context.Context) error { + taskFactorySingleton.registrationLock.Lock() + defer taskFactorySingleton.registrationLock.Unlock() + + taskFactorySingleton.registeredTasksForTypes = map[types.TaskType]types.Executor{} + taskFactorySingleton.registeredDefault = nil + return nil +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/config/config.go b/flyteplugins/go/tasks/v1/flytek8s/config/config.go new file mode 100755 index 0000000000..ecd05a45bd --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/config/config.go @@ -0,0 +1,59 @@ +package config + +import ( + "k8s.io/api/core/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/config" +) + +const k8sPluginConfigSectionKey = "k8s" +const defaultCpuRequest = "1000m" +const defaultMemoryRequest = "1024Mi" + +var ( + defaultK8sConfig = K8sPluginConfig{ + DefaultAnnotations: map[string]string{ + "cluster-autoscaler.kubernetes.io/safe-to-evict": "false", + }, + } + + // Top level k8s plugin config section. If you are a plugin developer writing a k8s plugin, + // register your config section as a subsection to this. + K8sPluginConfigSection = config.MustRegisterSubSection(k8sPluginConfigSectionKey, &defaultK8sConfig) +) + +// Top level k8s plugin config. 
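
As a usage sketch of the executor registry above: resolving the executor for a task type after all linked-in plugins have loaded. The flow mirrors what the loader package does, and the task type value passed in is an assumption for illustration.

package example

import (
    "context"
    "fmt"

    tasks "github.com/lyft/flyteplugins/go/tasks"
    v1 "github.com/lyft/flyteplugins/go/tasks/v1"
    "github.com/lyft/flyteplugins/go/tasks/v1/types"
)

// resolveExecutor loads all linked-in plugins, then looks up the executor registered
// for the given task type (falling back to the default executor if one is registered).
func resolveExecutor(ctx context.Context, taskType types.TaskType) error {
    if err := tasks.Load(ctx); err != nil {
        return err
    }

    exec, err := v1.GetTaskExecutor(taskType)
    if err != nil {
        return err
    }

    fmt.Printf("task type %v is handled by executor %v\n", taskType, exec.GetID())
    return nil
}
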
+type K8sPluginConfig struct { + // Boolean flag that indicates if a finalizer should be injected into every K8s resource launched + InjectFinalizer bool `json:"inject-finalizer" pflag:",Instructs the plugin to inject a finalizer on startTask and remove it on task termination."` + // Provide default annotations that should be added to K8s resource + DefaultAnnotations map[string]string `json:"default-annotations" pflag:",Defines a set of default annotations to add to the produced pods."` + // Provide default labels that should be added to K8s resource + DefaultLabels map[string]string `json:"default-labels" pflag:",Defines a set of default labels to add to the produced pods."` + // Provide additional environment variable pairs that plugin authors will provide to containers + DefaultEnvVars map[string]string `json:"default-env-vars" pflag:",Additional environment variable that should be injected into every resource"` + // Tolerations in the cluster that should be applied for a specific resource + // Currently we support simple resource based tolerations only + ResourceTolerations map[v1.ResourceName][]v1.Toleration `json:"resource-tolerations"` + // default cpu requests for a container + DefaultCpuRequest string `json:"default-cpus" pflag:",Defines a default value for cpu for containers if not specified."` + // default memory requests for a container + DefaultMemoryRequest string `json:"default-memory" pflag:",Defines a default value for memory for containers if not specified."` +} + +// Retrieves the current k8s plugin config or default. +func GetK8sPluginConfig() *K8sPluginConfig { + pluginsConfig := K8sPluginConfigSection.GetConfig().(*K8sPluginConfig) + if pluginsConfig.DefaultMemoryRequest == "" { + pluginsConfig.DefaultMemoryRequest = defaultMemoryRequest + } + if pluginsConfig.DefaultCpuRequest == "" { + pluginsConfig.DefaultCpuRequest = defaultCpuRequest + } + return pluginsConfig +} + +// [FOR TESTING ONLY] Sets current value for the config. +func SetK8sPluginConfig(cfg *K8sPluginConfig) error { + return K8sPluginConfigSection.SetConfig(cfg) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/constants.go b/flyteplugins/go/tasks/v1/flytek8s/constants.go new file mode 100755 index 0000000000..1cfbe30819 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/constants.go @@ -0,0 +1,11 @@ +package flytek8s + +import "time" + +const DefaultInformerResyncDuration = 30 * time.Second +const Kilobytes = 1024 * 1 +const Megabytes = 1024 * Kilobytes +const Gigabytes = 1024 * Megabytes +const MaxMetadataPayloadSizeBytes = 10 * Megabytes + +const finalizer = "flyte/flytek8s" diff --git a/flyteplugins/go/tasks/v1/flytek8s/container_helper.go b/flyteplugins/go/tasks/v1/flytek8s/container_helper.go new file mode 100755 index 0000000000..5c0f6dbf7b --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/container_helper.go @@ -0,0 +1,123 @@ +package flytek8s + +import ( + "context" + "regexp" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +var isAcceptableK8sName, _ = regexp.Compile("[a-z0-9]([-a-z0-9]*[a-z0-9])?") + +const resourceGPU = "GPU" + +// ResourceNvidiaGPU is the name of the Nvidia GPU resource. 
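
A small sketch of reading and, in tests (as the comment on SetK8sPluginConfig notes), overriding the k8s plugin defaults defined above; the "team" annotation key and value are illustrative.

package example

import (
    flyteK8sConfig "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config"
)

// overrideDefaultsForTest injects a finalizer and replaces the default annotations,
// then stores the value back so subsequent GetK8sPluginConfig calls observe it.
func overrideDefaultsForTest() error {
    cfg := flyteK8sConfig.GetK8sPluginConfig()
    cfg.InjectFinalizer = true
    cfg.DefaultAnnotations = map[string]string{"team": "data-platform"}

    return flyteK8sConfig.SetK8sPluginConfig(cfg)
}
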
+// Copied from: k8s.io/autoscaler/cluster-autoscaler/utils/gpu/gpu.go +const ResourceNvidiaGPU = "nvidia.com/gpu" + +func ApplyResourceOverrides(ctx context.Context, resources v1.ResourceRequirements) *v1.ResourceRequirements { + // set memory and cpu to default if not provided by user. + if len(resources.Requests) == 0 { + resources.Requests = make(v1.ResourceList) + } + if _, found := resources.Requests[v1.ResourceCPU]; !found { + resources.Requests[v1.ResourceCPU] = resource.MustParse(config.GetK8sPluginConfig().DefaultCpuRequest) + } + if _, found := resources.Requests[v1.ResourceMemory]; !found { + resources.Requests[v1.ResourceMemory] = resource.MustParse(config.GetK8sPluginConfig().DefaultMemoryRequest) + } + + if len(resources.Limits) == 0 { + resources.Limits = make(v1.ResourceList) + } + if len(resources.Requests) == 0 { + resources.Requests = make(v1.ResourceList) + } + if _, found := resources.Limits[v1.ResourceCPU]; !found { + logger.Infof(ctx, "found cpu limit missing, setting limit to the requested value %v", resources.Requests[v1.ResourceCPU]) + resources.Limits[v1.ResourceCPU] = resources.Requests[v1.ResourceCPU] + } + if _, found := resources.Limits[v1.ResourceMemory]; !found { + logger.Infof(ctx, "found memory limit missing, setting limit to the requested value %v", resources.Requests[v1.ResourceMemory]) + resources.Limits[v1.ResourceMemory] = resources.Requests[v1.ResourceMemory] + } + + // TODO: Make configurable. 1/15/2019 Flyte Cluster doesn't support setting storage requests/limits. + // https://github.com/kubernetes/enhancements/issues/362 + delete(resources.Requests, v1.ResourceStorage) + delete(resources.Requests, v1.ResourceEphemeralStorage) + + delete(resources.Limits, v1.ResourceStorage) + delete(resources.Limits, v1.ResourceEphemeralStorage) + + // Override GPU + if resource, found := resources.Requests[resourceGPU]; found { + resources.Requests[ResourceNvidiaGPU] = resource + delete(resources.Requests, resourceGPU) + } + if resource, found := resources.Limits[resourceGPU]; found { + resources.Limits[ResourceNvidiaGPU] = resource + delete(resources.Requests, resourceGPU) + } + + return &resources +} + +// Returns a K8s Container for the execution +func ToK8sContainer(ctx context.Context, taskCtx types.TaskContext, taskContainer *core.Container, inputs *core.LiteralMap) (*v1.Container, error) { + inputFile := taskCtx.GetInputsFile() + cmdLineArgs := utils.CommandLineTemplateArgs{ + Input: inputFile.String(), + OutputPrefix: taskCtx.GetDataDir().String(), + Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + } + + modifiedCommand, err := utils.ReplaceTemplateCommandArgs(ctx, taskContainer.GetCommand(), cmdLineArgs) + if err != nil { + return nil, err + } + + modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, taskContainer.GetArgs(), cmdLineArgs) + if err != nil { + return nil, err + } + + envVars := DecorateEnvVars(ctx, ToK8sEnvVar(taskContainer.GetEnv()), taskCtx.GetTaskExecutionID()) + + if taskCtx.GetOverrides() == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "platform/compiler error, overrides not set for task") + } + if taskCtx.GetOverrides() == nil || taskCtx.GetOverrides().GetResources() == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "resource requirements not found for container task, required!") + } + + res := taskCtx.GetOverrides().GetResources() + if res != nil { + res = ApplyResourceOverrides(ctx, *res) + } + + // Make the container name the same as the pod name, unless it violates K8s naming 
conventions + // Container names are subject to the DNS-1123 standard + containerName := taskCtx.GetTaskExecutionID().GetGeneratedName() + if !isAcceptableK8sName.MatchString(containerName) || len(containerName) > 63 { + containerName = rand.String(4) + } + + return &v1.Container{ + Name: containerName, + Image: taskContainer.GetImage(), + Args: modifiedArgs, + Command: modifiedCommand, + Env: envVars, + Resources: *res, + }, nil +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/container_helper_test.go b/flyteplugins/go/tasks/v1/flytek8s/container_helper_test.go new file mode 100755 index 0000000000..c1f530152e --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/container_helper_test.go @@ -0,0 +1,98 @@ +package flytek8s + +import ( + "context" + "github.com/stretchr/testify/assert" + "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + "testing" +) + +func TestApplyResourceOverrides_OverrideCpu(t *testing.T) { + cpuRequest := resource.MustParse("1") + overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: cpuRequest, + }, + }) + assert.EqualValues(t, cpuRequest, overrides.Requests[v1.ResourceCPU]) + assert.EqualValues(t, cpuRequest, overrides.Limits[v1.ResourceCPU]) + + cpuLimit := resource.MustParse("2") + overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: cpuRequest, + }, + Limits: v1.ResourceList{ + v1.ResourceCPU: cpuLimit, + }, + }) + assert.EqualValues(t, cpuRequest, overrides.Requests[v1.ResourceCPU]) + assert.EqualValues(t, cpuLimit, overrides.Limits[v1.ResourceCPU]) +} + +func TestApplyResourceOverrides_OverrideMemory(t *testing.T) { + memoryRequest := resource.MustParse("1") + overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: memoryRequest, + }, + }) + assert.EqualValues(t, memoryRequest, overrides.Requests[v1.ResourceMemory]) + assert.EqualValues(t, memoryRequest, overrides.Limits[v1.ResourceMemory]) + + memoryLimit := resource.MustParse("2") + overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceMemory: memoryRequest, + }, + Limits: v1.ResourceList{ + v1.ResourceMemory: memoryLimit, + }, + }) + assert.EqualValues(t, memoryRequest, overrides.Requests[v1.ResourceMemory]) + assert.EqualValues(t, memoryLimit, overrides.Limits[v1.ResourceMemory]) +} + +func TestApplyResourceOverrides_RemoveStorage(t *testing.T) { + requestedResourceQuantity := resource.MustParse("1") + overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceStorage: requestedResourceQuantity, + v1.ResourceMemory: requestedResourceQuantity, + v1.ResourceCPU: requestedResourceQuantity, + v1.ResourceEphemeralStorage: requestedResourceQuantity, + }, + Limits: v1.ResourceList{ + v1.ResourceStorage: requestedResourceQuantity, + v1.ResourceMemory: requestedResourceQuantity, + v1.ResourceEphemeralStorage: requestedResourceQuantity, + }, + }) + assert.EqualValues(t, v1.ResourceList{ + v1.ResourceMemory: requestedResourceQuantity, + v1.ResourceCPU: requestedResourceQuantity, + }, overrides.Requests) + + assert.EqualValues(t, v1.ResourceList{ + v1.ResourceMemory: requestedResourceQuantity, + v1.ResourceCPU: requestedResourceQuantity, + }, overrides.Limits) +} + +func TestApplyResourceOverrides_OverrideGpu(t *testing.T) { + gpuRequest := 
resource.MustParse("1") + overrides := ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Requests: v1.ResourceList{ + resourceGPU: gpuRequest, + }, + }) + assert.EqualValues(t, gpuRequest, overrides.Requests[ResourceNvidiaGPU]) + + overrides = ApplyResourceOverrides(context.Background(), v1.ResourceRequirements{ + Limits: v1.ResourceList{ + resourceGPU: gpuRequest, + }, + }) + assert.EqualValues(t, gpuRequest, overrides.Limits[ResourceNvidiaGPU]) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds.go b/flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds.go new file mode 100755 index 0000000000..757ca0d2d1 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds.go @@ -0,0 +1,154 @@ +package flytek8s + +import ( + "context" + + "github.com/lyft/flytestdlib/contextutils" + v1 "k8s.io/api/core/v1" + + v12 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +func GetContextEnvVars(ownerCtx context.Context) []v1.EnvVar { + envVars := []v1.EnvVar{} + + if ownerCtx == nil { + return envVars + } + + // Injecting useful env vars from the context + if wfName := contextutils.Value(ownerCtx, contextutils.WorkflowIDKey); wfName != "" { + envVars = append(envVars, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_EXECUTION_WORKFLOW", + Value: wfName, + }, + ) + } + return envVars +} + +func GetExecutionEnvVars(id types.TaskExecutionID) []v1.EnvVar { + + if id == nil || id.GetID().NodeExecutionId == nil || id.GetID().NodeExecutionId.ExecutionId == nil { + return []v1.EnvVar{} + } + + // Execution level env variables. + nodeExecutionId := id.GetID().NodeExecutionId.ExecutionId + envVars := []v1.EnvVar{ + { + Name: "FLYTE_INTERNAL_EXECUTION_ID", + Value: nodeExecutionId.Name, + }, + { + Name: "FLYTE_INTERNAL_EXECUTION_PROJECT", + Value: nodeExecutionId.Project, + }, + { + Name: "FLYTE_INTERNAL_EXECUTION_DOMAIN", + Value: nodeExecutionId.Domain, + }, + // TODO: Fill in these + // { + // Name: "FLYTE_INTERNAL_EXECUTION_WORKFLOW", + // Value: "", + // }, + // { + // Name: "FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", + // Value: "", + // }, + } + + // Task definition Level env variables. + if id.GetID().TaskId != nil { + taskId := id.GetID().TaskId + + envVars = append(envVars, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_TASK_PROJECT", + Value: taskId.Project, + }, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_TASK_DOMAIN", + Value: taskId.Domain, + }, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_TASK_NAME", + Value: taskId.Name, + }, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_TASK_VERSION", + Value: taskId.Version, + }, + // Historic Task Definition Level env variables. + // Remove these once SDK is migrated to use the new ones. + v1.EnvVar{ + Name: "FLYTE_INTERNAL_PROJECT", + Value: taskId.Project, + }, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_DOMAIN", + Value: taskId.Domain, + }, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_NAME", + Value: taskId.Name, + }, + v1.EnvVar{ + Name: "FLYTE_INTERNAL_VERSION", + Value: taskId.Version, + }) + + } + return envVars +} + +func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, id types.TaskExecutionID) []v1.EnvVar { + // Injecting workflow name into the container's env vars + envVars = append(envVars, GetContextEnvVars(ctx)...) + + envVars = append(envVars, GetExecutionEnvVars(id)...) 
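+ // Append any operator-configured default env vars for the plugin (default-env-vars) after the context and execution env vars.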
+ + for k, v := range config.GetK8sPluginConfig().DefaultEnvVars { + envVars = append(envVars, v1.EnvVar{Name: k, Value: v}) + } + return envVars +} + +func GetTolerationsForResources(resourceRequirements ...v1.ResourceRequirements) []v1.Toleration { + var tolerations []v1.Toleration + resourceNames := sets.NewString() + for _, resources := range resourceRequirements { + for r := range resources.Limits { + resourceNames.Insert(r.String()) + } + for r := range resources.Requests { + resourceNames.Insert(r.String()) + } + } + resourceTols := config.GetK8sPluginConfig().ResourceTolerations + for _, r := range resourceNames.UnsortedList() { + if v, ok := resourceTols[v1.ResourceName(r)]; ok { + tolerations = append(tolerations, v...) + } + } + return tolerations +} + +func AddObjectMetadata(taskCtx types.TaskContext, o K8sResource) { + o.SetNamespace(taskCtx.GetNamespace()) + o.SetAnnotations(UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, o.GetAnnotations(), utils.CopyMap(taskCtx.GetAnnotations()))) + o.SetLabels(UnionMaps(o.GetLabels(), utils.CopyMap(taskCtx.GetLabels()), config.GetK8sPluginConfig().DefaultLabels)) + o.SetOwnerReferences([]v12.OwnerReference{taskCtx.GetOwnerReference()}) + o.SetName(taskCtx.GetTaskExecutionID().GetGeneratedName()) + if config.GetK8sPluginConfig().InjectFinalizer { + f := append(o.GetFinalizers(), finalizer) + o.SetFinalizers(f) + } +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds_test.go b/flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds_test.go new file mode 100755 index 0000000000..ef1a623d56 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/k8s_resource_adds_test.go @@ -0,0 +1,283 @@ +package flytek8s + +import ( + "context" + "reflect" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + v12 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" +) + +func getMockTaskContext() *mocks.TaskContext { + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetNamespace").Return("ns") + taskCtx.On("GetAnnotations").Return(map[string]string{"aKey": "aVal"}) + taskCtx.On("GetLabels").Return(map[string]string{"lKey": "lVal"}) + taskCtx.On("GetOwnerReference").Return(v1.OwnerReference{Name: "x"}) + + id := &mocks.TaskExecutionID{} + id.On("GetGeneratedName").Return("test") + taskCtx.On("GetTaskExecutionID").Return(id) + return taskCtx +} + +func assertObjectAndTaskCtx(t *testing.T, taskCtx types.TaskContext, resource K8sResource) { + assert.Equal(t, taskCtx.GetTaskExecutionID().GetGeneratedName(), resource.GetName()) + assert.Equal(t, []v1.OwnerReference{taskCtx.GetOwnerReference()}, resource.GetOwnerReferences()) + assert.Equal(t, taskCtx.GetNamespace(), resource.GetNamespace()) + assert.Equal(t, map[string]string{ + "cluster-autoscaler.kubernetes.io/safe-to-evict": "false", + "aKey": "aVal", + }, resource.GetAnnotations()) + assert.Equal(t, taskCtx.GetLabels(), resource.GetLabels()) +} + +func TestAddObjectMetadata(t *testing.T) { + taskCtx := getMockTaskContext() + o := &v12.Pod{} + AddObjectMetadata(taskCtx, o) + assertObjectAndTaskCtx(t, taskCtx, o) +} + +func TestGetExecutionEnvVars(t *testing.T) { + mock := mockTaskExecutionIdentifier{} + envVars := GetExecutionEnvVars(mock) + assert.Len(t, envVars, 11) +} + 
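+// Illustrative only: with a K8s plugin config such as
+//   resource-tolerations:
+//     nvidia.com/gpu:
+//     - key: flyte/gpu
+//       value: dedicated
+//       operator: Equal
+//       effect: NoSchedule
+// GetTolerationsForResources returns the GPU toleration for any task that requests or limits
+// nvidia.com/gpu, as exercised by the table-driven test below.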
+func TestGetTolerationsForResources(t *testing.T) { + var empty []v12.Toleration + var emptyConfig map[v12.ResourceName][]v12.Toleration + + tolGPU := v12.Toleration{ + Key: "flyte/gpu", + Value: "dedicated", + Operator: v12.TolerationOpEqual, + Effect: v12.TaintEffectNoSchedule, + } + + tolStorage := v12.Toleration{ + Key: "storage", + Value: "dedicated", + Operator: v12.TolerationOpExists, + Effect: v12.TaintEffectNoSchedule, + } + + type args struct { + resources v12.ResourceRequirements + } + tests := []struct { + name string + args args + setVal map[v12.ResourceName][]v12.Toleration + want []v12.Toleration + }{ + { + "no-tolerations-limits", + args{ + v12.ResourceRequirements{ + Limits: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + emptyConfig, + empty, + }, + { + "no-tolerations-req", + args{ + v12.ResourceRequirements{ + Requests: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + emptyConfig, + empty, + }, + { + "no-tolerations-both", + args{ + v12.ResourceRequirements{ + Limits: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + Requests: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + emptyConfig, + empty, + }, + { + "tolerations-limits", + args{ + v12.ResourceRequirements{ + Limits: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + map[v12.ResourceName][]v12.Toleration{ + v12.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }, + []v12.Toleration{tolStorage}, + }, + { + "tolerations-req", + args{ + v12.ResourceRequirements{ + Requests: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + map[v12.ResourceName][]v12.Toleration{ + v12.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }, + []v12.Toleration{tolStorage}, + }, + { + "tolerations-both", + args{ + v12.ResourceRequirements{ + Limits: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + Requests: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + map[v12.ResourceName][]v12.Toleration{ + v12.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }, + []v12.Toleration{tolStorage}, + }, + { + "no-tolerations-both", + args{ + v12.ResourceRequirements{ + Limits: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: v12.ResourceList{ + v12.ResourceCPU: resource.MustParse("1024m"), + v12.ResourceStorage: resource.MustParse("100M"), + }, + }, + }, + map[v12.ResourceName][]v12.Toleration{ + v12.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }, + []v12.Toleration{tolStorage, tolGPU}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ResourceTolerations: tt.setVal})) + if got := GetTolerationsForResources(tt.args.resources); len(got) != len(tt.want) { + t.Errorf("GetTolerationsForResources() = %v, want %v", got, tt.want) + } 
else { + for _, tol := range tt.want { + assert.Contains(t, got, tol) + } + } + }) + } +} + +var testTaskExecutionIdentifier = core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "proj", + Domain: "domain", + Name: "name", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: "nodeId", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "proj", + Domain: "domain", + Name: "name", + }, + }, +} + +type mockTaskExecutionIdentifier struct{} + +func (m mockTaskExecutionIdentifier) GetID() core.TaskExecutionIdentifier { + return testTaskExecutionIdentifier +} + +func (m mockTaskExecutionIdentifier) GetGeneratedName() string { + return "task-exec-name" +} + +func TestDecorateEnvVars(t *testing.T) { + ctx := context.Background() + ctx = contextutils.WithWorkflowID(ctx, "fake_workflow") + + defaultEnv := []v12.EnvVar{ + { + Name: "x", + Value: "y", + }, + } + additionalEnv := map[string]string{ + "k": "v", + } + var emptyEnvVar map[string]string + + expected := append(defaultEnv, GetContextEnvVars(ctx)...) + expected = append(expected, GetExecutionEnvVars(mockTaskExecutionIdentifier{})...) + + aggregated := append(expected, v12.EnvVar{Name: "k", Value: "v"}) + type args struct { + envVars []v12.EnvVar + id types.TaskExecutionID + } + tests := []struct { + name string + args args + additionEnvVar map[string]string + want []v12.EnvVar + }{ + {"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, expected}, + {"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, aggregated}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultEnvVars: tt.additionEnvVar})) + if got := DecorateEnvVars(ctx, tt.args.envVars, tt.args.id); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DecorateEnvVars() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/mocks/Cache.go b/flyteplugins/go/tasks/v1/flytek8s/mocks/Cache.go new file mode 100755 index 0000000000..5fea98afe0 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/mocks/Cache.go @@ -0,0 +1,25 @@ +package mocks + +import ( + v1 "k8s.io/api/apps/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" + "sigs.k8s.io/controller-runtime/pkg/cache/informertest" +) +import "context" +import "k8s.io/apimachinery/pkg/runtime" +import "k8s.io/apimachinery/pkg/types" + +type Cache struct { + informertest.FakeInformers +} + +// Get on this mock will always return object not found. +func (_m *Cache) Get(ctx context.Context, key types.NamespacedName, obj runtime.Object) error { + accessor, err := meta.Accessor(obj) + if err != nil { + return err + } + + return errors.NewNotFound(v1.Resource("pod"), accessor.GetName()) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/mocks/Handler.go b/flyteplugins/go/tasks/v1/flytek8s/mocks/Handler.go new file mode 100755 index 0000000000..89445eb0ea --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/mocks/Handler.go @@ -0,0 +1,27 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" + +import mock "github.com/stretchr/testify/mock" +import runtime "k8s.io/apimachinery/pkg/runtime" + +// Handler is an autogenerated mock type for the Handler type +type Handler struct { + mock.Mock +} + +// Handle provides a mock function with given fields: _a0, _a1 +func (_m *Handler) Handle(_a0 context.Context, _a1 runtime.Object) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, runtime.Object) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResource.go b/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResource.go new file mode 100755 index 0000000000..05b409b3ff --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResource.go @@ -0,0 +1,383 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import runtime "k8s.io/apimachinery/pkg/runtime" +import schema "k8s.io/apimachinery/pkg/runtime/schema" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + +// K8sResource is an autogenerated mock type for the K8sResource type +type K8sResource struct { + mock.Mock +} + +// DeepCopyObject provides a mock function with given fields: +func (_m *K8sResource) DeepCopyObject() runtime.Object { + ret := _m.Called() + + var r0 runtime.Object + if rf, ok := ret.Get(0).(func() runtime.Object); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(runtime.Object) + } + } + + return r0 +} + +// GetAnnotations provides a mock function with given fields: +func (_m *K8sResource) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetClusterName provides a mock function with given fields: +func (_m *K8sResource) GetClusterName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *K8sResource) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetDeletionGracePeriodSeconds provides a mock function with given fields: +func (_m *K8sResource) GetDeletionGracePeriodSeconds() *int64 { + ret := _m.Called() + + var r0 *int64 + if rf, ok := ret.Get(0).(func() *int64); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*int64) + } + } + + return r0 +} + +// GetDeletionTimestamp provides a mock function with given fields: +func (_m *K8sResource) GetDeletionTimestamp() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetFinalizers provides a mock function with given fields: +func (_m *K8sResource) GetFinalizers() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetGenerateName provides a mock function with given 
fields: +func (_m *K8sResource) GetGenerateName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetGeneration provides a mock function with given fields: +func (_m *K8sResource) GetGeneration() int64 { + ret := _m.Called() + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// GetInitializers provides a mock function with given fields: +func (_m *K8sResource) GetInitializers() *v1.Initializers { + ret := _m.Called() + + var r0 *v1.Initializers + if rf, ok := ret.Get(0).(func() *v1.Initializers); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Initializers) + } + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *K8sResource) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *K8sResource) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *K8sResource) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetObjectKind provides a mock function with given fields: +func (_m *K8sResource) GetObjectKind() schema.ObjectKind { + ret := _m.Called() + + var r0 schema.ObjectKind + if rf, ok := ret.Get(0).(func() schema.ObjectKind); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(schema.ObjectKind) + } + } + + return r0 +} + +// GetOwnerReferences provides a mock function with given fields: +func (_m *K8sResource) GetOwnerReferences() []v1.OwnerReference { + ret := _m.Called() + + var r0 []v1.OwnerReference + if rf, ok := ret.Get(0).(func() []v1.OwnerReference); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]v1.OwnerReference) + } + } + + return r0 +} + +// GetResourceVersion provides a mock function with given fields: +func (_m *K8sResource) GetResourceVersion() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetSelfLink provides a mock function with given fields: +func (_m *K8sResource) GetSelfLink() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetUID provides a mock function with given fields: +func (_m *K8sResource) GetUID() types.UID { + ret := _m.Called() + + var r0 types.UID + if rf, ok := ret.Get(0).(func() types.UID); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.UID) + } + + return r0 +} + +// GroupVersionKind provides a mock function with given fields: +func (_m *K8sResource) GroupVersionKind() schema.GroupVersionKind { + ret := _m.Called() + + var r0 schema.GroupVersionKind + if rf, ok := ret.Get(0).(func() schema.GroupVersionKind); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(schema.GroupVersionKind) + } + + return 
r0 +} + +// SetAnnotations provides a mock function with given fields: annotations +func (_m *K8sResource) SetAnnotations(annotations map[string]string) { + _m.Called(annotations) +} + +// SetClusterName provides a mock function with given fields: clusterName +func (_m *K8sResource) SetClusterName(clusterName string) { + _m.Called(clusterName) +} + +// SetCreationTimestamp provides a mock function with given fields: timestamp +func (_m *K8sResource) SetCreationTimestamp(timestamp v1.Time) { + _m.Called(timestamp) +} + +// SetDeletionGracePeriodSeconds provides a mock function with given fields: _a0 +func (_m *K8sResource) SetDeletionGracePeriodSeconds(_a0 *int64) { + _m.Called(_a0) +} + +// SetDeletionTimestamp provides a mock function with given fields: timestamp +func (_m *K8sResource) SetDeletionTimestamp(timestamp *v1.Time) { + _m.Called(timestamp) +} + +// SetFinalizers provides a mock function with given fields: finalizers +func (_m *K8sResource) SetFinalizers(finalizers []string) { + _m.Called(finalizers) +} + +// SetGenerateName provides a mock function with given fields: name +func (_m *K8sResource) SetGenerateName(name string) { + _m.Called(name) +} + +// SetGeneration provides a mock function with given fields: generation +func (_m *K8sResource) SetGeneration(generation int64) { + _m.Called(generation) +} + +// SetGroupVersionKind provides a mock function with given fields: kind +func (_m *K8sResource) SetGroupVersionKind(kind schema.GroupVersionKind) { + _m.Called(kind) +} + +// SetInitializers provides a mock function with given fields: initializers +func (_m *K8sResource) SetInitializers(initializers *v1.Initializers) { + _m.Called(initializers) +} + +// SetLabels provides a mock function with given fields: labels +func (_m *K8sResource) SetLabels(labels map[string]string) { + _m.Called(labels) +} + +// SetName provides a mock function with given fields: name +func (_m *K8sResource) SetName(name string) { + _m.Called(name) +} + +// SetNamespace provides a mock function with given fields: namespace +func (_m *K8sResource) SetNamespace(namespace string) { + _m.Called(namespace) +} + +// SetOwnerReferences provides a mock function with given fields: _a0 +func (_m *K8sResource) SetOwnerReferences(_a0 []v1.OwnerReference) { + _m.Called(_a0) +} + +// SetResourceVersion provides a mock function with given fields: version +func (_m *K8sResource) SetResourceVersion(version string) { + _m.Called(version) +} + +// SetSelfLink provides a mock function with given fields: selfLink +func (_m *K8sResource) SetSelfLink(selfLink string) { + _m.Called(selfLink) +} + +// SetUID provides a mock function with given fields: uid +func (_m *K8sResource) SetUID(uid types.UID) { + _m.Called(uid) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go b/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go new file mode 100755 index 0000000000..292d056245 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go @@ -0,0 +1,91 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import events "github.com/lyft/flyteplugins/go/tasks/v1/events" +import flytek8s "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" +import mock "github.com/stretchr/testify/mock" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" + +// K8sResourceHandler is an autogenerated mock type for the K8sResourceHandler type +type K8sResourceHandler struct { + mock.Mock +} + +// BuildIdentityResource provides a mock function with given fields: ctx, taskCtx +func (_m *K8sResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx types.TaskContext) (flytek8s.K8sResource, error) { + ret := _m.Called(ctx, taskCtx) + + var r0 flytek8s.K8sResource + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext) flytek8s.K8sResource); ok { + r0 = rf(ctx, taskCtx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(flytek8s.K8sResource) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.TaskContext) error); ok { + r1 = rf(ctx, taskCtx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BuildResource provides a mock function with given fields: ctx, taskCtx, task, inputs +func (_m *K8sResourceHandler) BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) { + ret := _m.Called(ctx, taskCtx, task, inputs) + + var r0 flytek8s.K8sResource + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext, *core.TaskTemplate, *core.LiteralMap) flytek8s.K8sResource); ok { + r0 = rf(ctx, taskCtx, task, inputs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(flytek8s.K8sResource) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.TaskContext, *core.TaskTemplate, *core.LiteralMap) error); ok { + r1 = rf(ctx, taskCtx, task, inputs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTaskStatus provides a mock function with given fields: ctx, taskCtx, resource +func (_m *K8sResourceHandler) GetTaskStatus(ctx context.Context, taskCtx types.TaskContext, resource flytek8s.K8sResource) (types.TaskStatus, *events.TaskEventInfo, error) { + ret := _m.Called(ctx, taskCtx, resource) + + var r0 types.TaskStatus + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext, flytek8s.K8sResource) types.TaskStatus); ok { + r0 = rf(ctx, taskCtx, resource) + } else { + r0 = ret.Get(0).(types.TaskStatus) + } + + var r1 *events.TaskEventInfo + if rf, ok := ret.Get(1).(func(context.Context, types.TaskContext, flytek8s.K8sResource) *events.TaskEventInfo); ok { + r1 = rf(ctx, taskCtx, resource) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*events.TaskEventInfo) + } + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, types.TaskContext, flytek8s.K8sResource) error); ok { + r2 = rf(ctx, taskCtx, resource) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/mocks/RuntimeClient.go b/flyteplugins/go/tasks/v1/flytek8s/mocks/RuntimeClient.go new file mode 100755 index 0000000000..5fbf552902 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/mocks/RuntimeClient.go @@ -0,0 +1,142 @@ +package mocks + +import ( + "context" + "fmt" + "reflect" + "sync" + + "k8s.io/apimachinery/pkg/api/meta" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + 
"k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type MockRuntimeClient struct { + syncObj sync.RWMutex + Cache map[string]runtime.Object + CreateCb func(ctx context.Context, obj runtime.Object) (err error) + GetCb func(ctx context.Context, key client.ObjectKey, out runtime.Object) error + ListCb func(ctx context.Context, opts *client.ListOptions, list runtime.Object) error +} + +func formatKey(name types.NamespacedName, kind schema.GroupVersionKind) string { + key := fmt.Sprintf("%v:%v", name.String(), kind.String()) + return key +} + +func (m MockRuntimeClient) Get(ctx context.Context, key client.ObjectKey, out runtime.Object) error { + if m.GetCb != nil { + return m.GetCb(ctx, key, out) + } + return m.defaultGet(ctx, key, out) +} + +func (m MockRuntimeClient) Create(ctx context.Context, obj runtime.Object) (err error) { + if m.CreateCb != nil { + return m.CreateCb(ctx, obj) + } + return m.defaultCreate(ctx, obj) +} + +func (m MockRuntimeClient) List(ctx context.Context, opts *client.ListOptions, list runtime.Object) error { + if m.ListCb != nil { + return m.ListCb(ctx, opts, list) + } + return m.defaultList(ctx, opts, list) +} + +func (*MockRuntimeClient) Delete(ctx context.Context, obj runtime.Object, opts ...client.DeleteOptionFunc) error { + panic("implement me") +} + +func (m *MockRuntimeClient) Update(ctx context.Context, obj runtime.Object) error { + // TODO: split update/create, create should fail if already exists. + return m.Create(ctx, obj) +} + +func (*MockRuntimeClient) Status() client.StatusWriter { + panic("implement me") +} + +func (m MockRuntimeClient) defaultGet(ctx context.Context, key client.ObjectKey, out runtime.Object) error { + m.syncObj.RLock() + defer m.syncObj.RUnlock() + + item, found := m.Cache[formatKey(key, out.GetObjectKind().GroupVersionKind())] + if found { + // deep copy to avoid mutating cache + item = item.(runtime.Object).DeepCopyObject() + _, isUnstructured := out.(*unstructured.Unstructured) + if isUnstructured { + // Copy the value of the item in the cache to the returned value + outVal := reflect.ValueOf(out) + objVal := reflect.ValueOf(item) + if !objVal.Type().AssignableTo(outVal.Type()) { + return fmt.Errorf("cache had type %s, but %s was asked for", objVal.Type(), outVal.Type()) + } + reflect.Indirect(outVal).Set(reflect.Indirect(objVal)) + return nil + } + + p, err := runtime.DefaultUnstructuredConverter.ToUnstructured(item) + if err != nil { + return err + } + + return runtime.DefaultUnstructuredConverter.FromUnstructured(p, out) + } + + return errors.NewNotFound(schema.GroupResource{}, key.Name) +} + +func (m MockRuntimeClient) defaultList(ctx context.Context, opts *client.ListOptions, list runtime.Object) error { + m.syncObj.RLock() + defer m.syncObj.RUnlock() + + objs := make([]runtime.Object, 0, len(m.Cache)) + + for _, val := range m.Cache { + if opts.Raw != nil { + if val.GetObjectKind().GroupVersionKind().Kind != opts.Raw.Kind { + continue + } + + if val.GetObjectKind().GroupVersionKind().GroupVersion().String() != opts.Raw.APIVersion { + continue + } + } + + objs = append(objs, val.(runtime.Object).DeepCopyObject()) + } + + return meta.SetList(list, objs) +} + +func (m MockRuntimeClient) defaultCreate(ctx context.Context, obj runtime.Object) (err error) { + m.syncObj.Lock() + defer m.syncObj.Unlock() + + accessor, err := meta.Accessor(obj) + if err != nil { + return err + } + + m.Cache[formatKey(types.NamespacedName{ + Name: accessor.GetName(), + 
Namespace: accessor.GetNamespace(), + }, obj.GetObjectKind().GroupVersionKind())] = obj + + return nil +} + +func NewMockRuntimeClient() *MockRuntimeClient { + return &MockRuntimeClient{ + syncObj: sync.RWMutex{}, + Cache: map[string]runtime.Object{}, + } +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go b/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go new file mode 100755 index 0000000000..0e7cd948d9 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go @@ -0,0 +1,163 @@ +package flytek8s + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/lyft/flytestdlib/logger" + + "sigs.k8s.io/controller-runtime/pkg/cache/informertest" + "sigs.k8s.io/controller-runtime/pkg/client/config" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + "k8s.io/client-go/util/workqueue" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" + ctrlHandler "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/runtime/inject" + "sigs.k8s.io/controller-runtime/pkg/source" + + "k8s.io/apimachinery/pkg/runtime" +) + +var instance *flytek8s +var once sync.Once + +type flytek8s struct { + watchNamespace string + kubeClient client.Client + informersCache cache.Cache +} + +func (f *flytek8s) InjectClient(c client.Client) error { + f.kubeClient = c + return nil +} + +func InjectClient(c client.Client) error { + if instance == nil { + return fmt.Errorf("instance not initialized") + } + + return instance.InjectClient(c) +} + +func (f *flytek8s) InjectCache(c cache.Cache) error { + f.informersCache = c + return nil +} + +func InjectCache(c cache.Cache) error { + if instance == nil { + return fmt.Errorf("instance not initialized") + } + + return instance.InjectCache(c) +} + +func InitializeFake() client.Client { + once.Do(func() { + instance = &flytek8s{ + watchNamespace: "", + } + + instance.kubeClient = fake.NewFakeClient() + instance.informersCache = &informertest.FakeInformers{} + }) + + return instance.kubeClient +} + +func Initialize(ctx context.Context, watchNamespace string, resyncPeriod time.Duration) (err error) { + once.Do(func() { + instance = &flytek8s{ + watchNamespace: watchNamespace, + } + + kubeConfig := config.GetConfigOrDie() + instance.kubeClient, err = client.New(kubeConfig, client.Options{}) + if err != nil { + return + } + + instance.informersCache, err = cache.New(kubeConfig, cache.Options{ + Namespace: watchNamespace, + Resync: &resyncPeriod, + }) + + if err == nil { + go func() { + logger.Infof(ctx, "Starting informers cache.") + err = instance.informersCache.Start(ctx.Done()) + if err != nil { + logger.Panicf(ctx, "Failed to start informers cache. 
Error: %v", err) + } + }() + } + }) + + if err != nil { + return err + } + + if watchNamespace != instance.watchNamespace { + return fmt.Errorf("flytek8s is supposed to be used under single namespace."+ + " configured-for: %v, requested-for: %v", instance.watchNamespace, watchNamespace) + } + + return nil +} + +func RegisterResource(ctx context.Context, resourceToWatch runtime.Object, handler Handler) error { + if instance == nil { + return fmt.Errorf("instance not initialized") + } + + if handler == nil { + return fmt.Errorf("nil Handler for resource %s", resourceToWatch.GetObjectKind()) + } + + src := source.Kind{ + Type: resourceToWatch, + } + + if _, err := inject.CacheInto(instance.informersCache, &src); err != nil { + return err + } + + // TODO: a more unique workqueue name + q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), + resourceToWatch.GetObjectKind().GroupVersionKind().Kind) + + err := src.Start(ctrlHandler.Funcs{ + CreateFunc: func(evt event.CreateEvent, q2 workqueue.RateLimitingInterface) { + err := handler.Handle(ctx, evt.Object) + if err != nil { + logger.Warnf(ctx, "Failed to handle Create event for object [%v]", evt.Object) + } + }, + UpdateFunc: func(evt event.UpdateEvent, q2 workqueue.RateLimitingInterface) { + err := handler.Handle(ctx, evt.ObjectNew) + if err != nil { + logger.Warnf(ctx, "Failed to handle Update event for object [%v]", evt.ObjectNew) + } + }, + DeleteFunc: func(evt event.DeleteEvent, q2 workqueue.RateLimitingInterface) { + err := handler.Handle(ctx, evt.Object) + if err != nil { + logger.Warnf(ctx, "Failed to handle Delete event for object [%v]", evt.Object) + } + }, + GenericFunc: func(evt event.GenericEvent, q2 workqueue.RateLimitingInterface) { + err := handler.Handle(ctx, evt.Object) + if err != nil { + logger.Warnf(ctx, "Failed to handle Generic event for object [%v]", evt.Object) + } + }, + }, q) + + return err +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go new file mode 100755 index 0000000000..652e9775f2 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go @@ -0,0 +1,357 @@ +package flytek8s + +import ( + "context" + "time" + + "github.com/lyft/flytestdlib/contextutils" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flyteplugins/go/tasks/v1/events" + + k8stypes "k8s.io/apimachinery/pkg/types" + + "strings" + + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// A generic task executor for k8s-resource reliant tasks. +type K8sTaskExecutor struct { + types.OutputsResolver + recorder types.EventRecorder + id string + handler K8sResourceHandler + resyncPeriod time.Duration + // Supplied on Initialization + // TODO decide the right place to put these interfaces or late-bind them? 
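+ // store is used to read the error document a task writes on failure (see HandleTaskSuccess).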
+ store storage.ComposedProtobufStore + resourceToWatch runtime.Object + metrics K8sTaskExecutorMetrics +} + +type K8sTaskExecutorMetrics struct { + Scope promutils.Scope + GetCacheMiss labeled.StopWatch + GetCacheHit labeled.StopWatch + GetAPILatency labeled.StopWatch + ResourceDeleted labeled.Counter +} + +type ownerRegisteringHandler struct { + ownerKind string + enqueueOwner types.EnqueueOwner +} + +// A common handle for all k8s-resource reliant task executors that push workflow id on the work queue. +func (h ownerRegisteringHandler) Handle(ctx context.Context, k8sObject runtime.Object) error { + object := k8sObject.(metav1.Object) + ownerReference := metav1.GetControllerOf(object) + namespace := object.GetNamespace() + if ownerReference == nil { + return nil + } + + if ownerReference.Kind == h.ownerKind { + return h.enqueueOwner(k8stypes.NamespacedName{Name: ownerReference.Name, Namespace: namespace}) + } + + // Had a log line here but it was way too verbose. Every pod owned by something other than a + // Flyte workflow would come up. + return nil +} + +func (e *K8sTaskExecutor) GetID() types.TaskExecutorName { + return e.id +} + +func (e *K8sTaskExecutor) GetProperties() types.ExecutorProperties { + return types.ExecutorProperties{} +} + +func (e *K8sTaskExecutor) Initialize(ctx context.Context, params types.ExecutorInitializationParameters) error { + + if params.DataStore == nil { + return errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize plugin, data store cannot be nil or empty.") + } + + if params.EventRecorder == nil { + return errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize plugin, event recorder cannot be nil or empty.") + } + + if params.EnqueueOwner == nil { + return errors.Errorf(errors.PluginInitializationFailed, "Failed to initialize plugin, enqueue Owner cannot be nil or empty.") + } + + e.store = params.DataStore + e.recorder = params.EventRecorder + e.OutputsResolver = types.NewOutputsResolver(params.DataStore) + + metricScope := params.MetricsScope.NewSubScope(e.GetID()) + e.metrics = K8sTaskExecutorMetrics{ + Scope: metricScope, + GetCacheMiss: labeled.NewStopWatch("get_cache_miss", "Cache miss on get resource calls.", + time.Millisecond, metricScope), + GetCacheHit: labeled.NewStopWatch("get_cache_hit", "Cache miss on get resource calls.", + time.Millisecond, metricScope), + GetAPILatency: labeled.NewStopWatch("get_api", "Latency for APIServer Get calls.", + time.Millisecond, metricScope), + ResourceDeleted: labeled.NewCounter("pods_deleted", "Counts how many times CheckTaskStatus is"+ + " called with a deleted resource.", metricScope), + } + + return RegisterResource(ctx, e.resourceToWatch, ownerRegisteringHandler{ + enqueueOwner: params.EnqueueOwner, + ownerKind: params.OwnerKind, + }) +} + +func (e K8sTaskExecutor) HandleTaskSuccess(ctx context.Context, taskCtx types.TaskContext) (types.TaskStatus, error) { + errorPath := taskCtx.GetErrorFile() + metadata, err := e.store.Head(ctx, errorPath) + if err != nil { + return types.TaskStatusRetryableFailure(errors.Wrapf(errors.MetadataAccessFailed, err, + "failed to read error file")), nil + } + + if metadata.Exists() { + if metadata.Size() > MaxMetadataPayloadSizeBytes { + return types.TaskStatusPermanentFailure(errors.Errorf(errors.MetadataTooLarge, + "metadata file is too large [%v] bytes, max allowed [%v] bytes", metadata.Size(), + MaxMetadataPayloadSizeBytes)), nil + } + errorDoc := &core.ErrorDocument{} + err = e.store.ReadProtobuf(ctx, errorPath, errorDoc) + if err != 
nil { + if storage.IsNotFound(err) { + return types.TaskStatusRetryableFailure(errors.Wrapf(errors.MetadataAccessFailed, err, + "metadata file not found but head returned exists")), nil + } + + return types.TaskStatusUndefined, errors.Wrapf(errors.DownstreamSystemError, err, + "failed to read error data from task @[%s]", errorPath) + } + + if errorDoc.Error == nil { + return types.TaskStatusRetryableFailure(errors.Errorf(errors.MetadataAccessFailed, + "error not formatted correctly, missing error @path [%v]", errorPath)), nil + } + + if errorDoc.Error.Kind == core.ContainerError_RECOVERABLE { + return types.TaskStatusRetryableFailure(errors.Errorf(errorDoc.Error.Code, + "user-error. Message: [%s]", errorDoc.Error.Message)), nil + } + + return types.TaskStatusPermanentFailure(errors.Errorf(errorDoc.Error.Code, + "user-error. Message: [%s]", errorDoc.Error.Message)), nil + } + + return types.TaskStatusSucceeded, nil +} + +func (e *K8sTaskExecutor) StartTask(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) ( + types.TaskStatus, error) { + + o, err := e.handler.BuildResource(ctx, taskCtx, task, inputs) + if err != nil { + return types.TaskStatusUndefined, err + } + + AddObjectMetadata(taskCtx, o) + logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", o.GroupVersionKind(), o.GetNamespace(), o.GetName()) + + err = instance.kubeClient.Create(ctx, o) + if err != nil && !k8serrors.IsAlreadyExists(err) { + if k8serrors.IsBadRequest(err) { + logger.Errorf(ctx, "Bad Request. [%+v]", o) + return types.TaskStatusUndefined, err + } + if k8serrors.IsForbidden(err) { + if strings.Contains(err.Error(), "exceeded quota") { + // TODO: Quota errors are retried forever, it would be good to have support for backoff strategy. + logger.Warnf(ctx, "Failed to launch job, resource quota exceeded. Err: %v", err) + return types.TaskStatusNotReadyFailure(err), nil + } + return types.TaskStatusPermanentFailure(err), nil + } + logger.Errorf(ctx, "Failed to launch job, system error. Err: %v", err) + return types.TaskStatusUndefined, err + } + status := types.TaskStatusQueued + + ev := events.CreateEvent(taskCtx, status, nil) + err = e.recorder.RecordTaskEvent(ctx, ev) + if err != nil && eventErrors.IsEventAlreadyInTerminalStateError(err) { + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event. phase mis-match between Propeller %v and Control Plane.", &status.Phase)), nil + } else if err != nil { + return types.TaskStatusUndefined, errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record start task event") + } + + return status, nil +} + +func (e *K8sTaskExecutor) getResource(ctx context.Context, taskCtx types.TaskContext, o K8sResource) (types.TaskStatus, + *events.TaskEventInfo, error) { + + nsName := k8stypes.NamespacedName{Namespace: o.GetNamespace(), Name: o.GetName()} + start := time.Now() + // Attempt to get resource from informer cache, if not found, retrieve it from API server. + err := instance.informersCache.Get(ctx, nsName, o) + if err != nil && IsK8sObjectNotExists(err) { + e.metrics.GetCacheMiss.Observe(ctx, start, time.Now()) + e.metrics.GetAPILatency.Time(ctx, func() { + err = instance.kubeClient.Get(ctx, nsName, o) + }) + } else if err == nil { + e.metrics.GetCacheHit.Observe(ctx, start, time.Now()) + } + + if err != nil { + if IsK8sObjectNotExists(err) { + // This happens sometimes because a node gets removed and K8s deletes the pod. 
This will result in a + // Pod does not exist error. This should be retried using the retry policy + logger.Warningf(ctx, "Failed to find the Resource with name: %v. Error: %v", nsName, err) + return types.TaskStatusRetryableFailure(err), nil, nil + } + + logger.Warningf(ctx, "Failed to retrieve Resource Details with name: %v. Error: %v", nsName, err) + return types.TaskStatusUndefined, nil, err + } + + return e.handler.GetTaskStatus(ctx, taskCtx, o) +} + +func (e *K8sTaskExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate) ( + types.TaskStatus, error) { + + o, err := e.handler.BuildIdentityResource(ctx, taskCtx) + finalStatus := types.TaskStatus{ + Phase: taskCtx.GetPhase(), + PhaseVersion: taskCtx.GetPhaseVersion(), + } + + var info *events.TaskEventInfo + + if err != nil { + logger.Warningf(ctx, "Failed to build the Resource with name: %v. Error: %v", + taskCtx.GetTaskExecutionID().GetGeneratedName(), err) + finalStatus = types.TaskStatusPermanentFailure(err) + } else { + AddObjectMetadata(taskCtx, o) + finalStatus, info, err = e.getResource(ctx, taskCtx, o) + if err != nil { + return types.TaskStatusUndefined, err + } + + if o.GetDeletionTimestamp() != nil { + e.metrics.ResourceDeleted.Inc(ctx) + } + + if finalStatus.Phase == types.TaskPhaseSucceeded { + finalStatus, err = e.HandleTaskSuccess(ctx, taskCtx) + if err != nil { + return types.TaskStatusUndefined, err + } + } + } + + if finalStatus.Phase != taskCtx.GetPhase() { + ev := events.CreateEvent(taskCtx, finalStatus, info) + err := e.recorder.RecordTaskEvent(ctx, ev) + if err != nil && eventErrors.IsEventAlreadyInTerminalStateError(err) { + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event. phase mis-match between Propeller %v and Control Plane.", &ev.Phase)), nil + } else if err != nil { + return types.TaskStatusUndefined, errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record state transition [%v] -> [%v]", taskCtx.GetPhase(), finalStatus.Phase) + } + } + + // This must happen after sending admin event. It's safe against partial failures because if the event failed, we will + // simply retry in the next round. If the event succeeded but this failed, we will try again the next round to send + // the same event (idempotent) and then come here again... + if finalStatus.Phase.IsTerminal() && len(o.GetFinalizers()) > 0 { + err = e.ClearFinalizers(ctx, o) + if err != nil { + return types.TaskStatusUndefined, err + } + } + + // If the object has been deleted, that is, it has a deletion timestamp, but is not in a terminal state, we should + // mark the task as a retryable failure. We've seen this happen when a kubelet disappears - all pods running on + // the node are marked with a deletionTimestamp, but our finalizers prevent the pod from being deleted. + // This can also happen when a user deletes a Pod directly. 
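+ // In either case, clear the finalizer so the deletion can complete and report a retryable failure so the task is retried.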
+ if !finalStatus.Phase.IsTerminal() && o.GetDeletionTimestamp() != nil && len(o.GetFinalizers()) > 0 { + err = e.ClearFinalizers(ctx, o) + if err != nil { + return types.TaskStatusUndefined, err + } + return types.TaskStatusRetryableFailure(finalStatus.Err), nil + } + + return finalStatus, nil +} + +func (e *K8sTaskExecutor) KillTask(ctx context.Context, taskCtx types.TaskContext, reason string) error { + logger.Infof(ctx, "KillTask invoked for %v, clearing finalizers.", taskCtx.GetTaskExecutionID().GetGeneratedName()) + + o, err := e.handler.BuildIdentityResource(ctx, taskCtx) + if err != nil { + logger.Warningf(ctx, "Failed to build the Resource with name: %v. Error: %v", taskCtx.GetTaskExecutionID().GetGeneratedName(), err) + return err + } + AddObjectMetadata(taskCtx, o) + // Retrieve the object from cache/etcd to get the last known version. + _, _, err = e.getResource(ctx, taskCtx, o) + if err != nil { + return err + } + + return e.ClearFinalizers(ctx, o) +} + +func (e *K8sTaskExecutor) ClearFinalizers(ctx context.Context, o K8sResource) error { + if len(o.GetFinalizers()) > 0 { + o.SetFinalizers([]string{}) + err := instance.kubeClient.Update(ctx, o) + + if err != nil && !IsK8sObjectNotExists(err) { + logger.Warningf(ctx, "Failed to clear finalizers for Resource with name: %v/%v. Error: %v", + o.GetNamespace(), o.GetName(), err) + return err + } + } else { + logger.Debugf(ctx, "Finalizers are already empty for Resource with name: %v/%v", + o.GetNamespace(), o.GetName()) + } + + return nil +} + +// Creates a K8s generic task executor. This provides an easier way to build task executors that create K8s resources. +func NewK8sTaskExecutorForResource(id string, resourceToWatch runtime.Object, handler K8sResourceHandler, + resyncPeriod time.Duration) *K8sTaskExecutor { + + return &K8sTaskExecutor{ + id: id, + handler: handler, + resourceToWatch: resourceToWatch, + resyncPeriod: resyncPeriod, + } +} + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go new file mode 100755 index 0000000000..95015beb50 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go @@ -0,0 +1,608 @@ +package flytek8s_test + +import ( + "bytes" + "errors" + "fmt" + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + "testing" + "time" + + k8serrs "k8s.io/apimachinery/pkg/api/errors" + + taskerrs "github.com/lyft/flyteplugins/go/tasks/v1/errors" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/mocks" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/mock" + "k8s.io/api/core/v1" + v12 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + mocks2 "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/stretchr/testify/assert" + k8sBatch "k8s.io/api/batch/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +type k8sSampleHandler struct { +} + +func (k8sSampleHandler) BuildResource(ctx 
context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) { + panic("implement me") +} + +func (k8sSampleHandler) BuildIdentityResource(ctx context.Context, taskCtx types.TaskContext) (flytek8s.K8sResource, error) { + panic("implement me") +} + +func (k8sSampleHandler) GetTaskStatus(ctx context.Context, taskCtx types.TaskContext, resource flytek8s.K8sResource) (types.TaskStatus, *events.TaskEventInfo, error) { + panic("implement me") +} + +func ExampleNewK8sTaskExecutorForResource() { + exec := flytek8s.NewK8sTaskExecutorForResource("SampleHandler", &k8sBatch.Job{}, k8sSampleHandler{}, time.Second*1) + fmt.Printf("Created executor: %v\n", exec.GetID()) + + // Output: + // Created executor: SampleHandler +} + +func getMockTaskContext() *mocks2.TaskContext { + taskCtx := &mocks2.TaskContext{} + taskCtx.On("GetNamespace").Return("ns") + taskCtx.On("GetAnnotations").Return(map[string]string{"aKey": "aVal"}) + taskCtx.On("GetLabels").Return(map[string]string{"lKey": "lVal"}) + taskCtx.On("GetOwnerReference").Return(v12.OwnerReference{Name: "x"}) + taskCtx.On("GetOutputsFile").Return(storage.DataReference("outputs")) + taskCtx.On("GetInputsFile").Return(storage.DataReference("inputs")) + taskCtx.On("GetErrorFile").Return(storage.DataReference("error")) + + id := &mocks2.TaskExecutionID{} + id.On("GetGeneratedName").Return("test") + id.On("GetID").Return(core.TaskExecutionIdentifier{}) + taskCtx.On("GetTaskExecutionID").Return(id) + return taskCtx +} + +func init() { + _ = flytek8s.InitializeFake() + cacheMock := mocks.Cache{} + if err := flytek8s.InjectCache(&cacheMock); err != nil { + panic(err) + } +} + +func createExecutorInitializationParams(t testing.TB, evtRecorder *mocks2.EventRecorder) types.ExecutorInitializationParameters { + store, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, promutils.NewTestScope()) + assert.NoError(t, err) + + return types.ExecutorInitializationParameters{ + EventRecorder: evtRecorder, + DataStore: store, + EnqueueOwner: func(ownerId k8stypes.NamespacedName) error { return nil }, + OwnerKind: "x", + MetricsScope: promutils.NewTestScope(), + } +} + +func TestK8sTaskExecutor_StartTask(t *testing.T) { + + ctx := context.TODO() + tctx := getMockTaskContext() + var tmpl *core.TaskTemplate + var inputs *core.LiteralMap + c := flytek8s.InitializeFake() + + t.Run("jobQueued", func(t *testing.T) { + // common setup code + mockResourceHandler := &mocks.K8sResourceHandler{} + evRecorder := &mocks2.EventRecorder{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildResource", mock.Anything, tctx, tmpl, inputs).Return(&v1.Pod{}, nil) + err := k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder)) + assert.NoError(t, err) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return e.Phase == core.TaskExecution_QUEUED })).Return(nil) + + status, err := k.StartTask(ctx, tctx, nil, nil) + assert.NoError(t, err) + assert.Nil(t, status.State) + assert.Equal(t, types.TaskPhaseQueued, status.Phase) + createdPod := &v1.Pod{} + flytek8s.AddObjectMetadata(tctx, createdPod) + assert.NoError(t, c.Get(ctx, k8stypes.NamespacedName{Namespace: tctx.GetNamespace(), Name: tctx.GetTaskExecutionID().GetGeneratedName()}, createdPod)) + assert.Equal(t, tctx.GetTaskExecutionID().GetGeneratedName(), createdPod.Name) 
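+ // Delete the pod created by StartTask so the next subtest starts against an empty fake client.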
+ assert.NoError(t, c.Delete(ctx, createdPod)) + }) + + t.Run("jobAlreadyExists", func(t *testing.T) { + // common setup code + mockResourceHandler := &mocks.K8sResourceHandler{} + evRecorder := &mocks2.EventRecorder{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildResource", mock.Anything, tctx, tmpl, inputs).Return(&v1.Pod{}, nil) + err := k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder)) + assert.NoError(t, err) + + expectedNewStatus := types.TaskStatusQueued + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return e.Phase == core.TaskExecution_QUEUED })).Return(nil) + + status, err := k.StartTask(ctx, tctx, nil, nil) + assert.NoError(t, err) + assert.Nil(t, status.State) + assert.Equal(t, types.TaskPhaseQueued, status.Phase) + createdPod := &v1.Pod{} + flytek8s.AddObjectMetadata(tctx, createdPod) + assert.NoError(t, c.Get(ctx, k8stypes.NamespacedName{Namespace: tctx.GetNamespace(), Name: tctx.GetTaskExecutionID().GetGeneratedName()}, createdPod)) + assert.Equal(t, tctx.GetTaskExecutionID().GetGeneratedName(), createdPod.Name) + }) + + t.Run("jobDifferentTerminalState", func(t *testing.T) { + // common setup code + mockResourceHandler := &mocks.K8sResourceHandler{} + evRecorder := &mocks2.EventRecorder{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildResource", mock.Anything, tctx, tmpl, inputs).Return(&v1.Pod{}, nil) + err := k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder)) + assert.NoError(t, err) + + expectedNewStatus := types.TaskStatusQueued + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { + return e.Phase == core.TaskExecution_QUEUED + })).Return(&eventErrors.EventError{Code: eventErrors.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + }) + + status, err := k.StartTask(ctx, tctx, nil, nil) + assert.NoError(t, err) + assert.Nil(t, status.State) + assert.Equal(t, types.TaskPhasePermanentFailure, status.Phase) + }) + + t.Run("jobQuotaExceeded", func(t *testing.T) { + // common setup code + mockResourceHandler := &mocks.K8sResourceHandler{} + evRecorder := &mocks2.EventRecorder{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildResource", mock.Anything, tctx, tmpl, inputs).Return(&v1.Pod{}, nil) + err := k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder)) + assert.NoError(t, err) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return e.Phase == core.TaskExecution_QUEUED })).Return(nil) + + // override create to return quota exceeded + mockRuntimeClient := mocks.NewMockRuntimeClient() + mockRuntimeClient.CreateCb = func(ctx context.Context, obj runtime.Object) (err error) { + return k8serrs.NewForbidden(schema.GroupResource{}, "", 
errors.New("exceeded quota")) + } + if err := flytek8s.InjectClient(mockRuntimeClient); err != nil { + assert.NoError(t, err) + } + + status, err := k.StartTask(ctx, tctx, nil, nil) + assert.NoError(t, err) + assert.Nil(t, status.State) + assert.Equal(t, types.TaskPhaseNotReady, status.Phase) + + // reset the client back to fake client + if err := flytek8s.InjectClient(fake.NewFakeClient()); err != nil { + assert.NoError(t, err) + } + }) + + t.Run("jobForbidden", func(t *testing.T) { + // common setup code + mockResourceHandler := &mocks.K8sResourceHandler{} + evRecorder := &mocks2.EventRecorder{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildResource", mock.Anything, tctx, tmpl, inputs).Return(&v1.Pod{}, nil) + err := k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder)) + assert.NoError(t, err) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return e.Phase == core.TaskExecution_FAILED })).Return(nil) + + // override create to return forbidden + mockRuntimeClient := mocks.NewMockRuntimeClient() + mockRuntimeClient.CreateCb = func(ctx context.Context, obj runtime.Object) (err error) { + return k8serrs.NewForbidden(schema.GroupResource{}, "", nil) + } + if err := flytek8s.InjectClient(mockRuntimeClient); err != nil { + assert.NoError(t, err) + } + + status, err := k.StartTask(ctx, tctx, nil, nil) + assert.NoError(t, err) + assert.Nil(t, status.State) + assert.Equal(t, types.TaskPhasePermanentFailure, status.Phase) + + // reset the client back to fake client + if err := flytek8s.InjectClient(fake.NewFakeClient()); err != nil { + assert.NoError(t, err) + } + }) +} + +func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { + ctx := context.TODO() + c := flytek8s.InitializeFake() + + t.Run("phaseChange", func(t *testing.T) { + + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + + assert.NoError(t, k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder))) + testPod := &v1.Pod{} + testPod.SetName(tctx.GetTaskExecutionID().GetGeneratedName()) + testPod.SetNamespace(tctx.GetNamespace()) + testPod.SetOwnerReferences([]v12.OwnerReference{tctx.GetOwnerReference()}) + err := c.Delete(ctx, testPod) + if err != nil { + assert.True(t, k8serrs.IsNotFound(err)) + } + + assert.NoError(t, c.Create(ctx, testPod)) + defer func() { + assert.NoError(t, c.Delete(ctx, testPod)) + }() + + info := &events.TaskEventInfo{} + info.Logs = []*core.TaskLog{{Uri: "log1"}} + expectedOldPhase := types.TaskPhaseQueued + expectedNewStatus := types.TaskStatusRunning + expectedNewStatus.PhaseVersion = uint32(1) + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(1)) + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.Nil(t, s.State) + assert.NoError(t, err) + assert.Equal(t, 
expectedNewStatus, s) + }) + + + t.Run("PhaseMismatch", func(t *testing.T) { + + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + + assert.NoError(t, k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder))) + testPod := &v1.Pod{} + testPod.SetName(tctx.GetTaskExecutionID().GetGeneratedName()) + testPod.SetNamespace(tctx.GetNamespace()) + testPod.SetOwnerReferences([]v12.OwnerReference{tctx.GetOwnerReference()}) + err := c.Delete(ctx, testPod) + if err != nil { + assert.True(t, k8serrs.IsNotFound(err)) + } + + assert.NoError(t, c.Create(ctx, testPod)) + defer func() { + assert.NoError(t, c.Delete(ctx, testPod)) + }() + + info := &events.TaskEventInfo{} + info.Logs = []*core.TaskLog{{Uri: "log1"}} + expectedOldPhase := types.TaskPhaseRunning + expectedNewStatus := types.TaskStatusSucceeded + expectedNewStatus.PhaseVersion = uint32(1) + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(1)) + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { + return e.Phase == core.TaskExecution_SUCCEEDED + })).Return(&eventErrors.EventError{Code: eventErrors.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + }) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.NoError(t, err) + assert.Nil(t, s.State) + assert.Equal(t, types.TaskPhasePermanentFailure, s.Phase) + }) + + t.Run("noChange", func(t *testing.T) { + + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + + assert.NoError(t, k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder))) + testPod := &v1.Pod{} + testPod.SetName(tctx.GetTaskExecutionID().GetGeneratedName()) + testPod.SetNamespace(tctx.GetNamespace()) + testPod.SetOwnerReferences([]v12.OwnerReference{tctx.GetOwnerReference()}) + err := c.Delete(ctx, testPod) + if err != nil { + assert.True(t, k8serrs.IsNotFound(err)) + } + + assert.NoError(t, c.Create(ctx, testPod)) + defer func() { + assert.NoError(t, c.Delete(ctx, testPod)) + }() + + info := &events.TaskEventInfo{} + info.Logs = []*core.TaskLog{{Uri: "log1"}} + expectedOldPhase := types.TaskPhaseRunning + expectedNewStatus := types.TaskStatusRunning + expectedNewStatus.PhaseVersion = uint32(1) + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(1)) + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.Nil(t, s.State) + assert.NoError(t, err) + assert.Equal(t, expectedNewStatus, s) + }) + + t.Run("resourceNotFound", func(t *testing.T) { + + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := 
flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + + assert.NoError(t, k.Initialize(ctx, createExecutorInitializationParams(t, evRecorder))) + _ = flytek8s.InitializeFake() + + info := &events.TaskEventInfo{} + info.Logs = []*core.TaskLog{{Uri: "log1"}} + expectedOldPhase := types.TaskPhaseRunning + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(0)) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.Nil(t, s.State) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, s.Phase, "Expected failure got %s", s.Phase.String()) + }) + + t.Run("errorFileExit", func(t *testing.T) { + + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + + params := createExecutorInitializationParams(t, evRecorder) + store := params.DataStore + assert.NoError(t, k.Initialize(ctx, params)) + testPod := &v1.Pod{} + testPod.SetName(tctx.GetTaskExecutionID().GetGeneratedName()) + testPod.SetNamespace(tctx.GetNamespace()) + testPod.SetOwnerReferences([]v12.OwnerReference{tctx.GetOwnerReference()}) + + err := c.Delete(ctx, testPod) + if err != nil { + assert.True(t, k8serrs.IsNotFound(err)) + } + + assert.NoError(t, c.Create(ctx, testPod)) + defer func() { + assert.NoError(t, c.Delete(ctx, testPod)) + }() + + assert.NoError(t, store.WriteProtobuf(ctx, tctx.GetErrorFile(), storage.Options{}, &core.ErrorDocument{ + Error: &core.ContainerError{ + Kind: core.ContainerError_NON_RECOVERABLE, + Code: "code", + Message: "pleh", + }, + })) + + info := &events.TaskEventInfo{} + info.Logs = []*core.TaskLog{{Uri: "log1"}} + expectedOldPhase := types.TaskPhaseQueued + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(0)) + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.Nil(t, s.State) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhasePermanentFailure, s.Phase) + }) + + t.Run("errorFileExitRecoverable", func(t *testing.T) { + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + + params := createExecutorInitializationParams(t, evRecorder) + store := params.DataStore + + assert.NoError(t, k.Initialize(ctx, params)) + testPod := &v1.Pod{} + testPod.SetName(tctx.GetTaskExecutionID().GetGeneratedName()) + testPod.SetNamespace(tctx.GetNamespace()) + testPod.SetOwnerReferences([]v12.OwnerReference{tctx.GetOwnerReference()}) + + err := 
c.Delete(ctx, testPod) + if err != nil { + assert.True(t, k8serrs.IsNotFound(err)) + } + + assert.NoError(t, c.Create(ctx, testPod)) + defer func() { + assert.NoError(t, c.Delete(ctx, testPod)) + }() + + assert.NoError(t, store.WriteProtobuf(ctx, tctx.GetErrorFile(), storage.Options{}, &core.ErrorDocument{ + Error: &core.ContainerError{ + Kind: core.ContainerError_RECOVERABLE, + Code: "code", + Message: "pleh", + }, + })) + + info := &events.TaskEventInfo{} + info.Logs = []*core.TaskLog{{Uri: "log1"}} + expectedOldPhase := types.TaskPhaseQueued + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(0)) + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil) + + evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), + mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return e.Phase == core.TaskExecution_FAILED })).Return(nil) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.Nil(t, s.State) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, s.Phase) + }) + + t.Run("nodeGetsDeleted", func(t *testing.T) { + evRecorder := &mocks2.EventRecorder{} + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + params := createExecutorInitializationParams(t, evRecorder) + assert.NoError(t, k.Initialize(ctx, params)) + testPod := &v1.Pod{} + testPod.SetName(tctx.GetTaskExecutionID().GetGeneratedName()) + testPod.SetNamespace(tctx.GetNamespace()) + testPod.SetOwnerReferences([]v12.OwnerReference{tctx.GetOwnerReference()}) + + testPod.SetFinalizers([]string{"test_finalizer"}) + + // Ensure that the pod is not there + err := c.Delete(ctx, testPod) + assert.True(t, k8serrs.IsNotFound(err)) + + // Add a deletion timestamp to the pod definition and then create it + tt := time.Now() + k8sTime := v12.Time{ + Time: tt, + } + testPod.SetDeletionTimestamp(&k8sTime) + assert.NoError(t, c.Create(ctx, testPod)) + + // Make sure that the phase doesn't change so no events are recorded + expectedOldPhase := types.TaskPhaseQueued + tctx.On("GetPhase").Return(expectedOldPhase) + tctx.On("GetPhaseVersion").Return(uint32(0)) + mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusQueued, nil, nil) + + s, err := k.CheckTaskStatus(ctx, tctx, nil) + assert.Nil(t, s.State) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, s.Phase) + }) +} + +func TestK8sTaskExecutor_HandleTaskSuccess(t *testing.T) { + ctx := context.TODO() + + tctx := getMockTaskContext() + mockResourceHandler := &mocks.K8sResourceHandler{} + k := flytek8s.NewK8sTaskExecutorForResource("x", &v1.Pod{}, mockResourceHandler, time.Second) + + t.Run("no-errorfile", func(t *testing.T) { + assert.NoError(t, k.Initialize(ctx, createExecutorInitializationParams(t, nil))) + s, err := k.HandleTaskSuccess(ctx, tctx) + assert.NoError(t, err) + assert.Equal(t, s.Phase, types.TaskPhaseSucceeded) + }) + + t.Run("retryable-error", func(t *testing.T) { + params := createExecutorInitializationParams(t, nil) + store := params.DataStore + msg := &core.ErrorDocument{ + Error: &core.ContainerError{ + Kind: 
core.ContainerError_RECOVERABLE, + Code: "x", + Message: "y", + }, + } + assert.NoError(t, store.WriteProtobuf(ctx, tctx.GetErrorFile(), storage.Options{}, msg)) + assert.NoError(t, k.Initialize(ctx, params)) + s, err := k.HandleTaskSuccess(ctx, tctx) + assert.NoError(t, err) + assert.Equal(t, s.Phase, types.TaskPhaseRetryableFailure) + c, ok := taskerrs.GetErrorCode(s.Err) + assert.True(t, ok) + assert.Equal(t, c, "x") + }) + + t.Run("nonretryable-error", func(t *testing.T) { + params := createExecutorInitializationParams(t, nil) + store := params.DataStore + msg := &core.ErrorDocument{ + Error: &core.ContainerError{ + Kind: core.ContainerError_NON_RECOVERABLE, + Code: "m", + Message: "n", + }, + } + assert.NoError(t, store.WriteProtobuf(ctx, tctx.GetErrorFile(), storage.Options{}, msg)) + assert.NoError(t, k.Initialize(ctx, params)) + s, err := k.HandleTaskSuccess(ctx, tctx) + assert.NoError(t, err) + assert.Equal(t, s.Phase, types.TaskPhasePermanentFailure) + c, ok := taskerrs.GetErrorCode(s.Err) + assert.True(t, ok) + assert.Equal(t, c, "m") + }) + + t.Run("corrupted", func(t *testing.T) { + params := createExecutorInitializationParams(t, nil) + store := params.DataStore + r := bytes.NewReader([]byte{'x'}) + assert.NoError(t, store.WriteRaw(ctx, tctx.GetErrorFile(), r.Size(), storage.Options{}, r)) + assert.NoError(t, k.Initialize(ctx, params)) + s, err := k.HandleTaskSuccess(ctx, tctx) + assert.Error(t, err) + assert.Equal(t, s.Phase, types.TaskPhaseUndefined) + }) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go new file mode 100755 index 0000000000..844970fd31 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go @@ -0,0 +1,38 @@ +package flytek8s + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +//go:generate mockery -name K8sResourceHandler + +type Handler interface { + Handle(context.Context, runtime.Object) error +} + +// Defines an interface that deals with k8s resources. Combined with K8sTaskExecutor, this provides an easier and more +// consistent way to write TaskExecutors that create k8s resources. +type K8sResourceHandler interface { + // Defines a func to create the full resource object that will be posted to k8s. + BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (K8sResource, error) + + // Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s + // resources. + BuildIdentityResource(ctx context.Context, taskCtx types.TaskContext) (K8sResource, error) + + // Analyses the k8s resource and reports the status as TaskPhase. 
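+ // A rough, illustrative sketch of how a Pod-based handler typically implements this (simplified from the + // container handler in k8splugins; not part of this interface's contract): + // pod := resource.(*v1.Pod) + // switch pod.Status.Phase { + // case v1.PodSucceeded: + // return types.TaskStatusSucceeded, info, nil + // case v1.PodPending: + // status, err := DemystifyPending(pod.Status) + // return status, info, err + // default: + // return types.TaskStatusRunning, info, nil + // }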
+ GetTaskStatus(ctx context.Context, taskCtx types.TaskContext, resource K8sResource) (types.TaskStatus, *events.TaskEventInfo, error) +} + +type K8sResource interface { + runtime.Object + metav1.Object + schema.ObjectKind +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/pod_helper.go b/flyteplugins/go/tasks/v1/flytek8s/pod_helper.go new file mode 100755 index 0000000000..11fa85d86c --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/pod_helper.go @@ -0,0 +1,159 @@ +package flytek8s + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + "k8s.io/api/core/v1" + v12 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/types" +) + +const PodKind = "pod" + +func ToK8sPod(ctx context.Context, taskCtx types.TaskContext, taskContainer *core.Container, inputs *core.LiteralMap) (*v1.PodSpec, error) { + c, err := ToK8sContainer(ctx, taskCtx, taskContainer, inputs) + if err != nil { + return nil, err + } + + containers := []v1.Container{ + *c, + } + return &v1.PodSpec{ + // We could specify Scheduler, Affinity, nodename etc + RestartPolicy: v1.RestartPolicyNever, + Containers: containers, + Tolerations: GetTolerationsForResources(c.Resources), + ServiceAccountName: taskCtx.GetK8sServiceAccount(), + }, nil +} + +func BuildPodWithSpec(podSpec *v1.PodSpec) *v1.Pod { + pod := v1.Pod{ + TypeMeta: v12.TypeMeta{ + Kind: PodKind, + APIVersion: v1.SchemeGroupVersion.String(), + }, + Spec: *podSpec, + } + + return &pod +} + +func BuildIdentityPod() *v1.Pod { + return &v1.Pod{ + TypeMeta: v12.TypeMeta{ + Kind: PodKind, + APIVersion: v1.SchemeGroupVersion.String(), + }, + } +} + +// Important considerations. + // A pod can be Pending for various reasons, and sometimes this signals a problem. + // Case I: Pending because the image pull is failing and it is backing off. + // This could be transient, so we can actually rely on the failure reason. + // The failure transitions from ErrImagePull -> ImagePullBackOff. + // Case II: Not enough resources are available. This is tricky. It could be that the total amount of + // resources requested is beyond the capability of the system. For this we will rely on configuration + // and hence input gates. We should not let bad requests that ask for an excessive amount of resources through. + // If such a request does make it through, we will fail after a timeout. +func DemystifyPending(status v1.PodStatus) (types.TaskStatus, error) { + // Search over the different conditions in the status object. Note that the 'Pending' this function is + // demystifying is the 'phase' of the pod status. This is different from the PodReady condition type, which is also used below. + for _, c := range status.Conditions { + switch c.Type { + case v1.PodScheduled: + if c.Status == v1.ConditionFalse { + // Waiting to be scheduled. This usually refers to inability to acquire resources. + return types.TaskStatusQueued, nil + } + + case v1.PodReasonUnschedulable: + // We ignore the case in which we are unable to find resources on the cluster. This is because: + // - The resources may not be available at the moment, but may become available eventually. + // The pod scheduler will keep looking at this pod and trying to satisfy it. + // + // Pod status looks like this: + // message: '0/1 nodes are available: 1 Insufficient memory.' 
+ // reason: Unschedulable + // status: "False" + // type: PodScheduled + return types.TaskStatusQueued, nil + + case v1.PodReady: + if c.Status == v1.ConditionFalse { + // This happens when the image is having some problems. In the following example, K8s is having + // problems downloading an image. To confirm that, we will have to iterate over all the container statuses and + // check whether some container has an image pull failure + // e.g. + // - lastProbeTime: null + // lastTransitionTime: 2018-12-18T00:57:30Z + // message: 'containers with unready status: [myapp-container]' + // reason: ContainersNotReady + // status: "False" + // type: Ready + // + // e.g. Container status + // - image: blah + // imageID: "" + // lastState: {} + // name: myapp-container + // ready: false + // restartCount: 0 + // state: + // waiting: + // message: Back-off pulling image "blah" + // reason: ImagePullBackOff + for _, containerStatus := range status.ContainerStatuses { + if !containerStatus.Ready { + if containerStatus.State.Waiting != nil { + // There are a variety of reasons that can cause a pod to be in this waiting state. + // The waiting state may be legitimate when the container is being downloaded or started, or when init containers are running + reason := containerStatus.State.Waiting.Reason + switch reason { + case "ErrImagePull", "ContainerCreating", "PodInitializing": + // But there are only two "reasons" reported while a pod is successfully being created and hence is in the + // waiting state + // Refer to https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/kubelet_pods.go + // and look for the default waiting states + // We also want to allow image pulls to be retried, so ErrImagePull will be ignored + // as it eventually enters into ImagePullBackOff + // ErrImagePull -> Transitionary phase to ImagePullBackOff + // ContainerCreating -> Image is being downloaded + // PodInitializing -> Init containers are running + return types.TaskStatusQueued, nil + + case "ImagePullBackOff": + return types.TaskStatusRetryableFailure(errors.Errorf(reason, + containerStatus.State.Waiting.Message)), nil + + case "CreateContainerError": + // This happens if, for instance, the command given to the container is incorrect, i.e. it doesn't run + return types.TaskStatusPermanentFailure(errors.Errorf( + "CreateContainerError", containerStatus.State.Waiting.Reason)), nil + + default: + // Since we are not checking for all error states, we may end up perpetually + // in the queued state returned at the bottom of this function, until the Pod is reaped + // by K8s and we get elusive 'pod not found' errors + // So by default, if the container is not waiting with the PodInitializing/ContainerCreating + // reasons, then we will assume a failure reason and fail instantly + logger.Errorf(context.TODO(), "Pod pending with Waiting container with unhandled reason %s", reason) + return types.TaskStatusRetryableFailure(errors.Errorf(reason, containerStatus.State.Waiting.Message)), nil + } + + } + } + } + } + } + } + + return types.TaskStatusQueued, nil +} + diff --git a/flyteplugins/go/tasks/v1/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/v1/flytek8s/pod_helper_test.go new file mode 100755 index 0000000000..c089fe5b16 --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/pod_helper_test.go @@ -0,0 +1,317 @@ +package flytek8s + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "k8s.io/api/core/v1" + 
"k8s.io/apimachinery/pkg/api/resource" + metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" +) + +func dummyContainerTaskContext(resources *v1.ResourceRequirements) types.TaskContext { + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetNamespace").Return("test-namespace") + taskCtx.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + taskCtx.On("GetLabels").Return(map[string]string{"label-1": "val1"}) + taskCtx.On("GetOwnerReference").Return(metaV1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskCtx.On("GetDataDir").Return(storage.DataReference("/data/")) + taskCtx.On("GetInputsFile").Return(storage.DataReference("/input")) + taskCtx.On("GetK8sServiceAccount").Return("service-account") + + tID := &mocks.TaskExecutionID{} + tID.On("GetID").Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.On("GetGeneratedName").Return("some-acceptable-name") + taskCtx.On("GetTaskExecutionID").Return(tID) + + to := &mocks.TaskOverrides{} + to.On("GetResources").Return(resources) + taskCtx.On("GetOverrides").Return(to) + + return taskCtx +} + +func TestToK8sPod(t *testing.T) { + ctx := context.TODO() + command := []string{"command"} + args := []string{"{{.Input}}"} + container := &core.Container{ + Command: command, + Args: args, + } + + tolGPU := v1.Toleration{ + Key: "flyte/gpu", + Value: "dedicated", + Operator: v1.TolerationOpEqual, + Effect: v1.TaintEffectNoSchedule, + } + + tolStorage := v1.Toleration{ + Key: "storage", + Value: "dedicated", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + } + + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + ResourceTolerations: map[v1.ResourceName][]v1.Toleration{ + v1.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }}), + ) + + t.Run("WithGPU", func(t *testing.T) { + x := dummyContainerTaskContext(&v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, + }) + + p, err := ToK8sPod(ctx, x, container, nil) + assert.NoError(t, err) + assert.Equal(t, len(p.Tolerations), 1) + }) + + t.Run("NoGPU", func(t *testing.T) { + x := dummyContainerTaskContext(&v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, + }) + + p, err := ToK8sPod(ctx, x, container, nil) + assert.NoError(t, err) + assert.Equal(t, len(p.Tolerations), 0) + assert.Equal(t, "some-acceptable-name", p.Containers[0].Name) + }) +} + +func TestDemystifyPending(t *testing.T) { + + t.Run("PodNotScheduled", func(t *testing.T) { + s := v1.PodStatus{ + Phase: v1.PodPending, + Conditions: []v1.PodCondition{ + { + Type: v1.PodScheduled, + Status: v1.ConditionFalse, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, 
taskStatus.Phase) + }) + + t.Run("PodUnschedulable", func(t *testing.T) { + s := v1.PodStatus{ + Phase: v1.PodPending, + Conditions: []v1.PodCondition{ + { + Type: v1.PodReasonUnschedulable, + Status: v1.ConditionFalse, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + }) + + t.Run("PodNotScheduled", func(t *testing.T) { + s := v1.PodStatus{ + Phase: v1.PodPending, + Conditions: []v1.PodCondition{ + { + Type: v1.PodScheduled, + Status: v1.ConditionTrue, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + }) + + t.Run("PodUnschedulable", func(t *testing.T) { + s := v1.PodStatus{ + Phase: v1.PodPending, + Conditions: []v1.PodCondition{ + { + Type: v1.PodReasonUnschedulable, + Status: v1.ConditionUnknown, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + }) + + s := v1.PodStatus{ + Phase: v1.PodPending, + Conditions: []v1.PodCondition{ + { + Type: v1.PodReady, + Status: v1.ConditionFalse, + }, + { + Type: v1.PodReasonUnschedulable, + Status: v1.ConditionUnknown, + }, + { + Type: v1.PodScheduled, + Status: v1.ConditionTrue, + }, + }, + } + + t.Run("ContainerCreating", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "ContainerCreating", + Message: "this is not an error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + }) + + t.Run("ErrImagePull", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "ErrImagePull", + Message: "this is not an error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + }) + + t.Run("PodInitializing", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "PodInitializing", + Message: "this is not an error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + }) + + t.Run("ImagePullBackOff", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "ImagePullBackOff", + Message: "this is an error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, taskStatus.Phase) + }) + + t.Run("InvalidImageName", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "InvalidImageName", + Message: "this is an error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, taskStatus.Phase) + }) + + t.Run("RegistryUnavailable", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "RegistryUnavailable", + Message: "this is an 
error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, taskStatus.Phase) + }) + + t.Run("RandomError", func(t *testing.T) { + s.ContainerStatuses = []v1.ContainerStatus{ + { + Ready: false, + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "RandomError", + Message: "this is an error", + }, + }, + }, + } + taskStatus, err := DemystifyPending(s) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, taskStatus.Phase) + }) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/utils.go b/flyteplugins/go/tasks/v1/flytek8s/utils.go new file mode 100755 index 0000000000..1bb2afd67d --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/utils.go @@ -0,0 +1,39 @@ +package flytek8s + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" +) + +func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar { + envVars := make([]v1.EnvVar, 0, len(env)) + for _, kv := range env { + envVars = append(envVars, v1.EnvVar{Name: kv.Key, Value: kv.Value}) + } + return envVars +} + +// This function unions a list of maps (each can be nil or populated) by allocating a new map. +// Conflicting keys will always defer to the later input map's corresponding value. +func UnionMaps(maps ...map[string]string) map[string]string { + size := 0 + for _, m := range maps { + size += len(m) + } + + composite := make(map[string]string, size) + for _, m := range maps { + if m != nil { + for k, v := range m { + composite[k] = v + } + } + } + + return composite +} + +func IsK8sObjectNotExists(err error) bool { + return k8serrors.IsNotFound(err) || k8serrors.IsGone(err) || k8serrors.IsResourceExpired(err) +} diff --git a/flyteplugins/go/tasks/v1/flytek8s/utils_test.go b/flyteplugins/go/tasks/v1/flytek8s/utils_test.go new file mode 100755 index 0000000000..af56b60f3e --- /dev/null +++ b/flyteplugins/go/tasks/v1/flytek8s/utils_test.go @@ -0,0 +1,30 @@ +package flytek8s + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnionMaps(t *testing.T) { + assert.EqualValues(t, map[string]string{ + "left": "only", + }, UnionMaps(map[string]string{ + "left": "only", + }, nil)) + + assert.EqualValues(t, map[string]string{ + "right": "only", + }, UnionMaps(nil, map[string]string{ + "right": "only", + })) + + assert.EqualValues(t, map[string]string{ + "left": "val", + "right": "val", + }, UnionMaps(map[string]string{ + "left": "val", + }, map[string]string{ + "right": "val", + })) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/container.go b/flyteplugins/go/tasks/v1/k8splugins/container.go new file mode 100755 index 0000000000..3223b1f209 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/container.go @@ -0,0 +1,113 @@ +package k8splugins + +import ( + "context" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "k8s.io/api/core/v1" + metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + tasksV1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/v1/logs" + "github.com/lyft/flyteplugins/go/tasks/v1/types" +) + +const ( + containerTaskType = "container" +) + +func ConvertPodFailureToError(status v1.PodStatus) error { + reason := errors.TaskFailedUnknownError + message := "Container/Pod failed. 
No message received from kubernetes. Could be permissions?" + if status.Reason != "" { + reason = status.Reason + } + if status.Message != "" { + message = status.Message + } + return errors.Errorf(reason, message) +} + +func GetLastTransitionOccurredAt(pod *v1.Pod) metaV1.Time { + var lastTransitionTime metaV1.Time + containerStatuses := append(pod.Status.ContainerStatuses, pod.Status.InitContainerStatuses...) + for _, containerStatus := range containerStatuses { + if r := containerStatus.LastTerminationState.Running; r != nil { + if r.StartedAt.Unix() > lastTransitionTime.Unix() { + lastTransitionTime = r.StartedAt + } + } else if r := containerStatus.LastTerminationState.Terminated; r != nil { + if r.FinishedAt.Unix() > lastTransitionTime.Unix() { + // For a terminated container, the last transition is when it finished. + lastTransitionTime = r.FinishedAt + } + } + } + + if lastTransitionTime.IsZero() { + lastTransitionTime = metaV1.NewTime(time.Now()) + } + + return lastTransitionTime +} + +type containerTaskExecutor struct { +} + +func (containerTaskExecutor) GetTaskStatus(ctx context.Context, _ types.TaskContext, r flytek8s.K8sResource) ( + types.TaskStatus, *events.TaskEventInfo, error) { + + pod := r.(*v1.Pod) + + var info *events.TaskEventInfo + if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { + taskLogs, err := logs.GetLogsForContainerInPod(ctx, pod, 0, " (User)") + if err != nil { + return types.TaskStatusUndefined, nil, err + } + + t := GetLastTransitionOccurredAt(pod).Time + info = &events.TaskEventInfo{ + Logs: taskLogs, + OccurredAt: &t, + } + } + switch pod.Status.Phase { + case v1.PodSucceeded: + return types.TaskStatusSucceeded.WithOccurredAt(GetLastTransitionOccurredAt(pod).Time), info, nil + case v1.PodFailed: + return types.TaskStatusRetryableFailure(ConvertPodFailureToError(pod.Status)).WithOccurredAt(GetLastTransitionOccurredAt(pod).Time), info, nil + case v1.PodPending: + status, err := flytek8s.DemystifyPending(pod.Status) + return status, info, err + case v1.PodUnknown: + return types.TaskStatusUnknown, info, nil + } + return types.TaskStatusRunning, info, nil +} + +// Creates a new Pod that will exit on completion. 
The pods have no retries by design +func (c containerTaskExecutor) BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) { + + podSpec, err := flytek8s.ToK8sPod(ctx, taskCtx, task.GetContainer(), inputs) + if err != nil { + return nil, err + } + + pod := flytek8s.BuildPodWithSpec(podSpec) + return pod, nil +} + +func (containerTaskExecutor) BuildIdentityResource(_ context.Context, taskCtx types.TaskContext) (flytek8s.K8sResource, error) { + return flytek8s.BuildIdentityPod(), nil +} + +func init() { + tasksV1.RegisterLoader(func(ctx context.Context) error { + return tasksV1.K8sRegisterAsDefault(containerTaskType, &v1.Pod{}, flytek8s.DefaultInformerResyncDuration, + containerTaskExecutor{}) + }) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/container_test.go b/flyteplugins/go/tasks/v1/k8splugins/container_test.go new file mode 100755 index 0000000000..f078ff3d19 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/container_test.go @@ -0,0 +1,231 @@ +package k8splugins + +import ( + "context" + "testing" + + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/lyft/flytestdlib/storage" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/lyft/flyteplugins/go/tasks/v1/types" +) + +var resourceRequirements = &v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, +} + +func dummyContainerTaskContext(resources *v1.ResourceRequirements) types.TaskContext { + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetNamespace").Return("test-namespace") + taskCtx.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + taskCtx.On("GetLabels").Return(map[string]string{"label-1": "val1"}) + taskCtx.On("GetOwnerReference").Return(metav1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskCtx.On("GetDataDir").Return(storage.DataReference("/data/")) + taskCtx.On("GetInputsFile").Return(storage.DataReference("/input")) + taskCtx.On("GetK8sServiceAccount").Return("service-account") + + tID := &mocks.TaskExecutionID{} + tID.On("GetID").Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.On("GetGeneratedName").Return("some-acceptable-name") + taskCtx.On("GetTaskExecutionID").Return(tID) + + to := &mocks.TaskOverrides{} + to.On("GetResources").Return(resources) + taskCtx.On("GetOverrides").Return(to) + + return taskCtx +} + +func TestContainerTaskExecutor_BuildIdentityResource(t *testing.T) { + c := containerTaskExecutor{} + x := &mocks.TaskContext{} + r, err := c.BuildIdentityResource(context.TODO(), x) + assert.NoError(t, err) + assert.NotNil(t, r) + _, ok := r.(*v1.Pod) + assert.True(t, ok) + assert.Equal(t, flytek8s.PodKind, r.GetObjectKind().GroupVersionKind().Kind) +} + +func TestContainerTaskExecutor_BuildResource(t *testing.T) { + c := containerTaskExecutor{} + x := dummyContainerTaskContext(resourceRequirements) + command := []string{"command"} + args := []string{"{{.Input}}"} + + 
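+ // "{{.Input}}" is a template placeholder in the container args; BuildResource is expected to substitute it with + // the inputs file path ("/input" for this mock task context), which the assertions below verify.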
task := &core.TaskTemplate{ + Type: "test", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Command: command, + Args: args, + }, + }, + } + r, err := c.BuildResource(context.TODO(), x, task, nil) + assert.NoError(t, err) + assert.NotNil(t, r) + j, ok := r.(*v1.Pod) + assert.True(t, ok) + + assert.NotEmpty(t, j.Spec.Containers) + assert.Equal(t, resourceRequirements.Limits[v1.ResourceCPU], j.Spec.Containers[0].Resources.Limits[v1.ResourceCPU]) + + // TODO: Once configurable, test when setting storage is supported on the cluster vs not. + storageRes := j.Spec.Containers[0].Resources.Limits[v1.ResourceStorage] + assert.Equal(t, int64(0), (&storageRes).Value()) + + assert.Equal(t, command, j.Spec.Containers[0].Command) + assert.Equal(t, []string{"/input"}, j.Spec.Containers[0].Args) + + assert.Equal(t, "service-account", j.Spec.ServiceAccountName) +} + +func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { + c := containerTaskExecutor{} + j := &v1.Pod{ + Status: v1.PodStatus{}, + } + + ctx := context.TODO() + t.Run("running", func(t *testing.T) { + s, i, err := c.GetTaskStatus(ctx, nil, j) + assert.NoError(t, err) + assert.NotNil(t, i) + assert.Equal(t, types.TaskPhaseRunning, s.Phase) + }) + + t.Run("queued", func(t *testing.T) { + j.Status.Phase = v1.PodPending + s, i, err := c.GetTaskStatus(ctx, nil, j) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, s.Phase) + assert.Nil(t, i) + }) + + t.Run("failNoCondition", func(t *testing.T) { + j.Status.Phase = v1.PodFailed + s, i, err := c.GetTaskStatus(ctx, nil, j) + assert.NoError(t, err) + assert.NotNil(t, i) + assert.Equal(t, types.TaskPhaseRetryableFailure, s.Phase) + ec, ok := errors.GetErrorCode(s.Err) + assert.True(t, ok) + assert.Equal(t, errors.TaskFailedUnknownError, ec) + }) + + t.Run("failConditionUnschedulable", func(t *testing.T) { + j.Status.Phase = v1.PodFailed + j.Status.Reason = "Unschedulable" + j.Status.Message = "some message" + j.Status.Conditions = []v1.PodCondition{ + { + Type: v1.PodReasonUnschedulable, + }, + } + s, i, err := c.GetTaskStatus(ctx, nil, j) + assert.NoError(t, err) + assert.NotNil(t, i) + assert.Equal(t, types.TaskPhaseRetryableFailure, s.Phase) + ec, ok := errors.GetErrorCode(s.Err) + assert.True(t, ok) + assert.Equal(t, "Unschedulable", ec) + }) + + t.Run("success", func(t *testing.T) { + j.Status.Phase = v1.PodSucceeded + s, i, err := c.GetTaskStatus(ctx, nil, j) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseSucceeded, s.Phase) + assert.NotNil(t, i) + }) +} + +func TestConvertPodFailureToError(t *testing.T) { + t.Run("unknown-error", func(t *testing.T) { + err := ConvertPodFailureToError(v1.PodStatus{}) + assert.Error(t, err) + ec, ok := errors.GetErrorCode(err) + assert.True(t, ok) + assert.Equal(t, ec, errors.TaskFailedUnknownError) + }) + + t.Run("known-error", func(t *testing.T) { + err := ConvertPodFailureToError(v1.PodStatus{Reason: "hello"}) + assert.Error(t, err) + ec, ok := errors.GetErrorCode(err) + assert.True(t, ok) + assert.Equal(t, ec, "hello") + }) +} + +func advancePodPhases(ctx context.Context, runtimeClient client.Client) error { + podList := &v1.PodList{} + err := runtimeClient.List(ctx, &client.ListOptions{ + Raw: &metav1.ListOptions{ + TypeMeta: metav1.TypeMeta{ + Kind: flytek8s.PodKind, + APIVersion: v1.SchemeGroupVersion.String(), + }, + }, + }, podList) + if err != nil { + return err + } + + for _, pod := range podList.Items { + pod.Status.Phase = nextHappyPodPhase(pod.Status.Phase) + pod.Status.ContainerStatuses = 
[]v1.ContainerStatus{ + {ContainerID: "cont_123"}, + } + err = runtimeClient.Update(ctx, pod.DeepCopy()) + if err != nil { + return err + } + } + + return nil +} + +func nextHappyPodPhase(phase v1.PodPhase) v1.PodPhase { + switch phase { + case v1.PodUnknown: + fallthrough + case v1.PodPending: + fallthrough + case "": + return v1.PodRunning + case v1.PodRunning: + return v1.PodSucceeded + case v1.PodSucceeded: + return v1.PodSucceeded + } + + return v1.PodUnknown +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/mocks/AutoRefreshCache.go b/flyteplugins/go/tasks/v1/k8splugins/mocks/AutoRefreshCache.go new file mode 100755 index 0000000000..9c6f13efc2 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/mocks/AutoRefreshCache.go @@ -0,0 +1,56 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import utils "github.com/lyft/flytestdlib/utils" + +// AutoRefreshCache is an autogenerated mock type for the AutoRefreshCache type +type AutoRefreshCache struct { + mock.Mock +} + +// Get provides a mock function with given fields: id +func (_m *AutoRefreshCache) Get(id string) utils.CacheItem { + ret := _m.Called(id) + + var r0 utils.CacheItem + if rf, ok := ret.Get(0).(func(string) utils.CacheItem); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(utils.CacheItem) + } + } + + return r0 +} + +// GetOrCreate provides a mock function with given fields: item +func (_m *AutoRefreshCache) GetOrCreate(item utils.CacheItem) (utils.CacheItem, error) { + ret := _m.Called(item) + + var r0 utils.CacheItem + if rf, ok := ret.Get(0).(func(utils.CacheItem) utils.CacheItem); ok { + r0 = rf(item) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(utils.CacheItem) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(utils.CacheItem) error); ok { + r1 = rf(item) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Start provides a mock function with given fields: ctx +func (_m *AutoRefreshCache) Start(ctx context.Context) { + _m.Called(ctx) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/mocks/sidecar_custom b/flyteplugins/go/tasks/v1/k8splugins/mocks/sidecar_custom new file mode 100755 index 0000000000..c963af720c --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/mocks/sidecar_custom @@ -0,0 +1,59 @@ +{ + "podSpec": { + "restartPolicy": "OnFailure", + "containers": [{ + "name": "a container", + "image": "foo", + "args": ["pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"], + "volumeMounts": [{ + "mountPath": "some/where", + "name": "volume mount" + }], + "env": [{ + "name": "FLYTE_INTERNAL_CONFIGURATION_PATH", + "value": "flytekit.config" + }, { + "name": "FLYTE_INTERNAL_PROJECT", + "value": "" + }, { + "name": "foo", + "value": "bar" + }, { + "name": "FLYTE_INTERNAL_DOMAIN", + "value": "" + }, { + "name": "FLYTE_INTERNAL_VERSION", + "value": "" + }], + "resources": { + "requests": { + "cpu": { + "string": "10" + } + }, + "limits": { + "nvidia.com/gpu": { + "string": "2" + }, + "cpu": { + "string": "10" + } + } + } + }, { + "name": "another container" + }], + "volumes": [{ + "volumeSource": { + "emptyDir": { + "sizeLimit": { + "string": "10G" + }, + "medium": "Memory" + } + }, + "name": "dshm" + }] + }, + "primaryContainerName": "a container" +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go 
b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go new file mode 100755 index 0000000000..29c8adc627 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go @@ -0,0 +1,196 @@ +package k8splugins + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + + v1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/logs" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + + k8sv1 "k8s.io/api/core/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/types" +) + +const ( + sidecarTaskType = "sidecar" + primaryContainerKey = "primary" +) + +type sidecarResourceHandler struct{} + +// This method handles templatizing primary container input args, env variables and adds a GPU toleration to the pod +// spec if necessary. +func validateAndFinalizeContainers( + ctx context.Context, taskCtx types.TaskContext, primaryContainerName string, pod k8sv1.Pod, + inputs *core.LiteralMap) (*k8sv1.Pod, error) { + var hasPrimaryContainer bool + + finalizedContainers := make([]k8sv1.Container, len(pod.Spec.Containers)) + resReqs := make([]k8sv1.ResourceRequirements, 0, len(pod.Spec.Containers)) + for index, container := range pod.Spec.Containers { + if container.Name == primaryContainerName { + hasPrimaryContainer = true + } + modifiedCommand, err := utils.ReplaceTemplateCommandArgs(ctx, + container.Command, + utils.CommandLineTemplateArgs{ + Input: taskCtx.GetInputsFile().String(), + OutputPrefix: taskCtx.GetDataDir().String(), + Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + }) + + if err != nil { + return nil, err + } + container.Command = modifiedCommand + + modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, + container.Args, + utils.CommandLineTemplateArgs{ + Input: taskCtx.GetInputsFile().String(), + OutputPrefix: taskCtx.GetDataDir().String(), + Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + }) + + if err != nil { + return nil, err + } + container.Args = modifiedArgs + + container.Env = flytek8s.DecorateEnvVars(ctx, container.Env, taskCtx.GetTaskExecutionID()) + resources := flytek8s.ApplyResourceOverrides(ctx, container.Resources) + resReqs = append(resReqs, *resources) + finalizedContainers[index] = container + } + if !hasPrimaryContainer { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid Sidecar task, primary container [%s] not defined", primaryContainerName) + + } + pod.Spec.Containers = finalizedContainers + pod.Spec.Tolerations = flytek8s.GetTolerationsForResources(resReqs...) + return &pod, nil +} + +func (sidecarResourceHandler) BuildResource( + ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) ( + flytek8s.K8sResource, error) { + sidecarJob := plugins.SidecarJob{} + err := utils.UnmarshalStruct(task.GetCustom(), &sidecarJob) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + + pod := flytek8s.BuildPodWithSpec(sidecarJob.PodSpec) + // Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a + // CrashLoopBackoff after the initial job completion. 
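+ // (Plain container tasks get the same RestartPolicyNever via ToK8sPod in flytek8s; for a sidecar task the pod + // spec is supplied by the user, so the policy is overridden here explicitly.)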
+ pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever + + // We also want to update the pod's service account to that of the workflow + pod.Spec.ServiceAccountName = taskCtx.GetK8sServiceAccount() + + pod, err = validateAndFinalizeContainers(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod, inputs) + if err != nil { + return nil, err + } + + if pod.Annotations == nil { + pod.Annotations = make(map[string]string, 1) + } + + pod.Annotations[primaryContainerKey] = sidecarJob.PrimaryContainerName + + return pod, nil +} + +func (sidecarResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx types.TaskContext) ( + flytek8s.K8sResource, error) { + return flytek8s.BuildIdentityPod(), nil +} + +func determinePrimaryContainerStatus(primaryContainerName string, statuses []k8sv1.ContainerStatus) ( + types.TaskStatus, error) { + for _, s := range statuses { + if s.Name == primaryContainerName { + if s.State.Waiting != nil || s.State.Running != nil { + return types.TaskStatusRunning, nil + } + + if s.State.Terminated != nil { + if s.State.Terminated.ExitCode != 0 { + return types.TaskStatusRetryableFailure(errors.Errorf( + s.State.Terminated.Reason, s.State.Terminated.Message)), nil + } + return types.TaskStatusSucceeded, nil + } + } + } + + // If for some reason we can't find the primary container, always just return a permanent failure + return types.TaskStatusPermanentFailure(errors.Errorf("PrimaryContainerMissing", + "Primary container [%s] not found in pod's container statuses", primaryContainerName)), nil +} + +func (sidecarResourceHandler) GetTaskStatus( + ctx context.Context, taskCtx types.TaskContext, resource flytek8s.K8sResource) ( + types.TaskStatus, *events.TaskEventInfo, error) { + pod := resource.(*k8sv1.Pod) + + var info *events.TaskEventInfo + if pod.Status.Phase != k8sv1.PodPending && pod.Status.Phase != k8sv1.PodUnknown { + taskLogs := make([]*core.TaskLog, 0) + for idx, container := range pod.Spec.Containers { + containerLogs, err := logs.GetLogsForContainerInPod(ctx, pod, uint32(idx), fmt.Sprintf(" (%s)", container.Name)) + if err != nil { + return types.TaskStatusUndefined, nil, err + } + taskLogs = append(taskLogs, containerLogs...) + } + + t := GetLastTransitionOccurredAt(pod).Time + info = &events.TaskEventInfo{ + Logs: taskLogs, + OccurredAt: &t, + } + } + switch pod.Status.Phase { + case k8sv1.PodSucceeded: + return types.TaskStatusSucceeded, info, nil + case k8sv1.PodFailed: + return types.TaskStatusRetryableFailure(ConvertPodFailureToError(pod.Status)), info, nil + case k8sv1.PodPending: + status, err := flytek8s.DemystifyPending(pod.Status) + return status, info, err + case k8sv1.PodReasonUnschedulable: + return types.TaskStatusQueued, info, nil + case k8sv1.PodUnknown: + return types.TaskStatusUnknown, info, nil + } + + // Otherwise, assume the pod is running. 
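+ // For a running sidecar pod, the overall task status is driven by the primary container alone: the task + // succeeds or fails as soon as that container terminates, even if other (sidecar) containers are still running. + // See determinePrimaryContainerStatus above.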
+ primaryContainerName, ok := resource.GetAnnotations()[primaryContainerKey] + if !ok { + return types.TaskStatusUndefined, nil, errors.Errorf(errors.BadTaskSpecification, + "missing primary container annotation for pod") + } + + status, err := determinePrimaryContainerStatus(primaryContainerName, pod.Status.ContainerStatuses) + return status, info, err +} + +func init() { + v1.RegisterLoader(func(ctx context.Context) error { + return v1.K8sRegisterForTaskTypes(sidecarTaskType, &k8sv1.Pod{}, flytek8s.DefaultInformerResyncDuration, + sidecarResourceHandler{}, sidecarTaskType) + }) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/sidecar_test.go b/flyteplugins/go/tasks/v1/k8splugins/sidecar_test.go new file mode 100755 index 0000000000..43d28fe96a --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/sidecar_test.go @@ -0,0 +1,231 @@ +package k8splugins + +import ( + "context" + "io/ioutil" + "os" + "path" + "testing" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/stretchr/testify/assert" + "k8s.io/api/core/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +const ResourceNvidiaGPU = "nvidia.com/gpu" + +func getSidecarTaskTemplateForTest(sideCarJob plugins.SidecarJob) *core.TaskTemplate { + sidecarJSON, err := utils.MarshalToString(&sideCarJob) + if err != nil { + panic(err) + } + structObj := structpb.Struct{} + err = jsonpb.UnmarshalString(sidecarJSON, &structObj) + if err != nil { + panic(err) + } + return &core.TaskTemplate{ + Custom: &structObj, + } +} + +func TestBuildSidecarResource(t *testing.T) { + dir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + sidecarCustomJSON, err := ioutil.ReadFile(path.Join(dir, "mocks", "sidecar_custom")) + if err != nil { + t.Fatal(sidecarCustomJSON) + } + sidecarCustom := structpb.Struct{} + if err := jsonpb.UnmarshalString(string(sidecarCustomJSON), &sidecarCustom); err != nil { + t.Fatal(err) + } + task := core.TaskTemplate{ + Custom: &sidecarCustom, + } + + tolGPU := v1.Toleration{ + Key: "flyte/gpu", + Value: "dedicated", + Operator: v1.TolerationOpEqual, + Effect: v1.TaintEffectNoSchedule, + } + + tolStorage := v1.Toleration{ + Key: "storage", + Value: "dedicated", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + ResourceTolerations: map[v1.ResourceName][]v1.Toleration{ + v1.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }, + })) + handler := &sidecarResourceHandler{} + taskCtx := dummyContainerTaskContext(resourceRequirements) + resource, err := handler.BuildResource(context.TODO(), taskCtx, &task, nil) + assert.Nil(t, err) + assert.EqualValues(t, map[string]string{ + primaryContainerKey: "a container", + }, resource.GetAnnotations()) + assert.Contains(t, resource.(*v1.Pod).Spec.Tolerations, tolGPU) +} + +func TestBuildSidecarResourceMissingPrimary(t *testing.T) { + sideCarJob := plugins.SidecarJob{ + PrimaryContainerName: "PrimaryContainer", + PodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "SecondaryContainer", + }, + }, + }, + } + + task := getSidecarTaskTemplateForTest(sideCarJob) + + handler := &sidecarResourceHandler{} + taskCtx := 
dummyContainerTaskContext(resourceRequirements) + _, err := handler.BuildResource(context.TODO(), taskCtx, task, nil) + assert.EqualError(t, err, + "task failed, BadTaskSpecification: invalid Sidecar task, primary container [PrimaryContainer] not defined") +} + +func TestGetTaskSidecarStatus(t *testing.T) { + var testCases = map[v1.PodPhase]types.TaskPhase{ + v1.PodSucceeded: types.TaskPhaseSucceeded, + v1.PodFailed: types.TaskPhaseRetryableFailure, + v1.PodReasonUnschedulable: types.TaskPhaseQueued, + v1.PodUnknown: types.TaskPhaseUnknown, + } + + for podPhase, expectedTaskPhase := range testCases { + var resource flytek8s.K8sResource + resource = &v1.Pod{ + Status: v1.PodStatus{ + Phase: podPhase, + }, + } + handler := &sidecarResourceHandler{} + taskCtx := dummyContainerTaskContext(resourceRequirements) + status, _, err := handler.GetTaskStatus(context.TODO(), taskCtx, resource) + assert.Nil(t, err) + assert.Equal(t, expectedTaskPhase, status.Phase) + } +} + +func TestDemystifiedSidecarStatus_PrimaryFailed(t *testing.T) { + var resource flytek8s.K8sResource + resource = &v1.Pod{ + Status: v1.PodStatus{ + Phase: v1.PodRunning, + ContainerStatuses: []v1.ContainerStatus{ + { + Name: "Primary", + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + ExitCode: 1, + }, + }, + }, + }, + }, + } + resource.SetAnnotations(map[string]string{ + primaryContainerKey: "Primary", + }) + handler := &sidecarResourceHandler{} + taskCtx := dummyContainerTaskContext(resourceRequirements) + status, _, err := handler.GetTaskStatus(context.TODO(), taskCtx, resource) + assert.Nil(t, err) + assert.Equal(t, types.TaskPhaseRetryableFailure, status.Phase) +} + +func TestDemystifiedSidecarStatus_PrimarySucceeded(t *testing.T) { + var resource flytek8s.K8sResource + resource = &v1.Pod{ + Status: v1.PodStatus{ + Phase: v1.PodRunning, + ContainerStatuses: []v1.ContainerStatus{ + { + Name: "Primary", + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + ExitCode: 0, + }, + }, + }, + }, + }, + } + resource.SetAnnotations(map[string]string{ + primaryContainerKey: "Primary", + }) + handler := &sidecarResourceHandler{} + taskCtx := dummyContainerTaskContext(resourceRequirements) + status, _, err := handler.GetTaskStatus(context.TODO(), taskCtx, resource) + assert.Nil(t, err) + assert.Equal(t, types.TaskPhaseSucceeded, status.Phase) +} + +func TestDemystifiedSidecarStatus_PrimaryRunning(t *testing.T) { + var resource flytek8s.K8sResource + resource = &v1.Pod{ + Status: v1.PodStatus{ + Phase: v1.PodRunning, + ContainerStatuses: []v1.ContainerStatus{ + { + Name: "Primary", + State: v1.ContainerState{ + Waiting: &v1.ContainerStateWaiting{ + Reason: "stay patient", + }, + }, + }, + }, + }, + } + resource.SetAnnotations(map[string]string{ + primaryContainerKey: "Primary", + }) + handler := &sidecarResourceHandler{} + taskCtx := dummyContainerTaskContext(resourceRequirements) + status, _, err := handler.GetTaskStatus(context.TODO(), taskCtx, resource) + assert.Nil(t, err) + assert.Equal(t, types.TaskPhaseRunning, status.Phase) +} + +func TestDemystifiedSidecarStatus_PrimaryMissing(t *testing.T) { + var resource flytek8s.K8sResource + resource = &v1.Pod{ + Status: v1.PodStatus{ + Phase: v1.PodRunning, + ContainerStatuses: []v1.ContainerStatus{ + { + Name: "Secondary", + }, + }, + }, + } + resource.SetAnnotations(map[string]string{ + primaryContainerKey: "Primary", + }) + handler := &sidecarResourceHandler{} + taskCtx := dummyContainerTaskContext(resourceRequirements) + status, _, err := 
handler.GetTaskStatus(context.TODO(), taskCtx, resource) + assert.Nil(t, err) + assert.Equal(t, types.TaskPhasePermanentFailure, status.Phase) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/spark.go b/flyteplugins/go/tasks/v1/k8splugins/spark.go new file mode 100755 index 0000000000..de19a37828 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/spark.go @@ -0,0 +1,304 @@ +package k8splugins + +import ( + "context" + "fmt" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/config" + "github.com/lyft/flyteplugins/go/tasks/v1/logs" + + v1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" + + "k8s.io/client-go/kubernetes/scheme" + + sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta1" + logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/v1/config" +) + +const KindSparkApplication = "SparkApplication" +const sparkDriverUI = "sparkDriverUI" +const sparkHistoryUI = "sparkHistoryUI" + +var sparkTaskType = "spark" + +// Spark-specific configs +type SparkConfig struct { + DefaultSparkConfig map[string]string `json:"spark-config-default" pflag:",Key value pairs of default spark configuration that should be applied to every SparkJob"` + SparkHistoryServerURL string `json:"spark-history-server-url" pflag:",URL for SparkHistory Server that each job will publish the execution history to."` +} + +var ( + sparkConfigSection = pluginsConfig.MustRegisterSubSection("spark", &SparkConfig{}) +) + +func GetSparkConfig() *SparkConfig { + return sparkConfigSection.GetConfig().(*SparkConfig) +} + +// This method should be used for unit testing only +func setSparkConfig(cfg *SparkConfig) error { + return sparkConfigSection.SetConfig(cfg) +} + +type sparkResourceHandler struct { +} + +// Creates a new SparkApplication resource that executes the task's container (image and args) according to the SparkJob spec carried in the task's custom field.
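// As a rough illustration (values borrowed from the tests later in this patch, not a requirement of the plugin),
// the task's Custom field is expected to decode into a plugins.SparkJob such as:
//   plugins.SparkJob{
//       ApplicationType:     plugins.SparkApplication_PYTHON,
//       MainApplicationFile: "local:///spark_app.py",
//       SparkConf:           map[string]string{"spark.driver.cores": "1"},
//   }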
+func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) { + + sparkJob := plugins.SparkJob{} + err := utils.UnmarshalStruct(task.GetCustom(), &sparkJob) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + + annotations := flytek8s.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.GetAnnotations())) + labels := flytek8s.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.GetLabels())) + container := task.GetContainer() + + envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(task.GetContainer().GetEnv()), taskCtx.GetTaskExecutionID()) + + sparkEnvVars := make(map[string]string) + for _, envVar := range envVars { + sparkEnvVars[envVar.Name] = envVar.Value + } + + driverSpec := sparkOp.DriverSpec{ + SparkPodSpec: sparkOp.SparkPodSpec{ + Annotations: annotations, + Labels: labels, + EnvVars: sparkEnvVars, + Image: &container.Image, + }, + ServiceAccount: &sparkTaskType, + } + + executorSpec := sparkOp.ExecutorSpec{ + SparkPodSpec: sparkOp.SparkPodSpec{ + Annotations: annotations, + Labels: labels, + Image: &container.Image, + EnvVars: sparkEnvVars, + }, + } + + modifiedArgs, err := utils.ReplaceTemplateCommandArgs(context.TODO(), + task.GetContainer().GetArgs(), + utils.CommandLineTemplateArgs{ + Input: taskCtx.GetInputsFile().String(), + OutputPrefix: taskCtx.GetDataDir().String(), + Inputs: utils.LiteralMapToTemplateArgs(context.TODO(), inputs), + }) + + if err != nil { + return nil, err + } + + // Hack: Retry submit failures in-case of resource limits hit. + submissionFailureRetries := int32(14) + // Start with default config values. + sparkConfig := make(map[string]string) + for k, v := range GetSparkConfig().DefaultSparkConfig { + sparkConfig[k] = v + } + + if sparkJob.GetExecutorPath() != "" { + sparkConfig["spark.pyspark.python"] = sparkJob.GetExecutorPath() + sparkConfig["spark.pyspark.driver.python"] = sparkJob.GetExecutorPath() + } + + for k, v := range sparkJob.GetSparkConf() { + sparkConfig[k] = v + } + + // Set pod limits. + if sparkConfig["spark.kubernetes.driver.limit.cores"] == "" && sparkConfig["spark.driver.cores"] != "" { + sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"] + } + if sparkConfig["spark.kubernetes.executor.limit.cores"] == "" && sparkConfig["spark.executor.cores"] != "" { + sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"] + } + + j := &sparkOp.SparkApplication{ + TypeMeta: metav1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: sparkOp.SchemeGroupVersion.String(), + }, + Spec: sparkOp.SparkApplicationSpec{ + ServiceAccount: &sparkTaskType, + Type: getApplicationType(sparkJob.GetApplicationType()), + Mode: sparkOp.ClusterMode, + Image: &container.Image, + Arguments: modifiedArgs, + Driver: driverSpec, + Executor: executorSpec, + SparkConf: sparkConfig, + HadoopConf: sparkJob.GetHadoopConf(), + // SubmissionFailures handled here. Task Failures handled at Propeller/Job level. 
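// (submissionFailureRetries above only bounds how many times a failed spark-submit is retried by the operator;
// once an application itself fails, GetTaskStatus below surfaces it as a retryable task failure so the retry is
// owned by Propeller rather than the operator.)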
+ RestartPolicy: sparkOp.RestartPolicy{ + Type: sparkOp.OnFailure, + OnSubmissionFailureRetries: &submissionFailureRetries, + }, + }, + } + + if sparkJob.MainApplicationFile != "" { + j.Spec.MainApplicationFile = &sparkJob.MainApplicationFile + } else if sparkJob.MainClass != "" { + j.Spec.MainClass = &sparkJob.MainClass + } + + return j, nil +} + +// Convert SparkJob ApplicationType to Operator CRD ApplicationType +func getApplicationType(applicationType plugins.SparkApplication_Type) sparkOp.SparkApplicationType { + switch applicationType { + case plugins.SparkApplication_PYTHON: + return sparkOp.PythonApplicationType + case plugins.SparkApplication_JAVA: + return sparkOp.JavaApplicationType + case plugins.SparkApplication_SCALA: + return sparkOp.ScalaApplicationType + case plugins.SparkApplication_R: + return sparkOp.RApplicationType + } + return sparkOp.PythonApplicationType +} + +func (sparkResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx types.TaskContext) (flytek8s.K8sResource, error) { + return &sparkOp.SparkApplication{ + TypeMeta: metav1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: sparkOp.SchemeGroupVersion.String(), + }, + }, nil +} + +func getEventInfoForSpark(sj *sparkOp.SparkApplication) (*events.TaskEventInfo, error) { + var taskLogs []*core.TaskLog + customInfoMap := make(map[string]string) + + logConfig := logs.GetLogConfig() + if logConfig.IsKubernetesEnabled && sj.Status.DriverInfo.PodName != "" { + k8sLog, err := logUtils.NewKubernetesLogPlugin(logConfig.KubernetesURL).GetTaskLog( + sj.Status.DriverInfo.PodName, + sj.Namespace, + "", + "", + "Driver Logs (via Kubernetes)") + + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, &k8sLog) + } + + if logConfig.IsCloudwatchEnabled { + cwUserLogs := core.TaskLog{ + Uri: fmt.Sprintf( + "https://console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=%s;prefix=var.log.containers.%s;streamFilter=typeLogStreamPrefix", + logConfig.CloudwatchRegion, + logConfig.CloudwatchLogGroup, + sj.Status.DriverInfo.PodName), + Name: "User Driver Logs (via Cloudwatch)", + MessageFormat: core.TaskLog_JSON, + } + cwSystemLogs := core.TaskLog{ + Uri: fmt.Sprintf( + "https://console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=%s;prefix=system_log.var.log.containers.%s;streamFilter=typeLogStreamPrefix", + logConfig.CloudwatchRegion, + logConfig.CloudwatchLogGroup, + sj.Name), + Name: "System Logs (via Cloudwatch)", + MessageFormat: core.TaskLog_JSON, + } + taskLogs = append(taskLogs, &cwUserLogs) + taskLogs = append(taskLogs, &cwSystemLogs) + + } + + // Spark UI. + if sj.Status.AppState.State == sparkOp.FailedState || sj.Status.AppState.State == sparkOp.CompletedState { + if sj.Status.SparkApplicationID != "" && GetSparkConfig().SparkHistoryServerURL != "" { + customInfoMap[sparkHistoryUI] = fmt.Sprintf("%s/history/%s", GetSparkConfig().SparkHistoryServerURL, sj.Status.SparkApplicationID) + // Custom doesn't work unless the UI has a custom plugin to parse this, hence add to Logs as well. + taskLogs = append(taskLogs, &core.TaskLog{ + Uri: customInfoMap[sparkHistoryUI], + Name: "Spark History UI", + MessageFormat: core.TaskLog_JSON, + }) + } + } else if sj.Status.AppState.State == sparkOp.RunningState && sj.Status.DriverInfo.WebUIIngressAddress != "" { + // Append https as the operator doesn't currently. 
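// e.g. an ingress address of "spark-ui.flyte" (the value used in the tests below) is exposed as the log link
// "https://spark-ui.flyte".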
+ customInfoMap[sparkDriverUI] = fmt.Sprintf("https://%s", sj.Status.DriverInfo.WebUIIngressAddress) + // Custom doesn't work unless the UI has a custom plugin to parse this, hence add to Logs as well. + taskLogs = append(taskLogs, &core.TaskLog{ + Uri: customInfoMap[sparkDriverUI], + Name: "Spark Driver UI", + MessageFormat: core.TaskLog_JSON, + }) + } + + customInfo, err := utils.MarshalObjToStruct(customInfoMap) + if err != nil { + return nil, err + } + + return &events.TaskEventInfo{ + Logs: taskLogs, + CustomInfo: customInfo, + }, nil +} + +func (sparkResourceHandler) GetTaskStatus(_ context.Context, _ types.TaskContext, r flytek8s.K8sResource) ( + types.TaskStatus, *events.TaskEventInfo, error) { + + app := r.(*sparkOp.SparkApplication) + var status types.TaskStatus + switch app.Status.AppState.State { + case sparkOp.NewState, sparkOp.SubmittedState, sparkOp.PendingSubmissionState: + status = types.TaskStatusQueued + case sparkOp.FailedSubmissionState: + status = types.TaskStatusRetryableFailure(errors.Errorf(errors.DownstreamSystemError, "Spark Job Submission Failed with Error: %s", app.Status.AppState.ErrorMessage)) + case sparkOp.FailedState: + status = types.TaskStatusRetryableFailure(errors.Errorf(errors.DownstreamSystemError, "Spark Job Failed with Error: %s", app.Status.AppState.ErrorMessage)) + case sparkOp.CompletedState: + status = types.TaskStatusSucceeded + default: + status = types.TaskStatusRunning + } + + info, err := getEventInfoForSpark(app) + if err != nil { + return types.TaskStatusUndefined, nil, err + } + + return status, info, nil +} + +func (sparkResourceHandler) PopulateTaskEventInfo(taskCtx types.TaskContext, resource flytek8s.K8sResource) error { + return nil +} + +func init() { + if err := sparkOp.AddToScheme(scheme.Scheme); err != nil { + panic(err) + } + + v1.RegisterLoader(func(ctx context.Context) error { + return v1.K8sRegisterForTaskTypes(sparkTaskType, &sparkOp.SparkApplication{}, + flytek8s.DefaultInformerResyncDuration, sparkResourceHandler{}, sparkTaskType) + }) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/spark_test.go b/flyteplugins/go/tasks/v1/k8splugins/spark_test.go new file mode 100755 index 0000000000..b17c75e9f6 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/spark_test.go @@ -0,0 +1,270 @@ +package k8splugins + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flyteplugins/go/tasks/v1/logs" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" + + sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta1" + "github.com/golang/protobuf/jsonpb" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/types" +) + +const sparkApplicationFile = "local:///spark_app.py" +const testImage = "image://" +const sparkUIAddress = "spark-ui.flyte" + +var ( + dummySparkConf = map[string]string{ + "spark.driver.memory": "500M", + "spark.driver.cores": "1", + "spark.executor.cores": "1", + "spark.executor.instances": "3", + "spark.executor.memory": "500M", + } + + dummyEnvVars = []*core.KeyValuePair{ + {Key: "Env_Var", Value: "Env_Val"}, + } + + testArgs = []string{ + "execute-spark-task", + } +) + +func TestGetApplicationType(t *testing.T) { + assert.Equal(t, 
getApplicationType(plugins.SparkApplication_PYTHON), sj.PythonApplicationType) + assert.Equal(t, getApplicationType(plugins.SparkApplication_R), sj.RApplicationType) + assert.Equal(t, getApplicationType(plugins.SparkApplication_JAVA), sj.JavaApplicationType) + assert.Equal(t, getApplicationType(plugins.SparkApplication_SCALA), sj.ScalaApplicationType) +} + +func TestGetEventInfo(t *testing.T) { + assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ + IsCloudwatchEnabled: true, + CloudwatchRegion: "us-east-1", + CloudwatchLogGroup: "/kubernetes/flyte", + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + })) + info, err := getEventInfoForSpark(dummySparkApplication(sj.RunningState)) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("https://%s", sparkUIAddress), info.CustomInfo.Fields[sparkDriverUI].GetStringValue()) + assert.Equal(t, "k8s.com/#!/log/spark-namespace/spark-pod/pod?namespace=spark-namespace", info.Logs[0].Uri) + assert.Equal(t, "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.spark-pod;streamFilter=typeLogStreamPrefix", info.Logs[1].Uri) + assert.Equal(t, "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=system_log.var.log.containers.spark-app-name;streamFilter=typeLogStreamPrefix", info.Logs[2].Uri) + assert.Equal(t, "https://spark-ui.flyte", info.Logs[3].Uri) + + assert.NoError(t, setSparkConfig(&SparkConfig{ + SparkHistoryServerURL: "spark-history.flyte", + })) + + info, err = getEventInfoForSpark(dummySparkApplication(sj.FailedState)) + assert.NoError(t, err) + assert.Equal(t, "spark-history.flyte/history/app-id", info.CustomInfo.Fields[sparkHistoryUI].GetStringValue()) + assert.Equal(t, "k8s.com/#!/log/spark-namespace/spark-pod/pod?namespace=spark-namespace", info.Logs[0].Uri) + assert.Equal(t, "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.spark-pod;streamFilter=typeLogStreamPrefix", info.Logs[1].Uri) + assert.Equal(t, "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=system_log.var.log.containers.spark-app-name;streamFilter=typeLogStreamPrefix", info.Logs[2].Uri) + assert.Equal(t, "spark-history.flyte/history/app-id", info.Logs[3].Uri) +} + +func TestGetTaskStatus(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + + ctx := context.TODO() + taskStatus, i, err := sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.NewState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseQueued) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.SubmittedState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseQueued) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.RunningState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRunning) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.CompletedState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseSucceeded) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, 
dummySparkApplication(sj.InvalidatingState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRunning) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.FailingState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRunning) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.PendingRerunState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRunning) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.SucceedingState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRunning) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.FailedSubmissionState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRetryableFailure) + assert.NotNil(t, i) + assert.Nil(t, err) + + taskStatus, i, err = sparkResourceHandler.GetTaskStatus(ctx, nil, dummySparkApplication(sj.FailedState)) + assert.NoError(t, err) + assert.Equal(t, taskStatus.Phase, types.TaskPhaseRetryableFailure) + assert.NotNil(t, i) + assert.Nil(t, err) +} + +func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication { + + return &sj.SparkApplication{ + ObjectMeta: v1.ObjectMeta{ + Name: "spark-app-name", + Namespace: "spark-namespace", + }, + Status: sj.SparkApplicationStatus{ + SparkApplicationID: "app-id", + AppState: sj.ApplicationState{ + State: state, + }, + DriverInfo: sj.DriverInfo{ + PodName: "spark-pod", + WebUIIngressAddress: sparkUIAddress, + }, + ExecutionAttempts: 1, + }, + } +} + +func dummySparkCustomObj() *plugins.SparkJob { + sparkJob := plugins.SparkJob{} + + sparkJob.MainApplicationFile = sparkApplicationFile + sparkJob.SparkConf = dummySparkConf + sparkJob.ApplicationType = plugins.SparkApplication_PYTHON + return &sparkJob +} + +func dummySparkTaskTemplate(id string) *core.TaskTemplate { + + sparkJob := dummySparkCustomObj() + sparkJobJSON, err := utils.MarshalToString(sparkJob) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(sparkJobJSON, &structObj) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, + }, + }, + Custom: &structObj, + } +} + +func dummySparkTaskContext() types.TaskContext { + taskCtx := &mocks.TaskContext{} + taskCtx.On("GetNamespace").Return("test-namespace") + taskCtx.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + taskCtx.On("GetLabels").Return(map[string]string{"label-1": "val1"}) + taskCtx.On("GetOwnerReference").Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskCtx.On("GetDataDir").Return(storage.DataReference("/data/")) + taskCtx.On("GetInputsFile").Return(storage.DataReference("/input")) + + tID := &mocks.TaskExecutionID{} + tID.On("GetID").Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + 
tID.On("GetGeneratedName").Return("some-acceptable-name") + taskCtx.On("GetTaskExecutionID").Return(tID) + return taskCtx +} + +func TestBuildResourceSpark(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + + // Case1: Valid Spark Task-Template + taskTemplate := dummySparkTaskTemplate("blah-1") + + resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(), taskTemplate, nil) + assert.Nil(t, err) + + assert.NotNil(t, resource) + sparkApp, ok := resource.(*sj.SparkApplication) + assert.True(t, ok) + assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) + assert.Equal(t, sj.PythonApplicationType, sparkApp.Spec.Type) + assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, testImage, *sparkApp.Spec.Image) + + for confKey, confVal := range dummySparkConf { + exists := false + for k, v := range sparkApp.Spec.SparkConf { + if k == confKey { + assert.Equal(t, v, confVal) + exists = true + } + } + assert.True(t, exists) + } + + assert.Equal(t, dummySparkConf["spark.driver.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.driver.limit.cores"]) + assert.Equal(t, dummySparkConf["spark.executor.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.executor.limit.cores"]) + + // Case2: Invalid Spark Task-Template + taskTemplate.Custom = nil + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(), taskTemplate, nil) + assert.NotNil(t, err) + assert.Nil(t, resource) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go b/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go new file mode 100755 index 0000000000..2a5f82eeef --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go @@ -0,0 +1,552 @@ +package k8splugins + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/v1/config" + + "github.com/lyft/flytestdlib/config" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/clients/go/admin" + admin2 "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + utils2 "github.com/lyft/flytestdlib/utils" + v1 "k8s.io/api/core/v1" + + tasksV1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +const ( + waitableTaskType = "waitable" + waitablesKey = "waitables" + handoverKey = "handover" +) + +var marshaler = &jsonpb.Marshaler{} + +var ( + defaultWaitableConfig = &WaitableConfig{ + LruCacheSize: 1000, + } + + waitableConfigSection = pluginsConfig.MustRegisterSubSection("waitable", defaultWaitableConfig) +) + +type WaitableConfig struct { + // TODO: Reconsider this once we plug Console as a Log Provider plugin. + ConsoleURI config.URL `json:"console-uri" pflag:",URI for console. 
Used to expose links in the emitted events."` + LruCacheSize int `json:"lru-cache-size" pflag:",Size of the AutoRefreshCache"` +} + +func GetWaitableConfig() *WaitableConfig { + return waitableConfigSection.GetConfig().(*WaitableConfig) +} + +// WaitableTaskExecutor plugin handle tasks of type "waitable". These tasks take as input plugins.Waitable protos. Each +// proto includes an execution id. The plugin asynchronously waits for all executions to reach a terminal state before +// passing the execution over to a ContainerTaskExecutor plugin. +type waitableTaskExecutor struct { + *flytek8s.K8sTaskExecutor + containerTaskExecutor + adminClient service.AdminServiceClient + executionsCache utils2.AutoRefreshCache + dataStore *storage.DataStore + recorder types.EventRecorder +} + +// Wrapper object for plugins.waitable to implement CacheItem interface required to use AutoRefreshCache. +type waitableWrapper struct { + *plugins.Waitable +} + +// A unique identifier for the Waitable instance used to dedupe AutoRefreshCache. +func (d *waitableWrapper) ID() string { + if d.WfExecId != nil { + return d.WfExecId.String() + } + + return "" +} + +// An override to the default json marshaler to use jsonPb. +func (d *waitableWrapper) MarshalJSON() ([]byte, error) { + s, err := marshaler.MarshalToString(d.Waitable) + if err != nil { + return nil, err + } + + return []byte(s), nil +} + +// An override to the default json unmarshaler to use jsonPb. +func (d *waitableWrapper) UnmarshalJSON(b []byte) error { + w := plugins.Waitable{} + err := jsonpb.UnmarshalString(string(b), &w) + if err != nil { + return err + } + + d.Waitable = &w + return nil +} + +// Traverses a literal to find all Waitable objects. +func discoverWaitableInputs(l *core.Literal) (literals []*core.Literal, waitables []*waitableWrapper) { + switch o := l.Value.(type) { + case *core.Literal_Collection: + literals := make([]*core.Literal, 0, len(o.Collection.Literals)) + waitables := make([]*waitableWrapper, 0, len(o.Collection.Literals)) + for _, i := range o.Collection.Literals { + ls, ws := discoverWaitableInputs(i) + literals = append(literals, ls...) + waitables = append(waitables, ws...) + } + + return literals, waitables + case *core.Literal_Map: + literals := make([]*core.Literal, 0, len(o.Map.Literals)) + waitables := make([]*waitableWrapper, 0, len(o.Map.Literals)) + for _, i := range o.Map.Literals { + ls, ws := discoverWaitableInputs(i) + literals = append(literals, ls...) + waitables = append(waitables, ws...) + } + + return literals, waitables + case *core.Literal_Scalar: + switch v := o.Scalar.Value.(type) { + case *core.Scalar_Generic: + waitable := &plugins.Waitable{} + err := utils.UnmarshalStruct(v.Generic, waitable) + if err != nil { + // skip, it's just a different type? + return []*core.Literal{}, []*waitableWrapper{} + } + + return []*core.Literal{l}, []*waitableWrapper{{Waitable: waitable}} + } + } + + return []*core.Literal{}, []*waitableWrapper{} +} + +// Generates workflow execution links as log links. 
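// For example (illustrative values only), with ConsoleURI "https://console.example.com" and a waitable execution
// {Project: "p", Domain: "d", Name: "n"}, the emitted link is
//   https://console.example.com/projects/p/domains/d/executions/n
// and its display name carries the execution phase, e.g. "Exec: n (RUNNING)".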
+func generateWorkflowExecutionLinks(ctx context.Context, waitables []*waitableWrapper) []*core.TaskLog { + cfg := GetWaitableConfig() + if cfg == nil { + logger.Info(ctx, "No console endpoint config, skipping the generation of execution links.") + return []*core.TaskLog{} + } + + logs := make([]*core.TaskLog, 0, len(waitables)) + for _, w := range waitables { + logs = append(logs, &core.TaskLog{ + Name: fmt.Sprintf("Exec: %v (%v)", w.WfExecId.Name, w.Phase), + Uri: fmt.Sprintf("%v/projects/%v/domains/%v/executions/%v", cfg.ConsoleURI.String(), + w.WfExecId.Project, w.WfExecId.Domain, w.WfExecId.Name), + }) + } + + return logs +} + +// Shadows k8sExecutor initialize to capture some of the initialization params. +func (w *waitableTaskExecutor) Initialize(ctx context.Context, params types.ExecutorInitializationParameters) error { + w.dataStore = params.DataStore + w.recorder = params.EventRecorder + err := w.K8sTaskExecutor.Initialize(ctx, params) + if err != nil { + return err + } + + // We assign a mock one in tests, so let's not override if it's already there + w.executionsCache, err = utils2.NewAutoRefreshCache(w.syncItem, utils2.NewRateLimiter( + "admin-get-executions", 50, 50), + flytek8s.DefaultInformerResyncDuration, GetWaitableConfig().LruCacheSize, + params.MetricsScope.NewSubScope(waitableTaskType)) + if err != nil { + return err + } + + w.executionsCache.Start(ctx) + return nil +} + +// Shadows k8sExecutor +func (w waitableTaskExecutor) StartTask(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) ( + types.TaskStatus, error) { + + _, allWaitables := discoverWaitableInputs(&core.Literal{Value: &core.Literal_Map{Map: inputs}}) + for _, waitable := range allWaitables { + _, err := w.executionsCache.GetOrCreate(waitable) + if err != nil { + // System failure + return types.TaskStatusUndefined, err + } + } + + state := map[string]interface{}{ + waitablesKey: allWaitables, + } + + status := types.TaskStatus{Phase: types.TaskPhaseQueued, PhaseVersion: taskCtx.GetPhaseVersion(), State: state} + ev := events.CreateEvent(taskCtx, status, &events.TaskEventInfo{ + Logs: generateWorkflowExecutionLinks(ctx, allWaitables), + }) + + err := w.recorder.RecordTaskEvent(ctx, ev) + if err != nil && eventErrors.IsEventAlreadyInTerminalStateError(err) { + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event. phase mis-match between Propeller %v and Control Plane.", &status.Phase)), nil + } else if err != nil { + logger.Errorf(ctx, "Failed to record task event [%v]. Error: %v", ev, err) + return types.TaskStatusUndefined, errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record start task event") + } + + return status, nil +} + +func isTerminalWorkflowPhase(phase core.WorkflowExecution_Phase) bool { + return phase == core.WorkflowExecution_SUCCEEDED || + phase == core.WorkflowExecution_FAILED || + phase == core.WorkflowExecution_ABORTED +} + +func toWaitableWrapperSlice(ctx context.Context, sliceInterface interface{}) ([]*waitableWrapper, error) { + waitables, casted := sliceInterface.([]*waitableWrapper) + if casted { + return waitables, nil + } + + waitablesIfaceSlice, casted := sliceInterface.([]interface{}) + if !casted { + err := fmt.Errorf("failed to cast interface to watiableWrapper. 
Actual type: %v, Allowed types: [%v, %v]", + reflect.TypeOf(sliceInterface), + reflect.TypeOf([]interface{}{}), + reflect.TypeOf([]*waitableWrapper{})) + logger.Error(ctx, err) + + return []*waitableWrapper{}, err + } + + waitables = make([]*waitableWrapper, 0, len(waitablesIfaceSlice)) + for _, item := range waitablesIfaceSlice { + raw, err := json.Marshal(item) + if err != nil { + return nil, errors.Wrapf(errors.RuntimeFailure, err, "failed to marshal state item into json") + } + + waitable := &waitableWrapper{} + err = json.Unmarshal(raw, waitable) + if err != nil { + return nil, errors.Wrapf(errors.RuntimeFailure, err, "failed to unmarshal state into waitable wrapper") + } + + waitables = append(waitables, waitable) + } + + return waitables, nil +} + +func (w waitableTaskExecutor) getUpdatedWaitables(ctx context.Context, taskCtx types.TaskContext) ( + updatedWaitables []*waitableWrapper, terminatedCount int, hasChanged bool, err error) { + + state := taskCtx.GetCustomState() + if state == nil { + return []*waitableWrapper{}, 0, false, nil + } + + sliceInterface, found := state[waitablesKey] + if !found { + return []*waitableWrapper{}, 0, false, nil + } + + allWaitables, err := toWaitableWrapperSlice(ctx, sliceInterface) + if err != nil { + return []*waitableWrapper{}, 0, false, err + } + + updatedWaitables = make([]*waitableWrapper, 0, len(allWaitables)) + allDone := 0 + hasChanged = false + for _, waitable := range allWaitables { + if !isTerminalWorkflowPhase(waitable.GetPhase()) { + w, err := w.executionsCache.GetOrCreate(waitable) + if err != nil { + return nil, 0, false, err + } + + newWaitable := w.(*waitableWrapper) + if newWaitable.Phase != waitable.Phase { + hasChanged = true + } + + waitable = newWaitable + } + + if isTerminalWorkflowPhase(waitable.GetPhase()) { + allDone++ + } + + updatedWaitables = append(updatedWaitables, waitable) + } + + return updatedWaitables, allDone, hasChanged, nil +} + +func updateWaitableLiterals(literals []*core.Literal, waitables []*waitableWrapper) error { + index := make(map[string]*plugins.Waitable, len(waitables)) + for _, w := range waitables { + index[w.WfExecId.String()] = w.Waitable + } + + for _, l := range literals { + orig := &plugins.Waitable{} + if err := utils.UnmarshalStruct(l.GetScalar().GetGeneric(), orig); err != nil { + return err + } + + newW, found := index[orig.WfExecId.String()] + if !found { + return fmt.Errorf("couldn't find a waitable corresponding to literal WfID: %v", orig.WfExecId.String()) + } + + if err := utils.MarshalStruct(newW, l.GetScalar().GetGeneric()); err != nil { + return err + } + } + + return nil +} + +// Shadows K8sExecutor CheckTaskStatus to check for in-progress workflow executions and only schedule the container when +// they have reached a terminal state. +func (w waitableTaskExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate) ( + types.TaskStatus, error) { + + // Have we handed over execution to the container plugin already? if so, then go straight to K8sExecutor to check + // status. + _, handOver := taskCtx.GetCustomState()[handoverKey] + logger.Debugf(ctx, "Hand over state: %v", handOver) + + if taskCtx.GetPhase() == types.TaskPhaseQueued && !handOver { + logger.Infof(ctx, "Monitoring Launch Plan Execution.") + // Unmarshal custom state and get latest known state for the watched executions. 
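// Sketch of the branch below: refresh the watched executions; once every one of them is terminal, rewrite the
// inputs file with the final phases, start the user container via the wrapped K8sTaskExecutor, and record the
// handover marker so later rounds go straight to the k8s status check. The custom state then looks roughly like
//   map[string]interface{}{waitablesKey: allWaitables, handoverKey: true}
// (shape shown for illustration; the keys are the constants defined at the top of this file).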
+ allWaitables, allDone, hasChanged, err := w.getUpdatedWaitables(ctx, taskCtx) + if err != nil { + return types.TaskStatusRetryableFailure(err), nil + } + + // If all executions reached terminal states, hand over execution to K8sExecutor to launch the user container. + if allDone == len(allWaitables) { + logger.Infof(ctx, "All [%v] waitables have finished executing.", allDone) + inputs := &core.LiteralMap{} + // Rewrite inputs using the latest phases. + // 1. Read inputs + if err = w.dataStore.ReadProtobuf(ctx, taskCtx.GetInputsFile(), inputs); err != nil { + logger.Errorf(ctx, "Failed to read inputs file [%v]. Error: %v", taskCtx.GetInputsFile(), err) + return types.TaskStatusUndefined, err + } + + // 2. Get pointers to literals that contain waitables. + literals, _ := discoverWaitableInputs(&core.Literal{Value: &core.Literal_Map{Map: inputs}}) + logger.Debugf(ctx, "Discovered literals with waitables [%v].", len(literals)) + + // 3. Find the corresponding waitables and update original literals. + if err = updateWaitableLiterals(literals, allWaitables); err != nil { + logger.Errorf(ctx, "Failed to update Waitable literals. Error: %v", err) + return types.TaskStatusUndefined, err + } + + // 4. Write back results into inputs path. + // TODO: Read after update consistency? + if err = w.dataStore.WriteProtobuf(ctx, taskCtx.GetInputsFile(), storage.Options{}, inputs); err != nil { + logger.Errorf(ctx, "Failed to write inputs file back [%v]. Error: %v", taskCtx.GetInputsFile(), err) + return types.TaskStatusUndefined, err + } + + status, err := w.K8sTaskExecutor.StartTask(ctx, taskCtx, task, inputs) + if err != nil { + logger.Errorf(ctx, "Failed to start k8s task. Error: %v", err) + return status, err + } + + logger.Info(ctx, "Launched k8s task with status [%v]", status) + + // If launching the container resulted in queued state, then we need to advance phase version. + if status.Phase == types.TaskPhaseQueued { + logger.Debugf(ctx, "StartTask returned queued state.") + status.PhaseVersion = taskCtx.GetPhaseVersion() + 1 + } + + // Indicate we are done waiting and have launched the task so that next time we go straight to the + // container plugin. + state := taskCtx.GetCustomState() + state[handoverKey] = true + + return status.WithState(state), nil + } else if hasChanged { + logger.Infof(ctx, "Some waitable state has changed. Sending event.") + status := types.TaskStatus{ + Phase: types.TaskPhaseQueued, + PhaseVersion: taskCtx.GetPhaseVersion() + 1, + } + + ev := events.CreateEvent(taskCtx, status, &events.TaskEventInfo{ + Logs: generateWorkflowExecutionLinks(ctx, allWaitables), + }) + + err := w.recorder.RecordTaskEvent(ctx, ev) + if err != nil && eventErrors.IsEventAlreadyInTerminalStateError(err) { + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event. phase mis-match between Propeller %v and Control Plane.", &status.Phase)), nil + } else if err != nil { + logger.Errorf(ctx, "Failed to record task event [%v]. 
Error: %v", ev, err) + return types.TaskStatusUndefined, errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event") + } + return status.WithState(map[string]interface{}{ + waitablesKey: allWaitables, + }), nil + } + + logger.Debugf(ctx, "No update to waitables in this round.") + return types.TaskStatus{ + Phase: types.TaskPhaseQueued, + PhaseVersion: taskCtx.GetPhaseVersion(), + State: map[string]interface{}{ + waitablesKey: allWaitables, + }}, nil + } + + logger.Infof(ctx, "Handing over task status check to K8s Task Executor.") + status, err := w.K8sTaskExecutor.CheckTaskStatus(ctx, taskCtx, task) + if err != nil { + return status, err + } + + return status.WithState(taskCtx.GetCustomState()), nil +} + +func (w waitableTaskExecutor) GetTaskStatus(ctx context.Context, taskCtx types.TaskContext, r flytek8s.K8sResource) ( + types.TaskStatus, *events.TaskEventInfo, error) { + + status, eventInfo, err := w.containerTaskExecutor.GetTaskStatus(ctx, taskCtx, r) + if err != nil { + return status, eventInfo, err + } + + // TODO: no need to check with the cache anymore since all executions are terminated already. Optimize! + allWaitables, _, _, err := w.getUpdatedWaitables(ctx, taskCtx) + if err != nil { + return types.TaskStatusRetryableFailure(err), nil, nil + } + + if eventInfo == nil { + eventInfo = &events.TaskEventInfo{} + } + + logs := generateWorkflowExecutionLinks(ctx, allWaitables) + + if eventInfo.Logs == nil { + eventInfo.Logs = make([]*core.TaskLog, 0, len(logs)) + } + + eventInfo.Logs = append(eventInfo.Logs, logs...) + + return status, eventInfo, err +} + +func (w waitableTaskExecutor) KillTask(ctx context.Context, taskCtx types.TaskContext, reason string) error { + switch taskCtx.GetPhase() { + case types.TaskPhaseQueued: + // TODO: no need to check with the cache anymore since all executions are terminated already. Optimize! + allWaitables, _, _, err := w.getUpdatedWaitables(ctx, taskCtx) + if err != nil { + return err + } + + for _, waitable := range allWaitables { + if !isTerminalWorkflowPhase(waitable.GetPhase()) { + _, err := w.adminClient.TerminateExecution(ctx, &admin2.ExecutionTerminateRequest{ + Id: waitable.WfExecId, + Cause: reason, + }) + + if err != nil { + return err + } + } + } + + return nil + default: + return w.K8sTaskExecutor.KillTask(ctx, taskCtx, reason) + } +} + +func (w waitableTaskExecutor) syncItem(ctx context.Context, obj utils2.CacheItem) ( + utils2.CacheItem, utils2.CacheSyncAction, error) { + + waitable, casted := obj.(*waitableWrapper) + if !casted { + return nil, utils2.Unchanged, fmt.Errorf("wrong type. expected %v. 
got %v", reflect.TypeOf(&waitableWrapper{}), reflect.TypeOf(obj)) + } + + exec, err := w.adminClient.GetExecution(ctx, &admin2.WorkflowExecutionGetRequest{ + Id: waitable.WfExecId, + }) + + if err != nil { + return nil, utils2.Unchanged, err + } + + if waitable.Phase != exec.GetClosure().Phase { + waitable.Phase = exec.GetClosure().Phase + return waitable, utils2.Update, nil + } + + return waitable, utils2.Unchanged, nil +} + +func newWaitableTaskExecutor(ctx context.Context) (executor *waitableTaskExecutor, err error) { + waitableExec := &waitableTaskExecutor{ + containerTaskExecutor: containerTaskExecutor{}, + } + + waitableExec.K8sTaskExecutor = flytek8s.NewK8sTaskExecutorForResource(waitableTaskType, &v1.Pod{}, + waitableExec, flytek8s.DefaultInformerResyncDuration) + + waitableExec.adminClient, err = admin.InitializeAdminClientFromConfig(ctx) + if err != nil { + return waitableExec, err + } + + return waitableExec, nil +} + +func init() { + tasksV1.RegisterLoader(func(ctx context.Context) error { + waitableExec, err := newWaitableTaskExecutor(ctx) + if err != nil { + return err + } + + return tasksV1.RegisterForTaskTypes(waitableExec, waitableTaskType) + }) +} diff --git a/flyteplugins/go/tasks/v1/k8splugins/waitable_task_test.go b/flyteplugins/go/tasks/v1/k8splugins/waitable_task_test.go new file mode 100755 index 0000000000..32a041cb50 --- /dev/null +++ b/flyteplugins/go/tasks/v1/k8splugins/waitable_task_test.go @@ -0,0 +1,393 @@ +package k8splugins + +import ( + "context" + "testing" + + "errors" + structpb "github.com/golang/protobuf/ptypes/struct" + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + utils2 "github.com/lyft/flyteplugins/go/tasks/v1/utils" + + adminMocks "github.com/lyft/flyteidl/clients/go/admin/mocks" + "github.com/lyft/flyteidl/clients/go/coreutils" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + flytek8sMocks "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/k8splugins/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + tasksMocks "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/lyft/flytestdlib/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8sTypes "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/json" +) + +//go:generate mockery -dir ../../../../vendor/github.com/lyft/flytestdlib/utils -name AutoRefreshCache + +type mockAdminService struct { + *adminMocks.AdminServiceClient + executionPhaseCache map[string]core.WorkflowExecution_Phase +} + +type mockAutorefreshCache struct { + *mocks.AutoRefreshCache + waitableExec *waitableTaskExecutor +} + +func getMockTaskContext() *tasksMocks.TaskContext { + taskCtx := &tasksMocks.TaskContext{} + taskCtx.On("GetNamespace").Return("ns") + taskCtx.On("GetAnnotations").Return(map[string]string{"aKey": "aVal"}) + taskCtx.On("GetLabels").Return(map[string]string{"lKey": "lVal"}) + taskCtx.On("GetOwnerReference").Return(metav1.OwnerReference{Name: "x"}) + taskCtx.On("GetInputsFile").Return(storage.DataReference("/fake/inputs.pb")) + taskCtx.On("GetDataDir").Return(storage.DataReference("/fake/")) + 
taskCtx.On("GetErrorFile").Return(storage.DataReference("/fake/error.pb")) + taskCtx.On("GetOutputsFile").Return(storage.DataReference("/fake/inputs.pb")) + taskCtx.On("GetPhaseVersion").Return(uint32(1)) + + id := &tasksMocks.TaskExecutionID{} + id.On("GetGeneratedName").Return("test") + id.On("GetID").Return(core.TaskExecutionIdentifier{}) + taskCtx.On("GetTaskExecutionID").Return(id) + + to := &tasksMocks.TaskOverrides{} + to.On("GetResources").Return(resourceRequirements) + taskCtx.On("GetOverrides").Return(to) + return taskCtx +} + +func newDummpyTaskEventRecorder() types.EventRecorder { + s := &tasksMocks.EventRecorder{} + s.On("RecordTaskEvent", mock.Anything, mock.Anything).Return(nil) + return s +} + +func (m mockAdminService) GetExecution(ctx context.Context, in *admin.WorkflowExecutionGetRequest, opts ...grpc.CallOption) ( + *admin.Execution, error) { + wfExec := &admin.Execution{} + wfExec.Id = in.Id + wfExec.Closure = &admin.ExecutionClosure{} + if existingPhase, found := m.executionPhaseCache[in.Id.String()]; found { + if existingPhase < core.WorkflowExecution_SUCCEEDED { + wfExec.Closure.Phase = existingPhase + 1 + } else { + wfExec.Closure.Phase = existingPhase + } + } else { + wfExec.Closure.Phase = core.WorkflowExecution_QUEUED + } + + m.executionPhaseCache[in.Id.String()] = wfExec.Closure.Phase + + return wfExec, nil +} + +func (m mockAutorefreshCache) Start(ctx context.Context) { +} + +func (m mockAutorefreshCache) GetOrCreate(item utils.CacheItem) (utils.CacheItem, error) { + w := item.(*waitableWrapper) + item, _, err := m.waitableExec.syncItem(context.TODO(), w) + return item, err +} + +func setupMockExecutor(t testing.TB) *waitableTaskExecutor { + ctx := context.Background() + waitableExec := &waitableTaskExecutor{ + containerTaskExecutor: containerTaskExecutor{}, + } + + waitableExec.K8sTaskExecutor = flytek8s.NewK8sTaskExecutorForResource(waitableTaskType, &v1.Pod{}, waitableExec, + flytek8s.DefaultInformerResyncDuration) + + mockAdmin := &mockAdminService{ + AdminServiceClient: &adminMocks.AdminServiceClient{}, + executionPhaseCache: map[string]core.WorkflowExecution_Phase{}, + } + mockAdmin.On("TerminateExecution", mock.Anything, mock.Anything).Return(nil, nil).Times(2) + waitableExec.adminClient = mockAdmin + + mockCache := &mockAutorefreshCache{ + AutoRefreshCache: &mocks.AutoRefreshCache{}, + waitableExec: waitableExec, + } + + mem, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + ref, err := mem.ConstructReference(ctx, mem.GetBaseContainerFQN(ctx), "fake", "inputs.pb") + assert.NoError(t, err) + assert.NoError(t, mem.WriteProtobuf(ctx, ref, + storage.Options{}, coreutils.MustMakeLiteral(map[string]interface{}{ + "w": 1, + }))) + + assert.NoError(t, waitableExec.Initialize(context.TODO(), types.ExecutorInitializationParameters{ + EventRecorder: newDummpyTaskEventRecorder(), + DataStore: mem, + MetricsScope: promutils.NewTestScope(), + OwnerKind: "Pod", + EnqueueOwner: func(name k8sTypes.NamespacedName) error { + return nil + }, + })) + + // Wait until after Initialize is called since it sets the executionsCache there too, but we want + // to use the mock one + waitableExec.executionsCache = mockCache + mockCache.Start(ctx) + + return waitableExec +} + +func simulateMarshalUnMarshal(t testing.TB, in map[string]interface{}) (out map[string]interface{}) { + raw, err := json.Marshal(in) + assert.NoError(t, err) + + out = map[string]interface{}{} + assert.NoError(t, json.Unmarshal(raw, 
&out)) + return out +} + +func createWaitableLiteral(t testing.TB, execName string, phase core.WorkflowExecution_Phase) (*plugins.Waitable, *core.Literal) { + stObj := &structpb.Struct{} + expected := &plugins.Waitable{ + WfExecId: &core.WorkflowExecutionIdentifier{ + Name: execName, + Project: "exec_proj", + Domain: "exec_domain", + }, + Phase: phase, + } + assert.NoError(t, utils2.MarshalStruct(expected, stObj)) + + waitableLiteral := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: stObj, + }, + }, + }, + } + + return expected, waitableLiteral +} + +func TestWaitableTaskExecutor(t *testing.T) { + ctx := context.Background() + mockRuntimeClient := flytek8sMocks.NewMockRuntimeClient() + flytek8s.InitializeFake() + assert.NoError(t, flytek8s.InjectCache(&flytek8sMocks.Cache{})) + assert.NoError(t, flytek8s.InjectClient(mockRuntimeClient)) + + + t.Run("Initialize", func(t *testing.T) { + _, err := newWaitableTaskExecutor(context.TODO()) + assert.NoError(t, err) + }) + + t.Run("StartTask", func(t *testing.T) { + taskCtx := getMockTaskContext() + expected, waitableLiteral := createWaitableLiteral(t, "exec_name", 0) + exec := setupMockExecutor(t) + + status, err := exec.StartTask(ctx, taskCtx, &core.TaskTemplate{}, coreutils.MustMakeLiteral(map[string]interface{}{ + "w": waitableLiteral, + }).GetMap()) + + assert.NoError(t, err) + assert.Contains(t, status.State, waitablesKey) + wr := status.State[waitablesKey].([]*waitableWrapper) + assert.Equal(t, (&waitableWrapper{Waitable: &plugins.Waitable{ + WfExecId: expected.WfExecId, + Phase: core.WorkflowExecution_QUEUED, + }}).String(), wr[0].String()) + assert.Equal(t, types.TaskPhaseQueued, status.Phase) + }) + + t.Run("StartTaskStateMismatch", func(t *testing.T) { + taskCtx := getMockTaskContext() + _, waitableLiteral := createWaitableLiteral(t, "exec_name", 0) + mockRecorder := &tasksMocks.EventRecorder{} + mockRecorder.On("RecordTaskEvent", mock.Anything, mock.Anything).Return(&eventErrors.EventError{Code: eventErrors.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + }) + exec := setupMockExecutor(t) + exec.recorder = mockRecorder + status, err := exec.StartTask(ctx, taskCtx, &core.TaskTemplate{}, coreutils.MustMakeLiteral(map[string]interface{}{ + "w": waitableLiteral, + }).GetMap()) + + assert.NoError(t, err) + assert.Nil(t, status.State) + assert.Equal(t, types.TaskPhasePermanentFailure, status.Phase) + }) + + t.Run("CheckTaskUntilSuccess", func(t *testing.T) { + taskCtx := getMockTaskContext() + _, waitableLiteral := createWaitableLiteral(t, "exec_should_succeed", 0) + stObj := &structpb.Struct{} + wrongTypeProto := &plugins.ArrayJob{ + Size: 2, + } + assert.NoError(t, utils2.MarshalStruct(wrongTypeProto, stObj)) + taskTemplate := &core.TaskTemplate{} + exec := setupMockExecutor(t) + status, err := exec.StartTask(ctx, taskCtx, taskTemplate, coreutils.MustMakeLiteral(map[string]interface{}{ + "w": waitableLiteral, + "wArr": coreutils.MustMakeLiteral([]interface{}{waitableLiteral, waitableLiteral}), + "otherGeneric": &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Generic{Generic: stObj}}}}, + }).GetMap()) + + assert.NoError(t, err) + + taskCtx = getMockTaskContext() + taskCtx.On("GetCustomState").Return(simulateMarshalUnMarshal(t, status.State)) + taskCtx.On("GetPhase").Return(status.Phase) + taskCtx.On("GetK8sServiceAccount").Return("") + + for status.Phase != types.TaskPhaseSucceeded { + status, err = 
exec.CheckTaskStatus(ctx, taskCtx, taskTemplate) + assert.NoError(t, err) + if err != nil { + assert.FailNow(t, "failed to check status. Err: %v", err) + } + + taskCtx = getMockTaskContext() + taskCtx.On("GetCustomState").Return(simulateMarshalUnMarshal(t, status.State)) + taskCtx.On("GetPhase").Return(status.Phase) + + assert.NoError(t, advancePodPhases(ctx, mockRuntimeClient)) + } + + assert.Equal(t, types.TaskPhaseSucceeded, status.Phase) + }) + + t.Run("KillTask", func(t *testing.T) { + taskCtx := getMockTaskContext() + expected, waitableLiteral := createWaitableLiteral(t, "exec_should_terminate", 0) + exec := setupMockExecutor(t) + + status, err := exec.StartTask(ctx, taskCtx, &core.TaskTemplate{}, coreutils.MustMakeLiteral(map[string]interface{}{ + "w": waitableLiteral, + }).GetMap()) + + assert.NoError(t, err) + taskCtx.On("GetCustomState").Return(simulateMarshalUnMarshal(t, status.State)) + taskCtx.On("GetPhase").Return(status.Phase) + + assert.Contains(t, status.State, waitablesKey) + wr := status.State[waitablesKey].([]*waitableWrapper) + assert.Equal(t, (&waitableWrapper{Waitable: &plugins.Waitable{ + WfExecId: expected.WfExecId, + Phase: core.WorkflowExecution_QUEUED, + }}).String(), wr[0].String()) + assert.Equal(t, types.TaskPhaseQueued, status.Phase) + + assert.NoError(t, exec.KillTask(ctx, taskCtx, "cause I like to")) + }) +} + +func TestUpdateWaitableLiterals(t *testing.T) { + _, l1 := createWaitableLiteral(t, "a.a.exec", 0) + _, l2 := createWaitableLiteral(t, "b.exec", 0) + originalLiterals := coreutils.MustMakeLiteral(map[string]interface{}{ + "a": map[string]interface{}{ + "a.a": l1, + }, + "b": l2, + }) + + subLiterals, waitables := discoverWaitableInputs(originalLiterals) + assert.Len(t, subLiterals, 2) + assert.Len(t, waitables, 2) + + for _, w := range waitables { + if w.WfExecId.Name == "a.a.exec" { + w.Phase = core.WorkflowExecution_FAILED + } else { + w.Phase = core.WorkflowExecution_SUCCEEDED + } + } + + assert.NoError(t, updateWaitableLiterals(subLiterals, waitables)) + + _, l1 = createWaitableLiteral(t, "a.a.exec", core.WorkflowExecution_FAILED) + _, l2 = createWaitableLiteral(t, "b.exec", core.WorkflowExecution_SUCCEEDED) + expectedLiterals := coreutils.MustMakeLiteral(map[string]interface{}{ + "a": map[string]interface{}{ + "a.a": l1, + }, + "b": l2, + }) + + assert.Equal(t, expectedLiterals, originalLiterals) +} + +func TestToWaitableWrapperSlice(t *testing.T) { + input := []*waitableWrapper{ + { + Waitable: &plugins.Waitable{ + Phase: core.WorkflowExecution_SUCCEEDED, + }, + }, + } + + ctx := context.Background() + t.Run("WaitableWrapper Slice", func(t *testing.T) { + res, err := toWaitableWrapperSlice(ctx, input) + assert.NoError(t, err) + assert.Equal(t, input, res) + }) + + t.Run("Wrong type", func(t *testing.T) { + _, err := toWaitableWrapperSlice(ctx, "wrong type") + assert.Error(t, err) + }) + + t.Run("Interface", func(t *testing.T) { + a := make([]interface{}, 0, len(input)) + for _, w := range input { + a = append(a, w) + } + + res, err := toWaitableWrapperSlice(context.TODO(), a) + assert.NoError(t, err) + assert.Equal(t, input, res) + }) + + t.Run("Json", func(t *testing.T) { + raw, err := json.Marshal(input) + assert.NoError(t, err) + + var a []interface{} + assert.NoError(t, json.Unmarshal(raw, &a)) + + res, err := toWaitableWrapperSlice(context.TODO(), a) + assert.NoError(t, err) + assert.Equal(t, input, res) + }) + + t.Run("Wrong type in slice", func(t *testing.T) { + a := make([]interface{}, 0, len(input)) + for _, w := range input { + a = 
append(a, w) + } + + a = append(a, "wrong type") + + _, err := toWaitableWrapperSlice(context.TODO(), a) + assert.Error(t, err) + }) +} diff --git a/flyteplugins/go/tasks/v1/logs/config.go b/flyteplugins/go/tasks/v1/logs/config.go new file mode 100755 index 0000000000..f8a190aa95 --- /dev/null +++ b/flyteplugins/go/tasks/v1/logs/config.go @@ -0,0 +1,33 @@ +package logs + +import "github.com/lyft/flyteplugins/go/tasks/v1/config" + +//go:generate pflags LogConfig + +// Log plugins configs +type LogConfig struct { + IsCloudwatchEnabled bool `json:"cloudwatch-enabled" pflag:",Enable Cloudwatch Logging"` + CloudwatchRegion string `json:"cloudwatch-region" pflag:",AWS region in which Cloudwatch logs are stored."` + CloudwatchLogGroup string `json:"cloudwatch-log-group" pflag:",Log group to which streams are associated."` + + IsKubernetesEnabled bool `json:"kubernetes-enabled" pflag:",Enable Kubernetes Logging"` + KubernetesURL string `json:"kubernetes-url" pflag:",Console URL for Kubernetes logs"` + + IsStackDriverEnabled bool `json:"stackdriver-enabled" pflag:",Enable Log-links to stackdriver"` + GCPProjectName string `json:"gcp-project" pflag:",Name of the project in GCP"` + StackdriverLogResourceName string `json:"stackdriver-logresourcename" pflag:",Name of the logresource in stackdriver"` +} + +var ( + logConfigSection = config.MustRegisterSubSection("logs", &LogConfig{}) +) + +func GetLogConfig() *LogConfig { + return logConfigSection.GetConfig().(*LogConfig) +} + +// This method should be used for unit testing only +func SetLogConfig(logConfig *LogConfig) error { + return logConfigSection.SetConfig(logConfig) +} + diff --git a/flyteplugins/go/tasks/v1/logs/logconfig_flags.go b/flyteplugins/go/tasks/v1/logs/logconfig_flags.go new file mode 100755 index 0000000000..dae059f7be --- /dev/null +++ b/flyteplugins/go/tasks/v1/logs/logconfig_flags.go @@ -0,0 +1,53 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package logs + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (LogConfig) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (LogConfig) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in LogConfig and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
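As a rough sketch of how this generated flag set might be mounted (the helper name, flag-set name, and "logs." prefix below are illustrative, not part of this patch; only the spf13/pflag API is assumed):

	// registerLogFlags is a hypothetical helper: it mounts the generated LogConfig flags
	// under a "logs." prefix, so they surface as --logs.cloudwatch-enabled, --logs.kubernetes-url, etc.
	func registerLogFlags() *pflag.FlagSet {
		fs := pflag.NewFlagSet("example", pflag.ExitOnError)
		fs.AddFlagSet(LogConfig{}.GetPFlagSet("logs."))
		return fs
	}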
+func (cfg LogConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("LogConfig", pflag.ExitOnError) + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "cloudwatch-enabled"), *new(bool), "Enable Cloudwatch Logging") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "cloudwatch-region"), *new(string), "AWS region in which Cloudwatch logs are stored.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "cloudwatch-log-group"), *new(string), "Log group to which streams are associated.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "kubernetes-enabled"), *new(bool), "Enable Kubernetes Logging") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "kubernetes-url"), *new(string), "Console URL for Kubernetes logs") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "stackdriver-enabled"), *new(bool), "Enable Log-links to stackdriver") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "gcp-project"), *new(string), "Name of the project in GCP") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "stackdriver-logresourcename"), *new(string), "Name of the logresource in stackdriver") + return cmdFlags +} diff --git a/flyteplugins/go/tasks/v1/logs/logconfig_flags_test.go b/flyteplugins/go/tasks/v1/logs/logconfig_flags_test.go new file mode 100755 index 0000000000..966ab32e64 --- /dev/null +++ b/flyteplugins/go/tasks/v1/logs/logconfig_flags_test.go @@ -0,0 +1,278 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package logs + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsLogConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementLogConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsLogConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookLogConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementLogConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_LogConfig(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookLogConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_LogConfig(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_LogConfig(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_LogConfig(val, result)) +} + +func testDecodeSlice_LogConfig(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_LogConfig(vStringSlice, result)) +} + +func TestLogConfig_GetPFlagSet(t *testing.T) { + val := LogConfig{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestLogConfig_SetFlags(t *testing.T) { + actual := LogConfig{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_cloudwatch-enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("cloudwatch-enabled"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cloudwatch-enabled", testValue) + if vBool, err := cmdFlags.GetBool("cloudwatch-enabled"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vBool), &actual.IsCloudwatchEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cloudwatch-region", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("cloudwatch-region"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cloudwatch-region", testValue) + if vString, err := cmdFlags.GetString("cloudwatch-region"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vString), &actual.CloudwatchRegion) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cloudwatch-log-group", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("cloudwatch-log-group"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cloudwatch-log-group", testValue) + if vString, err := cmdFlags.GetString("cloudwatch-log-group"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vString), &actual.CloudwatchLogGroup) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_kubernetes-enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, 
err := cmdFlags.GetBool("kubernetes-enabled"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("kubernetes-enabled", testValue) + if vBool, err := cmdFlags.GetBool("kubernetes-enabled"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vBool), &actual.IsKubernetesEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_kubernetes-url", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("kubernetes-url"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("kubernetes-url", testValue) + if vString, err := cmdFlags.GetString("kubernetes-url"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vString), &actual.KubernetesURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_stackdriver-enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("stackdriver-enabled"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("stackdriver-enabled", testValue) + if vBool, err := cmdFlags.GetBool("stackdriver-enabled"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vBool), &actual.IsStackDriverEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_gcp-project", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("gcp-project"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("gcp-project", testValue) + if vString, err := cmdFlags.GetString("gcp-project"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vString), &actual.GCPProjectName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_stackdriver-logresourcename", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("stackdriver-logresourcename"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("stackdriver-logresourcename", testValue) + if vString, err := cmdFlags.GetString("stackdriver-logresourcename"); err == nil { + testDecodeJson_LogConfig(t, fmt.Sprintf("%v", vString), &actual.StackdriverLogResourceName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flyteplugins/go/tasks/v1/logs/logging_utils.go b/flyteplugins/go/tasks/v1/logs/logging_utils.go new file mode 100755 index 0000000000..f69cd49c01 --- /dev/null +++ b/flyteplugins/go/tasks/v1/logs/logging_utils.go @@ -0,0 +1,61 @@ +package logs + +import ( + "context" + + logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + 
"k8s.io/api/core/v1" +) + +func GetLogsForContainerInPod(ctx context.Context, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) { + var logs []*core.TaskLog + logConfig := GetLogConfig() + + logPlugins := map[string]logUtils.LogPlugin{} + + if logConfig.IsKubernetesEnabled { + logPlugins["Kubernetes Logs"] = logUtils.NewKubernetesLogPlugin(logConfig.KubernetesURL) + } + if logConfig.IsCloudwatchEnabled { + logPlugins["Cloudwatch Logs"] = logUtils.NewCloudwatchLogPlugin(logConfig.CloudwatchRegion, logConfig.CloudwatchLogGroup) + } + if logConfig.IsStackDriverEnabled { + logPlugins["Stackdriver Logs"] = logUtils.NewStackdriverLogPlugin(logConfig.GCPProjectName, logConfig.StackdriverLogResourceName) + } + + if len(logPlugins) == 0 { + return nil, nil + } + + if pod == nil { + logger.Error(ctx, "cannot extract logs for a nil container") + return nil, nil + } + + if uint32(len(pod.Spec.Containers)) <= index { + logger.Errorf(ctx, "container IndexOutOfBound, requested [%d], but total containers [%d] in pod phase [%v]", index, len(pod.Spec.Containers), pod.Status.Phase) + return nil, nil + } + + if uint32(len(pod.Status.ContainerStatuses)) <= index { + logger.Errorf(ctx, "containerStatus IndexOutOfBound, requested [%d], but total containerStatuses [%d] in pod phase [%v]", index, len(pod.Status.ContainerStatuses), pod.Status.Phase) + return nil, nil + } + + for name, plugin := range logPlugins { + log, err := plugin.GetTaskLog( + pod.Name, + pod.Namespace, + pod.Spec.Containers[index].Name, + pod.Status.ContainerStatuses[index].ContainerID, + name+nameSuffix, + ) + if err != nil { + return nil, err + } + logs = append(logs, &log) + } + return logs, nil +} diff --git a/flyteplugins/go/tasks/v1/logs/logging_utils_test.go b/flyteplugins/go/tasks/v1/logs/logging_utils_test.go new file mode 100755 index 0000000000..d6bdb48565 --- /dev/null +++ b/flyteplugins/go/tasks/v1/logs/logging_utils_test.go @@ -0,0 +1,205 @@ +package logs + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" +) + +const podName = "PodName" + +func TestGetLogsForContainerInPod_NoPlugins(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{})) + l, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix") + assert.NoError(t, err) + assert.Nil(t, l) +} + +func TestGetLogsForContainerInPod_NoLogs(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{ + IsCloudwatchEnabled: true, + CloudwatchRegion: "us-east-1", + CloudwatchLogGroup: "/kubernetes/flyte-production", + })) + p, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix") + assert.NoError(t, err) + assert.Nil(t, p) +} + +func TestGetLogsForContainerInPod_BadIndex(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{ + IsCloudwatchEnabled: true, + CloudwatchRegion: "us-east-1", + CloudwatchLogGroup: "/kubernetes/flyte-production", + })) + + pod := &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "ContainerName", + }, + }, + }, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + ContainerID: "ContainerID", + }, + }, + }, + } + pod.Name = podName + + p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix") + assert.NoError(t, err) + assert.Nil(t, p) +} + +func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{ + IsCloudwatchEnabled: true, + CloudwatchRegion: "us-east-1", + CloudwatchLogGroup: "/kubernetes/flyte-production", + })) + + pod := &v1.Pod{ + Spec: 
v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "ContainerName", + }, + }, + }, + Status: v1.PodStatus{ + }, + } + pod.Name = podName + + p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix") + assert.NoError(t, err) + assert.Nil(t, p) +} + +func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{IsCloudwatchEnabled: true, + CloudwatchRegion: "us-east-1", + CloudwatchLogGroup: "/kubernetes/flyte-production", + })) + + pod := &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "ContainerName", + }, + }, + }, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + ContainerID: "ContainerID", + }, + }, + }, + } + pod.Name = podName + + logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + assert.Nil(t, err) + assert.Len(t, logs, 1) +} + +func TestGetLogsForContainerInPod_K8s(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{ + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + })) + + pod := &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "ContainerName", + }, + }, + }, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + ContainerID: "ContainerID", + }, + }, + }, + } + pod.Name = podName + + logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + assert.Nil(t, err) + assert.Len(t, logs, 1) +} + +func TestGetLogsForContainerInPod_All(t *testing.T) { + assert.NoError(t, SetLogConfig(&LogConfig{ + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + IsCloudwatchEnabled: true, + CloudwatchRegion: "us-east-1", + CloudwatchLogGroup: "/kubernetes/flyte-production", + })) + + pod := &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "ContainerName", + }, + }, + }, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + ContainerID: "ContainerID", + }, + }, + }, + } + pod.Name = podName + + logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + assert.Nil(t, err) + assert.Len(t, logs, 2) +} + +func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) { + + assert.NoError(t, SetLogConfig(&LogConfig{ + IsStackDriverEnabled: true, + GCPProjectName: "myGCPProject", + StackdriverLogResourceName: "aws_ec2_instance", + })) + + pod := &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "ContainerName", + }, + }, + }, + Status: v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + { + ContainerID: "ContainerID", + }, + }, + }, + } + pod.Name = podName + + logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + assert.Nil(t, err) + assert.Len(t, logs, 1) +} diff --git a/flyteplugins/go/tasks/v1/qubole/client/mocks/QuboleClient.go b/flyteplugins/go/tasks/v1/qubole/client/mocks/QuboleClient.go new file mode 100755 index 0000000000..8c3bd3e353 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/client/mocks/QuboleClient.go @@ -0,0 +1,70 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
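The mockery-generated mock below follows the usual testify pattern; a minimal sketch of programming it in a test (literal values are illustrative, and the context, testify/mock, and client packages are assumed to be imported):

	q := &mocks.QuboleClient{}
	q.On("GetCommandStatus", mock.Anything, mock.Anything, mock.Anything).
		Return(client.QuboleStatusDone, nil)
	status, err := q.GetCommandStatus(context.TODO(), "123", "fake-key")
	// status == client.QuboleStatusDone and err == nil; q.AssertExpectations(t) would then verify the call.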
+ +package mocks + +import client "github.com/lyft/flyteplugins/go/tasks/v1/qubole/client" +import context "context" +import mock "github.com/stretchr/testify/mock" + +// QuboleClient is an autogenerated mock type for the QuboleClient type +type QuboleClient struct { + mock.Mock +} + +// ExecuteHiveCommand provides a mock function with given fields: ctx, commandStr, timeoutVal, clusterLabel, accountKey, tags +func (_m *QuboleClient) ExecuteHiveCommand(ctx context.Context, commandStr string, timeoutVal uint32, clusterLabel string, accountKey string, tags []string) (*client.QuboleCommandDetails, error) { + ret := _m.Called(ctx, commandStr, timeoutVal, clusterLabel, accountKey, tags) + + var r0 *client.QuboleCommandDetails + if rf, ok := ret.Get(0).(func(context.Context, string, uint32, string, string, []string) *client.QuboleCommandDetails); ok { + r0 = rf(ctx, commandStr, timeoutVal, clusterLabel, accountKey, tags) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.QuboleCommandDetails) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, uint32, string, string, []string) error); ok { + r1 = rf(ctx, commandStr, timeoutVal, clusterLabel, accountKey, tags) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetCommandStatus provides a mock function with given fields: ctx, commandID, accountKey +func (_m *QuboleClient) GetCommandStatus(ctx context.Context, commandID string, accountKey string) (client.QuboleStatus, error) { + ret := _m.Called(ctx, commandID, accountKey) + + var r0 client.QuboleStatus + if rf, ok := ret.Get(0).(func(context.Context, string, string) client.QuboleStatus); ok { + r0 = rf(ctx, commandID, accountKey) + } else { + r0 = ret.Get(0).(client.QuboleStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, commandID, accountKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// KillCommand provides a mock function with given fields: ctx, commandID, accountKey +func (_m *QuboleClient) KillCommand(ctx context.Context, commandID string, accountKey string) error { + ret := _m.Called(ctx, commandID, accountKey) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, commandID, accountKey) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go new file mode 100755 index 0000000000..3bb8470739 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go @@ -0,0 +1,253 @@ +package client + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "bytes" + "encoding/json" + "io" + "io/ioutil" + + "github.com/lyft/flytestdlib/logger" +) + +const url = "https://api.qubole.com/api" +const apiPath = "/v1.2/commands" +const QuboleLogLinkFormat = "https://api.qubole.com/v2/analyze?command_id=%s" + +const tokenKeyForAth = "X-AUTH-TOKEN" +const acceptHeaderKey = "Accept" +const hiveCommandType = "HiveCommand" +const killStatus = "kill" +const httpRequestTimeoutSecs = 30 +const host = "api.qubole.com" +const hostHeaderKey = "Host" +const HeaderContentType = "Content-Type" +const ContentTypeJSON = "application/json" + +// QuboleClient CommandStatus Response Format, only used to unmarshall the response +type quboleCmdDetailsInternal struct { + ID int64 + Status string +} + +type QuboleCommandDetails struct { + ID int64 + Status QuboleStatus +} + +// QuboleClient API 
Request Body, meant to be passed into JSON.marshal +// Any nil, 0 or "" fields will not be marshaled +type RequestBody struct { + Query string `json:"query,omitempty"` + ClusterLabel string `json:"label,omitempty"` + CommandType string `json:"command_type,omitempty"` + Retry uint32 `json:"retry,omitempty"` + Status string `json:"status,omitempty"` + Tags []string `json:"tags,omitempty"` + Timeout uint32 `json:"timeout,omitempty"` + InlineScript string `json:"inline,omitempty"` + Files string `json:"files,omitempty"` +} + +//go:generate mockery -name QuboleClient + +// Interface to interact with QuboleClient for hive tasks +type QuboleClient interface { + ExecuteHiveCommand(ctx context.Context, commandStr string, timeoutVal uint32, clusterLabel string, + accountKey string, tags []string) (*QuboleCommandDetails, error) + KillCommand(ctx context.Context, commandID string, accountKey string) error + GetCommandStatus(ctx context.Context, commandID string, accountKey string) (QuboleStatus, error) +} + +// TODO: The Qubole client needs a rate limiter +type quboleClient struct { + client *http.Client +} + +func (q *quboleClient) getHeaders(accountKey string) http.Header { + headers := make(http.Header) + headers.Set(tokenKeyForAth, accountKey) + headers.Set(HeaderContentType, ContentTypeJSON) + headers.Set(acceptHeaderKey, ContentTypeJSON) + headers.Set(hostHeaderKey, host) + + return headers +} + +// no-op closer for in-memory buffers used as io.Reader +type nopCloser struct { + io.Reader +} + +func (nopCloser) Close() error { return nil } + +func addJSONBody(req *http.Request, body interface{}) error { + // marshals body into JSON and set the request body + js, err := json.Marshal(body) + if err != nil { + return err + } + + req.Header.Add(HeaderContentType, ContentTypeJSON) + req.Body = &nopCloser{bytes.NewReader(js)} + return nil +} + +func unmarshalBody(res *http.Response, t interface{}) error { + bts, err := ioutil.ReadAll(res.Body) + if err != nil { + return err + } + + return json.Unmarshal(bts, t) +} + +func closeBody(ctx context.Context, response *http.Response) { + _, err := io.Copy(ioutil.Discard, response.Body) + if err != nil { + logger.Errorf(ctx, "unexpected failure writing to devNull: %v", err) + } + err = response.Body.Close() + if err != nil { + logger.Warnf(ctx, "failure closing response body: %v", err) + } +} + +// Helper method to execute the requests +func (q *quboleClient) executeRequest(ctx context.Context, method string, path string, body *RequestBody, accountKey string) (*http.Response, error) { + var req *http.Request + var err error + path = url + "/" + path + + switch method { + case http.MethodGet: + req, err = http.NewRequest("GET", path, nil) + case http.MethodPost: + req, err = http.NewRequest("POST", path, nil) + req.Header = q.getHeaders(accountKey) + case http.MethodPut: + req, err = http.NewRequest("PUT", path, nil) + } + + if err != nil { + return nil, err + } + + if body != nil { + err := addJSONBody(req, body) + if err != nil { + return nil, err + } + } + + logger.Debugf(ctx, "qubole endpoint: %v", path) + req.Header = q.getHeaders(accountKey) + return q.client.Do(req) +} + +/* + Execute Hive Command on the QuboleClient Hive Cluster and return the CommandId + param: context.Context ctx: The default go context. + param: string commandStr: the query to execute + param: uint32 timeoutVal: timeout for the query to execute in seconds + param: string ClusterLabel: label for cluster on which to execute the Hive Command. 
+ return: *int64: CommandId for the command executed + return: error: error in-case of a failure +*/ +func (q *quboleClient) ExecuteHiveCommand( + ctx context.Context, + commandStr string, + timeoutVal uint32, + clusterLabel string, + accountKey string, + tags []string) (*QuboleCommandDetails, error) { + + requestBody := RequestBody{ + CommandType: hiveCommandType, + Query: commandStr, + Timeout: timeoutVal, + ClusterLabel: clusterLabel, + Tags: tags, + } + response, err := q.executeRequest(ctx, http.MethodPost, apiPath, &requestBody, accountKey) + if err != nil { + return nil, err + } + defer closeBody(ctx, response) + + if response.StatusCode != 200 { + bts, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, err + } + return nil, errors.New(fmt.Sprintf("Bad response from Qubole creating query: %d %s", + response.StatusCode, string(bts))) + } + + var cmd quboleCmdDetailsInternal + if err = unmarshalBody(response, &cmd); err != nil { + return nil, err + } + + status := newQuboleStatus(ctx, cmd.Status) + return &QuboleCommandDetails{ID: cmd.ID, Status: status}, nil +} + +/* + Terminate a QuboleClient command + param: context.Context ctx: The default go context. + param: string CommandId: the CommandId to terminate. + return: error: error in-case of a failure +*/ +func (q *quboleClient) KillCommand(ctx context.Context, commandID string, accountKey string) error { + killPath := apiPath + "/" + commandID + requestBody := RequestBody{Status: killStatus} + + response, err := q.executeRequest(ctx, http.MethodPut, killPath, &requestBody, accountKey) + defer closeBody(ctx, response) + return err +} + +/* + Get the status of a QuboleClient command + param: context.Context ctx: The default go context. + param: string CommandId: the CommandId to fetch the status for + return: *string: commandStatus for the CommandId passed + return: error: error in-case of a failure +*/ +func (q *quboleClient) GetCommandStatus(ctx context.Context, commandID string, accountKey string) (QuboleStatus, error) { + statusPath := apiPath + "/" + commandID + response, err := q.executeRequest(ctx, http.MethodGet, statusPath, nil, accountKey) + if err != nil { + return QuboleStatusUnknown, err + } + defer closeBody(ctx, response) + + if response.StatusCode != 200 { + bts, err := ioutil.ReadAll(response.Body) + if err != nil { + return QuboleStatusUnknown, err + } + return QuboleStatusUnknown, errors.New(fmt.Sprintf("Bad response from Qubole getting command status: %d %s", + response.StatusCode, string(bts))) + } + + var cmd quboleCmdDetailsInternal + if err = unmarshalBody(response, &cmd); err != nil { + return QuboleStatusUnknown, err + } + + cmdStatus := newQuboleStatus(ctx, cmd.Status) + return cmdStatus, nil +} + +func NewQuboleClient() QuboleClient { + return &quboleClient{ + client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, + } +} diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_client_test.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_client_test.go new file mode 100755 index 0000000000..d9bac006f4 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/client/qubole_client_test.go @@ -0,0 +1,142 @@ +package client + +import ( + "bytes" + "context" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "strings" + "testing" +) + +var getCommandResponse = `{ + "command": { + "approx_mode": false, + "approx_aggregations": false, + "query": "select count(*) as num_rows from miniwikistats;", + "sample": false + }, + "qbol_session_id": 0, + "created_at": 
"2012-10-11T16:54:57Z", + "user_id": 0, + "status": "{STATUS}", + "command_type": "HiveCommand", + "id": 3852, + "progress": 100, + "throttled": true + }` + +var createCommandResponse = `{ + "command": { + "approx_mode": false, + "approx_aggregations": false, + "query": "show tables", + "sample": false + }, + "qbol_session_id": 0, + "created_at": "2012-10-11T16:01:09Z", + "user_id": 0, + "status": "waiting", + "command_type": "HiveCommand", + "id": 3850, + "progress": 0 + }` + +func TestQuboleClient_GetCommandStatus(t *testing.T) { + tests := []struct { + Name string + quboleInternalStatus string + status QuboleStatus + }{ + { + Name: "done status", + quboleInternalStatus: "done", + status: QuboleStatusDone, + }, + { + Name: "unknown status", + quboleInternalStatus: "bogus", + status: QuboleStatusUnknown, + }, + { + Name: "running status", + quboleInternalStatus: "running", + status: QuboleStatusRunning, + }, + } + + for _, test := range tests { + tc := test + t.Run(tc.Name, func(t *testing.T) { + client := createQuboleClient(strings.Replace(getCommandResponse, "{STATUS}", tc.quboleInternalStatus, 1)) + status, err := client.GetCommandStatus(context.Background(), "", "") + assert.NoError(t, err) + assert.Equal(t, tc.status, status) + }) + } +} + +func TestQuboleClient_ExecuteHiveCommand(t *testing.T) { + client := createQuboleClient(createCommandResponse) + details, err := client.ExecuteHiveCommand(context.Background(), + "", 0, "clusterLabel", "", nil) + assert.NoError(t, err) + assert.Equal(t, int64(3850), details.ID) + assert.Equal(t, QuboleStatusWaiting, details.Status) +} + +func TestQuboleClient_KillCommand(t *testing.T) { + client := createQuboleClient("OK") + err := client.KillCommand(context.Background(), "", "") + assert.NoError(t, err) +} + +func TestQuboleClient_ExecuteHiveCommandError(t *testing.T) { + client := createQuboleErrorClient("bad token") + details, err := client.ExecuteHiveCommand(context.Background(), + "", 0, "clusterLabel", "", nil) + assert.Error(t, err) + assert.Nil(t, details) +} + +func TestQuboleClient_GetCommandStatusError(t *testing.T) { + client := createQuboleErrorClient("bad token") + details, err := client.GetCommandStatus(context.Background(), "1234", "fake account key") + assert.Error(t, err) + assert.Equal(t, QuboleStatusUnknown, details) +} + +func createQuboleClient(response string) quboleClient { + hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) (*http.Response, error) { + response := &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(response)), + Header: make(http.Header), + } + response.Header.Set("Content-Type", "application/json") + return response, nil + + })} + + return quboleClient{client: hc} +} + +func createQuboleErrorClient(errorMsg string) quboleClient { + hc := &http.Client{Transport: RoundTripFunc(func(req *http.Request) (*http.Response, error) { + response := &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(bytes.NewBufferString(errorMsg)), + Header: make(http.Header), + } + response.Header.Set("Content-Type", "application/json") + return response, nil + + })} + + return quboleClient{client: hc} +} + +type RoundTripFunc func(*http.Request) (*http.Response, error) + +func (rt RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return rt(req) } diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go new file mode 100755 index 0000000000..bcafa6b007 --- /dev/null +++ 
b/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go @@ -0,0 +1,38 @@ +package client + +import ( + "context" + "github.com/lyft/flytestdlib/logger" + "strings" +) + +// This type is meant only to encapsulate the response coming from Qubole as a type, it is +// not meant to be stored locally. +type QuboleStatus string + +const ( + QuboleStatusUnknown QuboleStatus = "UNKNOWN" + QuboleStatusWaiting QuboleStatus = "WAITING" + QuboleStatusRunning QuboleStatus = "RUNNING" + QuboleStatusDone QuboleStatus = "DONE" + QuboleStatusError QuboleStatus = "ERROR" + QuboleStatusCancelled QuboleStatus = "CANCELLED" +) + +var QuboleStatuses = map[QuboleStatus]struct{}{ + QuboleStatusUnknown: {}, + QuboleStatusWaiting: {}, + QuboleStatusRunning: {}, + QuboleStatusDone: {}, + QuboleStatusError: {}, + QuboleStatusCancelled: {}, +} + +func newQuboleStatus(ctx context.Context, status string) QuboleStatus { + upperCased := strings.ToUpper(status) + if _, ok := QuboleStatuses[QuboleStatus(upperCased)]; ok { + return QuboleStatus(upperCased) + } + logger.Warnf(ctx, "Invalid Qubole Status found: %v", status) + return QuboleStatusUnknown +} diff --git a/flyteplugins/go/tasks/v1/qubole/config/config.go b/flyteplugins/go/tasks/v1/qubole/config/config.go new file mode 100755 index 0000000000..aabc8608ad --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/config/config.go @@ -0,0 +1,46 @@ +package config + +//go:generate pflags Config + +import ( + "time" + + "github.com/lyft/flytestdlib/config" + + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/v1/config" +) + +const quboleConfigSectionKey = "qubole" + +var ( + defaultConfig = Config{ + QuboleLimit: 200, + LruCacheSize: 1000, + LookasideBufferPrefix: "ql", + LookasideExpirySeconds: config.Duration{Duration: time.Hour * 24}, + } + + quboleConfigSection = pluginsConfig.MustRegisterSubSection(quboleConfigSectionKey, &defaultConfig) +) + +// Qubole plugin configs +type Config struct { + QuboleTokenPath string `json:"quboleTokenPath" pflag:",Where to find the Qubole secret"` + ResourceManagerType string `json:"resourceManagerType" pflag:"noop,Which resource manager to use"` + RedisHostPath string `json:"redisHostPath" pflag:",Redis host location"` + RedisHostKey string `json:"redisHostKey" pflag:",Key for local Redis access"` + RedisMaxRetries int `json:"redisMaxRetries" pflag:",See Redis client options for more info"` + QuboleLimit int `json:"quboleLimit" pflag:",Global limit for concurrent Qubole queries"` + LruCacheSize int `json:"lruCacheSize" pflag:",Size of the AutoRefreshCache"` + LookasideBufferPrefix string `json:"lookasideBufferPrefix" pflag:",Prefix used for lookaside buffer"` + LookasideExpirySeconds config.Duration `json:"lookasideExpirySeconds" pflag:",TTL for lookaside buffer if supported"` +} + +// Retrieves the current config value or default. +func GetQuboleConfig() *Config { + return quboleConfigSection.GetConfig().(*Config) +} + +func SetQuboleConfig(cfg *Config) error { + return quboleConfigSection.SetConfig(cfg) +} diff --git a/flyteplugins/go/tasks/v1/qubole/config/config_flags.go b/flyteplugins/go/tasks/v1/qubole/config/config_flags.go new file mode 100755 index 0000000000..8fb6b6ee11 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/config/config_flags.go @@ -0,0 +1,53 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. 
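Mirroring how SetLogConfig is exercised in the logging tests above, a minimal sketch of overriding and reading back this Qubole section from a unit test (the override value is arbitrary and illustrative):

	// Override the registered section, then read it back through the accessor.
	if err := SetQuboleConfig(&Config{QuboleLimit: 5}); err == nil {
		limit := GetQuboleConfig().QuboleLimit // expected to read back 5
		_ = limit
	}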
+ +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "quboleTokenPath"), defaultConfig.QuboleTokenPath, "Where to find the Qubole secret") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "resourceManagerType"), defaultConfig.ResourceManagerType, "Which resource manager to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redisHostPath"), defaultConfig.RedisHostPath, "Redis host location") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redisHostKey"), defaultConfig.RedisHostKey, "Key for local Redis access") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "quboleLimit"), defaultConfig.QuboleLimit, "Global limit for concurrent Qubole queries") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "lruCacheSize"), defaultConfig.LruCacheSize, "Size of the AutoRefreshCache") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "lookasideBufferPrefix"), defaultConfig.LookasideBufferPrefix, "Prefix used for lookaside buffer") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "lookasideExpirySeconds"), defaultConfig.LookasideExpirySeconds.String(), "TTL for lookaside buffer if supported") + return cmdFlags +} diff --git a/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go b/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go new file mode 100755 index 0000000000..336b875795 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go @@ -0,0 +1,278 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. 
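Concretely, this hook is what lets a plain string decode into a json.Unmarshaler-backed field such as config.Duration; a small sketch using the decode_Config helper defined just below (the map literal is illustrative):

	var cfg Config
	if err := decode_Config(map[string]interface{}{"lookasideExpirySeconds": "24h"}, &cfg); err == nil {
		// cfg.LookasideExpirySeconds.Duration should now hold 24h, parsed via Duration's UnmarshalJSON.
		_ = cfg.LookasideExpirySeconds
	}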
+func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_quboleTokenPath", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("quboleTokenPath"); err == nil { + assert.Equal(t, string(defaultConfig.QuboleTokenPath), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("quboleTokenPath", testValue) + if vString, err := cmdFlags.GetString("quboleTokenPath"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.QuboleTokenPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_resourceManagerType", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("resourceManagerType"); err == nil { + assert.Equal(t, string(defaultConfig.ResourceManagerType), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("resourceManagerType", testValue) + if vString, err := cmdFlags.GetString("resourceManagerType"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ResourceManagerType) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_redisHostPath", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := 
cmdFlags.GetString("redisHostPath"); err == nil { + assert.Equal(t, string(defaultConfig.RedisHostPath), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("redisHostPath", testValue) + if vString, err := cmdFlags.GetString("redisHostPath"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RedisHostPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_redisHostKey", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("redisHostKey"); err == nil { + assert.Equal(t, string(defaultConfig.RedisHostKey), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("redisHostKey", testValue) + if vString, err := cmdFlags.GetString("redisHostKey"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RedisHostKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_quboleLimit", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("quboleLimit"); err == nil { + assert.Equal(t, int(defaultConfig.QuboleLimit), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("quboleLimit", testValue) + if vInt, err := cmdFlags.GetInt("quboleLimit"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.QuboleLimit) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_lruCacheSize", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { + assert.Equal(t, int(defaultConfig.LruCacheSize), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("lruCacheSize", testValue) + if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.LruCacheSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_lookasideBufferPrefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("lookasideBufferPrefix"); err == nil { + assert.Equal(t, string(defaultConfig.LookasideBufferPrefix), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("lookasideBufferPrefix", testValue) + if vString, err := cmdFlags.GetString("lookasideBufferPrefix"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LookasideBufferPrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_lookasideExpirySeconds", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("lookasideExpirySeconds"); err == nil { + assert.Equal(t, string(defaultConfig.LookasideExpirySeconds.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.LookasideExpirySeconds.String() + + 
cmdFlags.Set("lookasideExpirySeconds", testValue) + if vString, err := cmdFlags.GetString("lookasideExpirySeconds"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LookasideExpirySeconds) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor.go b/flyteplugins/go/tasks/v1/qubole/hive_executor.go new file mode 100755 index 0000000000..814ae6769b --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor.go @@ -0,0 +1,574 @@ +package qubole + +import ( + "context" + "fmt" + "github.com/go-redis/redis" + "strconv" + "time" + + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + utils2 "github.com/lyft/flytestdlib/utils" + + tasksV1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/client" + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/config" + "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +const ResyncDuration = 30 * time.Second +const hiveExecutorId = "hiveExecutor" +const hiveTaskType = "hive" + +type HiveExecutor struct { + types.OutputsResolver + recorder types.EventRecorder + id string + secretsManager SecretsManager + executionsCache utils2.AutoRefreshCache + metrics HiveExecutorMetrics + quboleClient client.QuboleClient + redisClient *redis.Client + resourceManager resourcemanager.ResourceManager + executionBuffer resourcemanager.ExecutionLooksideBuffer +} + +type HiveExecutorMetrics struct { + Scope promutils.Scope + ReleaseResourceFailed labeled.Counter + AllocationGranted labeled.Counter + AllocationNotGranted labeled.Counter +} + +func (h HiveExecutor) GetID() types.TaskExecutorName { + return h.id +} + +func (h HiveExecutor) GetProperties() types.ExecutorProperties { + return types.ExecutorProperties{ + RequiresFinalizer: true, + } +} + +func getHiveExecutorMetrics(scope promutils.Scope) HiveExecutorMetrics { + return HiveExecutorMetrics{ + Scope: scope, + ReleaseResourceFailed: labeled.NewCounter("released_resource_failed", + "Error releasing allocation token", scope), + AllocationGranted: labeled.NewCounter("allocation_granted", + "Allocation request granted", scope), + AllocationNotGranted: labeled.NewCounter("allocation_not_granted", + "Allocation request did not fail but not granted", scope), + } +} + +// This runs once, after the constructor (since the constructor is called in the package init) +func (h *HiveExecutor) Initialize(ctx context.Context, param types.ExecutorInitializationParameters) error { + // Make sure we can get the Qubole token, and set up metrics so we have the scope + h.metrics = getHiveExecutorMetrics(param.MetricsScope) + _, err := h.secretsManager.GetToken() + if err != nil { + return err + } + + h.executionsCache, err = utils2.NewAutoRefreshCache(h.SyncQuboleQuery, + utils2.NewRateLimiter("qubole-api-updater", 5, 15), + ResyncDuration, config.GetQuboleConfig().LruCacheSize, param.MetricsScope.NewSubScope(hiveTaskType)) + if err != nil { + return err + } + + // Create Redis client + redisHost := config.GetQuboleConfig().RedisHostPath 
+ redisPassword := config.GetQuboleConfig().RedisHostKey + redisMaxRetries := config.GetQuboleConfig().RedisMaxRetries + redisClient, err := resourcemanager.NewRedisClient(ctx, redisHost, redisPassword, redisMaxRetries) + if err != nil { + return err + } + h.redisClient = redisClient + + // Assign the resource manager here. We do it here instead of the constructor because we need to pass in metrics + resourceManager, err := resourcemanager.GetResourceManagerByType(ctx, config.GetQuboleConfig().ResourceManagerType, + param.MetricsScope, h.redisClient) + if err != nil { + return err + } + h.resourceManager = resourceManager + + // Create a lookaside buffer in Redis to hold the command IDs created by Qubole + expiryDuration := config.GetQuboleConfig().LookasideExpirySeconds.Duration + h.executionBuffer = resourcemanager.NewRedisLookasideBuffer(ctx, h.redisClient, + config.GetQuboleConfig().LookasideBufferPrefix, expiryDuration) + + h.recorder = param.EventRecorder + + h.executionsCache.Start(ctx) + + return nil +} + +func (h HiveExecutor) getUniqueCacheKey(taskCtx types.TaskContext, idx int) string { + // The cache will be holding all hive jobs across the engine, so it's imperative that the id of the cache + // items be unique. It should be unique across tasks, nodes, retries, etc. It also needs to be deterministic + // so that we know what to look for in independent calls of CheckTaskStatus + // Appending the index of the query should be sufficient. + return fmt.Sprintf("%s_%d", taskCtx.GetTaskExecutionID().GetGeneratedName(), idx) +} + +// This function is only ever called once, assuming it doesn't return in error. +// Essentially, what this function does is translate the task's custom field into the TaskContext's CustomState +// that's stored back into etcd +func (h HiveExecutor) StartTask(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, + inputs *core.LiteralMap) (types.TaskStatus, error) { + // Fill in authorization stuff here in the future + hiveJob := plugins.QuboleHiveJob{} + err := utils.UnmarshalStruct(task.GetCustom(), &hiveJob) + if err != nil { + return types.TaskStatusPermanentFailure(errors.Errorf(errors.BadTaskSpecification, + "Invalid Job Specification in task: [%v]. Err: [%v]", task.GetCustom(), err)), nil + } + + // TODO: Asserts around queries, like len > 0 or something. + + // This custom state object will be passed back to us when the CheckTaskStatus call is received. + customState := make(map[string]interface{}) + + // Iterate through the queries that we'll need to run, and create custom objects for them. We don't even + // need to look at the query right now. We won't attempt to run the query yet. + for idx, q := range hiveJob.QueryCollection.Queries { + fullFlyteKey := h.getUniqueCacheKey(taskCtx, idx) + wrappedHiveJob := constructQuboleWorkItem(fullFlyteKey, "", QuboleWorkNotStarted) + + // Merge custom object Tags with labels to form Tags + tags := hiveJob.Tags + for k, v := range taskCtx.GetLabels() { + tags = append(tags, fmt.Sprintf("%s:%s", k, v)) + } + tags = append(tags, fmt.Sprintf("ns:%s", taskCtx.GetNamespace())) + wrappedHiveJob.Tags = tags + wrappedHiveJob.ClusterLabel = hiveJob.ClusterLabel + wrappedHiveJob.Query = q.Query + wrappedHiveJob.TimeoutSec = q.TimeoutSec + + customState[fullFlyteKey] = wrappedHiveJob + } + + // This Phase represents the phase of the entire Job, of all the queries. + // The Queued state is only ever used at the very beginning here. 
The first CheckTaskStatus + // call made on the object will return the running state. + status := types.TaskStatus{ + Phase: types.TaskPhaseQueued, + PhaseVersion: 0, + State: customState, + } + + return status, nil +} + +func (h HiveExecutor) convertCustomStateToQuboleWorkItems(customState map[string]interface{}) ( + map[string]QuboleWorkItem, error) { + + m := make(map[string]QuboleWorkItem, len(customState)) + for k, v := range customState { + // Cast the corresponding custom object + item, err := InterfaceConverter(v) + if err != nil { + return map[string]QuboleWorkItem{}, err + } + m[k] = item + } + return m, nil +} + +func (h HiveExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskContext, _ *core.TaskTemplate) ( + types.TaskStatus, error) { + // Get the custom task information, and the custom state information + customState := taskCtx.GetCustomState() + logger.Infof(ctx, "Checking status for task execution [%s] Phase [%v] length of custom [%d]", + taskCtx.GetTaskExecutionID().GetGeneratedName(), taskCtx.GetPhase(), len(customState)) + quboleApiKey, _ := h.secretsManager.GetToken() + + // Loop through all the queries and do whatever needs to be done + // Also accumulate the new CustomState while iterating + var newItems = make(map[string]interface{}) + quboleAttempts := 0 + quboleFailures := 0 + for workCacheKey, v := range customState { + // Cast the corresponding custom object + item, err := InterfaceConverter(v) + if err != nil { + logger.Errorf(ctx, "Error converting old state into an object for key %s", workCacheKey) + return types.TaskStatusUndefined, err + } + logger.Debugf(ctx, "CheckTaskStatus, customState iteration - key [%s] id [%s] status [%s]", + item.UniqueWorkCacheKey, item.CommandId, item.Status) + + // This copies the items in the cache into new objects. It's important to leave the initial custom state + // untouched because we compare new to old later for eventing. + // This if block handles transitions from NotStarted to Running - i.e. 
attempt to create the query on Qubole + // What happens if a job has ten queries and 8 of them launch successfully, but two of fail because + // of a Qubole error that has nothing to do with the user's code, temporary Qubole flakiness for instance + // To resolve this, we keep track of the errors in launching Qubole commands + // - if all calls fail, then we return a system level error + // - if only some calls fail, then that means we've updated the custom state with new Qubole command IDs + // so we shouldn't waste those, return as normal so that they get recorded + if item.Status == QuboleWorkNotStarted { + foundCommandId, err := h.executionBuffer.RetrieveExecution(ctx, workCacheKey) + if err != nil { + if err != resourcemanager.ExecutionNotFoundError { + logger.Errorf(ctx, "Unable to retrieve from cache for %s", workCacheKey) + return types.TaskStatusUndefined, errors.Wrapf(errors.DownstreamSystemError, err, + "unable to retrieve from cache for %s", workCacheKey) + } + // Get an allocation token + logger.Infof(ctx, "Attempting to get allocation token for %s", workCacheKey) + allocationStatus, err := h.resourceManager.AllocateResource(ctx, taskCtx.GetNamespace(), workCacheKey) + if err != nil { + logger.Errorf(ctx, "Resource manager broke for [%s] key [%s], owner [%s]", + taskCtx.GetTaskExecutionID().GetID(), workCacheKey, taskCtx.GetOwnerReference()) + return types.TaskStatusUndefined, err + } + logger.Infof(ctx, "Allocation result for [%s] is [%s]", workCacheKey, allocationStatus) + + // If successfully got an allocation token then kick off the query, and try to progress the job state + // if no token was granted, we stay in the NotStarted state. + if allocationStatus == resourcemanager.AllocationStatusGranted { + h.metrics.AllocationGranted.Inc(ctx) + + quboleAttempts++ + // Note that the query itself doesn't live in the work item object that's cached. That would take + // up too much room in the workflow CRD - instead we iterate through the task's custom field + // each time. + cmdDetails, err := h.quboleClient.ExecuteHiveCommand(ctx, item.Query, item.TimeoutSec, + item.ClusterLabel, quboleApiKey, item.Tags) + if err != nil { + // If we failed, we'll keep the NotStarted state + logger.Warnf(ctx, "Error creating Qubole query for %s", item.UniqueWorkCacheKey) + quboleFailures++ + // Deallocate token if Qubole API returns in error. + err := h.resourceManager.ReleaseResource(ctx, taskCtx.GetNamespace(), workCacheKey) + if err != nil { + h.metrics.ReleaseResourceFailed.Inc(ctx) + } + } else { + commandId := strconv.FormatInt(cmdDetails.ID, 10) + logger.Infof(ctx, "Created Qubole ID %s for %s", commandId, workCacheKey) + item.CommandId = commandId + item.Status = QuboleWorkRunning + item.Query = "" // Clear the query to save space in etcd once we've successfully launched + err := h.executionBuffer.ConfirmExecution(ctx, workCacheKey, commandId) + if err != nil { + logger.Errorf(ctx, "Unable to record execution for %s", workCacheKey) + return types.TaskStatusUndefined, errors.Wrapf(errors.DownstreamSystemError, err, + "unable to record execution for %s", workCacheKey) + } + } + } else { + h.metrics.AllocationNotGranted.Inc(ctx) + logger.Infof(ctx, "Unable to get allocation token for %s skipping...", workCacheKey) + } + } else { + // If found, this means that we've previously kicked off this execution, but CheckTaskStatus + // has been called with stale context. 
+ logger.Infof(ctx, "Unstarted Qubole work found in buffer for %s, setting to running with ID %s", + workCacheKey, foundCommandId) + item.CommandId = foundCommandId + item.Status = QuboleWorkRunning + item.Query = "" // Clear the query to save space in etcd once we've successfully launched + } + } + + // Add to the cache iff the item from the taskContext has a command ID (ie, has already been launched on Qubole) + // Then check and update the item if necessary. + if item.CommandId != "" { + logger.Debugf(ctx, "Calling GetOrCreate for [%s] id [%s] status [%s]", + item.UniqueWorkCacheKey, item.CommandId, item.Status) + cc, err := h.executionsCache.GetOrCreate(item) + if err != nil { + // This means that our cache has fundamentally broken... return a system error + logger.Errorf(ctx, "Cache is broken on execution [%s] cache key [%s], owner [%s]", + taskCtx.GetTaskExecutionID().GetID(), workCacheKey, taskCtx.GetOwnerReference()) + return types.TaskStatusUndefined, err + } + + cachedItem := cc.(QuboleWorkItem) + logger.Debugf(ctx, "Finished GetOrCreate - cache key [%s]->[%s] status [%s]->[%s]", + item.UniqueWorkCacheKey, cachedItem.UniqueWorkCacheKey, item.Status, cachedItem.Status) + + // TODO: Remove this sanity check if still here by late July 2019 + // This is a sanity check - get the item back immediately + sanityCheck := h.executionsCache.Get(item.UniqueWorkCacheKey) + if sanityCheck == nil { + // This means that our cache has fundamentally broken... return a system error + logger.Errorf(ctx, "Cache is b0rked!!! Unless there are a lot of evictions happening, a GetOrCreate"+ + " has failed to actually create!!! Cache key [%s], owner [%s]", + workCacheKey, taskCtx.GetOwnerReference()) + } else { + sanityCheckCast := sanityCheck.(QuboleWorkItem) + logger.Debugf(ctx, "Immediate cache write check worked [%s] status [%s]", + sanityCheckCast.UniqueWorkCacheKey, sanityCheckCast.Status) + } + // Handle all transitions after the initial one - If the one from the cache has a higher value, + // that means our loop has done something, and we should update the new custom state to reflect that. + if cachedItem.Status > item.Status { + item.Status = cachedItem.Status + } + + // Always copy the number of update retries + item.Retries = cachedItem.Retries + } + + // Always add the potentially modified item back to the new list so that it again can be persisted + // into etcd + newItems[workCacheKey] = item + } + + // If all creation attempts fail, then report a system error + if quboleFailures > 0 && quboleAttempts == quboleFailures { + err := errors.Errorf(errors.DownstreamSystemError, "All %d Hive creation attempts failed for %s", + quboleFailures, taskCtx.GetTaskExecutionID().GetGeneratedName()) + logger.Error(ctx, err) + return types.TaskStatusUndefined, err + } + + // Otherwise, look through the current state of things and decide what's up. + newStatus := h.TranslateCurrentState(newItems) + newStatus.PhaseVersion = taskCtx.GetPhaseVersion() + + // Determine whether or not to send an event. If the phase has changed, then we definitely want to, if not, + // we need to compare all the individual items to see if any were updated. 
+ var sendEvent = false + if taskCtx.GetPhase() != newStatus.Phase { + newStatus.PhaseVersion = 0 + sendEvent = true + } else { + oldItems, err := h.convertCustomStateToQuboleWorkItems(customState) + if err != nil { + // This error condition should not trigger because the exact same thing should've been done earlier + logger.Errorf(ctx, "Error converting custom state %v", err) + return types.TaskStatusUndefined, err + } + if !workItemMapsAreEqual(oldItems, newItems) { + // If any of the items we updated, we also need to increment the version in order for admin to record it + newStatus.PhaseVersion++ + sendEvent = true + } + } + if sendEvent { + info, err := constructEventInfoFromQuboleWorkItems(taskCtx, newStatus.State) + if err != nil { + logger.Errorf(ctx, "Error constructing event info for %s", + taskCtx.GetTaskExecutionID().GetGeneratedName()) + return types.TaskStatusUndefined, err + } + + ev := events.CreateEvent(taskCtx, newStatus, info) + + err = h.recorder.RecordTaskEvent(ctx, ev) + if err != nil && eventErrors.IsEventAlreadyInTerminalStateError(err) { + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event. state mis-match between Propeller %v and Control Plane.", &ev.Phase)), nil + } else if err != nil { + return types.TaskStatusUndefined, errors.Wrapf(errors.TaskEventRecordingFailed, err, + "failed to record task event") + } + } + logger.Debugf(ctx, "Task [%s] phase [%s]->[%s] phase version [%d]->[%d] sending event: %s", + taskCtx.GetTaskExecutionID().GetGeneratedName(), taskCtx.GetPhase(), newStatus.Phase, taskCtx.GetPhaseVersion(), + newStatus.PhaseVersion, sendEvent) + + return newStatus, nil +} + +// This translates a series of QuboleWorkItem statuses into what it means for the task as a whole +func (h HiveExecutor) TranslateCurrentState(state map[string]interface{}) types.TaskStatus { + succeeded := 0 + failed := 0 + total := len(state) + status := types.TaskStatus{ + State: state, + } + + for _, k := range state { + workItem := k.(QuboleWorkItem) + if workItem.Status == QuboleWorkSucceeded { + succeeded++ + } else if workItem.Status == QuboleWorkFailed { + failed++ + } + } + + if succeeded == total { + status.Phase = types.TaskPhaseSucceeded + } else if failed > 0 { + status.Phase = types.TaskPhaseRetryableFailure + status.Err = errors.Errorf(errors.DownstreamSystemError, "Qubole job failed") + } else { + status.Phase = types.TaskPhaseRunning + } + + return status +} + +// Loop through all the queries in the task, if there are any in a non-terminal state, then +// submit the request to terminate the Qubole query. If there are any problems with anything, then return +// an error +func (h HiveExecutor) KillTask(ctx context.Context, taskCtx types.TaskContext, reason string) error { + // Is it ever possible to get a CheckTaskStatus call for a task while this function is running? + // Or immediately after this function runs? 
+ customState := taskCtx.GetCustomState() + logger.Infof(ctx, "Kill task called on [%s] with [%d] customs", taskCtx.GetTaskExecutionID().GetGeneratedName(), + len(customState)) + quboleApiKey, _ := h.secretsManager.GetToken() + + var callsWithErrors = make([]string, 0, len(customState)) + + for key, value := range customState { + work, err := InterfaceConverter(value) + if err != nil { + logger.Errorf(ctx, "Error converting old state into an object for key %s", work.UniqueWorkCacheKey) + return err + } + logger.Debugf(ctx, "KillTask processing custom item key [%s] id [%s] on cluster [%s]", + work.UniqueWorkCacheKey, work.CommandId, work.ClusterLabel) + + status, err := h.quboleClient.GetCommandStatus(ctx, work.CommandId, quboleApiKey) + if err != nil { + logger.Errorf(ctx, "Problem getting command status while terminating %s %s %v", + work.CommandId, taskCtx.GetTaskExecutionID().GetGeneratedName(), err) + callsWithErrors = append(callsWithErrors, work.CommandId) + continue + } + + if !QuboleWorkIsTerminalState(QuboleStatusToWorkItemStatus(status)) { + logger.Debugf(ctx, "Terminating cache item [%s] id [%s] status [%s]", + work.UniqueWorkCacheKey, work.CommandId, work.Status) + + err := h.quboleClient.KillCommand(ctx, work.CommandId, quboleApiKey) + if err != nil { + logger.Errorf(ctx, "Error stopping Qubole command in termination sequence %s from %s with %v", + work.CommandId, key, err) + callsWithErrors = append(callsWithErrors, work.CommandId) + continue + } + err = h.resourceManager.ReleaseResource(ctx, "", work.UniqueWorkCacheKey) + if err != nil { + logger.Errorf(ctx, "Failed to release resource [%s]", work.UniqueWorkCacheKey) + h.metrics.ReleaseResourceFailed.Inc(ctx) + } + logger.Debugf(ctx, "Finished terminating cache item [%s] id [%s]", + work.UniqueWorkCacheKey, work.CommandId) + } else { + logger.Debugf(ctx, "Custom work in terminal state [%s] id [%s] status [%s]", + work.UniqueWorkCacheKey, work.CommandId, work.Status) + + // This is idempotent anyways, just be tripley safe we're not leaking resources + err := h.resourceManager.ReleaseResource(ctx, "", work.UniqueWorkCacheKey) + if err != nil { + logger.Errorf(ctx, "Failed to release resource [%s]", work.UniqueWorkCacheKey) + h.metrics.ReleaseResourceFailed.Inc(ctx) + } + } + } + + if len(callsWithErrors) > 0 { + return errors.Errorf(errors.DownstreamSystemError, "%d errors found for Qubole commands %v", + len(callsWithErrors), callsWithErrors) + } + + return nil +} + +// This should do minimal work - basically grab an updated status from the Qubole API and store it in the cache +// All other handling should be in the synchronous loop. +func (h *HiveExecutor) SyncQuboleQuery(ctx context.Context, obj utils2.CacheItem) ( + utils2.CacheItem, utils2.CacheSyncAction, error) { + + workItem := obj.(QuboleWorkItem) + + // TODO: Remove this if block if still here by late July 2019. This should not happen any more ever. + if workItem.CommandId == "" { + logger.Debugf(ctx, "Sync loop - CommandID is blank for [%s] skipping", workItem.UniqueWorkCacheKey) + // No need to do anything if the work hasn't been kicked off yet + return workItem, utils2.Unchanged, nil + } + + logger.Debugf(ctx, "Sync loop - processing Hive job [%s] - cache key [%s]", + workItem.CommandId, workItem.UniqueWorkCacheKey) + + quboleApiKey, _ := h.secretsManager.GetToken() + + if QuboleWorkIsTerminalState(workItem.Status) { + // Release again - this is idempotent anyways, shouldn't be a huge deal to be on the safe side and release + // many times. 
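+        // The UniqueWorkCacheKey doubles as the allocation token (see QuboleWorkItem), so releasing by cache
+        // key here frees the same token that CheckTaskStatus allocated before launching the command.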
+ logger.Debugf(ctx, "Sync loop - Qubole id [%s] in terminal state, re-releasing cache key [%s]", + workItem.CommandId, workItem.UniqueWorkCacheKey) + + err := h.resourceManager.ReleaseResource(ctx, "", workItem.UniqueWorkCacheKey) + if err != nil { + h.metrics.ReleaseResourceFailed.Inc(ctx) + } + return workItem, utils2.Unchanged, nil + } + + // Get an updated status from Qubole + logger.Debugf(ctx, "Querying Qubole for %s - %s", workItem.CommandId, workItem.UniqueWorkCacheKey) + commandStatus, err := h.quboleClient.GetCommandStatus(ctx, workItem.CommandId, quboleApiKey) + if err != nil { + logger.Errorf(ctx, "Error from Qubole command %s", workItem.CommandId) + workItem.Retries++ + // Make sure we don't return nil for the first argument, because that deletes it from the cache. + return workItem, utils2.Update, err + } + workItemStatus := QuboleStatusToWorkItemStatus(commandStatus) + + // Careful how we call this, don't want to ever go backwards, unless it's unknown + if workItemStatus > workItem.Status || workItemStatus == QuboleWorkUnknown { + workItem.Status = workItemStatus + logger.Infof(ctx, "Moving Qubole work %s %s from %s to %s", workItem.CommandId, workItem.UniqueWorkCacheKey, + workItem.Status, workItemStatus) + + if QuboleWorkIsTerminalState(workItem.Status) { + err := h.resourceManager.ReleaseResource(ctx, "", workItem.UniqueWorkCacheKey) + if err != nil { + h.metrics.ReleaseResourceFailed.Inc(ctx) + } + } + + return workItem, utils2.Update, nil + } + + return workItem, utils2.Unchanged, nil +} + +func NewHiveTaskExecutorWithCache(ctx context.Context) (*HiveExecutor, error) { + hiveExecutor := HiveExecutor{ + id: hiveExecutorId, + secretsManager: NewSecretsManager(), + quboleClient: client.NewQuboleClient(), + } + + return &hiveExecutor, nil +} + +func init() { + tasksV1.RegisterLoader(func(ctx context.Context) error { + hiveExecutor, err := NewHiveTaskExecutorWithCache(ctx) + if err != nil { + return err + } + + return tasksV1.RegisterForTaskTypes(hiveExecutor, hiveTaskType) + }) +} diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor_test.go b/flyteplugins/go/tasks/v1/qubole/hive_executor_test.go new file mode 100755 index 0000000000..915b8f188c --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor_test.go @@ -0,0 +1,403 @@ +package qubole + +import ( + "context" + "errors" + "testing" + + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/client" + clientMocks "github.com/lyft/flyteplugins/go/tasks/v1/qubole/client/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager" + resourceManagerMocks "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + eventMocks "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// This line is here because we need to generate a mock for the cache as the flytestdlib does not have one. +// Remove this if we ever do generate one inside flytestdlib itself. 
+//go:generate mockery -dir ../../../../vendor/github.com/lyft/flytestdlib/utils -name AutoRefreshCache + +func getDummyHiveExecutor() *HiveExecutor { + return &HiveExecutor{ + OutputsResolver: types.OutputsResolver{}, + recorder: &eventMocks.EventRecorder{}, + id: "test-hive-executor", + secretsManager: MockSecretsManager{}, + executionsCache: &MockAutoRefreshCache{}, + metrics: getHiveExecutorMetrics(promutils.NewTestScope()), + quboleClient: &clientMocks.QuboleClient{}, + redisClient: nil, + resourceManager: resourcemanager.NoopResourceManager{}, + executionBuffer: &resourceManagerMocks.ExecutionLooksideBuffer{}, + } +} + +func TestUniqueCacheKey(t *testing.T) { + mockTaskContext := CreateMockTaskContextWithRealTaskExecId() + executor := getDummyHiveExecutor() + + out := executor.getUniqueCacheKey(mockTaskContext, 42) + assert.Equal(t, "test-hive-job_42", out) +} + +func TestTranslateCurrentState(t *testing.T) { + t.Run("just one item", func(t *testing.T) { + workItems := map[string]interface{}{ + "key_1": NewQuboleWorkItem( + "key_1", + "12345", + QuboleWorkSucceeded, + "default", + []string{}, + 0, + ), + } + + executor := getDummyHiveExecutor() + taskStatus := executor.TranslateCurrentState(workItems) + assert.Equal(t, types.TaskPhaseSucceeded, taskStatus.Phase) + }) + + t.Run("partial completion", func(t *testing.T) { + workItems := map[string]interface{}{ + "key_1": NewQuboleWorkItem( + "key_1", + "12345", + QuboleWorkSucceeded, + "default", + []string{}, + 0, + ), + "key_2": NewQuboleWorkItem( + "key_2", + "45645", + QuboleWorkRunning, + "default", + []string{}, + 0, + ), + } + + executor := getDummyHiveExecutor() + taskStatus := executor.TranslateCurrentState(workItems) + assert.Equal(t, types.TaskPhaseRunning, taskStatus.Phase) + }) + + t.Run("any failure is a failure", func(t *testing.T) { + workItems := map[string]interface{}{ + "key_1": NewQuboleWorkItem( + "key_1", + "12345", + QuboleWorkSucceeded, + "default", + []string{}, + 0, + ), + "key_2": NewQuboleWorkItem( + "key_2", + "45645", + QuboleWorkFailed, + "default", + []string{}, + 0, + ), + } + + executor := getDummyHiveExecutor() + taskStatus := executor.TranslateCurrentState(workItems) + assert.Equal(t, types.TaskPhaseRetryableFailure, taskStatus.Phase) + }) +} + +func TestHiveExecutor_StartTask(t *testing.T) { + ctx := context.Background() + mockContext := CreateMockTaskContextWithRealTaskExecId() + taskTemplate := createDummyHiveTaskTemplate("hive-task-id") + + executor := getDummyHiveExecutor() + taskStatus, err := executor.StartTask(ctx, mockContext, taskTemplate, nil) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, taskStatus.Phase) + customState := taskStatus.State + assert.Equal(t, 1, len(customState)) + + workItem := customState["test-hive-job_0"].(QuboleWorkItem) + + assert.Equal(t, []string{"tag1", "tag2", "label-1:val1", "ns:test-namespace"}, workItem.Tags) + assert.Equal(t, "cluster-label", workItem.ClusterLabel) + assert.Equal(t, QuboleWorkNotStarted, workItem.Status) +} + +func NewMockedHiveTaskExecutor() *HiveExecutor { + mockCache := NewMockAutoRefreshCache() + mockQubole := &clientMocks.QuboleClient{} + mockResourceManager := &resourceManagerMocks.ResourceManager{} + mockEventRecorder := &eventMocks.EventRecorder{} + mockExecutionBuffer := &resourceManagerMocks.ExecutionLooksideBuffer{} + + return &HiveExecutor{ + OutputsResolver: types.OutputsResolver{}, + recorder: mockEventRecorder, + id: "test-hive-executor", + secretsManager: MockSecretsManager{}, + executionsCache: mockCache, + 
metrics: getHiveExecutorMetrics(promutils.NewTestScope()), + quboleClient: mockQubole, + redisClient: nil, + resourceManager: mockResourceManager, + executionBuffer: mockExecutionBuffer, + } +} + +func TestHiveExecutor_CheckTaskStatus(t *testing.T) { + ctx := context.Background() + + t.Run("basic lifecycle", func(t *testing.T) { + taskTemplate := createDummyHiveTaskTemplate("hive-task-id") + + // Get executor and add hooks for mocks + executor := NewMockedHiveTaskExecutor() + executor.resourceManager.(*resourceManagerMocks.ResourceManager).On("AllocateResource", + mock.Anything, mock.Anything, mock.Anything).Return(resourcemanager.AllocationStatusGranted, nil) + + var eventRecorded = false + executor.recorder.(*eventMocks.EventRecorder).On("RecordTaskEvent", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + eventRecorded = true + }) + executor.executionBuffer.(*resourceManagerMocks.ExecutionLooksideBuffer).On("RetrieveExecution", + mock.Anything, mock.Anything).Return("", resourcemanager.ExecutionNotFoundError) + executor.executionBuffer.(*resourceManagerMocks.ExecutionLooksideBuffer).On("ConfirmExecution", + mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // StartTask + taskStatus_0, err := executor.StartTask(ctx, CreateMockTaskContextWithRealTaskExecId(), taskTemplate, nil) + assert.NoError(t, err) + + // Create a new mock task context to return the first custom state, the one constructed by the StartTask call + mockContext := CreateMockTaskContextWithRealTaskExecId() + mockContext.On("GetCustomState").Return(taskStatus_0.State) + mockContext.On("GetPhase").Return(types.TaskPhaseQueued) + mockContext.On("GetPhaseVersion").Return(uint32(0)) + + // Call CheckTaskStatus twice + // The first time the mock qubole client will create the query + executor.quboleClient.(*clientMocks.QuboleClient).On("ExecuteHiveCommand", mock.Anything, + mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + &client.QuboleCommandDetails{ + Status: client.QuboleStatusRunning, + ID: 55482218961153, + }, nil) + taskStatus_1, err := executor.CheckTaskStatus(ctx, mockContext, taskTemplate) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRunning, taskStatus_1.Phase) + assert.True(t, eventRecorded) + customState := taskStatus_1.State + assert.Equal(t, 1, len(customState)) + workItem := customState["test-hive-job_0"].(QuboleWorkItem) + assert.Equal(t, "55482218961153", workItem.CommandId) + assert.Equal(t, QuboleWorkRunning, workItem.Status) + + // This bit mimics what the AutoRefreshCache would've been doing in the background. Pull out the workitem, + // update the status to succeeded, and put back into the cache. 
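+        // In production this transition is made by SyncQuboleQuery polling GetCommandStatus; the test instead
+        // writes QuboleWorkSucceeded straight into the mock cache's backing map.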
+ cachedWorkItem := executor.executionsCache.(MockAutoRefreshCache).values["test-hive-job_0"].(QuboleWorkItem) + cachedWorkItem.Status = QuboleWorkSucceeded + executor.executionsCache.(MockAutoRefreshCache).values["test-hive-job_0"] = cachedWorkItem + + // Second call to CheckTaskStatus, cache will say that it's finished + // Reset eventRecorded + eventRecorded = false + mockContext = CreateMockTaskContextWithRealTaskExecId() + mockContext.On("GetCustomState").Return(taskStatus_1.State) + mockContext.On("GetPhase").Return(types.TaskPhaseRunning) + taskStatus_2, err := executor.CheckTaskStatus(ctx, mockContext, taskTemplate) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseSucceeded, taskStatus_2.Phase) + customState = taskStatus_2.State + assert.True(t, eventRecorded) + assert.Equal(t, 1, len(customState)) + workItem = customState["test-hive-job_0"].(QuboleWorkItem) + assert.Equal(t, "55482218961153", workItem.CommandId) + assert.Equal(t, QuboleWorkSucceeded, workItem.Status) + }) + + t.Run("resource manager gates properly", func(t *testing.T) { + // Get executor and add hooks for mocks + executor := NewMockedHiveTaskExecutor() + executor.resourceManager.(*resourceManagerMocks.ResourceManager).On("AllocateResource", mock.Anything, mock.Anything, mock.Anything). + Return(resourcemanager.AllocationStatusExhausted, nil) + executor.recorder.(*eventMocks.EventRecorder).On("RecordTaskEvent", mock.Anything, mock.Anything).Return(nil) + executor.executionBuffer.(*resourceManagerMocks.ExecutionLooksideBuffer).On("RetrieveExecution", + mock.Anything, mock.Anything).Return("", resourcemanager.ExecutionNotFoundError) + + mockContext := CreateMockTaskContextWithRealTaskExecId() + taskTemplate := createDummyHiveTaskTemplate("hive-task-id") + + taskStatus_0, err := executor.StartTask(ctx, mockContext, taskTemplate, nil) + assert.NoError(t, err) + mockContext.On("GetCustomState").Return(taskStatus_0.State) + mockContext.On("GetPhase").Return(types.TaskPhaseQueued) + + // If the AllocateResource call doesn't return successfully, the query is not kicked off + executor.quboleClient.(*clientMocks.QuboleClient).On("ExecuteHiveCommand", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + &client.QuboleCommandDetails{}, nil) + taskStatus_1, err := executor.CheckTaskStatus(ctx, mockContext, taskTemplate) + + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseRunning, taskStatus_1.Phase) + customState := taskStatus_1.State + assert.Equal(t, 1, len(customState)) + workItem := customState["test-hive-job_0"].(QuboleWorkItem) + assert.Equal(t, QuboleWorkNotStarted, workItem.Status) + assert.Equal(t, 0, len(executor.quboleClient.(*clientMocks.QuboleClient).Calls)) + }) + + t.Run("executor doesn't launch when already in buffer", func(t *testing.T) { + var allocationCalled = false + + // Get executor and add hooks for mocks + executor := NewMockedHiveTaskExecutor() + executor.resourceManager.(*resourceManagerMocks.ResourceManager).On("AllocateResource", mock.Anything, + mock.Anything, mock.Anything).Return(resourcemanager.AllocationStatusGranted, nil).Run(func(args mock.Arguments) { + allocationCalled = true + }) + executor.recorder.(*eventMocks.EventRecorder).On("RecordTaskEvent", mock.Anything, mock.Anything).Return(nil) + executor.executionBuffer.(*resourceManagerMocks.ExecutionLooksideBuffer).On("RetrieveExecution", + mock.Anything, mock.Anything).Return("123456", nil) + + taskTemplate := createDummyHiveTaskTemplate("hive-task-id") + taskStatus_0, err := 
executor.StartTask(ctx, CreateMockTaskContextWithRealTaskExecId(), taskTemplate, nil) + assert.NoError(t, err) + workItem_0 := taskStatus_0.State["test-hive-job_0"].(QuboleWorkItem) + // object should be initialized in the not started state + assert.Equal(t, QuboleWorkNotStarted, workItem_0.Status) + + // Create a new mock task context to return the first custom state, the one constructed by the StartTask call + mockContext := CreateMockTaskContextWithRealTaskExecId() + mockContext.On("GetCustomState").Return(taskStatus_0.State) + mockContext.On("GetPhase").Return(types.TaskPhaseQueued) + mockContext.On("GetPhaseVersion").Return(uint32(0)) + + taskStatus_1, err := executor.CheckTaskStatus(ctx, mockContext, taskTemplate) + assert.NoError(t, err) + assert.False(t, allocationCalled) + assert.Equal(t, types.TaskPhaseRunning, taskStatus_1.Phase) + customState := taskStatus_1.State + assert.Equal(t, 1, len(customState)) + workItem_1 := customState["test-hive-job_0"].(QuboleWorkItem) + // If found in the lookaside buffer, CheckTaskStatus should set the status to running, and set the command ID + assert.Equal(t, "123456", workItem_1.CommandId) + assert.Equal(t, QuboleWorkRunning, workItem_1.Status) + }) +} + +func TestHiveExecutor_SyncQuboleQuery(t *testing.T) { + ctx := context.Background() + executor := getDummyHiveExecutor() + + t.Run("command failed", func(t *testing.T) { + mockQubole := &clientMocks.QuboleClient{} + mockQubole.On("GetCommandStatus", mock.Anything, mock.Anything, mock.Anything). + Return(client.QuboleStatusError, nil) + executor.quboleClient = mockQubole + + workItem := QuboleWorkItem{ + CommandId: "123456789", + Status: QuboleWorkRunning, + } + x, action, err := executor.SyncQuboleQuery(ctx, workItem) + newWorkItem := x.(QuboleWorkItem) + assert.NoError(t, err) + assert.Equal(t, QuboleWorkFailed, newWorkItem.Status) + assert.Equal(t, utils.Update, action) + }) + + t.Run("command still running", func(t *testing.T) { + mockQubole := &clientMocks.QuboleClient{} + mockQubole.On("GetCommandStatus", mock.Anything, mock.Anything, mock.Anything). + Return(client.QuboleStatusRunning, nil) + executor.quboleClient = mockQubole + + workItem := QuboleWorkItem{ + CommandId: "123456789", + Status: QuboleWorkRunning, + } + x, action, err := executor.SyncQuboleQuery(ctx, workItem) + newWorkItem := x.(QuboleWorkItem) + assert.NoError(t, err) + assert.Equal(t, QuboleWorkRunning, newWorkItem.Status) + assert.Equal(t, utils.Unchanged, action) + }) + + t.Run("command succeeded", func(t *testing.T) { + mockQubole := &clientMocks.QuboleClient{} + mockQubole.On("GetCommandStatus", mock.Anything, mock.Anything, mock.Anything). + Return(client.QuboleStatusDone, nil) + executor.quboleClient = mockQubole + + mockResourceManager := &resourceManagerMocks.ResourceManager{} + mockResourceManager.On("ReleaseResource", mock.Anything, mock.Anything, mock.Anything). 
+ Return(nil) + executor.resourceManager = mockResourceManager + + workItem := QuboleWorkItem{ + CommandId: "123456789", + Status: QuboleWorkRunning, + } + x, action, err := executor.SyncQuboleQuery(ctx, workItem) + newWorkItem := x.(QuboleWorkItem) + assert.NoError(t, err) + assert.Equal(t, QuboleWorkSucceeded, newWorkItem.Status) + assert.Equal(t, 1, len(mockResourceManager.Calls)) + assert.Equal(t, utils.Update, action) + }) +} + +func TestHiveExecutor_CheckTaskStatusStateMismatch(t *testing.T) { + ctx := context.Background() + taskTemplate := createDummyHiveTaskTemplate("hive-task-id") + mockQubole := &clientMocks.QuboleClient{} + + mockResourceManager := &resourceManagerMocks.ResourceManager{} + mockResourceManager.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything). + Return(resourcemanager.AllocationStatusGranted, nil) + mockEventRecorder := &eventMocks.EventRecorder{} + mockEventRecorder.On("RecordTaskEvent", mock.Anything, mock.Anything).Return(&eventErrors.EventError{Code: eventErrors.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + }) + + executor := NewMockedHiveTaskExecutor() + executor.recorder = mockEventRecorder + executor.resourceManager = mockResourceManager + executor.quboleClient = mockQubole + executor.executionBuffer.(*resourceManagerMocks.ExecutionLooksideBuffer).On("RetrieveExecution", + mock.Anything, mock.Anything).Return("", resourcemanager.ExecutionNotFoundError) + executor.executionBuffer.(*resourceManagerMocks.ExecutionLooksideBuffer).On("ConfirmExecution", + mock.Anything, mock.Anything, mock.Anything).Return(nil) + taskstatus0, err := executor.StartTask(ctx, CreateMockTaskContextWithRealTaskExecId(), taskTemplate, nil) + assert.NoError(t, err) + + // Create a new mock task context to return the first custom state, the one constructed by the StartTask call + mockContext := CreateMockTaskContextWithRealTaskExecId() + mockContext.On("GetCustomState").Return(taskstatus0.State) + mockContext.On("GetPhase").Return(types.TaskPhaseSucceeded) + mockContext.On("GetPhaseVersion").Return(uint32(0)) + + // Unit test will call CheckTaskStatus twice + // - the first time the mock qubole client will create the query + mockQubole.On("ExecuteHiveCommand", mock.Anything, mock.Anything, mock.Anything, + mock.Anything, mock.Anything, mock.Anything).Return( + &client.QuboleCommandDetails{ + Status: client.QuboleStatusRunning, + ID: 55482218961153, + }, nil) + taskstatus1, err := executor.CheckTaskStatus(ctx, mockContext, taskTemplate) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhasePermanentFailure, taskstatus1.Phase) + assert.Nil(t, taskstatus1.State) +} diff --git a/flyteplugins/go/tasks/v1/qubole/mocks/AutoRefreshCache.go b/flyteplugins/go/tasks/v1/qubole/mocks/AutoRefreshCache.go new file mode 100755 index 0000000000..9c6f13efc2 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/mocks/AutoRefreshCache.go @@ -0,0 +1,56 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import utils "github.com/lyft/flytestdlib/utils" + +// AutoRefreshCache is an autogenerated mock type for the AutoRefreshCache type +type AutoRefreshCache struct { + mock.Mock +} + +// Get provides a mock function with given fields: id +func (_m *AutoRefreshCache) Get(id string) utils.CacheItem { + ret := _m.Called(id) + + var r0 utils.CacheItem + if rf, ok := ret.Get(0).(func(string) utils.CacheItem); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(utils.CacheItem) + } + } + + return r0 +} + +// GetOrCreate provides a mock function with given fields: item +func (_m *AutoRefreshCache) GetOrCreate(item utils.CacheItem) (utils.CacheItem, error) { + ret := _m.Called(item) + + var r0 utils.CacheItem + if rf, ok := ret.Get(0).(func(utils.CacheItem) utils.CacheItem); ok { + r0 = rf(item) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(utils.CacheItem) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(utils.CacheItem) error); ok { + r1 = rf(item) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Start provides a mock function with given fields: ctx +func (_m *AutoRefreshCache) Start(ctx context.Context) { + _m.Called(ctx) +} diff --git a/flyteplugins/go/tasks/v1/qubole/qubole_work.go b/flyteplugins/go/tasks/v1/qubole/qubole_work.go new file mode 100755 index 0000000000..bcfebf6995 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/qubole_work.go @@ -0,0 +1,202 @@ +package qubole + +import ( + "fmt" + "encoding/json" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/events" + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/client" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +// This struct is supposed to represent all the details of one query/unit of work on Qubole. For instance, a user's +// @qubole_hive_task will get unpacked to one of these for each query contained in the task. +// It is intentionally vaguely named, in an effort to potentially support extensibility to other things Qubole +// is capable of executing in the future. +// Retries and Status are the only two fields that should get changed. +type QuboleWorkItem struct { + // This ID is the cache key and so will need to be unique across all objects in the cache (it will probably be + // unique across all of Flyte) and needs to be deterministic. + // This will also be used as the allocation token for now. + UniqueWorkCacheKey string `json:"unique_work_cache_key"` + + // This will store the command ID from Qubole + CommandId string `json:"command_id,omitempty"` + + // Our representation of the status of this work item + Status QuboleWorkItemStatus `json:"status,omitempty"` + + // The Qubole cluster to do this work + ClusterLabel string `json:"cluster_label,omitempty"` + + // These are Qubole Tags that show up on their UI + Tags []string `json:"tags,omitempty"` + + // This number keeps track of the number of retries within the sync function. Without this, what happens in + // the sync function is entirely opaque. 
Note that this field is not meant to represent the number of retries + // of the work itself, just errors with the Qubole API when attempting to sync + Retries int `json:"retries,omitempty"` + + // For Hive jobs, this is the query that will be run + // Not necessary for other Qubole task types necessarily + Query string `json:"query,omitempty"` + + TimeoutSec uint32 `json:"timeout,omitempty"` +} + +// This ID will be used in a process-wide cache, so it needs to be unique across all concurrent work being done by +// that process, but does not necessarily need to be universally unique +func (q QuboleWorkItem) ID() string { + return q.UniqueWorkCacheKey +} + +func constructQuboleWorkItem(uniqueWorkCacheKey string, quboleCommandId string, status QuboleWorkItemStatus) QuboleWorkItem { + return QuboleWorkItem{ + UniqueWorkCacheKey: uniqueWorkCacheKey, + CommandId: quboleCommandId, + Status: status, + } +} + +func NewQuboleWorkItem(uniqueWorkCacheKey string, quboleCommandId string, status QuboleWorkItemStatus, clusterLabel string, + tags []string, retries int) QuboleWorkItem { + return QuboleWorkItem{ + UniqueWorkCacheKey: uniqueWorkCacheKey, + CommandId: quboleCommandId, + Status: status, + ClusterLabel: clusterLabel, + Tags: tags, + Retries: retries, + } +} + +// This status encapsulates all possible states for our custom object. It is different from the QuboleStatus type +// in that this is our Flyte type. It represents the same thing as QuboleStatus, but will actually persist in etcd. +// It is also different from the TaskStatuses in that this is on the qubole job level, not the task level. A task, +// can contain many queries/spark jobs, etc. +type QuboleWorkItemStatus int + +const ( + QuboleWorkNotStarted QuboleWorkItemStatus = iota + QuboleWorkUnknown + QuboleWorkRunnable + QuboleWorkRunning + QuboleWorkExecutionFailed + QuboleWorkExecutionSucceeded + QuboleWorkFailed + QuboleWorkSucceeded +) + +func QuboleWorkIsTerminalState(status QuboleWorkItemStatus) bool { + return status == QuboleWorkFailed || status == QuboleWorkSucceeded +} + +func (q QuboleWorkItemStatus) String() string { + switch q { + case QuboleWorkNotStarted: + return "NotStarted" + case QuboleWorkUnknown: + return "Unknown" + case QuboleWorkRunnable: + return "Runnable" + case QuboleWorkRunning: + return "Running" + case QuboleWorkExecutionFailed: + return "ExecutionFailed" + case QuboleWorkExecutionSucceeded: + return "ExecutionSucceeded" + case QuboleWorkFailed: + return "Failed" + case QuboleWorkSucceeded: + return "Succeeded" + } + return "IllegalQuboleWorkStatus" +} + +func (q QuboleWorkItem) EqualTo(other QuboleWorkItem) bool { + if q.UniqueWorkCacheKey != other.UniqueWorkCacheKey || + q.Status != other.Status || q.CommandId != other.CommandId || + q.Retries != other.Retries || len(q.Tags) != len(other.Tags) { + return false + } + + return true +} + +func workItemMapsAreEqual(old map[string]QuboleWorkItem, new map[string]interface{}) bool { + if len(old) != len(new) { + return false + } + + for k, oldItem := range old { + if x, ok := new[k]; ok { + newItem := x.(QuboleWorkItem) + if !oldItem.EqualTo(newItem) { + return false + } + } else { + return false + } + } + return true +} + +func constructEventInfoFromQuboleWorkItems(taskCtx types.TaskContext, quboleWorkItems map[string]interface{}) (*events.TaskEventInfo, error) { + logs := make([]*core.TaskLog, 0, len(quboleWorkItems)) + for _, v := range quboleWorkItems { + workItem := v.(QuboleWorkItem) + if workItem.CommandId != "" { + logs = append(logs, &core.TaskLog{ + 
Name: fmt.Sprintf("Retry: %d Status: %s [%s]", + taskCtx.GetTaskExecutionID().GetID().RetryAttempt, workItem.Status, workItem.CommandId), + MessageFormat: core.TaskLog_UNKNOWN, + Uri: fmt.Sprintf(client.QuboleLogLinkFormat, workItem.CommandId), + }) + } + } + + customInfo, err := utils.MarshalObjToStruct(quboleWorkItems) + if err != nil { + return nil, err + } + + return &events.TaskEventInfo{ + CustomInfo: customInfo, + Logs: logs, + }, nil +} + +func QuboleStatusToWorkItemStatus(s client.QuboleStatus) QuboleWorkItemStatus { + switch s { + case client.QuboleStatusDone: + return QuboleWorkSucceeded + case client.QuboleStatusCancelled: + return QuboleWorkFailed + case client.QuboleStatusError: + return QuboleWorkFailed + case client.QuboleStatusUnknown: + return QuboleWorkUnknown + case client.QuboleStatusWaiting: + return QuboleWorkRunning + case client.QuboleStatusRunning: + return QuboleWorkRunning + default: + return QuboleWorkRunning + } +} + +func InterfaceConverter(cachedInterface interface{}) (QuboleWorkItem, error) { + raw, err := json.Marshal(cachedInterface) + if err != nil { + return QuboleWorkItem{}, err + } + + item := &QuboleWorkItem{} + err = json.Unmarshal(raw, item) + if err != nil { + return QuboleWorkItem{}, fmt.Errorf("Failed to unmarshal state into Qubole work item") + } + + return *item, nil +} \ No newline at end of file diff --git a/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go b/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go new file mode 100755 index 0000000000..af4daa313f --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go @@ -0,0 +1,177 @@ +package qubole + +import ( + "encoding/json" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + tasksMocks "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/stretchr/testify/assert" + "strings" + "testing" +) + +func getMockTaskContext() *tasksMocks.TaskContext { + taskCtx := &tasksMocks.TaskContext{} + + id := &tasksMocks.TaskExecutionID{} + id.On("GetGeneratedName").Return("flyteplugins_integration") + id.On("GetID").Return(core.TaskExecutionIdentifier{ + RetryAttempt: 0, + }) + taskCtx.On("GetTaskExecutionID").Return(id) + + return taskCtx +} + +func TestConstructEventInfoFromQuboleWorkItems(t *testing.T) { + workItems := map[string]interface{}{ + "key_1": QuboleWorkItem{ + CommandId: "12345", + UniqueWorkCacheKey: "key_1", + Retries: 0, + Status: QuboleWorkSucceeded, + ClusterLabel: "default", + Tags: []string{}, + }, + } + + out, err := constructEventInfoFromQuboleWorkItems(getMockTaskContext(), workItems) + assert.NoError(t, err) + assert.Equal(t, "Retry: 0 Status: Succeeded [12345]", out.Logs[0].Name) + assert.Equal(t, "12345", out.CustomInfo.Fields["key_1"].GetStructValue().Fields["command_id"].GetStringValue()) + status := out.CustomInfo.Fields["key_1"].GetStructValue().Fields["status"] + assert.Equal(t, float64(7), status.GetNumberValue()) + assert.True(t, strings.Contains(out.Logs[0].Uri, "api.qubole")) +} + +func TestPrinting(t *testing.T) { + assert.Equal(t, "NotStarted", QuboleWorkNotStarted.String()) + assert.Equal(t, "Unknown", QuboleWorkUnknown.String()) + assert.Equal(t, "Runnable", QuboleWorkRunnable.String()) + assert.Equal(t, "Running", QuboleWorkRunning.String()) + assert.Equal(t, "ExecutionFailed", QuboleWorkExecutionFailed.String()) + assert.Equal(t, "ExecutionSucceeded", QuboleWorkExecutionSucceeded.String()) + assert.Equal(t, "Failed", QuboleWorkFailed.String()) + assert.Equal(t, "Succeeded", QuboleWorkSucceeded.String()) +} + +func TestEquality(t 
*testing.T) { + first := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "123456789", + Status: QuboleWorkRunnable, + UniqueWorkCacheKey: "fdsfa", + ClusterLabel: "default", + } + second := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "123456789", + Status: QuboleWorkRunnable, + UniqueWorkCacheKey: "fdsfa", + ClusterLabel: "default", + } + third := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "123456789", + Status: QuboleWorkFailed, + UniqueWorkCacheKey: "fdsfa", + ClusterLabel: "default", + } + fourth := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "1234789", + Status: QuboleWorkRunnable, + UniqueWorkCacheKey: "fdsfad", + ClusterLabel: "default", + } + + assert.True(t, first.EqualTo(second)) + assert.False(t, first.EqualTo(third)) + assert.False(t, first.EqualTo(fourth)) +} + +func TestMapEquality(t *testing.T) { + first := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "123456789", + Status: QuboleWorkRunnable, + UniqueWorkCacheKey: "fdsfa", + ClusterLabel: "default", + } + second := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "123456789", + Status: QuboleWorkRunnable, + UniqueWorkCacheKey: "fdsfa", + ClusterLabel: "default", + } + third := QuboleWorkItem{ + Tags: []string{"hello"}, + Retries: 0, + CommandId: "123456789", + Status: QuboleWorkFailed, + UniqueWorkCacheKey: "fdsfa", + ClusterLabel: "default", + } + + old := map[string]QuboleWorkItem{ + "first": first, + "second": second, + } + + old2 := map[string]QuboleWorkItem{ + "first": first, + "second": second, + "third": third, + } + + new1 := map[string]interface{}{ + "first": first, + "second": second, + "third": third, + } + + new2 := map[string]interface{}{ + "first": first, + "second": second, + } + + new3 := map[string]interface{}{ + "first": first, + "second": third, + } + + assert.False(t, workItemMapsAreEqual(old, new1)) + assert.False(t, workItemMapsAreEqual(old2, new2)) + assert.True(t, workItemMapsAreEqual(old, new2)) + assert.False(t, workItemMapsAreEqual(old, new3)) +} + +func TestInterfaceConverter(t *testing.T) { + // This is a complicated step to reproduce what will ultimately be given to the function at runtime, the values + // inside the CustomState + item := QuboleWorkItem{ + Status: QuboleWorkRunning, + CommandId: "123456", + Query: "", + UniqueWorkCacheKey: "fjdsakfjd", + } + raw, err := json.Marshal(map[string]interface{}{"":item}) + assert.NoError(t, err) + + // We can't unmarshal into a interface{} but we can unmarhsal into a interface{} if it's the value of a map. 
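+    // Unmarshalling into map[string]interface{} leaves each value as a generic interface{}, which is exactly
+    // the shape InterfaceConverter receives from taskCtx.GetCustomState() at runtime.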
+ interfaceItem := map[string]interface{}{} + err = json.Unmarshal(raw, &interfaceItem) + assert.NoError(t, err) + + convertedItem, err := InterfaceConverter(interfaceItem[""]) + assert.NoError(t, err) + assert.Equal(t, "123456", convertedItem.CommandId) + assert.Equal(t, QuboleWorkRunning, convertedItem.Status) + assert.Equal(t, "fjdsakfjd", convertedItem.UniqueWorkCacheKey) +} diff --git a/flyteplugins/go/tasks/v1/qubole/secrets_manager.go b/flyteplugins/go/tasks/v1/qubole/secrets_manager.go new file mode 100755 index 0000000000..02f3ad4063 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/secrets_manager.go @@ -0,0 +1,56 @@ +package qubole + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "os" + "strings" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/config" +) + +type secretsManager struct { + // Memoize the key + quboleKey string +} + +type SecretsManager interface { + GetToken() (string, error) +} + +func NewSecretsManager() SecretsManager { + return &secretsManager{} +} + +func (s *secretsManager) GetToken() (string, error) { + if s.quboleKey != "" { + return s.quboleKey, nil + } + + // If the environment variable has been defined, then just use the value of the environment + // variable. This is primarily for local use/testing purposes. Otherwise we expect the token + // to exist in a file. + if key := os.Getenv("QUBOLE_API_KEY"); key != "" { + return key, nil + } + + // Assume that secrets have been mounted somehow + fileLocation := config.GetQuboleConfig().QuboleTokenPath + + b, err := ioutil.ReadFile(fileLocation) + if err != nil { + logger.Errorf(context.Background(), "Could not read entry at %s", fileLocation) + return "", errors.New(fmt.Sprintf("Bad Qubole token file, could not read file at [%s]", fileLocation)) + } + s.quboleKey = strings.TrimSpace(string(b)) + if s.quboleKey == "" { + logger.Errorf(context.Background(), "Qubole token was empty") + return "", errors.New("bad Qubole token - file was read but was empty") + } + + return s.quboleKey, nil +} diff --git a/flyteplugins/go/tasks/v1/qubole/test_helper.go b/flyteplugins/go/tasks/v1/qubole/test_helper.go new file mode 100755 index 0000000000..dbab03e050 --- /dev/null +++ b/flyteplugins/go/tasks/v1/qubole/test_helper.go @@ -0,0 +1,116 @@ +package qubole + +import ( + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + quboleMocks "github.com/lyft/flyteplugins/go/tasks/v1/qubole/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" + "github.com/lyft/flytestdlib/storage" + libUtils "github.com/lyft/flytestdlib/utils" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type MockSecretsManager struct { +} + +func (m MockSecretsManager) GetToken() (string, error) { + return "sample-token", nil +} + +func CreateMockTaskContextWithRealTaskExecId() *mocks.TaskContext { + taskExId := core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "flyteplugins", + Domain: "testing", + Name: "send_event_test", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: "nodeId", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "flyteplugins", + Domain: "testing", + Name: "node_1", + }, + }, + } + mockTaskCtx := &mocks.TaskContext{} + id := &mocks.TaskExecutionID{} + 
id.On("GetGeneratedName").Return("test-hive-job") + id.On("GetID").Return(taskExId) + + mockTaskCtx.On("GetLabels").Return(map[string]string{"label-1": "val1"}) + mockTaskCtx.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + mockTaskCtx.On("GetTaskExecutionID").Return(id) + mockTaskCtx.On("GetNamespace").Return("test-namespace") + mockTaskCtx.On("GetOwnerReference").Return(metav1.OwnerReference{}) + mockTaskCtx.On("GetPhaseVersion").Return(uint32(0)) + dummyStorageReference := storage.DataReference("s3://fake/file/in") + mockTaskCtx.On("GetInputsFile").Return(dummyStorageReference) + dummyStorageReferenceOut := storage.DataReference("s3://fake/file/out") + mockTaskCtx.On("GetOutputsFile").Return(dummyStorageReferenceOut) + + return mockTaskCtx +} + +func createDummyHiveCustomObj() *plugins.QuboleHiveJob { + hiveJob := plugins.QuboleHiveJob{} + + hiveJob.ClusterLabel = "cluster-label" + hiveJob.Tags = []string{"tag1", "tag2"} + hiveJob.QueryCollection = &plugins.HiveQueryCollection{Queries: []*plugins.HiveQuery{{Query: "Select 5"}}} + return &hiveJob +} + +func createDummyHiveTaskTemplate(id string) *core.TaskTemplate { + hiveJob := createDummyHiveCustomObj() + hiveJobJSON, err := utils.MarshalToString(hiveJob) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(hiveJobJSON, &structObj) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "hive", + Custom: &structObj, + } +} + +type MockAutoRefreshCache struct { + *quboleMocks.AutoRefreshCache + values map[string]libUtils.CacheItem +} + +func (m MockAutoRefreshCache) GetOrCreate(item libUtils.CacheItem) (libUtils.CacheItem, error) { + if cachedItem, ok := m.values[item.ID()]; ok { + return cachedItem, nil + } else { + m.values[item.ID()] = item + } + return item, nil +} + +func NewMockAutoRefreshCache() MockAutoRefreshCache { + return MockAutoRefreshCache{ + AutoRefreshCache: &quboleMocks.AutoRefreshCache{}, + values: make(map[string]libUtils.CacheItem), + } +} + +func (m MockAutoRefreshCache) Get(key string) libUtils.CacheItem { + if cachedItem, ok := m.values[key]; ok { + return cachedItem + } else { + return nil + } +} diff --git a/flyteplugins/go/tasks/v1/registry.go b/flyteplugins/go/tasks/v1/registry.go new file mode 100755 index 0000000000..7afa772bf2 --- /dev/null +++ b/flyteplugins/go/tasks/v1/registry.go @@ -0,0 +1,24 @@ +package v1 + +import "context" + +type PluginLoaderFn func(ctx context.Context) error + +var loaders []PluginLoaderFn + +// Registers a plugin loader to be called when it's safe to perform plugin initialization logic. This function is NOT +// thread-safe and is expected to be called in an init() function (which runs in a single thread). +func RegisterLoader(fn PluginLoaderFn) { + loaders = append(loaders, fn) +} + +// Runs all plugin loader functions and errors out if any of the loaders fails to finish successfully. 
+func RunAllLoaders(ctx context.Context) error { + for _, fn := range loaders { + if err := fn(ctx); err != nil { + return err + } + } + + return nil +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/lookaside_buffer.go b/flyteplugins/go/tasks/v1/resourcemanager/lookaside_buffer.go new file mode 100755 index 0000000000..657ed2fb4b --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/lookaside_buffer.go @@ -0,0 +1,18 @@ +package resourcemanager + +import ( + "context" + "fmt" +) + +// This error will be returned when a key is not found in the buffer +var ExecutionNotFoundError = fmt.Errorf("Execution not found") + +//go:generate mockery -name ExecutionLooksideBuffer -case=underscore + +// Remembers an execution key to a value. Specifically for example in the Qubole case, the key will be the same +// key that's used in the AutoRefreshCache +type ExecutionLooksideBuffer interface { + ConfirmExecution(ctx context.Context, executionKey string, executionValue string) error + RetrieveExecution(ctx context.Context, executionKey string) (string, error) +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/mocks/execution_lookside_buffer.go b/flyteplugins/go/tasks/v1/resourcemanager/mocks/execution_lookside_buffer.go new file mode 100755 index 0000000000..fc94b9bd96 --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/mocks/execution_lookside_buffer.go @@ -0,0 +1,46 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" + +// ExecutionLooksideBuffer is an autogenerated mock type for the ExecutionLooksideBuffer type +type ExecutionLooksideBuffer struct { + mock.Mock +} + +// ConfirmExecution provides a mock function with given fields: ctx, executionKey, executionValue +func (_m *ExecutionLooksideBuffer) ConfirmExecution(ctx context.Context, executionKey string, executionValue string) error { + ret := _m.Called(ctx, executionKey, executionValue) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, executionKey, executionValue) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RetrieveExecution provides a mock function with given fields: ctx, executionKey +func (_m *ExecutionLooksideBuffer) RetrieveExecution(ctx context.Context, executionKey string) (string, error) { + ret := _m.Called(ctx, executionKey) + + var r0 string + if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = rf(ctx, executionKey) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, executionKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager.go b/flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager.go new file mode 100755 index 0000000000..fbe51fcdfa --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager.go @@ -0,0 +1,47 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import resourcemanager "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager" + +// ResourceManager is an autogenerated mock type for the ResourceManager type +type ResourceManager struct { + mock.Mock +} + +// AllocateResource provides a mock function with given fields: ctx, namespace, allocationToken +func (_m *ResourceManager) AllocateResource(ctx context.Context, namespace string, allocationToken string) (resourcemanager.AllocationStatus, error) { + ret := _m.Called(ctx, namespace, allocationToken) + + var r0 resourcemanager.AllocationStatus + if rf, ok := ret.Get(0).(func(context.Context, string, string) resourcemanager.AllocationStatus); ok { + r0 = rf(ctx, namespace, allocationToken) + } else { + r0 = ret.Get(0).(resourcemanager.AllocationStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, namespace, allocationToken) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReleaseResource provides a mock function with given fields: ctx, namespace, allocationToken +func (_m *ResourceManager) ReleaseResource(ctx context.Context, namespace string, allocationToken string) error { + ret := _m.Called(ctx, namespace, allocationToken) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, namespace, allocationToken) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager_ext.go b/flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager_ext.go new file mode 100755 index 0000000000..7592a33e23 --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/mocks/resource_manager_ext.go @@ -0,0 +1,44 @@ +package mocks + +import ( + "context" + + "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager" + "github.com/stretchr/testify/mock" +) + +type AllocateResourceCall struct { + *mock.Call +} + +func (_a *AllocateResourceCall) Return(status resourcemanager.AllocationStatus, err error) *AllocateResourceCall { + return &AllocateResourceCall{Call: _a.Call.Return(status, err)} +} + +func (_m *ResourceManager) OnAllocateResource(ctx context.Context, namespace string, allocationToken string) *AllocateResourceCall { + call := _m.On("AllocateResource", ctx, namespace, allocationToken) + return &AllocateResourceCall{Call: call} +} + +func (_m *ResourceManager) OnAllocateResourceWithMatchers(matchers []interface{}) *AllocateResourceCall { + call := _m.On("AllocateResource", matchers...) + return &AllocateResourceCall{Call: call} +} + +type ReleaseResourceCall struct { + *mock.Call +} + +func (_a *ReleaseResourceCall) Return(err error) *ReleaseResourceCall { + return &ReleaseResourceCall{Call: _a.Call.Return(err)} +} + +func (_m *ResourceManager) OnReleaseResource(ctx context.Context, namespace string, allocationToken string) *ReleaseResourceCall { + call := _m.On("ReleaseResource", ctx, namespace, allocationToken) + return &ReleaseResourceCall{Call: call} +} + +func (_m *ResourceManager) OnReleaseResourceWithMatchers(matchers []interface{}) *ReleaseResourceCall { + call := _m.On("ReleaseResource", matchers...) 
+ return &ReleaseResourceCall{Call: call} +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/redis_client.go b/flyteplugins/go/tasks/v1/resourcemanager/redis_client.go new file mode 100755 index 0000000000..ee7435d531 --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/redis_client.go @@ -0,0 +1,24 @@ +package resourcemanager + +import ( + "context" + "github.com/go-redis/redis" + "github.com/lyft/flytestdlib/logger" +) + +func NewRedisClient(ctx context.Context, host string, key string, maxRetries int) (*redis.Client, error) { + client := redis.NewClient(&redis.Options{ + Addr: host, + Password: key, + DB: 0, // use default DB + MaxRetries: maxRetries, + }) + + _, err := client.Ping().Result() + if err != nil { + logger.Errorf(ctx, "Error creating Redis client at %s %v", host, err) + return nil, err + } + logger.Infof(ctx, "Created Redis client with host %s key %s ...", host, key) + return client, nil +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer.go b/flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer.go new file mode 100755 index 0000000000..6506497f24 --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer.go @@ -0,0 +1,46 @@ +package resourcemanager + +import ( + "context" + "fmt" + "github.com/go-redis/redis" + "time" +) + +type RedisLookasideBuffer struct { + client *redis.Client + redisPrefix string + expiry time.Duration +} + +func createKey(prefix, key string) string { + return fmt.Sprintf("%s:%s", prefix, key) +} + +func (r RedisLookasideBuffer) ConfirmExecution(ctx context.Context, executionKey string, executionValue string) error { + err := r.client.Set(createKey(r.redisPrefix, executionKey), executionValue, r.expiry).Err() + if err != nil { + return err + } + + return nil +} + +func (r RedisLookasideBuffer) RetrieveExecution(ctx context.Context, executionKey string) (string, error) { + value, err := r.client.Get(createKey(r.redisPrefix, executionKey)).Result() + if err == redis.Nil { + return "", ExecutionNotFoundError + } else if err != nil { + return "", err + } + + return value, nil +} + +func NewRedisLookasideBuffer(ctx context.Context, client *redis.Client, redisPrefix string, expiry time.Duration) RedisLookasideBuffer { + return RedisLookasideBuffer{ + client: client, + redisPrefix: redisPrefix, + expiry: expiry, + } +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer_test.go b/flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer_test.go new file mode 100755 index 0000000000..1b0172d74a --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/redis_lookaside_buffer_test.go @@ -0,0 +1,10 @@ +package resourcemanager + +import ( + "github.com/magiconair/properties/assert" + "testing" +) + +func TestCreateKey(t *testing.T) { + assert.Equal(t, "asdf:fdsa", createKey("asdf", "fdsa")) +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/redis_resource_manager.go b/flyteplugins/go/tasks/v1/resourcemanager/redis_resource_manager.go new file mode 100755 index 0000000000..8317122b39 --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/redis_resource_manager.go @@ -0,0 +1,122 @@ +package resourcemanager + +import ( + "context" + "time" + + "github.com/go-redis/redis" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/config" +) + +// This is the key that will point to the Redis Set. 
+// https://redis.io/commands#set +const RedisSetKey = "qubole" + +type RedisResourceManager struct { + client *redis.Client + redisSetKey string + Metrics RedisResourceManagerMetrics +} + +type RedisResourceManagerMetrics struct { + Scope promutils.Scope + RedisSizeCheckTime promutils.StopWatch + AllocatedTokensGauge prometheus.Gauge +} + +func NewRedisResourceManagerMetrics(scope promutils.Scope) RedisResourceManagerMetrics { + return RedisResourceManagerMetrics{ + Scope: scope, + RedisSizeCheckTime: scope.MustNewStopWatch("redis:size_check_time_ms", + "The time it takes to measure the size of the Redis Set where all the queries are stored", time.Millisecond), + + AllocatedTokensGauge: scope.MustNewGauge("size", + "The number of allocation tokens currently in the Redis set"), + } +} + +func (r RedisResourceManager) AllocateResource(ctx context.Context, namespace string, allocationToken string) ( + AllocationStatus, error) { + + // Check to see if the allocation token is already in the set + found, err := r.client.SIsMember(r.redisSetKey, allocationToken).Result() + if err != nil { + logger.Errorf(ctx, "Error getting size of Redis set %v", err) + return AllocationUndefined, err + } + if found { + logger.Infof(ctx, "Already allocated [%s:%s]", namespace, allocationToken) + return AllocationStatusGranted, nil + } + + size, err := r.client.SCard(r.redisSetKey).Result() + if err != nil { + logger.Errorf(ctx, "Error getting size of Redis set %v", err) + return AllocationUndefined, err + } + + if size > int64(config.GetQuboleConfig().QuboleLimit) { + logger.Infof(ctx, "Too many allocations (total [%d]), rejecting [%s:%s]", size, namespace, allocationToken) + return AllocationStatusExhausted, nil + } + + countAdded, err := r.client.SAdd(r.redisSetKey, allocationToken).Result() + if err != nil { + logger.Errorf(ctx, "Error adding token [%s:%s] %v", namespace, allocationToken, err) + return AllocationUndefined, err + } + logger.Infof(ctx, "Added %d to the Redis Qubole set", countAdded) + + return AllocationStatusGranted, err +} + +func (r RedisResourceManager) ReleaseResource(ctx context.Context, namespace string, allocationToken string) error { + countRemoved, err := r.client.SRem(r.redisSetKey, allocationToken).Result() + if err != nil { + logger.Errorf(ctx, "Error removing token [%s:%s] %v", namespace, allocationToken, err) + return err + } + logger.Infof(ctx, "Removed %d token: %s", countRemoved, allocationToken) + + return nil +} + +func (r *RedisResourceManager) pollRedis(ctx context.Context) { + stopWatch := r.Metrics.RedisSizeCheckTime.Start() + defer stopWatch.Stop() + size, err := r.client.SCard(r.redisSetKey).Result() + if err != nil { + logger.Errorf(ctx, "Error getting size of Redis set in metrics poller %v", err) + return + } + r.Metrics.AllocatedTokensGauge.Set(float64(size)) +} + +func (r *RedisResourceManager) startMetricsGathering(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + r.pollRedis(ctx) + } + } + }() +} + +func NewRedisResourceManager(ctx context.Context, client *redis.Client, scope promutils.Scope) (*RedisResourceManager, error) { + rm := &RedisResourceManager{ + client: client, + Metrics: NewRedisResourceManagerMetrics(scope), + redisSetKey: RedisSetKey, + } + rm.startMetricsGathering(ctx) + + return rm, nil +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/resource_manager.go b/flyteplugins/go/tasks/v1/resourcemanager/resource_manager.go new file mode 100755 
index 0000000000..14aed73881 --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/resource_manager.go @@ -0,0 +1,68 @@ +package resourcemanager + +import ( + "context" + "github.com/go-redis/redis" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +//go:generate mockery -name ResourceManager -case=underscore + +type AllocationStatus string + +const ( + // This is the enum returned when there's an error + AllocationUndefined AllocationStatus = "AllocationUndefined" + + // Go for it + AllocationStatusGranted AllocationStatus = "ResourceGranted" + + // This means that no resources are available globally. This is the only rejection message we use right now. + AllocationStatusExhausted AllocationStatus = "ResourceExhausted" + + // We're not currently using this - but this would indicate that things globally are okay, but that your + // own namespace is too busy + AllocationStatusNamespaceQuotaExceeded AllocationStatus = "NamespaceQuotaExceeded" +) + +// Resource Manager manages a single resource type, and each allocation is of size one +type ResourceManager interface { + AllocateResource(ctx context.Context, namespace string, allocationToken string) (AllocationStatus, error) + ReleaseResource(ctx context.Context, namespace string, allocationToken string) error +} + +type NoopResourceManager struct { +} + +func (NoopResourceManager) AllocateResource(ctx context.Context, namespace string, allocationToken string) ( + AllocationStatus, error) { + + return AllocationStatusGranted, nil +} + +func (NoopResourceManager) ReleaseResource(ctx context.Context, namespace string, allocationToken string) error { + return nil +} + +// Gets or creates a resource manager for the given resource name. This function is thread-safe and calling it with the +// same resource name will return the same instance of resource manager every time.
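A minimal usage sketch, not part of the patch, showing the allocate-then-release flow a plugin is expected to follow against this interface; the runWithAllocation helper and the namespace/token values are hypothetical. The GetOrCreateResourceManagerFor function documented in the comment above follows immediately after this sketch.

package sketch

import (
	"context"
	"fmt"

	"github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager"
)

// runWithAllocation asks the resource manager for a token before doing work and always
// releases the token afterwards so the underlying store does not leak allocations.
func runWithAllocation(ctx context.Context, rm resourcemanager.ResourceManager, namespace, token string, work func(context.Context) error) error {
	status, err := rm.AllocateResource(ctx, namespace, token)
	if err != nil {
		return err
	}
	if status != resourcemanager.AllocationStatusGranted {
		// Exhausted (or any other non-granted status) means: back off and retry later.
		return fmt.Errorf("allocation not granted for %s/%s: %v", namespace, token, status)
	}
	defer func() {
		_ = rm.ReleaseResource(ctx, namespace, token)
	}()
	return work(ctx)
}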
+func GetOrCreateResourceManagerFor(ctx context.Context, resourceName string) (ResourceManager, error) { + return NoopResourceManager{}, nil +} + +func GetResourceManagerByType(ctx context.Context, managerType string, scope promutils.Scope, redisClient *redis.Client) ( + ResourceManager, error) { + + switch managerType { + case "noop": + logger.Infof(ctx, "Using the NOOP resource manager") + return NoopResourceManager{}, nil + case "redis": + logger.Infof(ctx, "Using Redis based resource manager") + return NewRedisResourceManager(ctx, redisClient, scope.NewSubScope("resourcemanager:redis")) + } + logger.Infof(ctx, "Using the NOOP resource manager by default") + return NoopResourceManager{}, nil +} diff --git a/flyteplugins/go/tasks/v1/resourcemanager/resource_manager_test.go b/flyteplugins/go/tasks/v1/resourcemanager/resource_manager_test.go new file mode 100755 index 0000000000..d92483b3cc --- /dev/null +++ b/flyteplugins/go/tasks/v1/resourcemanager/resource_manager_test.go @@ -0,0 +1,88 @@ +package resourcemanager + +import ( + "context" + "reflect" + "testing" +) + +func TestGetOrCreateResourceManagerFor(t *testing.T) { + type args struct { + ctx context.Context + resourceName string + } + tests := []struct { + name string + args args + want ResourceManager + wantErr bool + }{ + {name: "Simple", args: args{ctx: context.TODO(), resourceName: "simple"}, want: NoopResourceManager{}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetOrCreateResourceManagerFor(tt.args.ctx, tt.args.resourceName) + if (err != nil) != tt.wantErr { + t.Errorf("GetOrCreateResourceManagerFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetOrCreateResourceManagerFor() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNoopResourceManager_AllocateResource(t *testing.T) { + type args struct { + ctx context.Context + namespace string + allocationToken string + } + tests := []struct { + name string + n NoopResourceManager + args args + want AllocationStatus + wantErr bool + }{ + {name: "Simple", n: NoopResourceManager{}, args: args{ctx: context.TODO(), namespace: "namespace", allocationToken: "token"}, want: AllocationStatusGranted, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NoopResourceManager{} + got, err := n.AllocateResource(tt.args.ctx, tt.args.namespace, tt.args.allocationToken) + if (err != nil) != tt.wantErr { + t.Errorf("NoopResourceManager.AllocateResource() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("NoopResourceManager.AllocateResource() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNoopResourceManager_ReleaseResource(t *testing.T) { + type args struct { + ctx context.Context + namespace string + allocationToken string + } + tests := []struct { + name string + n NoopResourceManager + args args + wantErr bool + }{ + {name: "Simple", n: NoopResourceManager{}, args: args{ctx: context.TODO(), namespace: "namespace", allocationToken: "token"}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NoopResourceManager{} + if err := n.ReleaseResource(tt.args.ctx, tt.args.namespace, tt.args.allocationToken); (err != nil) != tt.wantErr { + t.Errorf("NoopResourceManager.ReleaseResource() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/flyteplugins/go/tasks/v1/testdata/config.yaml 
b/flyteplugins/go/tasks/v1/testdata/config.yaml new file mode 100755 index 0000000000..45daa4b20b --- /dev/null +++ b/flyteplugins/go/tasks/v1/testdata/config.yaml @@ -0,0 +1,63 @@ +# Sample plugins config +plugins: + # Set of enabled plugins at root level + enabled-plugins: + - container + # All k8s plugins default configuration + k8s: + inject-finalizer: true + default-annotations: + - annotationKey1: annotationValue1 + - annotationKey2: annotationValue2 + default-labels: + - label1: labelValue1 + - label2: labelValue2 + resource-tolerations: + nvidia.com/gpu: + key: flyte/gpu + value: dedicated + operator: Equal + effect: NoSchedule + storage: + - key: storage + value: special + operator: Equal + effect: PreferNoSchedule + default-env-vars: + - AWS_METADATA_SERVICE_TIMEOUT: 5 + - AWS_METADATA_SERVICE_NUM_ATTEMPTS: 20 + - FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000" + - FLYTE_AWS_ACCESS_KEY_ID: minio + - FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage + # Spark Plugin configuration + spark: + spark-config-default: + - spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version: "2" + - spark.kubernetes.allocation.batch.size: "50" + - spark.hadoop.fs.s3a.acl.default: "BucketOwnerFullControl" + - spark.hadoop.fs.s3n.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3n.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3a.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3a.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3a.multipart.threshold: "536870912" + - spark.blacklist.enabled: "true" + - spark.blacklist.timeout: "5m" + # Qubole plugin configuration + qubole: + # Either create this file with your username with the real token, or set the QUBOLE_API_KEY environment variable + # See the secrets_manager.go file in the plugins repo for usage. Since the dev/test deployment of + # this has a dummy QUBOLE_API_KEY env var built in, this fake path won't break anything. + quboleTokenPath: "/Users/yourusername/CREDENTIALS_FLYTE_QUBOLE_CLIENT_TOKEN" + resourceManagerType: redis + redisHostPath: redis-resource-manager.flyte:6379 + redisHostKey: mypassword + quboleLimit: 10 + # Waitable plugin configuration + waitable: + console-uri: http://localhost:30081/console + # Logging configuration + logs: + kubernetes-enabled: true + kubernetes-url: "http://localhost:30082" diff --git a/flyteplugins/go/tasks/v1/types/mocks/EventRecorder.go b/flyteplugins/go/tasks/v1/types/mocks/EventRecorder.go new file mode 100755 index 0000000000..62d0d1d373 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/mocks/EventRecorder.go @@ -0,0 +1,26 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" +import event "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" +import mock "github.com/stretchr/testify/mock" + +// EventRecorder is an autogenerated mock type for the EventRecorder type +type EventRecorder struct { + mock.Mock +} + +// RecordTaskEvent provides a mock function with given fields: ctx, _a1 +func (_m *EventRecorder) RecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent) error { + ret := _m.Called(ctx, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.TaskExecutionEvent) error); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/types/mocks/Executor.go b/flyteplugins/go/tasks/v1/types/mocks/Executor.go new file mode 100755 index 0000000000..a2221c4270 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/mocks/Executor.go @@ -0,0 +1,141 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" + +// Executor is an autogenerated mock type for the Executor type +type Executor struct { + mock.Mock +} + +// CheckTaskStatus provides a mock function with given fields: ctx, taskCtx, task +func (_m *Executor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate) (types.TaskStatus, error) { + ret := _m.Called(ctx, taskCtx, task) + + var r0 types.TaskStatus + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext, *core.TaskTemplate) types.TaskStatus); ok { + r0 = rf(ctx, taskCtx, task) + } else { + r0 = ret.Get(0).(types.TaskStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.TaskContext, *core.TaskTemplate) error); ok { + r1 = rf(ctx, taskCtx, task) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetID provides a mock function with given fields: +func (_m *Executor) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetProperties provides a mock function with given fields: +func (_m *Executor) GetProperties() types.ExecutorProperties { + ret := _m.Called() + + var r0 types.ExecutorProperties + if rf, ok := ret.Get(0).(func() types.ExecutorProperties); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.ExecutorProperties) + } + + return r0 +} + +// Initialize provides a mock function with given fields: ctx, param +func (_m *Executor) Initialize(ctx context.Context, param types.ExecutorInitializationParameters) error { + ret := _m.Called(ctx, param) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.ExecutorInitializationParameters) error); ok { + r0 = rf(ctx, param) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// KillTask provides a mock function with given fields: ctx, taskCtx, reason +func (_m *Executor) KillTask(ctx context.Context, taskCtx types.TaskContext, reason string) error { + ret := _m.Called(ctx, taskCtx, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext, string) error); ok { + r0 = rf(ctx, taskCtx, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ResolveOutputs provides a mock function with given fields: ctx, taskCtx, outputVariables +func (_m *Executor) ResolveOutputs(ctx context.Context, taskCtx 
types.TaskContext, outputVariables ...string) (map[string]*core.Literal, error) { + _va := make([]interface{}, len(outputVariables)) + for _i := range outputVariables { + _va[_i] = outputVariables[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, taskCtx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 map[string]*core.Literal + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext, ...string) map[string]*core.Literal); ok { + r0 = rf(ctx, taskCtx, outputVariables...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*core.Literal) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.TaskContext, ...string) error); ok { + r1 = rf(ctx, taskCtx, outputVariables...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StartTask provides a mock function with given fields: ctx, taskCtx, task, inputs +func (_m *Executor) StartTask(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (types.TaskStatus, error) { + ret := _m.Called(ctx, taskCtx, task, inputs) + + var r0 types.TaskStatus + if rf, ok := ret.Get(0).(func(context.Context, types.TaskContext, *core.TaskTemplate, *core.LiteralMap) types.TaskStatus); ok { + r0 = rf(ctx, taskCtx, task, inputs) + } else { + r0 = ret.Get(0).(types.TaskStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.TaskContext, *core.TaskTemplate, *core.LiteralMap) error); ok { + r1 = rf(ctx, taskCtx, task, inputs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteplugins/go/tasks/v1/types/mocks/TaskContext.go b/flyteplugins/go/tasks/v1/types/mocks/TaskContext.go new file mode 100755 index 0000000000..ce6b400eb4 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/mocks/TaskContext.go @@ -0,0 +1,234 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import pkgtypes "k8s.io/apimachinery/pkg/types" +import storage "github.com/lyft/flytestdlib/storage" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + +// TaskContext is an autogenerated mock type for the TaskContext type +type TaskContext struct { + mock.Mock +} + +// GetAnnotations provides a mock function with given fields: +func (_m *TaskContext) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetCustomState provides a mock function with given fields: +func (_m *TaskContext) GetCustomState() map[string]interface{} { + ret := _m.Called() + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func() map[string]interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + return r0 +} + +// GetDataDir provides a mock function with given fields: +func (_m *TaskContext) GetDataDir() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetErrorFile provides a mock function with given fields: +func (_m *TaskContext) GetErrorFile() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetInputsFile provides a mock function with given fields: +func (_m *TaskContext) GetInputsFile() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetK8sServiceAccount provides a mock function with given fields: +func (_m *TaskContext) GetK8sServiceAccount() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *TaskContext) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *TaskContext) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetOutputsFile provides a mock function with given fields: +func (_m *TaskContext) GetOutputsFile() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetOverrides provides a mock function with given fields: +func (_m *TaskContext) GetOverrides() types.TaskOverrides { + ret := _m.Called() + + var r0 types.TaskOverrides + if rf, ok := ret.Get(0).(func() types.TaskOverrides); ok { 
+ r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.TaskOverrides) + } + } + + return r0 +} + +// GetOwnerID provides a mock function with given fields: +func (_m *TaskContext) GetOwnerID() pkgtypes.NamespacedName { + ret := _m.Called() + + var r0 pkgtypes.NamespacedName + if rf, ok := ret.Get(0).(func() pkgtypes.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(pkgtypes.NamespacedName) + } + + return r0 +} + +// GetOwnerReference provides a mock function with given fields: +func (_m *TaskContext) GetOwnerReference() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *TaskContext) GetPhase() types.TaskPhase { + ret := _m.Called() + + var r0 types.TaskPhase + if rf, ok := ret.Get(0).(func() types.TaskPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.TaskPhase) + } + + return r0 +} + +// GetPhaseVersion provides a mock function with given fields: +func (_m *TaskContext) GetPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// GetTaskExecutionID provides a mock function with given fields: +func (_m *TaskContext) GetTaskExecutionID() types.TaskExecutionID { + ret := _m.Called() + + var r0 types.TaskExecutionID + if rf, ok := ret.Get(0).(func() types.TaskExecutionID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.TaskExecutionID) + } + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/types/mocks/TaskExecutionID.go b/flyteplugins/go/tasks/v1/types/mocks/TaskExecutionID.go new file mode 100755 index 0000000000..bd0f4efdd2 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/mocks/TaskExecutionID.go @@ -0,0 +1,39 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// TaskExecutionID is an autogenerated mock type for the TaskExecutionID type +type TaskExecutionID struct { + mock.Mock +} + +// GetGeneratedName provides a mock function with given fields: +func (_m *TaskExecutionID) GetGeneratedName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *TaskExecutionID) GetID() core.TaskExecutionIdentifier { + ret := _m.Called() + + var r0 core.TaskExecutionIdentifier + if rf, ok := ret.Get(0).(func() core.TaskExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(core.TaskExecutionIdentifier) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/types/mocks/TaskOverrides.go b/flyteplugins/go/tasks/v1/types/mocks/TaskOverrides.go new file mode 100755 index 0000000000..e5e9d0fb5f --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/mocks/TaskOverrides.go @@ -0,0 +1,44 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +import v1 "k8s.io/api/core/v1" + +// TaskOverrides is an autogenerated mock type for the TaskOverrides type +type TaskOverrides struct { + mock.Mock +} + +// GetConfig provides a mock function with given fields: +func (_m *TaskOverrides) GetConfig() *v1.ConfigMap { + ret := _m.Called() + + var r0 *v1.ConfigMap + if rf, ok := ret.Get(0).(func() *v1.ConfigMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ConfigMap) + } + } + + return r0 +} + +// GetResources provides a mock function with given fields: +func (_m *TaskOverrides) GetResources() *v1.ResourceRequirements { + ret := _m.Called() + + var r0 *v1.ResourceRequirements + if rf, ok := ret.Get(0).(func() *v1.ResourceRequirements); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ResourceRequirements) + } + } + + return r0 +} diff --git a/flyteplugins/go/tasks/v1/types/outputs_resolver.go b/flyteplugins/go/tasks/v1/types/outputs_resolver.go new file mode 100755 index 0000000000..34abfcf03e --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/outputs_resolver.go @@ -0,0 +1,46 @@ +package types + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/storage" +) + +// Provides a default implementation for ResolveOutputs method by reading 'outputs.pb' from task directory into a LiteralMap. +type OutputsResolver struct { + store storage.ComposedProtobufStore +} + +func (r OutputsResolver) ResolveOutputs(ctx context.Context, taskCtx TaskContext, outputVariables ...VarName) ( + values map[VarName]*core.Literal, err error) { + + d := &core.LiteralMap{} + outputsFileRef := taskCtx.GetOutputsFile() + if err := r.store.ReadProtobuf(ctx, outputsFileRef, d); err != nil { + return nil, fmt.Errorf("failed to read data from dataDir [%v]. Error: %v", taskCtx.GetOutputsFile(), err) + } + + if d == nil || d.Literals == nil { + return nil, fmt.Errorf("outputs from Task [%v] not found at [%v]", taskCtx.GetTaskExecutionID().GetGeneratedName(), + outputsFileRef) + } + + values = make(map[VarName]*core.Literal, len(outputVariables)) + for _, varName := range outputVariables { + l, ok := d.Literals[varName] + if !ok { + return nil, fmt.Errorf("failed to find [%v].[%v]", taskCtx.GetTaskExecutionID().GetGeneratedName(), varName) + } + + values[varName] = l + } + + return values, nil +} + +// Creates a default outputs resolver that expects a LiteralMap to exist in the task's outputFile location. 
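A small sketch, not part of the patch, of how a plugin would typically wire this up: embed OutputsResolver so the Executor's ResolveOutputs method comes for free. The myExecutor type is hypothetical; NewOutputsResolver is the constructor defined immediately below.

package sketch

import (
	"github.com/lyft/flyteplugins/go/tasks/v1/types"
	"github.com/lyft/flytestdlib/storage"
)

// myExecutor inherits the default ResolveOutputs implementation, which reads the task's
// outputs file (a LiteralMap protobuf) from the data store, by embedding OutputsResolver.
type myExecutor struct {
	types.OutputsResolver
}

func newMyExecutor(store storage.ComposedProtobufStore) myExecutor {
	return myExecutor{OutputsResolver: types.NewOutputsResolver(store)}
}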
+func NewOutputsResolver(store storage.ComposedProtobufStore) OutputsResolver { + return OutputsResolver{store: store} +} diff --git a/flyteplugins/go/tasks/v1/types/status.go b/flyteplugins/go/tasks/v1/types/status.go new file mode 100755 index 0000000000..f028316c09 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/status.go @@ -0,0 +1,131 @@ +package types + +import ( + "strconv" + "strings" + "time" +) + +type TaskPhase int + +// NOTE: if we add a status here, we should make sure it converts correctly when reporting Task event +// See events_publisher.go +const ( + TaskPhaseQueued TaskPhase = iota + TaskPhaseRunning + TaskPhaseRetryableFailure + TaskPhasePermanentFailure + TaskPhaseSucceeded + TaskPhaseUndefined + TaskPhaseNotReady + TaskPhaseUnknown +) + +func (t TaskPhase) String() string { + switch t { + case TaskPhaseQueued: + return "Queued" + case TaskPhaseRunning: + return "Running" + case TaskPhaseRetryableFailure: + return "RetryableFailure" + case TaskPhasePermanentFailure: + return "PermanentFailure" + case TaskPhaseSucceeded: + return "Succeeded" + case TaskPhaseNotReady: + return "NotReady" + case TaskPhaseUndefined: + return "Undefined" + } + return "Unknown" +} + +func (t TaskPhase) IsTerminal() bool { + return t.IsSuccess() || t.IsPermanentFailure() +} + +func (t TaskPhase) IsSuccess() bool { + return t == TaskPhaseSucceeded +} + +func (t TaskPhase) IsPermanentFailure() bool { + return t == TaskPhasePermanentFailure +} + +func (t TaskPhase) IsRetryableFailure() bool { + return t == TaskPhaseRetryableFailure +} + +type TaskStatus struct { + Phase TaskPhase + PhaseVersion uint32 + Err error + State CustomState + OccurredAt time.Time +} + +func (t TaskStatus) String() string { + sb := strings.Builder{} + sb.WriteString("{Phase: ") + sb.WriteString(t.Phase.String()) + sb.WriteString(", Version: ") + sb.WriteString(strconv.Itoa(int(t.PhaseVersion))) + sb.WriteString(", At: ") + sb.WriteString(t.OccurredAt.String()) + if t.Err != nil { + sb.WriteString(", Err: ") + sb.WriteString(t.Err.Error()) + } + + sb.WriteString(", CustomStateLen: ") + sb.WriteString(strconv.Itoa(len(t.State))) + + sb.WriteString("}") + + return sb.String() +} + +var TaskStatusNotReady = TaskStatus{Phase: TaskPhaseNotReady} +var TaskStatusQueued = TaskStatus{Phase: TaskPhaseQueued} +var TaskStatusRunning = TaskStatus{Phase: TaskPhaseRunning} +var TaskStatusSucceeded = TaskStatus{Phase: TaskPhaseSucceeded} +var TaskStatusUndefined = TaskStatus{Phase: TaskPhaseUndefined} +var TaskStatusUnknown = TaskStatus{Phase: TaskPhaseUnknown} + +func (t TaskStatus) WithPhaseVersion(version uint32) TaskStatus { + t.PhaseVersion = version + return t +} + +func (t TaskStatus) WithState(state CustomState) TaskStatus { + t.State = state + return t +} + +func (t TaskStatus) WithOccurredAt(time time.Time) TaskStatus { + t.OccurredAt = time + return t +} + +// This failure can be used to indicate that the task wasn't accepted due to resource quota or similar constraints. +func TaskStatusNotReadyFailure(err error) TaskStatus { + return TaskStatus{Phase: TaskPhaseNotReady, Err: err} +} + +// This failure can be used to indicate that the task failed with an error that is most probably transient +// and if the task retries (retry strategy) permits, it is safe to retry this task again. +// The same task execution will not be retried, but a new task execution will be created. 
+func TaskStatusRetryableFailure(err error) TaskStatus { + return TaskStatus{Phase: TaskPhaseRetryableFailure, Err: err} +} + +// PermanentFailure should be used to signal that either +// 1. The user wants to signal that the task has failed with something NON-RECOVERABLE +// 2. The plugin writer wants to signal that the task has failed with NON-RECOVERABLE +// Essentially a permanent failure will force the statemachine to shutdown and stop the task from being retried +// further, even if retries exist. +// If it is desirable to retry the task (a separate execution) then, use RetryableFailure +func TaskStatusPermanentFailure(err error) TaskStatus { + return TaskStatus{Phase: TaskPhasePermanentFailure, Err: err} +} diff --git a/flyteplugins/go/tasks/v1/types/status_test.go b/flyteplugins/go/tasks/v1/types/status_test.go new file mode 100755 index 0000000000..7bb0a1686f --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/status_test.go @@ -0,0 +1,34 @@ +package types + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTaskStatus_CopyConstructor(t *testing.T) { + t.Run("WithOccurredAt", func(t *testing.T) { + input := TaskStatusSucceeded + assert.True(t, input.OccurredAt.IsZero()) + actual := input.WithOccurredAt(time.Now()) + assert.True(t, input.OccurredAt.IsZero()) + assert.False(t, actual.OccurredAt.IsZero()) + }) + + t.Run("WithPhaseVersion", func(t *testing.T) { + input := TaskStatusSucceeded + assert.Zero(t, input.PhaseVersion) + actual := input.WithPhaseVersion(4) + assert.Zero(t, input.PhaseVersion) + assert.NotZero(t, actual.PhaseVersion) + }) + + t.Run("WithState", func(t *testing.T) { + input := TaskStatusSucceeded + assert.Nil(t, input.State) + actual := input.WithState(map[string]interface{}{"hello": "world"}) + assert.Nil(t, input.State) + assert.NotNil(t, actual.State) + }) +} diff --git a/flyteplugins/go/tasks/v1/types/task.go b/flyteplugins/go/tasks/v1/types/task.go new file mode 100755 index 0000000000..346630f68f --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/task.go @@ -0,0 +1,92 @@ +package types + +import ( + "context" + + "github.com/lyft/flytestdlib/promutils" + "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/storage" +) + +//go:generate mockery -name Executor + +type TaskType = string +type WorkflowID = string +type VarName = string +type TaskExecutorName = string + +type EnqueueOwner func(name types.NamespacedName) error + +// Defines optional properties for the executor. +type ExecutorProperties struct { + // If the executor needs to clean-up external resources that won't automatically be garbage-collected by the fact that + // the containing-k8s object is being deleted, it should set this value to true. This ensures that the containing-k8s + // object is not deleted until all executors of non-terminal phase tasks report success for KillTask calls. + RequiresFinalizer bool + + // If set, the execution engine will not perform node-level task caching and retrieval. This can be useful for more + // fine-grained executors that implement their own logic for caching. + DisableNodeLevelCaching bool +} + +// Defines the exposed interface for plugins to record task events. +// TODO: Add a link to explain how events are structred and linked together. 
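Before the EventRecorder interface that the comment above introduces, a brief illustration (values are made up, not part of the patch) of how the TaskStatus constants and With* builders from status.go are meant to be chained when a plugin reports progress:

package sketch

import (
	"time"

	"github.com/lyft/flyteplugins/go/tasks/v1/types"
)

// nextStatus reports Running with a bumped phase version and plugin-specific custom state.
func nextStatus(commandID string) types.TaskStatus {
	return types.TaskStatusRunning.
		WithPhaseVersion(1).
		WithState(map[string]interface{}{"commandId": commandID}).
		WithOccurredAt(time.Now())
}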
+type EventRecorder interface { + RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent) error +} + +// Defines the Catalog client interface exposed for plugins +type CatalogClient interface { + Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) + Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error +} + +// Defines the all-optional initialization parameters passed to plugins. +type ExecutorInitializationParameters struct { + CatalogClient CatalogClient + EventRecorder EventRecorder + DataStore *storage.DataStore + EnqueueOwner EnqueueOwner + OwnerKind string + MetricsScope promutils.Scope +} + +// Defines a task executor interface. +type Executor interface { + // Gets a unique identifier for the executor. No two executors can have the same ID. + GetID() TaskExecutorName + + // Gets optional properties about this executor. These properties are not task-specific. + GetProperties() ExecutorProperties + + // Initializes the executor. The executor should not have any heavy initialization logic in its constructor and should + // delay all initialization logic till this method is called. + Initialize(ctx context.Context, param ExecutorInitializationParameters) error + + // Start the task with an initial state that could be empty and return the new state of the task once it started + StartTask(ctx context.Context, taskCtx TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) ( + status TaskStatus, err error) + + // ChecksTaskStatus is called every time client needs to know the latest status of a given task this specific + // executor launched. It passes the same task context that was used when StartTask was called as well as the last + // known state of the task. Note that there is no strict guarantee that the previous state is literally the last + // status returned due to the nature of eventual consistency in the system. The system guarantees idempotency as long + // as it's within kubernetes boundaries or if external services support idempotency. + CheckTaskStatus(ctx context.Context, taskCtx TaskContext, task *core.TaskTemplate) (status TaskStatus, err error) + + // The engine will ensure kill task is called in abort scenarios. KillTask will not be called in case CheckTaskStatus + // ever returned a terminal phase. + KillTask(ctx context.Context, taskCtx TaskContext, reason string) error + + // ResolveOutputs is responsible for retrieving outputs variables from a task. For simple tasks, adding OutputsResolver + // in the executor is enough to get a default implementation. + ResolveOutputs(ctx context.Context, taskCtx TaskContext, outputVariables ...VarName) ( + values map[VarName]*core.Literal, err error) +} + +// Represents a free-form state that allows plugins to store custom information between invocations. 
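The CustomState alias that the comment above describes follows right after this sketch. To make the Executor contract concrete, here is a minimal, hypothetical implementation (a sketch only, not a real plugin) that starts a task, immediately reports success, and reuses the embedded OutputsResolver:

package sketch

import (
	"context"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flyteplugins/go/tasks/v1/types"
)

// echoExecutor exists only to show the Executor surface; it does no real work.
type echoExecutor struct {
	types.OutputsResolver // supplies ResolveOutputs
}

func (e echoExecutor) GetID() types.TaskExecutorName { return "echo" }

func (e echoExecutor) GetProperties() types.ExecutorProperties {
	// Nothing external to clean up, so no finalizer is required.
	return types.ExecutorProperties{}
}

func (e echoExecutor) Initialize(ctx context.Context, param types.ExecutorInitializationParameters) error {
	return nil
}

func (e echoExecutor) StartTask(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (types.TaskStatus, error) {
	return types.TaskStatusRunning, nil
}

func (e echoExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate) (types.TaskStatus, error) {
	return types.TaskStatusSucceeded, nil
}

func (e echoExecutor) KillTask(ctx context.Context, taskCtx types.TaskContext, reason string) error {
	return nil
}

// Compile-time check that echoExecutor satisfies the Executor interface.
var _ types.Executor = echoExecutor{}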
+type CustomState = map[string]interface{} diff --git a/flyteplugins/go/tasks/v1/types/task_context.go b/flyteplugins/go/tasks/v1/types/task_context.go new file mode 100755 index 0000000000..ad78075fa7 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/task_context.go @@ -0,0 +1,44 @@ +package types + +import ( + "github.com/lyft/flytestdlib/storage" + "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + typesv1 "k8s.io/api/core/v1" + metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +//go:generate mockery -name TaskContext + +// Interface to expose any overrides that have been set for this task (like resource overrides etc) +type TaskOverrides interface { + GetResources() *typesv1.ResourceRequirements + GetConfig() *typesv1.ConfigMap +} + +// Simple Interface to expose the ExecutionID of the running Task +type TaskExecutionID interface { + GetGeneratedName() string + GetID() core.TaskExecutionIdentifier +} + +// TaskContext represents any execution information for a Task. It is used to communicate meta information about the +// execution or any previously stored information +type TaskContext interface { + GetOwnerID() types.NamespacedName + GetTaskExecutionID() TaskExecutionID + GetDataDir() storage.DataReference + GetInputsFile() storage.DataReference + GetOutputsFile() storage.DataReference + GetErrorFile() storage.DataReference + GetNamespace() string + GetOwnerReference() metaV1.OwnerReference + GetOverrides() TaskOverrides + GetLabels() map[string]string + GetAnnotations() map[string]string + GetCustomState() CustomState + GetK8sServiceAccount() string + GetPhase() TaskPhase + GetPhaseVersion() uint32 +} diff --git a/flyteplugins/go/tasks/v1/types/task_test.go b/flyteplugins/go/tasks/v1/types/task_test.go new file mode 100755 index 0000000000..1a981f8512 --- /dev/null +++ b/flyteplugins/go/tasks/v1/types/task_test.go @@ -0,0 +1,52 @@ +package types + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type myCustomState struct { + String string `json:"str"` + Recursive *myCustomState `json:"recursive"` + Map map[string]*myCustomState `json:"map"` +} + +func marshalIntoMap(x interface{}) (map[string]interface{}, error) { + raw, err := json.Marshal(x) + if err != nil { + return nil, err + } + + res := map[string]interface{}{} + return res, json.Unmarshal(raw, &res) +} + +func TestState_marshaling(t *testing.T) { + expectedMessage := "TestMessage" + expectedCustom := myCustomState{ + String: expectedMessage, + Recursive: &myCustomState{ + String: expectedMessage + expectedMessage, + }, + Map: map[string]*myCustomState{ + "key1": { + String: expectedMessage, + }, + }, + } + + rawCustom, err := marshalIntoMap(expectedCustom) + assert.NoError(t, err) + + raw, err := json.Marshal(&rawCustom) + assert.NoError(t, err) + + newState := &myCustomState{} + err = json.Unmarshal(raw, newState) + assert.NoError(t, err) + assert.True(t, reflect.DeepEqual(expectedCustom, *newState), "%v != %v", rawCustom, *newState) + assert.Equal(t, expectedMessage, newState.Map["key1"].String) +} diff --git a/flyteplugins/go/tasks/v1/utils/marshal_utils.go b/flyteplugins/go/tasks/v1/utils/marshal_utils.go new file mode 100755 index 0000000000..e5d5d9c8b6 --- /dev/null +++ b/flyteplugins/go/tasks/v1/utils/marshal_utils.go @@ -0,0 +1,66 @@ +package utils + +import ( + "encoding/json" + "fmt" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + structpb "github.com/golang/protobuf/ptypes/struct" +) + +var 
jsonPbMarshaler = jsonpb.Marshaler{} + +func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { + if structObj == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + if err != nil { + return err + } + + if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { + return err + } + + return nil +} + +func MarshalStruct(in proto.Message, out *structpb.Struct) error { + if out == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := jsonPbMarshaler.MarshalToString(in) + if err != nil { + return err + } + + if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { + return err + } + + return nil +} + +func MarshalToString(msg proto.Message) (string, error) { + return jsonPbMarshaler.MarshalToString(msg) +} + +// TODO: Use the stdlib version in the future, or move there if not there. +// Don't use this if input is a proto Message. +func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { + b, err := json.Marshal(input) + if err != nil { + return nil, err + } + + // Turn JSON into a protobuf struct + structObj := &structpb.Struct{} + if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { + return nil, err + } + return structObj, nil +} diff --git a/flyteplugins/go/tasks/v1/utils/template.go b/flyteplugins/go/tasks/v1/utils/template.go new file mode 100755 index 0000000000..1d88ec6f0e --- /dev/null +++ b/flyteplugins/go/tasks/v1/utils/template.go @@ -0,0 +1,152 @@ +package utils + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" +) + +var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`) +var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`) +var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P<input_name>[^}\s]+)\s*}}`) + +// Contains arguments passed down to command line templates. +type CommandLineTemplateArgs struct { + Input string `json:"input"` + OutputPrefix string `json:"output"` + Inputs map[string]string `json:"inputs"` +} + +// Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive. +// Supported templates are: +// - {{ .Input }} to receive the input file path. The protocol used will depend on the underlying system +// configuration. E.g. s3://bucket/key/to/file.pb or /var/run/local.pb are both valid. +// - {{ .OutputPrefix }} to receive the path prefix for where to store the outputs. +// - {{ .Inputs.myInput }} to receive the actual value of the input passed. See docs on LiteralMapToTemplateArgs for +// what to expect each literal type to be serialized as. +// If a command isn't a valid template or failed to evaluate, it'll be returned as is. +// NOTE: I wanted to do in-place replacement, until I realized that in-place replacement will alter the definition of the graph.
This is not desirable, as we may have to retry and in that case the replacement will not work and we want +// to create a new location for outputs +func ReplaceTemplateCommandArgs(ctx context.Context, command []string, args CommandLineTemplateArgs) ([]string, error) { + res := make([]string, 0, len(command)) + for _, commandTemplate := range command { + updated, err := replaceTemplateCommandArgs(ctx, commandTemplate, &args) + if err != nil { + return res, err + } + + res = append(res, updated) + } + + return res, nil +} + +func replaceTemplateCommandArgs(_ context.Context, commandTemplate string, args *CommandLineTemplateArgs) (string, error) { + val := inputFileRegex.ReplaceAllString(commandTemplate, args.Input) + val = outputRegex.ReplaceAllString(val, args.OutputPrefix) + groupMatches := inputVarRegex.FindAllStringSubmatchIndex(val, -1) + if len(groupMatches) == 0 { + return val, nil + } else if len(groupMatches) > 1 { + return val, fmt.Errorf("only one level of inputs nesting is supported. Syntax in [%v] is invalid", commandTemplate) + } else if len(groupMatches[0]) > 4 { + return val, fmt.Errorf("longer submatches not supported. Syntax in [%v] is invalid", commandTemplate) + } else { + startIdx := groupMatches[0][0] + endIdx := groupMatches[0][1] + inputStartIdx := groupMatches[0][2] + inputEndIdx := groupMatches[0][3] + inputName := val[inputStartIdx:inputEndIdx] + inputVal, exists := args.Inputs[inputName] + if !exists { + return val, fmt.Errorf("requested input is not found [%v] while processing template [%v]", + inputName, commandTemplate) + } + + if endIdx >= len(val) { + return val[:startIdx] + inputVal, nil + } + + return val[:startIdx] + inputVal + val[endIdx:], nil + } +} + +// Converts a literal map to a go map that can be used in templates. It drops literals that don't have a defined way to +// be safely serialized into a string. 
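LiteralMapToTemplateArgs, described by the comment above, is defined immediately after this sketch. Combined with ReplaceTemplateCommandArgs, a plugin might render a container command roughly like this (the command, bucket paths, and input name are made up, not part of the patch):

package sketch

import (
	"context"
	"fmt"

	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flyteplugins/go/tasks/v1/utils"
)

// renderCommand substitutes the supported templates into a container command line.
func renderCommand(ctx context.Context, inputs *core.LiteralMap) ([]string, error) {
	cmd := []string{
		"my-entrypoint",
		"--input={{ .Input }}",
		"--output-prefix={{ .OutputPrefix }}",
		"--x={{ .Inputs.x }}",
	}
	args := utils.CommandLineTemplateArgs{
		Input:        "s3://bucket/exec/inputs.pb",
		OutputPrefix: "s3://bucket/exec/out",
		Inputs:       utils.LiteralMapToTemplateArgs(ctx, inputs),
	}
	rendered, err := utils.ReplaceTemplateCommandArgs(ctx, cmd, args)
	if err != nil {
		return nil, fmt.Errorf("failed to render command: %v", err)
	}
	return rendered, nil
}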
+func LiteralMapToTemplateArgs(ctx context.Context, m *core.LiteralMap) map[string]string { + if m == nil { + return map[string]string{} + } + + res := make(map[string]string, len(m.Literals)) + + for key, val := range m.Literals { + serialized, ok := serializeLiteral(ctx, val) + if ok { + res[key] = serialized + } + } + + return res +} + +func serializePrimitive(ctx context.Context, p *core.Primitive) (string, bool) { + switch o := p.Value.(type) { + case *core.Primitive_Integer: + return fmt.Sprintf("%v", o.Integer), true + case *core.Primitive_Boolean: + return fmt.Sprintf("%v", o.Boolean), true + case *core.Primitive_Datetime: + return ptypes.TimestampString(o.Datetime), true + case *core.Primitive_Duration: + return o.Duration.String(), true + case *core.Primitive_FloatValue: + return fmt.Sprintf("%v", o.FloatValue), true + case *core.Primitive_StringValue: + return o.StringValue, true + default: + logger.Warnf(ctx, "Received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) + return "", false + } +} + +func serializeLiteralScalar(ctx context.Context, l *core.Scalar) (string, bool) { + switch o := l.Value.(type) { + case *core.Scalar_Primitive: + return serializePrimitive(ctx, o.Primitive) + case *core.Scalar_Blob: + return o.Blob.Uri, true + default: + logger.Warnf(ctx, "Received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) + return "", false + } +} + +func serializeLiteral(ctx context.Context, l *core.Literal) (string, bool) { + switch o := l.Value.(type) { + case *core.Literal_Collection: + res := make([]string, 0, len(o.Collection.Literals)) + for _, sub := range o.Collection.Literals { + s, ok := serializeLiteral(ctx, sub) + if !ok { + return "", false + } + + res = append(res, s) + } + + return fmt.Sprintf("[%v]", strings.Join(res, ",")), true + case *core.Literal_Scalar: + return serializeLiteralScalar(ctx, o.Scalar) + default: + logger.Warnf(ctx, "Received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) + return "", false + } +} diff --git a/flyteplugins/go/tasks/v1/utils/template_test.go b/flyteplugins/go/tasks/v1/utils/template_test.go new file mode 100755 index 0000000000..5060927db6 --- /dev/null +++ b/flyteplugins/go/tasks/v1/utils/template_test.go @@ -0,0 +1,239 @@ +package utils + +import ( + "bytes" + "context" + "testing" + "text/template" + "time" + + "github.com/lyft/flyteidl/clients/go/coreutils" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkRegexCommandArgs(b *testing.B) { + for i := 0; i < b.N; i++ { + inputFileRegex.MatchString("{{ .InputFile }}") + } +} + +// Benchmark results: +// Regex_replacement-8 3000000 583 ns/op +// NotCompiled-8 100000 14684 ns/op +// Precompile/Execute-8 500000 2706 ns/op +func BenchmarkReplacements(b *testing.B) { + cmd := `abc {{ .Inputs.x }} ` + cmdTemplate := `abc {{ index .Inputs "x" }}` + cmdArgs := CommandLineTemplateArgs{ + Input: "inputfile.pb", + Inputs: map[string]string{ + "x": "1", + }, + } + + b.Run("NotCompiled", func(b *testing.B) { + for i := 0; i < b.N; i++ { + t, err := template.New("NotCompiled").Parse(cmdTemplate) + assert.NoError(b, err) + var buf bytes.Buffer + err = t.Execute(&buf, cmdArgs) + assert.NoError(b, err) + } + }) + + b.Run("Precompile", func(b *testing.B) { + t, err := template.New("NotCompiled").Parse(cmdTemplate) + assert.NoError(b, err) + + b.Run("Execute", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var buf bytes.Buffer + err = t.Execute(&buf, cmdArgs) + assert.NoError(b, err) + } + }) + }) + + b.Run("Regex replacement", func(b 
*testing.B) { + for i := 0; i < b.N; i++ { + inputVarRegex.FindAllStringSubmatchIndex(cmd, -1) + } + }) +} + +func TestInputRegexMatch(t *testing.T) { + assert.True(t, inputFileRegex.MatchString("{{$input}}")) + assert.True(t, inputFileRegex.MatchString("{{ $Input }}")) + assert.True(t, inputFileRegex.MatchString("{{.input}}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input}}")) + assert.True(t, inputFileRegex.MatchString("{{.Input }}")) + assert.True(t, inputFileRegex.MatchString("--something={{.Input}}")) + assert.False(t, inputFileRegex.MatchString("{{input}}"), "Missing $") + assert.False(t, inputFileRegex.MatchString("{$input}}"), "Missing Brace") +} + +func TestOutputRegexMatch(t *testing.T) { + assert.True(t, outputRegex.MatchString("{{.OutputPrefix}}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix}}")) + assert.True(t, outputRegex.MatchString("{{.OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("--something={{.OutputPrefix}}")) + assert.False(t, outputRegex.MatchString("{{output}}"), "Missing $") + assert.False(t, outputRegex.MatchString("{.OutputPrefix}}"), "Missing Brace") +} + +func TestReplaceTemplateCommandArgs(t *testing.T) { + t.Run("empty cmd", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), + []string{}, + CommandLineTemplateArgs{Input: "input/blah", OutputPrefix: "output/blah"}) + assert.NoError(t, err) + assert.Equal(t, []string{}, actual) + }) + + t.Run("nothing to substitute", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + }, CommandLineTemplateArgs{Input: "input/blah", OutputPrefix: "output/blah"}) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "hello", + "world", + }, actual) + }) + + t.Run("Sub InputFile", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + "{{ .Input }}", + }, CommandLineTemplateArgs{Input: "input/blah", OutputPrefix: "output/blah"}) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "hello", + "world", + "input/blah", + }, actual) + }) + + t.Run("Sub Output Prefix", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{Input: "input/blah", OutputPrefix: "output/blah"}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "output/blah", + }, actual) + }) + + t.Run("Sub Input Output prefix", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + "{{ .Input }}", + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{Input: "input/blah", OutputPrefix: "output/blah"}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "input/blah", + "output/blah", + }, actual) + }) + + t.Run("Bad input template", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + "${{input}}", + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{Input: "input/blah", OutputPrefix: "output/blah"}) + assert.NoError(t, err) + 
assert.Equal(t, []string{ + "hello", + "world", + "${{input}}", + "output/blah", + }, actual) + }) + + t.Run("Input arg", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.arr }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: map[string]string{ + "arr": "[a,b]", + }}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "--someArg [a,b]", + "output/blah", + }, actual) + }) +} + +func TestLiteralMapToTemplateArgs(t *testing.T) { + t.Run("Scalars", func(t *testing.T) { + expected := map[string]string{ + "str": "blah", + "int": "5", + "date": "1900-01-01T01:01:01.000000001Z", + } + + dd := time.Date(1900, 1, 1, 1, 1, 1, 1, time.UTC) + lit := coreutils.MustMakeLiteral(map[string]interface{}{ + "str": "blah", + "int": 5, + "date": dd, + }) + + actual := LiteralMapToTemplateArgs(context.TODO(), lit.GetMap()) + + assert.Equal(t, expected, actual) + }) + + t.Run("1d array", func(t *testing.T) { + expected := map[string]string{ + "arr": "[a,b]", + } + + actual := LiteralMapToTemplateArgs(context.TODO(), coreutils.MustMakeLiteral(map[string]interface{}{ + "arr": []interface{}{"a", "b"}, + }).GetMap()) + + assert.Equal(t, expected, actual) + }) + + t.Run("2d array", func(t *testing.T) { + expected := map[string]string{ + "arr": "[[a,b],[1,2]]", + } + + actual := LiteralMapToTemplateArgs(context.TODO(), coreutils.MustMakeLiteral(map[string]interface{}{ + "arr": []interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}, + }).GetMap()) + + assert.Equal(t, expected, actual) + }) +} diff --git a/flyteplugins/go/tasks/v1/utils/transformers.go b/flyteplugins/go/tasks/v1/utils/transformers.go new file mode 100755 index 0000000000..3faf6f12b3 --- /dev/null +++ b/flyteplugins/go/tasks/v1/utils/transformers.go @@ -0,0 +1,26 @@ +package utils + +func CopyMap(o map[string]string) (r map[string]string) { + if o == nil { + return nil + } + r = make(map[string]string, len(o)) + for k, v := range o { + r[k] = v + } + return +} + +func Contains(s []string, e string) bool { + if s == nil { + return false + } + + for _, a := range s { + if a == e { + return true + } + } + + return false +} diff --git a/flyteplugins/go/tasks/v1/utils/transformers_test.go b/flyteplugins/go/tasks/v1/utils/transformers_test.go new file mode 100755 index 0000000000..f16732901d --- /dev/null +++ b/flyteplugins/go/tasks/v1/utils/transformers_test.go @@ -0,0 +1,18 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContains(t *testing.T) { + + assert.True(t, Contains([]string{"a", "b", "c"}, "b")) + + assert.False(t, Contains([]string{"a", "b", "c"}, "spark")) + + assert.False(t, Contains([]string{}, "spark")) + + assert.False(t, Contains(nil, "b")) +} diff --git a/flyteplugins/tests/hive_integration_test.go b/flyteplugins/tests/hive_integration_test.go new file mode 100755 index 0000000000..d09c4c077c --- /dev/null +++ b/flyteplugins/tests/hive_integration_test.go @@ -0,0 +1,201 @@ +// +build manualintegration +// Be sure to add this to your goland build settings Tags in order to get linting/testing +// Set the QUBOLE_API_KEY environment variable in your test to be a real token + +package tests + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + config2 
"github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flyteplugins/go/tasks/v1/qubole" + "github.com/lyft/flyteplugins/go/tasks/v1/qubole/config" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + tasksMocks "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flyteplugins/go/tasks/v1/utils" +) + +func getMockTaskContext() *tasksMocks.TaskContext { + taskCtx := &tasksMocks.TaskContext{} + taskCtx.On("GetNamespace").Return("ns") + taskCtx.On("GetAnnotations").Return(map[string]string{"aKey": "aVal"}) + taskCtx.On("GetLabels").Return(map[string]string{"lKey": "lVal"}) + taskCtx.On("GetOwnerReference").Return(metav1.OwnerReference{Name: "x"}) + taskCtx.On("GetInputsFile").Return(storage.DataReference("/fake/inputs.pb")) + taskCtx.On("GetDataDir").Return(storage.DataReference("/fake/")) + taskCtx.On("GetErrorFile").Return(storage.DataReference("/fake/error.pb")) + taskCtx.On("GetOutputsFile").Return(storage.DataReference("/fake/inputs.pb")) + taskCtx.On("GetPhaseVersion").Return(uint32(1)) + + id := &tasksMocks.TaskExecutionID{} + id.On("GetGeneratedName").Return("flyteplugins_integration") + id.On("GetID").Return(core.TaskExecutionIdentifier{}) + taskCtx.On("GetTaskExecutionID").Return(id) + + return taskCtx +} + +func getTaskTemplate() core.TaskTemplate { + hiveJob := plugins.QuboleHiveJob{ + ClusterLabel: "default", + Tags: []string{"flyte_plugin_test"}, + QueryCollection: &plugins.HiveQueryCollection{ + Queries: []*plugins.HiveQuery{ + {TimeoutSec: 500, + Query: "select 'one'", + RetryCount: 0}, + }, + }, + } + stObj := &structpb.Struct{} + utils.MarshalStruct(&hiveJob, stObj) + tt := core.TaskTemplate{ + Type: "hive", + Custom: stObj, + Id: &core.Identifier{ + Name: "integrationtest1", + Project: "flyteplugins", + Version: "1", + ResourceType: core.ResourceType_TASK, + }, + } + + return tt +} + +func TestHappy(t *testing.T) { + _ = logger.SetConfig(&logger.Config{ + IncludeSourceCode: true, + Level: 6, + }) + ctx := context.Background() + testScope := promutils.NewTestScope() + mockEventRecorder := tasksMocks.EventRecorder{} + mockEventRecorder.On("RecordTaskEvent", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + fmt.Printf("Event: %v\n", args) + }) + + assert.NoError(t, config.SetQuboleConfig(&config.Config{ + RedisHostPath: "localhost:6379", + RedisHostKey: "mypassword", + LookasideBufferPrefix: "test", + LookasideExpirySeconds: config2.Duration{Duration: time.Second * 1000}, + LruCacheSize: 100, + })) + + executor, err := qubole.NewHiveTaskExecutorWithCache(ctx) + assert.NoError(t, err) + err = executor.Initialize(ctx, types.ExecutorInitializationParameters{ + MetricsScope: testScope, + EventRecorder: &mockEventRecorder}) + assert.NoError(t, err) + + var taskCtx *tasksMocks.TaskContext + taskCtx = getMockTaskContext() + taskTemplate := getTaskTemplate() + + statuses, err := executor.StartTask(ctx, taskCtx, &taskTemplate, nil) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, statuses.Phase) + + // The initial custom state returned by the StartTask function + customState0 := statuses.State + work := customState0["flyteplugins_integration_0"].(qubole.QuboleWorkItem) + assert.Equal(t, qubole.QuboleWorkNotStarted, work.Status) + assert.Equal(t, 0, work.Retries) + assert.Equal(t, 
"flyteplugins_integration_0", work.ID()) + + taskCtx.On("GetCustomState").Return(customState0) + taskCtx.On("GetPhase").Return(types.TaskPhaseQueued) + + for true { + taskStatus, err := executor.CheckTaskStatus(ctx, taskCtx, &taskTemplate) + assert.NoError(t, err) + fmt.Printf("New status phase %s custom state %v\n", taskStatus.Phase, taskStatus.State) + if taskStatus.Phase == types.TaskPhaseSucceeded { + fmt.Println("success") + break + } + taskCtx = getMockTaskContext() + taskCtx.On("GetCustomState").Return(taskStatus.State) + taskCtx.On("GetPhase").Return(types.TaskPhaseRunning) + time.Sleep(15 * time.Second) + } +} + +func TestMultipleCallbacks(t *testing.T) { + _ = logger.SetConfig(&logger.Config{ + IncludeSourceCode: true, + Level: 6, + }) + ctx := context.Background() + testScope := promutils.NewTestScope() + mockEventRecorder := tasksMocks.EventRecorder{} + mockEventRecorder.On("RecordTaskEvent", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + fmt.Printf("Event: %v\n", args) + }) + + assert.NoError(t, config.SetQuboleConfig(&config.Config{ + RedisHostPath: "localhost:6379", + RedisHostKey: "mypassword", + LookasideBufferPrefix: "test", + LookasideExpirySeconds: config2.Duration{Duration: time.Second * 1000}, + LruCacheSize: 100, + })) + + executor, err := qubole.NewHiveTaskExecutorWithCache(ctx) + assert.NoError(t, err) + err = executor.Initialize(ctx, types.ExecutorInitializationParameters{ + MetricsScope: testScope, + EventRecorder: &mockEventRecorder}) + assert.NoError(t, err) + + var taskCtx *tasksMocks.TaskContext + taskCtx = getMockTaskContext() + taskTemplate := getTaskTemplate() + + statuses, err := executor.StartTask(ctx, taskCtx, &taskTemplate, nil) + assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, statuses.Phase) + + // The initial custom state returned by the StartTask function + customState0 := statuses.State + work := customState0["flyteplugins_integration_0"].(qubole.QuboleWorkItem) + assert.Equal(t, qubole.QuboleWorkNotStarted, work.Status) + assert.Equal(t, 0, work.Retries) + assert.Equal(t, "flyteplugins_integration_0", work.ID()) + + taskCtx.On("GetCustomState").Return(customState0) + taskCtx.On("GetPhase").Return(types.TaskPhaseQueued) + + for true { + taskStatus, err := executor.CheckTaskStatus(ctx, taskCtx, &taskTemplate) + assert.NoError(t, err) + for i := 0; i < 100; i++ { + taskStatus, err = executor.CheckTaskStatus(ctx, taskCtx, &taskTemplate) + assert.NoError(t, err) + } + fmt.Printf("New status phase %s custom state %v\n", taskStatus.Phase, taskStatus.State) + if taskStatus.Phase == types.TaskPhaseSucceeded { + fmt.Println("success") + break + } + taskCtx = getMockTaskContext() + taskCtx.On("GetCustomState").Return(taskStatus.State) + taskCtx.On("GetPhase").Return(types.TaskPhaseRunning) + time.Sleep(15 * time.Second) + } +} \ No newline at end of file diff --git a/flyteplugins/tests/redis_lookaside_buffer_test.go b/flyteplugins/tests/redis_lookaside_buffer_test.go new file mode 100755 index 0000000000..423dd02db7 --- /dev/null +++ b/flyteplugins/tests/redis_lookaside_buffer_test.go @@ -0,0 +1,35 @@ +// +build manualintegration +// Be sure to add this to your goland build settings Tags in order to get linting/testing +// In order to run these integration tests you will need to run +// +// kubectl -n flyte port-forward service/redis-resource-manager 6379:6379 + +package tests + +import ( + "context" + "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager" + "github.com/lyft/flytestdlib/logger" + 
"github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestRedisLookasideBuffer(t *testing.T) { + _ = logger.SetConfig(&logger.Config{ + IncludeSourceCode: true, + Level: 6, + }) + + ctx := context.Background() + redisClient, err := resourcemanager.NewRedisClient(ctx, "localhost:6379", "mypassword") + assert.NoError(t, err) + expiry := time.Duration(1) * time.Second // To ensure your local Redis cache stays clean + buffer := resourcemanager.NewRedisLookasideBuffer(ctx, redisClient, "testPrefix", expiry) + assert.NoError(t, err) + + err = buffer.ConfirmExecution(ctx, "mykey", "123456") + assert.NoError(t, err) + commandId, err := buffer.RetrieveExecution(ctx, "mykey") + assert.Equal(t, "123456", commandId) +} diff --git a/flyteplugins/tests/redis_resource_manager_test.go b/flyteplugins/tests/redis_resource_manager_test.go new file mode 100755 index 0000000000..175404af1a --- /dev/null +++ b/flyteplugins/tests/redis_resource_manager_test.go @@ -0,0 +1,36 @@ +// +build manualintegration +// Be sure to add this to your goland build settings Tags in order to get linting/testing +// In order to run these integration tests you will need to run +// +// kubectl -n flyte port-forward service/redis-resource-manager 6379:6379 + +package tests + +import ( + "context" + "fmt" + "github.com/lyft/flyteplugins/go/tasks/v1/resourcemanager" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRedisResourceManager(t *testing.T) { + _ = logger.SetConfig(&logger.Config{ + IncludeSourceCode: true, + Level: 6, + }) + + ctx := context.Background() + scope := promutils.NewScope("test") + redisClient, err := resourcemanager.NewRedisClient(ctx, "localhost:6379", "mypassword") + assert.NoError(t, err) + manager, err := resourcemanager.NewRedisResourceManager(ctx, redisClient, scope) + assert.NoError(t, err) + + status, err := manager.AllocateResource(ctx, "default", "my-token-1") + assert.Equal(t, resourcemanager.AllocationStatusGranted, status) + + fmt.Println(status) +} From e7a247a43f921b02c48d3074cddcfb059097e86d Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Wed, 21 Aug 2019 10:47:47 -0700 Subject: [PATCH 0074/1918] Initial Commit --- datacatalog/.gitignore | 5 + datacatalog/.golangci.yml | 30 + datacatalog/.travis.yml | 26 + datacatalog/CODE_OF_CONDUCT.md | 2 + datacatalog/Dockerfile | 33 + datacatalog/Gopkg.lock | 601 ++++++++++ datacatalog/Gopkg.toml | 34 + datacatalog/LICENSE | 202 ++++ datacatalog/Makefile | 20 + datacatalog/NOTICE | 5 + datacatalog/README.md | 2 + .../boilerplate/lyft/docker_build/Makefile | 12 + .../boilerplate/lyft/docker_build/Readme.rst | 23 + .../lyft/docker_build/docker_build.sh | 67 ++ .../golang_dockerfile/Dockerfile.GoTemplate | 33 + .../lyft/golang_dockerfile/Readme.rst | 16 + .../lyft/golang_dockerfile/update.sh | 15 + .../lyft/golang_test_targets/Makefile | 38 + .../lyft/golang_test_targets/Readme.rst | 31 + .../lyft/golang_test_targets/goimports | 8 + .../lyft/golangci_file/.golangci.yml | 30 + .../boilerplate/lyft/golangci_file/Readme.rst | 8 + .../boilerplate/lyft/golangci_file/update.sh | 14 + datacatalog/boilerplate/update.cfg | 4 + datacatalog/boilerplate/update.sh | 53 + datacatalog/cmd/entrypoints/migrate.go | 85 ++ datacatalog/cmd/entrypoints/root.go | 76 ++ datacatalog/cmd/entrypoints/serve.go | 72 ++ datacatalog/cmd/entrypoints/serve_dummy.go | 59 + datacatalog/cmd/main.go | 14 + datacatalog/datacatalog_config.yaml | 27 + datacatalog/pkg/config/config.go | 41 
+ datacatalog/pkg/config/config_flags.go | 21 + datacatalog/pkg/config/config_flags_test.go | 152 +++ datacatalog/pkg/errors/errors.go | 45 + .../pkg/manager/impl/artifact_data_store.go | 62 + .../pkg/manager/impl/artifact_manager.go | 126 ++ .../pkg/manager/impl/artifact_manager_test.go | 343 ++++++ .../pkg/manager/impl/dataset_manager.go | 70 ++ .../pkg/manager/impl/dataset_manager_test.go | 154 +++ datacatalog/pkg/manager/impl/tag_manager.go | 54 + .../pkg/manager/impl/tag_manager_test.go | 127 ++ .../impl/validators/artifact_validator.go | 66 ++ .../pkg/manager/impl/validators/common.go | 8 + .../impl/validators/dataset_validator.go | 31 + .../pkg/manager/impl/validators/errors.go | 20 + .../manager/impl/validators/tag_validator.go | 28 + .../pkg/manager/interfaces/artifact.go | 12 + datacatalog/pkg/manager/interfaces/dataset.go | 12 + datacatalog/pkg/manager/interfaces/tag.go | 11 + datacatalog/pkg/manager/mocks/artifact.go | 59 + datacatalog/pkg/manager/mocks/dataset.go | 59 + datacatalog/pkg/manager/mocks/tag.go | 36 + .../pkg/repositories/config/database.go | 27 + .../config/dbconfigsection_flags.go | 25 + .../config/dbconfigsection_flags_test.go | 232 ++++ .../pkg/repositories/config/postgres.go | 81 ++ .../repositories/errors/error_transformer.go | 6 + datacatalog/pkg/repositories/errors/errors.go | 16 + .../pkg/repositories/errors/postgres.go | 54 + datacatalog/pkg/repositories/factory.go | 45 + .../pkg/repositories/gormimpl/artifact.go | 67 ++ .../repositories/gormimpl/artifact_test.go | 168 +++ .../pkg/repositories/gormimpl/dataset.go | 57 + .../pkg/repositories/gormimpl/dataset_test.go | 122 ++ datacatalog/pkg/repositories/gormimpl/tag.go | 51 + .../pkg/repositories/gormimpl/tag_test.go | 136 +++ datacatalog/pkg/repositories/handle.go | 77 ++ datacatalog/pkg/repositories/handle_test.go | 78 ++ .../repositories/interfaces/artifact_repo.go | 12 + .../pkg/repositories/interfaces/base.go | 7 + .../repositories/interfaces/dataset_repo.go | 12 + .../pkg/repositories/interfaces/tag_repo.go | 12 + .../pkg/repositories/mocks/artifact.go | 48 + datacatalog/pkg/repositories/mocks/base.go | 21 + datacatalog/pkg/repositories/mocks/dataset.go | 48 + datacatalog/pkg/repositories/mocks/tag.go | 48 + .../pkg/repositories/models/artifact.go | 24 + datacatalog/pkg/repositories/models/base.go | 9 + .../pkg/repositories/models/dataset.go | 14 + datacatalog/pkg/repositories/models/tag.go | 16 + datacatalog/pkg/repositories/postgres_repo.go | 35 + .../pkg/repositories/transformers/artifact.go | 55 + .../transformers/artifact_test.go | 102 ++ .../pkg/repositories/transformers/dataset.go | 52 + .../repositories/transformers/dataset_test.go | 66 ++ .../pkg/repositories/transformers/tag.go | 16 + .../pkg/repositories/transformers/tag_test.go | 26 + .../pkg/repositories/transformers/util.go | 25 + .../repositories/transformers/util_test.go | 25 + .../pkg/repositories/utils/test_utils.go | 19 + .../pkg/rpc/datacatalogservice/service.go | 94 ++ .../runtime/application_config_provider.go | 59 + .../runtime/configs/data_catalog_config.go | 8 + .../configs/datacatalogconfig_flags.go | 19 + .../configs/datacatalogconfig_flags_test.go | 113 ++ .../pkg/runtime/configuration_provider.go | 21 + datacatalog/protos/gen/service.pb.go | 1046 +++++++++++++++++ datacatalog/protos/idl/service.proto | 92 ++ 99 files changed, 6568 insertions(+) create mode 100644 datacatalog/.gitignore create mode 100644 datacatalog/.golangci.yml create mode 100644 datacatalog/.travis.yml create mode 100644 datacatalog/CODE_OF_CONDUCT.md 
create mode 100644 datacatalog/Dockerfile create mode 100644 datacatalog/Gopkg.lock create mode 100644 datacatalog/Gopkg.toml create mode 100644 datacatalog/LICENSE create mode 100644 datacatalog/Makefile create mode 100644 datacatalog/NOTICE create mode 100644 datacatalog/README.md create mode 100644 datacatalog/boilerplate/lyft/docker_build/Makefile create mode 100644 datacatalog/boilerplate/lyft/docker_build/Readme.rst create mode 100755 datacatalog/boilerplate/lyft/docker_build/docker_build.sh create mode 100644 datacatalog/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate create mode 100644 datacatalog/boilerplate/lyft/golang_dockerfile/Readme.rst create mode 100755 datacatalog/boilerplate/lyft/golang_dockerfile/update.sh create mode 100644 datacatalog/boilerplate/lyft/golang_test_targets/Makefile create mode 100644 datacatalog/boilerplate/lyft/golang_test_targets/Readme.rst create mode 100755 datacatalog/boilerplate/lyft/golang_test_targets/goimports create mode 100644 datacatalog/boilerplate/lyft/golangci_file/.golangci.yml create mode 100644 datacatalog/boilerplate/lyft/golangci_file/Readme.rst create mode 100755 datacatalog/boilerplate/lyft/golangci_file/update.sh create mode 100644 datacatalog/boilerplate/update.cfg create mode 100755 datacatalog/boilerplate/update.sh create mode 100644 datacatalog/cmd/entrypoints/migrate.go create mode 100644 datacatalog/cmd/entrypoints/root.go create mode 100644 datacatalog/cmd/entrypoints/serve.go create mode 100644 datacatalog/cmd/entrypoints/serve_dummy.go create mode 100644 datacatalog/cmd/main.go create mode 100644 datacatalog/datacatalog_config.yaml create mode 100644 datacatalog/pkg/config/config.go create mode 100755 datacatalog/pkg/config/config_flags.go create mode 100755 datacatalog/pkg/config/config_flags_test.go create mode 100644 datacatalog/pkg/errors/errors.go create mode 100644 datacatalog/pkg/manager/impl/artifact_data_store.go create mode 100644 datacatalog/pkg/manager/impl/artifact_manager.go create mode 100644 datacatalog/pkg/manager/impl/artifact_manager_test.go create mode 100644 datacatalog/pkg/manager/impl/dataset_manager.go create mode 100644 datacatalog/pkg/manager/impl/dataset_manager_test.go create mode 100644 datacatalog/pkg/manager/impl/tag_manager.go create mode 100644 datacatalog/pkg/manager/impl/tag_manager_test.go create mode 100644 datacatalog/pkg/manager/impl/validators/artifact_validator.go create mode 100644 datacatalog/pkg/manager/impl/validators/common.go create mode 100644 datacatalog/pkg/manager/impl/validators/dataset_validator.go create mode 100644 datacatalog/pkg/manager/impl/validators/errors.go create mode 100644 datacatalog/pkg/manager/impl/validators/tag_validator.go create mode 100644 datacatalog/pkg/manager/interfaces/artifact.go create mode 100644 datacatalog/pkg/manager/interfaces/dataset.go create mode 100644 datacatalog/pkg/manager/interfaces/tag.go create mode 100644 datacatalog/pkg/manager/mocks/artifact.go create mode 100644 datacatalog/pkg/manager/mocks/dataset.go create mode 100644 datacatalog/pkg/manager/mocks/tag.go create mode 100644 datacatalog/pkg/repositories/config/database.go create mode 100755 datacatalog/pkg/repositories/config/dbconfigsection_flags.go create mode 100755 datacatalog/pkg/repositories/config/dbconfigsection_flags_test.go create mode 100644 datacatalog/pkg/repositories/config/postgres.go create mode 100644 datacatalog/pkg/repositories/errors/error_transformer.go create mode 100644 datacatalog/pkg/repositories/errors/errors.go create mode 100644 
datacatalog/pkg/repositories/errors/postgres.go create mode 100644 datacatalog/pkg/repositories/factory.go create mode 100644 datacatalog/pkg/repositories/gormimpl/artifact.go create mode 100644 datacatalog/pkg/repositories/gormimpl/artifact_test.go create mode 100644 datacatalog/pkg/repositories/gormimpl/dataset.go create mode 100644 datacatalog/pkg/repositories/gormimpl/dataset_test.go create mode 100644 datacatalog/pkg/repositories/gormimpl/tag.go create mode 100644 datacatalog/pkg/repositories/gormimpl/tag_test.go create mode 100644 datacatalog/pkg/repositories/handle.go create mode 100644 datacatalog/pkg/repositories/handle_test.go create mode 100644 datacatalog/pkg/repositories/interfaces/artifact_repo.go create mode 100644 datacatalog/pkg/repositories/interfaces/base.go create mode 100644 datacatalog/pkg/repositories/interfaces/dataset_repo.go create mode 100644 datacatalog/pkg/repositories/interfaces/tag_repo.go create mode 100644 datacatalog/pkg/repositories/mocks/artifact.go create mode 100644 datacatalog/pkg/repositories/mocks/base.go create mode 100644 datacatalog/pkg/repositories/mocks/dataset.go create mode 100644 datacatalog/pkg/repositories/mocks/tag.go create mode 100644 datacatalog/pkg/repositories/models/artifact.go create mode 100644 datacatalog/pkg/repositories/models/base.go create mode 100644 datacatalog/pkg/repositories/models/dataset.go create mode 100644 datacatalog/pkg/repositories/models/tag.go create mode 100644 datacatalog/pkg/repositories/postgres_repo.go create mode 100644 datacatalog/pkg/repositories/transformers/artifact.go create mode 100644 datacatalog/pkg/repositories/transformers/artifact_test.go create mode 100644 datacatalog/pkg/repositories/transformers/dataset.go create mode 100644 datacatalog/pkg/repositories/transformers/dataset_test.go create mode 100644 datacatalog/pkg/repositories/transformers/tag.go create mode 100644 datacatalog/pkg/repositories/transformers/tag_test.go create mode 100644 datacatalog/pkg/repositories/transformers/util.go create mode 100644 datacatalog/pkg/repositories/transformers/util_test.go create mode 100644 datacatalog/pkg/repositories/utils/test_utils.go create mode 100644 datacatalog/pkg/rpc/datacatalogservice/service.go create mode 100644 datacatalog/pkg/runtime/application_config_provider.go create mode 100644 datacatalog/pkg/runtime/configs/data_catalog_config.go create mode 100755 datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go create mode 100755 datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go create mode 100644 datacatalog/pkg/runtime/configuration_provider.go create mode 100644 datacatalog/protos/gen/service.pb.go create mode 100644 datacatalog/protos/idl/service.proto diff --git a/datacatalog/.gitignore b/datacatalog/.gitignore new file mode 100644 index 0000000000..72ff26ecff --- /dev/null +++ b/datacatalog/.gitignore @@ -0,0 +1,5 @@ +.idea/ +vendor/ +vendor-new/ +.DS_Store +bin/ diff --git a/datacatalog/.golangci.yml b/datacatalog/.golangci.yml new file mode 100644 index 0000000000..a414f33f79 --- /dev/null +++ b/datacatalog/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. 
+# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/datacatalog/.travis.yml b/datacatalog/.travis.yml new file mode 100644 index 0000000000..671f277711 --- /dev/null +++ b/datacatalog/.travis.yml @@ -0,0 +1,26 @@ +sudo: required +language: go +go: + - "1.10" +services: + - docker +jobs: + include: + - if: fork = true + stage: test + name: docker build + install: true + script: make docker_build + - if: fork = false + stage: test + name: docker build and push + install: true + script: make dockerhub_push + - stage: test + install: make install + name: lint + script: make lint + - stage: test + name: unit tests + install: make install + script: make test_unit diff --git a/datacatalog/CODE_OF_CONDUCT.md b/datacatalog/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..4c3a38cc48 --- /dev/null +++ b/datacatalog/CODE_OF_CONDUCT.md @@ -0,0 +1,2 @@ +This project is governed by [Lyft's code of conduct](https://github.com/lyft/code-of-conduct). +All contributors and participants agree to abide by its terms. diff --git a/datacatalog/Dockerfile b/datacatalog/Dockerfile new file mode 100644 index 0000000000..8129304d34 --- /dev/null +++ b/datacatalog/Dockerfile @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/datacatalog/ +WORKDIR /go/src/github.com/lyft/datacatalog + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . /go/src/github.com/lyft/datacatalog/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/datacatalog +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["datacatalog"] diff --git a/datacatalog/Gopkg.lock b/datacatalog/Gopkg.lock new file mode 100644 index 0000000000..130738217c --- /dev/null +++ b/datacatalog/Gopkg.lock @@ -0,0 +1,601 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + digest = "1:920e6fa9e64d9a7524a4bb535f0a95fff8492d73ed26c81177ba381bd5d32aa2" + name = "github.com/Selvatico/go-mocket" + packages = ["."] + pruneopts = "NUT" + revision = "c368d4162be502eea110ae12fb85e98567b0f1e6" + version = "v1.0.7" + +[[projects]] + digest = "1:baf8d1f197a20ab6b3b92c63cb194ee8e4570d381d7a3299e0f5163252ba6df4" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", + "private/protocol/eventstream/eventstreamapi", + "private/protocol/json/jsonutil", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/s3", + "service/sts", + "service/sts/stsiface", + ] + pruneopts = "NUT" + revision = "3d64de9191c7cf300083c8ff3049bf079c5b67ba" + version = "v1.21.3" + +[[projects]] + digest = "1:707ebe952a8b3d00b343c01536c79c73771d100f63ec6babeaed5c79e2b8a8dd" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "NUT" + revision = "4b2b341e8d7715fae06375aa633dbb6e91b3fb46" + version = "v1.0.0" + +[[projects]] + digest = "1:0c31fa2fb2c809d61d640e28cc400087fe205df6ec9623dd1eb91a7de8d4f5d6" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "NUT" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:6f502d622502fc4a896b69acdf0837c550d7b3657d124c9870f5fcc681f15b53" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "NUT" + revision = "3c79a0a23c1940ab4479332fb3e0127265650ce3" + version = "v1.1.0" + +[[projects]] + digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "NUT" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + digest = "1:ade392a843b2035effb4b4a2efa2c3bab3eb29b992e98bacf9c898b0ecb54e45" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "NUT" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + branch = "master" + digest = "1:91bbc8ba11c3bbee849fc48e6816e03b827d9847249c00a976c2fef2f58be59e" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "NUT" + revision = "1485a34d5d5723fea214f5710708e19a831720e4" + +[[projects]] + branch = "master" + digest = "1:e2b86e41f3d669fc36b50d31d32d22c8ac656c75aa5ea89717ce7177e134ff2a" + name = "github.com/golang/glog" + packages = ["."] + pruneopts = "NUT" + revision = "23def4e6c14b4da8ac2ed8007337bc5eb5007998" + +[[projects]] + digest = "1:3ea3429061d04eff320611732c74432f1e8fb0e61f69a86bce7d7c73de99e4f2" + name = "github.com/golang/protobuf" + packages = [ + "proto", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + ] + pruneopts = "NUT" + revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" + version = "v1.3.1" + +[[projects]] + digest = 
"1:3c7231e9eb47d3b907d205991a5659fa30b4abed804eee7612d19c2da131f398" + name = "github.com/graymeta/stow" + packages = [ + ".", + "local", + "s3", + ] + pruneopts = "NUT" + revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + +[[projects]] + digest = "1:11c6c696067d3127ecf332b10f89394d386d9083f82baf71f40f2da31841a009" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "NUT" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:406338ad39ab2e37b7f4452906442a3dbf0eb3379dd1f06aafb5c07e769a5fbb" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "NUT" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:8481b0bd2fa016025fe683e5b84d1051ebfbf356ae2d6aad5374f23edb21c10a" + name = "github.com/jinzhu/gorm" + packages = [ + ".", + "dialects/postgres", + "dialects/sqlite", + ] + pruneopts = "NUT" + revision = "836fb2c19d84dac7b0272958dfb9af7cf0d0ade4" + version = "v1.9.10" + +[[projects]] + digest = "1:802f75230c29108e787d40679f9bf5da1a5673eaf5c10eb89afd993e18972909" + name = "github.com/jinzhu/inflection" + packages = ["."] + pruneopts = "NUT" + revision = "f5c5f50e6090ae76a29240b61ae2a90dd810112e" + version = "v1.0.0" + +[[projects]] + digest = "1:1f2aebae7e7c856562355ec0198d8ca2fa222fb05e5b1b66632a1fce39631885" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "NUT" + revision = "c2b33e84" + +[[projects]] + digest = "1:58999a98719fddbac6303cb17e8d85b945f60b72f48e3a2df6b950b97fa926f1" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "NUT" + revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" + version = "v1.0.2" + +[[projects]] + digest = "1:d8d85b5aace516f6897fcc81d3ad2da280e1d4d2ccd22c7270649bf44715d7a9" + name = "github.com/lib/pq" + packages = [ + ".", + "hstore", + "oid", + "scram", + ] + pruneopts = "NUT" + revision = "3427c32cb71afc948325f299f040e53c1dd78979" + version = "v1.2.0" + +[[projects]] + digest = "1:72841617053e049a34e8d98232b863fe5c173cdf03a4d8bc8dc9039303ad418e" + name = "github.com/lyft/flyteidl" + packages = ["gen/pb-go/flyteidl/core"] + pruneopts = "T" + revision = "793b09d190148236f41ad8160b5cec9a3325c16f" + version = "v0.1.0" + +[[projects]] + digest = "1:6cc3cfda698262608d464cd89bfab217bb8fa8d507bf13b82cd585a520aed37d" + name = "github.com/lyft/flytestdlib" + packages = [ + "atomic", + "config", + "config/files", + "config/viper", + "contextutils", + "ioutils", + "logger", + "promutils", + "promutils/labeled", + "storage", + ] + pruneopts = "NUT" + revision = "2577ff228d559b8fdf687f6cfad196bfbf1bd50a" + version = "v0.2.10" + +[[projects]] + digest = "1:802689c84994b7f2e41ffe7e39d29a8d8227f2121938dc025db44dfaa9633b15" + name = "github.com/magiconair/properties" + packages = ["."] + pruneopts = "NUT" + revision = "de8848e004dd33dc07a2947b3d76f618a7fc7ef1" + version = "v1.8.1" + +[[projects]] + digest = "1:08c231ec84231a7e23d67e4b58f975e1423695a32467a362ee55a803f9de8061" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "NUT" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:666428435471609937bd13f56e2e31501b5660ffc44a786cf2217dc53291a604" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "NUT" + 
revision = "1311e847b0cb909da63b5fecfb5370aa66236465" + version = "v0.0.8" + +[[projects]] + digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5" + name = "github.com/mattn/go-sqlite3" + packages = ["."] + pruneopts = "NUT" + revision = "c7c4067b79cc51e6dfdcef5c702e74b1e0fa7c75" + version = "v1.10.0" + +[[projects]] + digest = "1:5985ef4caf91ece5d54817c11ea25f182697534f8ae6521eadcd628c142ac4b6" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "NUT" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:a45ae66dea4c899d79fceb116accfa1892105c251f0dcd9a217ddc276b42ec68" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "NUT" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:4e9827f31d4fc1ddd732a0e3af4e863d281dd405adb2bfb96a25cc5346a77caf" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "NUT" + revision = "728039f679cbcd4f6a54e080d2219a4c4928c546" + version = "v1.4.0" + +[[projects]] + digest = "1:14715f705ff5dfe0ffd6571d7d201dd8e921030f8070321a79380d8ca4ec1a24" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "NUT" + revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "NUT" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:0f362379987ecc2cf4df1b8e4c1653a782f6f9f77f749547b734499b3c543080" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + ] + pruneopts = "NUT" + revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" + version = "v0.9.4" + +[[projects]] + branch = "master" + digest = "1:2d5cd61daa5565187e1d96bae64dbbc6080dacf741448e9629c64fd93203b0d4" + name = "github.com/prometheus/client_model" + packages = ["go"] + pruneopts = "NUT" + revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + +[[projects]] + digest = "1:d03ca24670416dc8fccc78b05d6736ec655416ca7db0a028e8fb92cfdfe3b55e" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "NUT" + revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" + version = "v0.6.0" + +[[projects]] + digest = "1:19305fc369377c111c865a7a01e11c675c57c52a932353bbd4ea360bd5b72d99" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/fs", + ] + pruneopts = "NUT" + revision = "3f98efb27840a48a7a2898ec80be07674d19f9c8" + version = "v0.0.3" + +[[projects]] + digest = "1:f4aaa07a6c33f2b354726d0571acbc8ca118837c75709f6353203ae1a3f8eeab" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "NUT" + revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" + version = "v1.4.2" + +[[projects]] + digest = "1:6792bb72ea0e7112157d02e4e175cd421b43d004a853f56316a19beca6e0c074" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "NUT" + revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424" + version = "v1.2.2" + +[[projects]] + digest = "1:c5e6b121ef3d2043505edaf4c80e5a008cec2513dc8804795eb0479d1555bcf7" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "NUT" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = 
"1:343d44e06621142ab09ae0c76c1799104cdfddd3ffb445d78b1adf8dc3ffaf3d" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "NUT" + revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385" + version = "v0.0.3" + +[[projects]] + digest = "1:3d72352adb74e79d6d5a43d6f51bfd2d0bd0c9b5f3c00cf5a4b1636d8d3b9d92" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "NUT" + revision = "94f6ae3ed3bceceafa716478c5fbf8d29ca601a1" + version = "v1.1.0" + +[[projects]] + digest = "1:9d8420bbf131d1618bde6530af37c3799340d3762cc47210c1d9532a4c3a2779" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "NUT" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:7e1a17cfd0ad758abc27d10beeda550f5bb1e1b8746f0d67040a45775931c226" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "NUT" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = "1:60a46e2410edbf02b419f833372dd1d24d7aa1b916a990a7370e792fada1eadd" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "NUT" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:4fec2fe9ce617252d634de83ca8dd5f9c7c391637540896542c80ab63f071f51" + name = "github.com/stretchr/testify" + packages = [ + "assert", + "mock", + ] + pruneopts = "NUT" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + +[[projects]] + branch = "master" + digest = "1:bafd5fe279fd563e58b661da3c7c258beb33b55d48294446a8870caeecefd3a1" + name = "golang.org/x/net" + packages = [ + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "NUT" + revision = "1f3472d942ba824034fb77cab6a6cfc1bc8a2c3c" + +[[projects]] + branch = "master" + digest = "1:f632c225ef300cb47bb7363bc1006ca57601af35f03b8fe41d03ed67b581c4fc" + name = "golang.org/x/sys" + packages = ["unix"] + pruneopts = "NUT" + revision = "e8e3143a4f4a00f1fafef0dd82ba78223281b01b" + +[[projects]] + digest = "1:e7071ed636b5422cc51c0e3a6cebc229d6c9fffc528814b519a980641422d619" + name = "golang.org/x/text" + packages = [ + "collate", + "collate/build", + "internal/colltab", + "internal/gen", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", + "transform", + "unicode/bidi", + "unicode/cldr", + "unicode/norm", + "unicode/rangetable", + ] + pruneopts = "NUT" + revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" + version = "v0.3.0" + +[[projects]] + branch = "master" + digest = "1:9fdc2b55e8e0fafe4b41884091e51e77344f7dc511c5acedcfd98200003bff90" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "NUT" + revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" + +[[projects]] + branch = "master" + digest = "1:c3076e7defee87de1236f1814beb588f40a75544c60121e6eb38b3b3721783e2" + name = "google.golang.org/genproto" + packages = ["googleapis/rpc/status"] + pruneopts = "NUT" + revision = "e7d98fc518a78c9f8b5ee77be7b0b317475d89e1" + +[[projects]] + digest = "1:ffb498178a6bbe5a877e715cc85a40d5a712883d85f5bf05acf26dbd6c8f71e2" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/balancerload", + "internal/binarylog", + "internal/channelz", + "internal/envconfig", + 
"internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "stats", + "status", + "tap", + ] + pruneopts = "NUT" + revision = "25c4f928eaa6d96443009bd842389fb4fa48664e" + version = "v1.20.1" + +[[projects]] + digest = "1:18108594151654e9e696b27b181b953f9a90b16bf14d253dd1b397b025a1487f" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "NUT" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + digest = "1:47d558c9776c006fd56b90df0ffddf57d11bedcd4f62a75153f230fe55873a37" + name = "k8s.io/apimachinery" + packages = [ + "pkg/util/clock", + "pkg/util/rand", + "pkg/util/runtime", + ] + pruneopts = "NUT" + revision = "2b1284ed4c93a43499e781493253e2ac5959c4fd" + version = "kubernetes-1.13.1" + +[[projects]] + digest = "1:b6412f8acd9a9fc6fb67302c24966618b16501b9d769a20bee42ce61e510c92c" + name = "k8s.io/client-go" + packages = ["util/workqueue"] + pruneopts = "NUT" + revision = "8d9ed539ba3134352c586810e749e58df4e94e4f" + version = "kubernetes-1.13.1" + +[[projects]] + digest = "1:43099cc4ed575c40f80277c7ba7168df37d0c663bdc4f541325430bd175cce8a" + name = "k8s.io/klog" + packages = ["."] + pruneopts = "NUT" + revision = "d98d8acdac006fb39831f1b25640813fef9c314f" + version = "v0.3.3" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/Selvatico/go-mocket", + "github.com/golang/glog", + "github.com/golang/protobuf/proto", + "github.com/jinzhu/gorm", + "github.com/jinzhu/gorm/dialects/postgres", + "github.com/jinzhu/gorm/dialects/sqlite", + "github.com/lib/pq", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core", + "github.com/lyft/flytestdlib/config", + "github.com/lyft/flytestdlib/config/viper", + "github.com/lyft/flytestdlib/contextutils", + "github.com/lyft/flytestdlib/logger", + "github.com/lyft/flytestdlib/promutils", + "github.com/lyft/flytestdlib/promutils/labeled", + "github.com/lyft/flytestdlib/storage", + "github.com/mitchellh/mapstructure", + "github.com/pkg/errors", + "github.com/spf13/cobra", + "github.com/spf13/pflag", + "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", + "google.golang.org/grpc", + "google.golang.org/grpc/codes", + "google.golang.org/grpc/status", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/datacatalog/Gopkg.toml b/datacatalog/Gopkg.toml new file mode 100644 index 0000000000..f1d8843b74 --- /dev/null +++ b/datacatalog/Gopkg.toml @@ -0,0 +1,34 @@ +# Gopkg.toml example +# +# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html +# for detailed Gopkg.toml documentation. 
+# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true +[prune] + go-tests = true + unused-packages = true + non-go = true + + [[prune.project]] + name = "github.com/lyft/flyteidl" + non-go = false + unused-packages = false diff --git a/datacatalog/LICENSE b/datacatalog/LICENSE new file mode 100644 index 0000000000..bed437514f --- /dev/null +++ b/datacatalog/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/datacatalog/Makefile b/datacatalog/Makefile new file mode 100644 index 0000000000..4ef36d1fb7 --- /dev/null +++ b/datacatalog/Makefile @@ -0,0 +1,20 @@ +export REPOSITORY=datacatalog +include boilerplate/lyft/docker_build/Makefile +include boilerplate/lyft/golang_test_targets/Makefile + +.PHONY: update_boilerplate +update_boilerplate: + @boilerplate/update.sh + +.PHONY: compile +compile: + mkdir -p ./bin + go build -o datacatalog ./cmd/main.go && mv ./datacatalog ./bin + +.PHONY: linux_compile +linux_compile: + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/datacatalog ./cmd/ + +.PHONY: generate_idl +generate_idl: + protoc -I ./vendor/github.com/lyft/flyteidl/protos/ -I ./protos/idl/. --go_out=plugins=grpc:protos/gen ./protos/idl/service.proto diff --git a/datacatalog/NOTICE b/datacatalog/NOTICE new file mode 100644 index 0000000000..0f113fb25a --- /dev/null +++ b/datacatalog/NOTICE @@ -0,0 +1,5 @@ +datacatalog +Copyright 2019-2020 Lyft Inc. + +This product includes software developed at Lyft Inc. + diff --git a/datacatalog/README.md b/datacatalog/README.md new file mode 100644 index 0000000000..9c80ac52f3 --- /dev/null +++ b/datacatalog/README.md @@ -0,0 +1,2 @@ +# datacatalog +Service that catalogs data to allow for data discovery, lineage and tagging diff --git a/datacatalog/boilerplate/lyft/docker_build/Makefile b/datacatalog/boilerplate/lyft/docker_build/Makefile new file mode 100644 index 0000000000..4019dab839 --- /dev/null +++ b/datacatalog/boilerplate/lyft/docker_build/Makefile @@ -0,0 +1,12 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. 
+# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +.PHONY: docker_build +docker_build: + IMAGE_NAME=$$REPOSITORY ./boilerplate/lyft/docker_build/docker_build.sh + +.PHONY: dockerhub_push +dockerhub_push: + IMAGE_NAME=lyft/$$REPOSITORY REGISTRY=docker.io ./boilerplate/lyft/docker_build/docker_build.sh diff --git a/datacatalog/boilerplate/lyft/docker_build/Readme.rst b/datacatalog/boilerplate/lyft/docker_build/Readme.rst new file mode 100644 index 0000000000..bb6af9b49e --- /dev/null +++ b/datacatalog/boilerplate/lyft/docker_build/Readme.rst @@ -0,0 +1,23 @@ +Docker Build and Push +~~~~~~~~~~~~~~~~~~~~~ + +Provides a ``make docker_build`` target that builds your image locally. + +Provides a ``make dockerhub_push`` target that pushes your final image to Dockerhub. + +The Dockerhub image will be tagged ``:`` + +If git head has a git tag, the Dockerhub image will also be tagged ``:``. + +**To Enable:** + +Add ``lyft/docker_build`` to your ``boilerplate/update.cfg`` file. + +Add ``include boilerplate/lyft/docker_build/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/docker_build/Makefile + +(this ensures the extra Make targets get included in your main Makefile) diff --git a/datacatalog/boilerplate/lyft/docker_build/docker_build.sh b/datacatalog/boilerplate/lyft/docker_build/docker_build.sh new file mode 100755 index 0000000000..f504c100c7 --- /dev/null +++ b/datacatalog/boilerplate/lyft/docker_build/docker_build.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +echo "" +echo "------------------------------------" +echo " DOCKER BUILD" +echo "------------------------------------" +echo "" + +if [ -n "$REGISTRY" ]; then + # Do not push if there are unstaged git changes + CHANGED=$(git status --porcelain) + if [ -n "$CHANGED" ]; then + echo "Please commit git changes before pushing to a registry" + exit 1 + fi +fi + + +GIT_SHA=$(git rev-parse HEAD) + +IMAGE_TAG_SUFFIX="" +# for intermediate build phases, append -$BUILD_PHASE to all image tags +if [ -n "$BUILD_PHASE" ]; then + IMAGE_TAG_SUFFIX="-${BUILD_PHASE}" +fi + +IMAGE_TAG_WITH_SHA="${IMAGE_NAME}:${GIT_SHA}${IMAGE_TAG_SUFFIX}" + +RELEASE_SEMVER=$(git describe --tags --exact-match "$GIT_SHA" 2>/dev/null) || true +if [ -n "$RELEASE_SEMVER" ]; then + IMAGE_TAG_WITH_SEMVER="${IMAGE_NAME}:${RELEASE_SEMVER}${IMAGE_TAG_SUFFIX}" +fi + +# build the image +# passing no build phase will build the final image +docker build -t "$IMAGE_TAG_WITH_SHA" --target=${BUILD_PHASE} . +echo "${IMAGE_TAG_WITH_SHA} built locally." + +# if REGISTRY is specified, push the images to the remote registry +if [ -n "$REGISTRY" ]; then + + if [ -n "${DOCKER_REGISTRY_PASSWORD}" ]; then + docker login --username="$DOCKER_REGISTRY_USERNAME" --password="$DOCKER_REGISTRY_PASSWORD" + fi + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SHA} pushed to remote."
+ + # If the current commit has a semver tag, also push the images with the semver tag + if [ -n "$RELEASE_SEMVER" ]; then + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER} pushed to remote." + + fi +fi diff --git a/datacatalog/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate b/datacatalog/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate new file mode 100644 index 0000000000..5e7b984a11 --- /dev/null +++ b/datacatalog/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/{{REPOSITORY}}/ +WORKDIR /go/src/github.com/lyft/{{REPOSITORY}} + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . /go/src/github.com/lyft/{{REPOSITORY}}/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/{{REPOSITORY}} +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["{{REPOSITORY}}"] diff --git a/datacatalog/boilerplate/lyft/golang_dockerfile/Readme.rst b/datacatalog/boilerplate/lyft/golang_dockerfile/Readme.rst new file mode 100644 index 0000000000..f801ef98d6 --- /dev/null +++ b/datacatalog/boilerplate/lyft/golang_dockerfile/Readme.rst @@ -0,0 +1,16 @@ +Golang Dockerfile +~~~~~~~~~~~~~~~~~ + +Provides a Dockerfile that produces a small image. + +**To Enable:** + +Add ``lyft/golang_dockerfile`` to your ``boilerplate/update.cfg`` file. + +Create and configure a ``make linux_compile`` target that compiles your go binaries to the ``/artifacts`` directory :: + + .PHONY: linux_compile + linux_compile: + RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts {{ packages }} + +All binaries compiled to ``/artifacts`` will be available at ``/bin`` in your final image. diff --git a/datacatalog/boilerplate/lyft/golang_dockerfile/update.sh b/datacatalog/boilerplate/lyft/golang_dockerfile/update.sh new file mode 100755 index 0000000000..293eedcb8b --- /dev/null +++ b/datacatalog/boilerplate/lyft/golang_dockerfile/update.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +# Dockerfile for Golang Services +echo " - generating Dockerfile in root directory." 
+# replace FLYTEADMIN_SHA with the test SHA +sed -e "s/{{REPOSITORY}}/${REPOSITORY}/g" ${DIR}/Dockerfile.GoTemplate > ${DIR}/../../../Dockerfile diff --git a/datacatalog/boilerplate/lyft/golang_test_targets/Makefile b/datacatalog/boilerplate/lyft/golang_test_targets/Makefile new file mode 100644 index 0000000000..6c1e527fd6 --- /dev/null +++ b/datacatalog/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,38 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +DEP_SHA=1f7c19e5f52f49ffb9f956f64c010be14683468b + +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.16.0 + golangci-lint run --exclude deprecated + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl "https://raw.githubusercontent.com/golang/dep/${DEP_SHA}/install.sh" | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... -race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/datacatalog/boilerplate/lyft/golang_test_targets/Readme.rst b/datacatalog/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100644 index 0000000000..acc5744f59 --- /dev/null +++ b/datacatalog/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` install golang dependencies. + +Provides a ``lint`` make target that uses golangci to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. + +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). + +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/datacatalog/boilerplate/lyft/golang_test_targets/goimports b/datacatalog/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 0000000000..160525a8cc --- /dev/null +++ b/datacatalog/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. 
+# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/datacatalog/boilerplate/lyft/golangci_file/.golangci.yml b/datacatalog/boilerplate/lyft/golangci_file/.golangci.yml new file mode 100644 index 0000000000..a414f33f79 --- /dev/null +++ b/datacatalog/boilerplate/lyft/golangci_file/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/datacatalog/boilerplate/lyft/golangci_file/Readme.rst b/datacatalog/boilerplate/lyft/golangci_file/Readme.rst new file mode 100644 index 0000000000..ba5d2b61ce --- /dev/null +++ b/datacatalog/boilerplate/lyft/golangci_file/Readme.rst @@ -0,0 +1,8 @@ +GolangCI File +~~~~~~~~~~~~~ + +Provides a ``.golangci`` file with the linters we've agreed upon. + +**To Enable:** + +Add ``lyft/golangci_file`` to your ``boilerplate/update.cfg`` file. diff --git a/datacatalog/boilerplate/lyft/golangci_file/update.sh b/datacatalog/boilerplate/lyft/golangci_file/update.sh new file mode 100755 index 0000000000..9e9e6c1f46 --- /dev/null +++ b/datacatalog/boilerplate/lyft/golangci_file/update.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +# Clone the .golangci file +echo " - copying ${DIR}/.golangci to the root directory." +cp ${DIR}/.golangci.yml ${DIR}/../../../.golangci.yml diff --git a/datacatalog/boilerplate/update.cfg b/datacatalog/boilerplate/update.cfg new file mode 100644 index 0000000000..5417c80464 --- /dev/null +++ b/datacatalog/boilerplate/update.cfg @@ -0,0 +1,4 @@ +lyft/docker_build +lyft/golang_test_targets +lyft/golangci_file +lyft/golang_dockerfile diff --git a/datacatalog/boilerplate/update.sh b/datacatalog/boilerplate/update.sh new file mode 100755 index 0000000000..bea661d9a0 --- /dev/null +++ b/datacatalog/boilerplate/update.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +OUT="$(mktemp -d)" +git clone git@github.com:lyft/boilerplate.git "${OUT}" + +echo "Updating the update.sh script." +cp "${OUT}/boilerplate/update.sh" "${DIR}/update.sh" +echo "" + + +CONFIG_FILE="${DIR}/update.cfg" +README="https://github.com/lyft/boilerplate/blob/master/Readme.rst" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "$CONFIG_FILE not found." 
+ echo "This file is required in order to select which features to include." + echo "See $README for more details." + exit 1 +fi + +if [ -z "$REPOSITORY" ]; then + echo '$REPOSITORY is required to run this script' + echo "See $README for more details." + exit 1 +fi + +while read directory; do + echo "***********************************************************************************" + echo "$directory is configured in update.cfg." + echo "-----------------------------------------------------------------------------------" + echo "syncing files from source." + dir_path="${OUT}/boilerplate/${directory}" + rm -rf "${DIR}/${directory}" + mkdir -p $(dirname "${DIR}/${directory}") + cp -r "$dir_path" "${DIR}/${directory}" + if [ -f "${DIR}/${directory}/update.sh" ]; then + echo "executing ${DIR}/${directory}/update.sh" + "${DIR}/${directory}/update.sh" + fi + echo "***********************************************************************************" + echo "" +done < "$CONFIG_FILE" + +rm -rf "${OUT}" diff --git a/datacatalog/cmd/entrypoints/migrate.go b/datacatalog/cmd/entrypoints/migrate.go new file mode 100644 index 0000000000..b7b641f479 --- /dev/null +++ b/datacatalog/cmd/entrypoints/migrate.go @@ -0,0 +1,85 @@ +package entrypoints + +import ( + "github.com/lib/pq" + "github.com/lyft/datacatalog/pkg/repositories" + "github.com/lyft/datacatalog/pkg/runtime" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + + "context" + + _ "github.com/jinzhu/gorm/dialects/postgres" // Required to import database driver. + _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/spf13/cobra" +) + +var parentMigrateCmd = &cobra.Command{ + Use: "migrate", + Short: "This command controls migration behavior for the Flyte Catalog database. Please choose a subcommand.", +} + +var migrationsScope = promutils.NewScope("migrations") +var migrateScope = migrationsScope.NewSubScope("migrate") + +// all postgres servers come by default with a db name named postgres +const defaultDB = "postgres" +const pqInvalidDBCode = "3D000" + +// This runs all the migrations +var migrateCmd = &cobra.Command{ + Use: "run", + Short: "This command will run all the migrations for the database", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + configProvider := runtime.NewConfigurationProvider() + dbConfigValues := configProvider.ApplicationConfiguration().GetDbConfig() + + dbName := dbConfigValues.DbName + dbHandle, err := repositories.NewDBHandle(dbConfigValues, migrateScope) + + if err != nil { + // if db does not exist, try creating it + pqError, ok := err.(*pq.Error) + if ok && pqError.Code == pqInvalidDBCode { + logger.Warningf(ctx, "Database [%v] does not exist, trying to create it now", dbName) + + dbConfigValues.DbName = defaultDB + setupDBHandler, err := repositories.NewDBHandle(dbConfigValues, migrateScope) + if err != nil { + logger.Errorf(ctx, "Failed to connect to default DB %v, err %v", defaultDB, err) + panic(err) + } + + // Create the database if it doesn't exist + // NOTE: this is non-destructive - if for some reason one does exist an err will be thrown + err = setupDBHandler.CreateDB(dbName) + if err != nil { + logger.Errorf(ctx, "Failed to create DB %v err %v", dbName, err) + panic(err) + } + + dbConfigValues.DbName = dbName + dbHandle, err = repositories.NewDBHandle(dbConfigValues, migrateScope) + if err != nil { + logger.Errorf(ctx, "Failed to connect DB err %v", err) + panic(err) + } + } else { + logger.Errorf(ctx, "Failed to connect DB err %v", err) + 
panic(err) + } + } + + logger.Infof(ctx, "Created DB connection.") + + // TODO: checkpoints for migrations + dbHandle.Migrate() + logger.Infof(ctx, "Ran DB migration successfully.") + }, +} + +func init() { + RootCmd.AddCommand(parentMigrateCmd) + parentMigrateCmd.AddCommand(migrateCmd) +} diff --git a/datacatalog/cmd/entrypoints/root.go b/datacatalog/cmd/entrypoints/root.go new file mode 100644 index 0000000000..4ffce611af --- /dev/null +++ b/datacatalog/cmd/entrypoints/root.go @@ -0,0 +1,76 @@ +package entrypoints + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +var ( + cfgFile string + + configAccessor = viper.NewAccessor(config.Options{StrictMode: true}) +) + +func init() { + // See https://gist.github.com/nak3/78a32817a8a3950ae48f239a44cd3663 + // allows `$ datacatalog --logtostderr` to work + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + + // Add persistent flags - persistent flags persist through all sub-commands + RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is ./datacatalog_config.yaml)") + + RootCmd.AddCommand(viper.GetConfigCommand()) + + // Allow viper to read the value of the flags + configAccessor.InitializePflags(RootCmd.PersistentFlags()) + + err := flag.CommandLine.Parse([]string{}) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +// Execute adds all child commands to the root command sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() error { + if err := RootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } + return nil +} + +// RootCmd represents the base command when called without any subcommands +var RootCmd = &cobra.Command{ + Use: "datacatalog", + Short: "Launches datacatalog", + Long: ` +To get started run the serve subcommand which will start a server on localhost:8089: + + datacatalog serve +`, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return initConfig(cmd.Flags()) + }, +} + +func initConfig(flags *pflag.FlagSet) error { + configAccessor = viper.NewAccessor(config.Options{ + SearchPaths: []string{cfgFile, ".", "/etc/flyte/config", "$GOPATH/src/github.com/lyft/datacatalog"}, + StrictMode: false, + }) + + fmt.Println("Using config file: ", configAccessor.ConfigFilesUsed()) + + configAccessor.InitializePflags(flags) + + return configAccessor.UpdateConfig(context.TODO()) +} diff --git a/datacatalog/cmd/entrypoints/serve.go b/datacatalog/cmd/entrypoints/serve.go new file mode 100644 index 0000000000..ee84033969 --- /dev/null +++ b/datacatalog/cmd/entrypoints/serve.go @@ -0,0 +1,72 @@ +package entrypoints + +import ( + "context" + "fmt" + "html" + "net" + "net/http" + + "github.com/lyft/datacatalog/pkg/config" + "github.com/lyft/datacatalog/pkg/rpc/datacatalogservice" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/spf13/cobra" + "google.golang.org/grpc" +) + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Launches the Data Catalog server", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := context.Background() + cfg := config.GetConfig() + + // serve a http healthcheck endpoint + go func() { + err := serveHealthcheck(ctx, cfg) + if err != nil { + logger.Errorf(ctx, "Unable to serve 
http", config.GetConfig().GetGrpcHostAddress(), err) + } + }() + + return serveInsecure(ctx, cfg) + }, +} + +func init() { + RootCmd.AddCommand(serveCmd) + + labeled.SetMetricKeys(contextutils.AppNameKey) +} + +// Create and start the gRPC server +func serveInsecure(ctx context.Context, cfg *config.Config) error { + grpcServer := newGRPCServer(ctx) + + grpcListener, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) + if err != nil { + return err + } + + logger.Infof(ctx, "Serving DataCatalog Insecure on port %v", config.GetConfig().GetGrpcHostAddress()) + return grpcServer.Serve(grpcListener) +} + +// Creates a new GRPC Server with all the configuration +func newGRPCServer(_ context.Context) *grpc.Server { + grpcServer := grpc.NewServer() + datacatalog.RegisterDataCatalogServer(grpcServer, datacatalogservice.NewDataCatalogService()) + return grpcServer +} + +func serveHealthcheck(ctx context.Context, cfg *config.Config) error { + http.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Healthcheck success on %v", html.EscapeString(r.URL.Path)) + }) + + logger.Infof(ctx, "Serving DataCatalog http on port %v", cfg.GetHTTPHostAddress()) + return http.ListenAndServe(cfg.GetHTTPHostAddress(), nil) +} diff --git a/datacatalog/cmd/entrypoints/serve_dummy.go b/datacatalog/cmd/entrypoints/serve_dummy.go new file mode 100644 index 0000000000..8eacd00ea6 --- /dev/null +++ b/datacatalog/cmd/entrypoints/serve_dummy.go @@ -0,0 +1,59 @@ +package entrypoints + +import ( + "context" + "net" + + "github.com/lyft/datacatalog/pkg/config" + "github.com/lyft/datacatalog/pkg/rpc/datacatalogservice" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/spf13/cobra" + "google.golang.org/grpc" +) + +var serveDummyCmd = &cobra.Command{ + Use: "serve-dummy", + Short: "Launches the Data Catalog server without any connections", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := context.Background() + cfg := config.GetConfig() + return serveDummy(ctx, cfg) + }, +} + +func init() { + RootCmd.AddCommand(serveDummyCmd) + + labeled.SetMetricKeys(contextutils.AppNameKey) +} + +// Create and start the gRPC server and http healthcheck endpoint +func serveDummy(ctx context.Context, cfg *config.Config) error { + // serve a http healthcheck endpoint + go func() { + err := serveHealthcheck(ctx, cfg) + if err != nil { + logger.Errorf(ctx, "Unable to serve http", cfg.GetGrpcHostAddress(), err) + } + }() + + grpcServer := newGRPCDummyServer(ctx) + + grpcListener, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) + if err != nil { + return err + } + + logger.Infof(ctx, "Serving DataCatalog Insecure on port %v", cfg.GetGrpcHostAddress()) + return grpcServer.Serve(grpcListener) +} + +// Creates a new GRPC Server with all the configuration +func newGRPCDummyServer(_ context.Context) *grpc.Server { + grpcServer := grpc.NewServer() + datacatalog.RegisterDataCatalogServer(grpcServer, &datacatalogservice.DataCatalogService{}) + return grpcServer +} diff --git a/datacatalog/cmd/main.go b/datacatalog/cmd/main.go new file mode 100644 index 0000000000..391cbb7d93 --- /dev/null +++ b/datacatalog/cmd/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "github.com/golang/glog" + "github.com/lyft/datacatalog/cmd/entrypoints" +) + +func main() { + glog.V(2).Info("Beginning Data Catalog") + err := entrypoints.Execute() + if err != nil { + panic(err) 
+ } +} diff --git a/datacatalog/datacatalog_config.yaml b/datacatalog/datacatalog_config.yaml new file mode 100644 index 0000000000..b33bb9acc5 --- /dev/null +++ b/datacatalog/datacatalog_config.yaml @@ -0,0 +1,27 @@ +# This is a sample configuration file. +# Real configuration when running inside K8s (local or otherwise) lives in a ConfigMap +# Look in the artifacts directory in the flyte repo for what's actually run +application: + grpcPort: 8089 + httpPort: 8080 +datacatalog: + storage-prefix: "metadata" +storage: + connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://localhost:9000 + region: my-region-here + secret-key: miniostorage + cache: + max_size_mbs: 10 + target_gc_percent: 100 + container: my-container + type: minio +database: + port: 5432 + username: postgres + host: localhost + dbname: datacatalog + options: "sslmode=disable" diff --git a/datacatalog/pkg/config/config.go b/datacatalog/pkg/config/config.go new file mode 100644 index 0000000000..a1dbadadec --- /dev/null +++ b/datacatalog/pkg/config/config.go @@ -0,0 +1,41 @@ +package config + +import ( + "fmt" + + "github.com/lyft/flytestdlib/config" +) + +const SectionKey = "application" + +//go:generate pflags Config + +type Config struct { + GrpcPort int `json:"grpcPort" pflag:",On which grpc port to serve Catalog"` + HTTPPort int `json:"httpPort" pflag:",On which http port to serve Catalog"` + Secure bool `json:"secure" pflag:",Whether to run Catalog in secure mode or not"` +} + +var applicationConfig = config.MustRegisterSection(SectionKey, &Config{}) + +func GetConfig() *Config { + return applicationConfig.GetConfig().(*Config) +} + +func SetConfig(c *Config) { + if err := applicationConfig.SetConfig(c); err != nil { + panic(err) + } +} + +func (c Config) GetGrpcHostAddress() string { + return fmt.Sprintf(":%d", c.GrpcPort) +} + +func (c Config) GetHTTPHostAddress() string { + return fmt.Sprintf(":%d", c.HTTPPort) +} + +func init() { + SetConfig(&Config{}) +} diff --git a/datacatalog/pkg/config/config_flags.go b/datacatalog/pkg/config/config_flags.go new file mode 100755 index 0000000000..f52950d17c --- /dev/null +++ b/datacatalog/pkg/config/config_flags.go @@ -0,0 +1,21 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots at +// 2019-08-15 17:34:15.109663842 -0700 PDT m=+1.854915201 + +package config + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpcPort"), *new(int), "On which grpc port to serve Catalog") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "httpPort"), *new(int), "On which http port to serve Catalog") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "secure"), *new(bool), "Whether to run Catalog in secure mode or not") + return cmdFlags +} diff --git a/datacatalog/pkg/config/config_flags_test.go b/datacatalog/pkg/config/config_flags_test.go new file mode 100755 index 0000000000..ef51aabcb4 --- /dev/null +++ b/datacatalog/pkg/config/config_flags_test.go @@ -0,0 +1,152 @@ +// Code generated by go generate; DO NOT EDIT. 
+// This file was generated by robots at +// 2019-08-15 17:34:15.109663842 -0700 PDT m=+1.854915201 + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_grpcPort", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("grpcPort"); err == nil { + assert.Equal(t, *new(int), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("grpcPort", "1") + if vInt, err := cmdFlags.GetInt("grpcPort"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.GrpcPort) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_httpPort", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("httpPort"); err == nil { + assert.Equal(t, *new(int), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("httpPort", "1") + if vInt, err := cmdFlags.GetInt("httpPort"); err == nil { + testDecodeJson_Config(t, 
fmt.Sprintf("%v", vInt), &actual.HTTPPort) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_secure", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("secure"); err == nil { + assert.Equal(t, *new(bool), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("secure", "1") + if vBool, err := cmdFlags.GetBool("secure"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Secure) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/datacatalog/pkg/errors/errors.go b/datacatalog/pkg/errors/errors.go new file mode 100644 index 0000000000..a6d3ef0910 --- /dev/null +++ b/datacatalog/pkg/errors/errors.go @@ -0,0 +1,45 @@ +package errors + +import ( + "fmt" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type DataCatalogError interface { + Error() string + Code() codes.Code + GRPCStatus() *status.Status + String() string +} + +type dataCatalogErrorImpl struct { + status *status.Status +} + +func (e *dataCatalogErrorImpl) Error() string { + return e.status.Message() +} + +func (e *dataCatalogErrorImpl) Code() codes.Code { + return e.status.Code() +} + +func (e *dataCatalogErrorImpl) GRPCStatus() *status.Status { + return e.status +} + +func (e *dataCatalogErrorImpl) String() string { + return fmt.Sprintf("status: %v", e.status) +} + +func NewDataCatalogError(code codes.Code, message string) error { + return &dataCatalogErrorImpl{ + status: status.New(code, message), + } +} + +func NewDataCatalogErrorf(code codes.Code, format string, a ...interface{}) error { + return NewDataCatalogError(code, fmt.Sprintf(format, a...)) +} diff --git a/datacatalog/pkg/manager/impl/artifact_data_store.go b/datacatalog/pkg/manager/impl/artifact_data_store.go new file mode 100644 index 0000000000..2e4aee21f2 --- /dev/null +++ b/datacatalog/pkg/manager/impl/artifact_data_store.go @@ -0,0 +1,62 @@ +package impl + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/storage" + "google.golang.org/grpc/codes" +) + +const artifactDataFile = "data.pb" + +// ArtifactDataStore stores and retrieves ArtifactData values in a data.pb +type ArtifactDataStore interface { + PutData(ctx context.Context, artifact datacatalog.Artifact, data datacatalog.ArtifactData) (storage.DataReference, error) + GetData(ctx context.Context, dataModel models.ArtifactData) (*core.Literal, error) +} + +type artifactDataStore struct { + store *storage.DataStore + storagePrefix storage.DataReference +} + +func (m *artifactDataStore) getDataLocation(ctx context.Context, artifact datacatalog.Artifact, data datacatalog.ArtifactData) (storage.DataReference, error) { + dataset := artifact.Dataset + return m.store.ConstructReference(ctx, m.storagePrefix, dataset.Project, dataset.Domain, dataset.Name, dataset.Version, artifact.Id, data.Name, artifactDataFile) +} + +// Store marshalled data in data.pb under the storage prefix +func (m *artifactDataStore) PutData(ctx context.Context, artifact datacatalog.Artifact, data datacatalog.ArtifactData) (storage.DataReference, error) { + dataLocation, err := m.getDataLocation(ctx, artifact, data) + if err != nil { + return "", 
errors.NewDataCatalogErrorf(codes.Internal, "Unable to generate data location %s, err %v", dataLocation.String(), err) + } + err = m.store.WriteProtobuf(ctx, dataLocation, storage.Options{}, data.Value) + if err != nil { + return "", errors.NewDataCatalogErrorf(codes.Internal, "Unable to store artifact data in location %s, value %v, err %v", dataLocation.String(), data.Value, err) + } + + return dataLocation, nil +} + +// Retrieve the literal value of the ArtifactData from its specified location +func (m *artifactDataStore) GetData(ctx context.Context, dataModel models.ArtifactData) (*core.Literal, error) { + var value core.Literal + err := m.store.ReadProtobuf(ctx, storage.DataReference(dataModel.Location), &value) + if err != nil { + return nil, errors.NewDataCatalogErrorf(codes.Internal, "Unable to read artifact data from location %s, err %v", dataModel.Location, err) + } + + return &value, nil +} + +func NewArtifactDataStore(store *storage.DataStore, storagePrefix storage.DataReference) ArtifactDataStore { + return &artifactDataStore{ + store: store, + storagePrefix: storagePrefix, + } +} diff --git a/datacatalog/pkg/manager/impl/artifact_manager.go b/datacatalog/pkg/manager/impl/artifact_manager.go new file mode 100644 index 0000000000..6b65f46ca5 --- /dev/null +++ b/datacatalog/pkg/manager/impl/artifact_manager.go @@ -0,0 +1,126 @@ +package impl + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/manager/impl/validators" + "github.com/lyft/datacatalog/pkg/manager/interfaces" + "github.com/lyft/datacatalog/pkg/repositories" + datacatalog "github.com/lyft/datacatalog/protos/gen" + + "github.com/lyft/datacatalog/pkg/repositories/models" + "github.com/lyft/datacatalog/pkg/repositories/transformers" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "google.golang.org/grpc/codes" +) + +type artifactManager struct { + repo repositories.RepositoryInterface + artifactStore ArtifactDataStore +} + +// Create an Artifact along with the associated ArtifactData. The ArtifactData will be stored in an offloaded location. +func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalog.CreateArtifactRequest) (*datacatalog.CreateArtifactResponse, error) { + artifact := request.Artifact + err := validators.ValidateArtifact(artifact) + if err != nil { + return nil, err + } + + datasetKey := transformers.FromDatasetID(*artifact.Dataset) + + // The dataset must exist for the artifact, let's verify that first + _, err = m.repo.DatasetRepo().Get(ctx, datasetKey) + if err != nil { + return nil, err + } + + // create Artifact Data offloaded storage files + artifactDataModels := make([]models.ArtifactData, len(request.Artifact.Data)) + for i, artifactData := range request.Artifact.Data { + dataLocation, err := m.artifactStore.PutData(ctx, *artifact, *artifactData) + if err != nil { + return nil, err + } + + artifactDataModels[i].Name = artifactData.Name + artifactDataModels[i].Location = dataLocation.String() + } + + artifactModel, err := transformers.CreateArtifactModel(request, artifactDataModels) + if err != nil { + return nil, err + } + + err = m.repo.ArtifactRepo().Create(ctx, artifactModel) + if err != nil { + return nil, err + } + return &datacatalog.CreateArtifactResponse{}, nil +} + +// Get the Artifact and its associated ArtifactData. The request can query by ArtifactID or TagName. 
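+// When querying by tag, the tag is resolved first and the artifact it points to is
+// returned; in both paths each ArtifactData value is read back from offloaded storage
+// before the response is assembled.
+//
+// Minimal caller-side sketch (the variable names and the tag value below are
+// illustrative assumptions, not part of this change; datasetID is the
+// *datacatalog.DatasetID of an existing dataset):
+//
+//	resp, err := artifactManager.GetArtifact(ctx, datacatalog.GetArtifactRequest{
+//		Dataset:     datasetID,
+//		QueryHandle: &datacatalog.GetArtifactRequest_TagName{TagName: "my-tag"},
+//	})
+//	// On success, resp.Artifact.Data holds the literals read back from storage.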
+func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.GetArtifactRequest) (*datacatalog.GetArtifactResponse, error) { + datasetID := request.Dataset + err := validators.ValidateGetArtifactRequest(request) + if err != nil { + return nil, err + } + + var artifactModel models.Artifact + switch request.QueryHandle.(type) { + case *datacatalog.GetArtifactRequest_ArtifactId: + artifactKey := transformers.ToArtifactKey(*datasetID, request.GetArtifactId()) + artifactModel, err = m.repo.ArtifactRepo().Get(ctx, artifactKey) + + if err != nil { + return nil, err + } + case *datacatalog.GetArtifactRequest_TagName: + tagKey := transformers.ToTagKey(*datasetID, request.GetTagName()) + tag, err := m.repo.TagRepo().Get(ctx, tagKey) + + if err != nil { + return nil, err + } + + artifactModel = tag.Artifact + } + + if len(artifactModel.ArtifactData) == 0 { + return nil, errors.NewDataCatalogErrorf(codes.Internal, "artifact [%+v] does not have artifact data associated", request) + } + + artifact, err := transformers.FromArtifactModel(artifactModel) + if err != nil { + return nil, err + } + + artifactDataList := make([]*datacatalog.ArtifactData, len(artifactModel.ArtifactData)) + for i, artifactData := range artifactModel.ArtifactData { + value, err := m.artifactStore.GetData(ctx, artifactData) + if err != nil { + return nil, err + } + + artifactDataList[i] = &datacatalog.ArtifactData{ + Name: artifactData.Name, + Value: value, + } + } + artifact.Data = artifactDataList + + return &datacatalog.GetArtifactResponse{ + Artifact: &artifact, + }, nil +} + +func NewArtifactManager(repo repositories.RepositoryInterface, store *storage.DataStore, storagePrefix storage.DataReference, artifactScope promutils.Scope) interfaces.ArtifactManager { + return &artifactManager{ + repo: repo, + artifactStore: NewArtifactDataStore(store, storagePrefix), + } +} diff --git a/datacatalog/pkg/manager/impl/artifact_manager_test.go b/datacatalog/pkg/manager/impl/artifact_manager_test.go new file mode 100644 index 0000000000..5800768116 --- /dev/null +++ b/datacatalog/pkg/manager/impl/artifact_manager_test.go @@ -0,0 +1,343 @@ +package impl + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/repositories/mocks" + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/contextutils" + mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func createInmemoryDataStore(t testing.TB, scope mockScope.Scope) *storage.DataStore { + labeled.SetMetricKeys(contextutils.AppNameKey) + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func getTestStringLiteral() *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "value1"}}, + }, + }, + }, + } +} + +func getTestArtifact() *datacatalog.Artifact { + + return &datacatalog.Artifact{ + Id: "test-id", + Dataset: &datacatalog.DatasetID{ + Project: "test-project", + Domain: 
"test-domain", + Name: "test-name", + Version: "test-version", + }, + Metadata: &datacatalog.Metadata{ + KeyMap: map[string]string{"key1": "value1"}, + }, + Data: []*datacatalog.ArtifactData{ + { + Name: "data1", + Value: getTestStringLiteral(), + }, + }, + } +} + +func newMockDataCatalogRepo() *mocks.DataCatalogRepo { + return &mocks.DataCatalogRepo{ + MockDatasetRepo: &mocks.DatasetRepo{}, + MockArtifactRepo: &mocks.ArtifactRepo{}, + } +} + +func getExpectedDatastoreLocation(ctx context.Context, store *storage.DataStore, prefix storage.DataReference, artifact *datacatalog.Artifact, idx int) (storage.DataReference, error) { + dataset := artifact.Dataset + return store.ConstructReference(ctx, prefix, dataset.Project, dataset.Domain, dataset.Name, dataset.Version, artifact.Id, artifact.Data[idx].Name, artifactDataFile) +} + +func TestCreateArtifact(t *testing.T) { + ctx := context.Background() + datastore := createInmemoryDataStore(t, mockScope.NewTestScope()) + testStoragePrefix, err := datastore.ConstructReference(ctx, datastore.GetBaseContainerFQN(ctx), "test") + assert.NoError(t, err) + + t.Run("HappyPath", func(t *testing.T) { + datastore := createInmemoryDataStore(t, mockScope.NewTestScope()) + ctx := context.Background() + dcRepo := newMockDataCatalogRepo() + dcRepo.MockDatasetRepo.On("Get", mock.Anything, + mock.MatchedBy(func(dataset models.DatasetKey) bool { + expectedDataset := getTestDataset() + return dataset.Project == expectedDataset.Id.Project && + dataset.Domain == expectedDataset.Id.Domain && + dataset.Name == expectedDataset.Id.Name && + dataset.Version == expectedDataset.Id.Version + })).Return(models.Dataset{ + DatasetKey: models.DatasetKey{ + Project: getTestDataset().Id.Project, + Domain: getTestDataset().Id.Domain, + Name: getTestDataset().Id.Name, + Version: getTestDataset().Id.Version, + }, + }, nil) + + dcRepo.MockArtifactRepo.On("Create", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(artifact models.Artifact) bool { + expectedArtifact := getTestArtifact() + return artifact.ArtifactID == expectedArtifact.Id && + artifact.SerializedMetadata != nil && + len(artifact.ArtifactData) == len(expectedArtifact.Data) && + artifact.ArtifactKey.DatasetProject == expectedArtifact.Dataset.Project && + artifact.ArtifactKey.DatasetDomain == expectedArtifact.Dataset.Domain && + artifact.ArtifactKey.DatasetName == expectedArtifact.Dataset.Name && + artifact.ArtifactKey.DatasetVersion == expectedArtifact.Dataset.Version + })).Return(nil) + + request := datacatalog.CreateArtifactRequest{Artifact: getTestArtifact()} + artifactManager := NewArtifactManager(dcRepo, datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.CreateArtifact(ctx, request) + assert.NoError(t, err) + assert.NotNil(t, artifactResponse) + + // check that the datastore has the artifactData + dataRef, err := getExpectedDatastoreLocation(ctx, datastore, testStoragePrefix, getTestArtifact(), 0) + assert.NoError(t, err) + var value core.Literal + err = datastore.ReadProtobuf(ctx, dataRef, &value) + assert.NoError(t, err) + assert.Equal(t, value, *getTestArtifact().Data[0].Value) + }) + + t.Run("Dataset does not exist", func(t *testing.T) { + dcRepo := newMockDataCatalogRepo() + dcRepo.MockDatasetRepo.On("Get", mock.Anything, mock.Anything).Return(models.Dataset{}, status.Error(codes.NotFound, "not found")) + + request := datacatalog.CreateArtifactRequest{Artifact: getTestArtifact()} + artifactManager := NewArtifactManager(dcRepo, 
datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.CreateArtifact(ctx, request) + assert.Error(t, err) + assert.Nil(t, artifactResponse) + responseCode := status.Code(err) + assert.Equal(t, codes.NotFound, responseCode) + }) + + t.Run("Artifact missing ID", func(t *testing.T) { + request := datacatalog.CreateArtifactRequest{ + Artifact: &datacatalog.Artifact{ + // missing artifact id + Dataset: getTestDataset().Id, + }, + } + + artifactManager := NewArtifactManager(&mocks.DataCatalogRepo{}, createInmemoryDataStore(t, mockScope.NewTestScope()), testStoragePrefix, mockScope.NewTestScope()) + _, err := artifactManager.CreateArtifact(ctx, request) + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) + + t.Run("Artifact missing artifact data", func(t *testing.T) { + request := datacatalog.CreateArtifactRequest{ + Artifact: &datacatalog.Artifact{ + Id: "test", + Dataset: getTestDataset().Id, + // missing artifactData + }, + } + + artifactManager := NewArtifactManager(&mocks.DataCatalogRepo{}, datastore, testStoragePrefix, mockScope.NewTestScope()) + _, err := artifactManager.CreateArtifact(ctx, request) + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) + + t.Run("Already exists", func(t *testing.T) { + dcRepo := &mocks.DataCatalogRepo{ + MockDatasetRepo: &mocks.DatasetRepo{}, + MockArtifactRepo: &mocks.ArtifactRepo{}, + } + dcRepo.MockDatasetRepo.On("Get", mock.Anything, + mock.MatchedBy(func(dataset models.DatasetKey) bool { + expectedDataset := getTestDataset() + return dataset.Project == expectedDataset.Id.Project && + dataset.Domain == expectedDataset.Id.Domain && + dataset.Name == expectedDataset.Id.Name && + dataset.Version == expectedDataset.Id.Version + })).Return(models.Dataset{ + DatasetKey: models.DatasetKey{ + Project: getTestDataset().Id.Project, + Domain: getTestDataset().Id.Domain, + Name: getTestDataset().Id.Name, + Version: getTestDataset().Id.Version, + }, + }, nil) + + dcRepo.MockArtifactRepo.On("Create", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(artifact models.Artifact) bool { + expectedArtifact := getTestArtifact() + return artifact.ArtifactID == expectedArtifact.Id && + artifact.SerializedMetadata != nil && + len(artifact.ArtifactData) == len(expectedArtifact.Data) && + artifact.ArtifactKey.DatasetProject == expectedArtifact.Dataset.Project && + artifact.ArtifactKey.DatasetDomain == expectedArtifact.Dataset.Domain && + artifact.ArtifactKey.DatasetName == expectedArtifact.Dataset.Name && + artifact.ArtifactKey.DatasetVersion == expectedArtifact.Dataset.Version + })).Return(status.Error(codes.AlreadyExists, "test already exists")) + + request := datacatalog.CreateArtifactRequest{Artifact: getTestArtifact()} + artifactManager := NewArtifactManager(dcRepo, datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.CreateArtifact(ctx, request) + assert.Error(t, err) + assert.Nil(t, artifactResponse) + + responseCode := status.Code(err) + assert.Equal(t, codes.AlreadyExists, responseCode) + }) +} + +func TestGetArtifact(t *testing.T) { + ctx := context.Background() + datastore := createInmemoryDataStore(t, mockScope.NewTestScope()) + testStoragePrefix, err := datastore.ConstructReference(ctx, datastore.GetBaseContainerFQN(ctx), "test") + assert.NoError(t, err) + + dcRepo := &mocks.DataCatalogRepo{ + MockDatasetRepo: 
&mocks.DatasetRepo{}, + MockArtifactRepo: &mocks.ArtifactRepo{}, + MockTagRepo: &mocks.TagRepo{}, + } + + expectedArtifact := getTestArtifact() + expectedDataset := expectedArtifact.Dataset + + // Write the artifact data to the expected location and see if the retrieved data matches + dataLocation, err := getExpectedDatastoreLocation(ctx, datastore, testStoragePrefix, expectedArtifact, 0) + assert.NoError(t, err) + err = datastore.WriteProtobuf(ctx, dataLocation, storage.Options{}, getTestStringLiteral()) + assert.NoError(t, err) + + // construct the artifact model we will return on the queries + serializedMetadata, err := proto.Marshal(expectedArtifact.Metadata) + assert.NoError(t, err) + datasetKey := models.DatasetKey{ + Project: expectedDataset.Project, + Domain: expectedDataset.Domain, + Version: expectedDataset.Version, + Name: expectedDataset.Name, + } + testArtifactModel := models.Artifact{ + ArtifactKey: models.ArtifactKey{ + DatasetProject: expectedDataset.Project, + DatasetDomain: expectedDataset.Domain, + DatasetVersion: expectedDataset.Version, + DatasetName: expectedDataset.Name, + ArtifactID: expectedArtifact.Id, + }, + ArtifactData: []models.ArtifactData{ + {Name: "data1", Location: dataLocation.String()}, + }, + Dataset: models.Dataset{ + DatasetKey: datasetKey, + SerializedMetadata: serializedMetadata, + }, + SerializedMetadata: serializedMetadata, + } + + t.Run("Get by Id", func(t *testing.T) { + + dcRepo.MockArtifactRepo.On("Get", mock.Anything, + mock.MatchedBy(func(artifactKey models.ArtifactKey) bool { + return artifactKey.ArtifactID == expectedArtifact.Id && + artifactKey.DatasetProject == expectedArtifact.Dataset.Project && + artifactKey.DatasetDomain == expectedArtifact.Dataset.Domain && + artifactKey.DatasetVersion == expectedArtifact.Dataset.Version && + artifactKey.DatasetName == expectedArtifact.Dataset.Name + })).Return(testArtifactModel, nil) + + artifactManager := NewArtifactManager(dcRepo, datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.GetArtifact(ctx, datacatalog.GetArtifactRequest{ + Dataset: getTestDataset().Id, + QueryHandle: &datacatalog.GetArtifactRequest_ArtifactId{ArtifactId: expectedArtifact.Id}, + }) + assert.NoError(t, err) + + assert.True(t, proto.Equal(expectedArtifact, artifactResponse.Artifact)) + }) + + t.Run("Get by Artifact Tag", func(t *testing.T) { + expectedTag := getTestTag() + + dcRepo.MockTagRepo.On("Get", mock.Anything, + mock.MatchedBy(func(tag models.TagKey) bool { + return tag.TagName == expectedTag.TagName && + tag.DatasetProject == expectedTag.DatasetProject && + tag.DatasetDomain == expectedTag.DatasetDomain && + tag.DatasetVersion == expectedTag.DatasetVersion && + tag.DatasetName == expectedTag.DatasetName + })).Return(models.Tag{ + TagKey: models.TagKey{ + DatasetProject: expectedTag.DatasetProject, + DatasetDomain: expectedTag.DatasetDomain, + DatasetName: expectedTag.DatasetName, + DatasetVersion: expectedTag.DatasetVersion, + TagName: expectedTag.TagName, + }, + Artifact: testArtifactModel, + ArtifactID: testArtifactModel.ArtifactID, + }, nil) + + artifactManager := NewArtifactManager(dcRepo, datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.GetArtifact(ctx, datacatalog.GetArtifactRequest{ + Dataset: getTestDataset().Id, + QueryHandle: &datacatalog.GetArtifactRequest_TagName{TagName: expectedTag.TagName}, + }) + assert.NoError(t, err) + + assert.True(t, proto.Equal(expectedArtifact, artifactResponse.Artifact)) + }) + + 
t.Run("Get missing input", func(t *testing.T) { + artifactManager := NewArtifactManager(dcRepo, datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.GetArtifact(ctx, datacatalog.GetArtifactRequest{Dataset: getTestDataset().Id}) + assert.Error(t, err) + assert.Nil(t, artifactResponse) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) + + t.Run("Get does not exist", func(t *testing.T) { + dcRepo.MockTagRepo.On("Get", mock.Anything, mock.Anything).Return( + models.Tag{}, errors.NewDataCatalogError(codes.NotFound, "tag with artifact does not exist")) + artifactManager := NewArtifactManager(dcRepo, datastore, testStoragePrefix, mockScope.NewTestScope()) + artifactResponse, err := artifactManager.GetArtifact(ctx, datacatalog.GetArtifactRequest{Dataset: getTestDataset().Id, QueryHandle: &datacatalog.GetArtifactRequest_TagName{TagName: "test"}}) + assert.Error(t, err) + assert.Nil(t, artifactResponse) + responseCode := status.Code(err) + assert.Equal(t, codes.NotFound, responseCode) + }) + +} diff --git a/datacatalog/pkg/manager/impl/dataset_manager.go b/datacatalog/pkg/manager/impl/dataset_manager.go new file mode 100644 index 0000000000..a4a318aa56 --- /dev/null +++ b/datacatalog/pkg/manager/impl/dataset_manager.go @@ -0,0 +1,70 @@ +package impl + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/manager/impl/validators" + "github.com/lyft/datacatalog/pkg/manager/interfaces" + "github.com/lyft/datacatalog/pkg/repositories" + "github.com/lyft/datacatalog/pkg/repositories/transformers" + datacatalog "github.com/lyft/datacatalog/protos/gen" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" +) + +type datasetManager struct { + repo repositories.RepositoryInterface + store *storage.DataStore +} + +// Create a Dataset with optional metadata. If one already exists a grpc AlreadyExists err will be returned +func (dm *datasetManager) CreateDataset(ctx context.Context, request datacatalog.CreateDatasetRequest) (*datacatalog.CreateDatasetResponse, error) { + err := validators.ValidateDatasetID(request.Dataset.Id) + if err != nil { + return nil, err + } + + datasetModel, err := transformers.CreateDatasetModel(request.Dataset) + if err != nil { + return nil, err + } + + err = dm.repo.DatasetRepo().Create(ctx, *datasetModel) + if err != nil { + return nil, err + } + + return &datacatalog.CreateDatasetResponse{}, nil +} + +// Get a Dataset with the given DatasetID if it exists. 
If none exist a grpc NotFound err will be returned +func (dm *datasetManager) GetDataset(ctx context.Context, request datacatalog.GetDatasetRequest) (*datacatalog.GetDatasetResponse, error) { + err := validators.ValidateDatasetID(request.Dataset) + if err != nil { + return nil, err + } + + datasetKey := transformers.FromDatasetID(*request.Dataset) + datasetModel, err := dm.repo.DatasetRepo().Get(ctx, datasetKey) + + if err != nil { + return nil, err + } + + datasetResponse, err := transformers.FromDatasetModel(datasetModel) + if err != nil { + return nil, err + } + + return &datacatalog.GetDatasetResponse{ + Dataset: datasetResponse, + }, nil +} + +func NewDatasetManager(repo repositories.RepositoryInterface, store *storage.DataStore, datasetScope promutils.Scope) interfaces.DatasetManager { + return &datasetManager{ + repo: repo, + store: store, + } +} diff --git a/datacatalog/pkg/manager/impl/dataset_manager_test.go b/datacatalog/pkg/manager/impl/dataset_manager_test.go new file mode 100644 index 0000000000..9f9c39e611 --- /dev/null +++ b/datacatalog/pkg/manager/impl/dataset_manager_test.go @@ -0,0 +1,154 @@ +package impl + +import ( + "testing" + + "context" + + "github.com/golang/protobuf/proto" + "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/repositories/mocks" + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" + mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func getTestDataset() *datacatalog.Dataset { + return &datacatalog.Dataset{ + Id: &datacatalog.DatasetID{ + Project: "test-project", + Domain: "test-domain", + Name: "test-name", + Version: "test-version", + }, + Metadata: &datacatalog.Metadata{ + KeyMap: map[string]string{"key1": "value1"}, + }, + } +} + +func getDataCatalogRepo() *mocks.DataCatalogRepo { + return &mocks.DataCatalogRepo{ + MockDatasetRepo: &mocks.DatasetRepo{}, + } +} + +func TestCreateDataset(t *testing.T) { + + expectedDataset := getTestDataset() + + t.Run("HappyPath", func(t *testing.T) { + dcRepo := getDataCatalogRepo() + datasetManager := NewDatasetManager(dcRepo, nil, mockScope.NewTestScope()) + dcRepo.MockDatasetRepo.On("Create", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(dataset models.Dataset) bool { + + return dataset.Name == expectedDataset.Id.Name && + dataset.Project == expectedDataset.Id.Project && + dataset.Domain == expectedDataset.Id.Domain && + dataset.Version == expectedDataset.Id.Version + })).Return(nil) + request := datacatalog.CreateDatasetRequest{Dataset: expectedDataset} + datasetResponse, err := datasetManager.CreateDataset(context.Background(), request) + assert.NoError(t, err) + assert.NotNil(t, datasetResponse) + }) + + t.Run("MissingInput", func(t *testing.T) { + dcRepo := getDataCatalogRepo() + datasetManager := NewDatasetManager(dcRepo, nil, mockScope.NewTestScope()) + request := datacatalog.CreateDatasetRequest{ + Dataset: &datacatalog.Dataset{ + Id: &datacatalog.DatasetID{ + Domain: "missing-domain", + Name: "missing-name", + Version: "missing-version", + }, + }, + } + + _, err := datasetManager.CreateDataset(context.Background(), request) + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) + + t.Run("AlreadyExists", func(t *testing.T) { + dcRepo := getDataCatalogRepo() + 
datasetManager := NewDatasetManager(dcRepo, nil, mockScope.NewTestScope()) + + dcRepo.MockDatasetRepo.On("Create", + mock.Anything, + mock.Anything).Return(status.Error(codes.AlreadyExists, "test already exists")) + request := datacatalog.CreateDatasetRequest{ + Dataset: getTestDataset(), + } + + _, err := datasetManager.CreateDataset(context.Background(), request) + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.AlreadyExists, responseCode) + }) +} + +func TestGetDataset(t *testing.T) { + expectedDataset := getTestDataset() + + t.Run("HappyPath", func(t *testing.T) { + dcRepo := getDataCatalogRepo() + datasetManager := NewDatasetManager(dcRepo, nil, mockScope.NewTestScope()) + + serializedMetadata, _ := proto.Marshal(expectedDataset.Metadata) + datasetModelResponse := models.Dataset{ + DatasetKey: models.DatasetKey{ + Project: expectedDataset.Id.Project, + Domain: expectedDataset.Id.Domain, + Version: expectedDataset.Id.Version, + Name: expectedDataset.Id.Name, + }, + SerializedMetadata: serializedMetadata, + } + + dcRepo.MockDatasetRepo.On("Get", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(datasetKey models.DatasetKey) bool { + + return datasetKey.Name == expectedDataset.Id.Name && + datasetKey.Project == expectedDataset.Id.Project && + datasetKey.Domain == expectedDataset.Id.Domain && + datasetKey.Version == expectedDataset.Id.Version + })).Return(datasetModelResponse, nil) + request := datacatalog.GetDatasetRequest{Dataset: getTestDataset().Id} + datasetResponse, err := datasetManager.GetDataset(context.Background(), request) + assert.NoError(t, err) + assert.NotNil(t, datasetResponse) + assert.True(t, proto.Equal(datasetResponse.Dataset, expectedDataset)) + assert.EqualValues(t, datasetResponse.Dataset.Metadata.KeyMap, expectedDataset.Metadata.KeyMap) + }) + + t.Run("Does not exist", func(t *testing.T) { + dcRepo := getDataCatalogRepo() + datasetManager := NewDatasetManager(dcRepo, nil, mockScope.NewTestScope()) + + dcRepo.MockDatasetRepo.On("Get", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(datasetKey models.DatasetKey) bool { + + return datasetKey.Name == expectedDataset.Id.Name && + datasetKey.Project == expectedDataset.Id.Project && + datasetKey.Domain == expectedDataset.Id.Domain && + datasetKey.Version == expectedDataset.Id.Version + })).Return(models.Dataset{}, errors.NewDataCatalogError(codes.NotFound, "dataset does not exist")) + request := datacatalog.GetDatasetRequest{Dataset: getTestDataset().Id} + _, err := datasetManager.GetDataset(context.Background(), request) + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.NotFound, responseCode) + }) + +} diff --git a/datacatalog/pkg/manager/impl/tag_manager.go b/datacatalog/pkg/manager/impl/tag_manager.go new file mode 100644 index 0000000000..2c2de65cb3 --- /dev/null +++ b/datacatalog/pkg/manager/impl/tag_manager.go @@ -0,0 +1,54 @@ +package impl + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/manager/impl/validators" + "github.com/lyft/datacatalog/pkg/manager/interfaces" + "github.com/lyft/datacatalog/pkg/repositories" + + "github.com/lyft/datacatalog/pkg/repositories/models" + "github.com/lyft/datacatalog/pkg/repositories/transformers" + datacatalog "github.com/lyft/datacatalog/protos/gen" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" +) + +type tagManager struct { + repo repositories.RepositoryInterface + store 
*storage.DataStore +} + +func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagRequest) (*datacatalog.AddTagResponse, error) { + + if err := validators.ValidateTag(request.Tag); err != nil { + return nil, err + } + + // verify the artifact exists before adding a tag to it + datasetID := *request.Tag.Dataset + artifactKey := transformers.ToArtifactKey(datasetID, request.Tag.ArtifactId) + _, err := m.repo.ArtifactRepo().Get(ctx, artifactKey) + if err != nil { + return nil, err + } + + tagKey := transformers.ToTagKey(datasetID, request.Tag.Name) + err = m.repo.TagRepo().Create(ctx, models.Tag{ + TagKey: tagKey, + ArtifactID: request.Tag.ArtifactId, + }) + if err != nil { + return nil, err + } + + return &datacatalog.AddTagResponse{}, nil +} + +func NewTagManager(repo repositories.RepositoryInterface, store *storage.DataStore, tagScope promutils.Scope) interfaces.TagManager { + return &tagManager{ + repo: repo, + store: store, + } +} diff --git a/datacatalog/pkg/manager/impl/tag_manager_test.go b/datacatalog/pkg/manager/impl/tag_manager_test.go new file mode 100644 index 0000000000..7f8a75db85 --- /dev/null +++ b/datacatalog/pkg/manager/impl/tag_manager_test.go @@ -0,0 +1,127 @@ +package impl + +import ( + "context" + "testing" + + "github.com/lyft/datacatalog/pkg/repositories/mocks" + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" + + mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func getTestTag() models.Tag { + return models.Tag{ + TagKey: models.TagKey{ + DatasetProject: "test-project", + DatasetDomain: "test-domain", + DatasetVersion: "test-version", + DatasetName: "test-name", + TagName: "test-tag", + }, + ArtifactID: "test-artifactID", + } +} + +func TestAddTag(t *testing.T) { + dcRepo := &mocks.DataCatalogRepo{ + MockDatasetRepo: &mocks.DatasetRepo{}, + MockArtifactRepo: &mocks.ArtifactRepo{}, + MockTagRepo: &mocks.TagRepo{}, + } + + expectedTag := getTestTag() + + t.Run("HappyPath", func(t *testing.T) { + dcRepo.MockTagRepo.On("Create", mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(tag models.Tag) bool { + + return tag.DatasetProject == expectedTag.DatasetProject && + tag.DatasetDomain == expectedTag.DatasetDomain && + tag.DatasetName == expectedTag.DatasetName && + tag.DatasetVersion == expectedTag.DatasetVersion && + tag.ArtifactID == expectedTag.ArtifactID && + tag.TagName == expectedTag.TagName + })).Return(nil) + + artifactKey := models.ArtifactKey{ + DatasetProject: expectedTag.DatasetProject, + DatasetDomain: expectedTag.DatasetDomain, + DatasetName: expectedTag.DatasetVersion, + DatasetVersion: expectedTag.DatasetName, + ArtifactID: expectedTag.ArtifactID, + } + + dcRepo.MockArtifactRepo.On("Get", mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(artifactKey models.ArtifactKey) bool { + return artifactKey.DatasetProject == expectedTag.DatasetProject && + artifactKey.DatasetDomain == expectedTag.DatasetDomain && + artifactKey.DatasetName == expectedTag.DatasetName && + artifactKey.DatasetVersion == expectedTag.DatasetVersion && + artifactKey.ArtifactID == expectedTag.ArtifactID + })).Return(models.Artifact{ArtifactKey: artifactKey}, nil) + + tagManager := NewTagManager(dcRepo, nil, mockScope.NewTestScope()) + _, err := tagManager.AddTag(context.Background(), 
datacatalog.AddTagRequest{ + Tag: &datacatalog.Tag{ + Name: expectedTag.TagName, + ArtifactId: expectedTag.ArtifactID, + Dataset: &datacatalog.DatasetID{ + Project: expectedTag.DatasetProject, + Domain: expectedTag.DatasetDomain, + Version: expectedTag.DatasetVersion, + Name: expectedTag.DatasetName, + }, + }, + }) + + assert.NoError(t, err) + }) + + t.Run("NoDataset", func(t *testing.T) { + tagManager := NewTagManager(dcRepo, nil, mockScope.NewTestScope()) + _, err := tagManager.AddTag(context.Background(), datacatalog.AddTagRequest{ + Tag: &datacatalog.Tag{ + Name: "noDataset", + ArtifactId: "noDataset", + }, + }) + + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) + + t.Run("NoTagName", func(t *testing.T) { + tagManager := NewTagManager(dcRepo, nil, mockScope.NewTestScope()) + _, err := tagManager.AddTag(context.Background(), datacatalog.AddTagRequest{ + Tag: &datacatalog.Tag{ + ArtifactId: "noArtifact", + Dataset: getTestDataset().Id, + }, + }) + + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) + + t.Run("NoArtifactID", func(t *testing.T) { + tagManager := NewTagManager(dcRepo, nil, mockScope.NewTestScope()) + _, err := tagManager.AddTag(context.Background(), datacatalog.AddTagRequest{ + Tag: &datacatalog.Tag{ + Name: "noArtifact", + Dataset: getTestDataset().Id, + }, + }) + + assert.Error(t, err) + responseCode := status.Code(err) + assert.Equal(t, codes.InvalidArgument, responseCode) + }) +} diff --git a/datacatalog/pkg/manager/impl/validators/artifact_validator.go b/datacatalog/pkg/manager/impl/validators/artifact_validator.go new file mode 100644 index 0000000000..ceba3c0441 --- /dev/null +++ b/datacatalog/pkg/manager/impl/validators/artifact_validator.go @@ -0,0 +1,66 @@ +package validators + +import ( + "fmt" + + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +const ( + artifactID = "artifactID" + artifactDataEntity = "artifactData" + artifactEntity = "artifact" +) + +func ValidateGetArtifactRequest(request datacatalog.GetArtifactRequest) error { + if err := ValidateDatasetID(request.Dataset); err != nil { + return err + } + + if request.QueryHandle == nil { + return NewMissingArgumentError(fmt.Sprintf("one of %s/%s", artifactID, tagName)) + } + + switch request.QueryHandle.(type) { + case *datacatalog.GetArtifactRequest_ArtifactId: + if err := ValidateEmptyStringField(request.GetArtifactId(), artifactID); err != nil { + return err + } + case *datacatalog.GetArtifactRequest_TagName: + if err := ValidateEmptyStringField(request.GetTagName(), tagName); err != nil { + return err + } + default: + return NewInvalidArgumentError("QueryHandle", "invalid type") + } + + return nil +} + +func ValidateEmptyArtifactData(artifactData []*datacatalog.ArtifactData) error { + if len(artifactData) == 0 { + return NewMissingArgumentError(artifactDataEntity) + } + + return nil +} + +func ValidateArtifact(artifact *datacatalog.Artifact) error { + if artifact == nil { + return NewMissingArgumentError(artifactEntity) + } + + if err := ValidateDatasetID(artifact.Dataset); err != nil { + return err + } + + if err := ValidateEmptyStringField(artifact.Id, artifactID); err != nil { + return err + } + + if err := ValidateEmptyArtifactData(artifact.Data); err != nil { + return err + } + + return nil +} diff --git a/datacatalog/pkg/manager/impl/validators/common.go b/datacatalog/pkg/manager/impl/validators/common.go new file mode 100644 index 0000000000..906b7aea27 
--- /dev/null +++ b/datacatalog/pkg/manager/impl/validators/common.go @@ -0,0 +1,8 @@ +package validators + +func ValidateEmptyStringField(field, fieldName string) error { + if field == "" { + return NewMissingArgumentError(fieldName) + } + return nil +} diff --git a/datacatalog/pkg/manager/impl/validators/dataset_validator.go b/datacatalog/pkg/manager/impl/validators/dataset_validator.go new file mode 100644 index 0000000000..5ba8d17da3 --- /dev/null +++ b/datacatalog/pkg/manager/impl/validators/dataset_validator.go @@ -0,0 +1,31 @@ +package validators + +import datacatalog "github.com/lyft/datacatalog/protos/gen" + +const ( + datasetEntity = "dataset" + datasetProject = "project" + datasetDomain = "domain" + datasetName = "name" + datasetVersion = "version" +) + +// Validate that the DatasetID has all the fields filled +func ValidateDatasetID(ds *datacatalog.DatasetID) error { + if ds == nil { + return NewMissingArgumentError(datasetEntity) + } + if err := ValidateEmptyStringField(ds.Project, datasetProject); err != nil { + return err + } + if err := ValidateEmptyStringField(ds.Domain, datasetDomain); err != nil { + return err + } + if err := ValidateEmptyStringField(ds.Name, datasetName); err != nil { + return err + } + if err := ValidateEmptyStringField(ds.Version, datasetVersion); err != nil { + return err + } + return nil +} diff --git a/datacatalog/pkg/manager/impl/validators/errors.go b/datacatalog/pkg/manager/impl/validators/errors.go new file mode 100644 index 0000000000..8c9d592347 --- /dev/null +++ b/datacatalog/pkg/manager/impl/validators/errors.go @@ -0,0 +1,20 @@ +package validators + +import ( + "fmt" + + "github.com/lyft/datacatalog/pkg/errors" + + "google.golang.org/grpc/codes" +) + +const missingFieldFormat = "missing %s" +const invalidArgFormat = "invalid value for %s, value:[%s]" + +func NewMissingArgumentError(field string) error { + return errors.NewDataCatalogErrorf(codes.InvalidArgument, fmt.Sprintf(missingFieldFormat, field)) +} + +func NewInvalidArgumentError(field string, value string) error { + return errors.NewDataCatalogErrorf(codes.InvalidArgument, fmt.Sprintf(invalidArgFormat, field, value)) +} diff --git a/datacatalog/pkg/manager/impl/validators/tag_validator.go b/datacatalog/pkg/manager/impl/validators/tag_validator.go new file mode 100644 index 0000000000..5bf778ff96 --- /dev/null +++ b/datacatalog/pkg/manager/impl/validators/tag_validator.go @@ -0,0 +1,28 @@ +package validators + +import ( + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +const ( + tagName = "tagName" + tagEntity = "tag" +) + +func ValidateTag(tag *datacatalog.Tag) error { + if tag == nil { + return NewMissingArgumentError(tagEntity) + } + if err := ValidateDatasetID(tag.Dataset); err != nil { + return err + } + + if err := ValidateEmptyStringField(tag.Name, tagName); err != nil { + return err + } + + if err := ValidateEmptyStringField(tag.ArtifactId, artifactID); err != nil { + return err + } + return nil +} diff --git a/datacatalog/pkg/manager/interfaces/artifact.go b/datacatalog/pkg/manager/interfaces/artifact.go new file mode 100644 index 0000000000..71896f7454 --- /dev/null +++ b/datacatalog/pkg/manager/interfaces/artifact.go @@ -0,0 +1,12 @@ +package interfaces + +import ( + "context" + + idl_datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +type ArtifactManager interface { + CreateArtifact(ctx context.Context, request idl_datacatalog.CreateArtifactRequest) (*idl_datacatalog.CreateArtifactResponse, error) + GetArtifact(ctx context.Context, request 
idl_datacatalog.GetArtifactRequest) (*idl_datacatalog.GetArtifactResponse, error) +} diff --git a/datacatalog/pkg/manager/interfaces/dataset.go b/datacatalog/pkg/manager/interfaces/dataset.go new file mode 100644 index 0000000000..eaa9c0eef4 --- /dev/null +++ b/datacatalog/pkg/manager/interfaces/dataset.go @@ -0,0 +1,12 @@ +package interfaces + +import ( + "context" + + idl_datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +type DatasetManager interface { + CreateDataset(ctx context.Context, request idl_datacatalog.CreateDatasetRequest) (*idl_datacatalog.CreateDatasetResponse, error) + GetDataset(ctx context.Context, request idl_datacatalog.GetDatasetRequest) (*idl_datacatalog.GetDatasetResponse, error) +} diff --git a/datacatalog/pkg/manager/interfaces/tag.go b/datacatalog/pkg/manager/interfaces/tag.go new file mode 100644 index 0000000000..274fd17c59 --- /dev/null +++ b/datacatalog/pkg/manager/interfaces/tag.go @@ -0,0 +1,11 @@ +package interfaces + +import ( + "context" + + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +type TagManager interface { + AddTag(ctx context.Context, request datacatalog.AddTagRequest) (*datacatalog.AddTagResponse, error) +} diff --git a/datacatalog/pkg/manager/mocks/artifact.go b/datacatalog/pkg/manager/mocks/artifact.go new file mode 100644 index 0000000000..70fab799ef --- /dev/null +++ b/datacatalog/pkg/manager/mocks/artifact.go @@ -0,0 +1,59 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import idl_datacatalog "github.com/lyft/datacatalog/protos/gen" + +import mock "github.com/stretchr/testify/mock" + +// ArtifactManager is an autogenerated mock type for the ArtifactManager type +type ArtifactManager struct { + mock.Mock +} + +// CreateArtifact provides a mock function with given fields: ctx, request +func (_m *ArtifactManager) CreateArtifact(ctx context.Context, request idl_datacatalog.CreateArtifactRequest) (*idl_datacatalog.CreateArtifactResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *idl_datacatalog.CreateArtifactResponse + if rf, ok := ret.Get(0).(func(context.Context, idl_datacatalog.CreateArtifactRequest) *idl_datacatalog.CreateArtifactResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*idl_datacatalog.CreateArtifactResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, idl_datacatalog.CreateArtifactRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetArtifact provides a mock function with given fields: ctx, request +func (_m *ArtifactManager) GetArtifact(ctx context.Context, request idl_datacatalog.GetArtifactRequest) (*idl_datacatalog.GetArtifactResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *idl_datacatalog.GetArtifactResponse + if rf, ok := ret.Get(0).(func(context.Context, idl_datacatalog.GetArtifactRequest) *idl_datacatalog.GetArtifactResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*idl_datacatalog.GetArtifactResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, idl_datacatalog.GetArtifactRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/datacatalog/pkg/manager/mocks/dataset.go b/datacatalog/pkg/manager/mocks/dataset.go new file mode 100644 index 0000000000..ccfa1babfb --- /dev/null +++ b/datacatalog/pkg/manager/mocks/dataset.go @@ -0,0 +1,59 @@ +// Code generated by 
mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import idl_datacatalog "github.com/lyft/datacatalog/protos/gen" + +import mock "github.com/stretchr/testify/mock" + +// DatasetManager is an autogenerated mock type for the DatasetManager type +type DatasetManager struct { + mock.Mock +} + +// CreateDataset provides a mock function with given fields: ctx, request +func (_m *DatasetManager) CreateDataset(ctx context.Context, request idl_datacatalog.CreateDatasetRequest) (*idl_datacatalog.CreateDatasetResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *idl_datacatalog.CreateDatasetResponse + if rf, ok := ret.Get(0).(func(context.Context, idl_datacatalog.CreateDatasetRequest) *idl_datacatalog.CreateDatasetResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*idl_datacatalog.CreateDatasetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, idl_datacatalog.CreateDatasetRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetDataset provides a mock function with given fields: ctx, request +func (_m *DatasetManager) GetDataset(ctx context.Context, request idl_datacatalog.GetDatasetRequest) (*idl_datacatalog.GetDatasetResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *idl_datacatalog.GetDatasetResponse + if rf, ok := ret.Get(0).(func(context.Context, idl_datacatalog.GetDatasetRequest) *idl_datacatalog.GetDatasetResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*idl_datacatalog.GetDatasetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, idl_datacatalog.GetDatasetRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/datacatalog/pkg/manager/mocks/tag.go b/datacatalog/pkg/manager/mocks/tag.go new file mode 100644 index 0000000000..ccad49b131 --- /dev/null +++ b/datacatalog/pkg/manager/mocks/tag.go @@ -0,0 +1,36 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import context "context" +import idl_datacatalog "github.com/lyft/datacatalog/protos/gen" + +import mock "github.com/stretchr/testify/mock" + +// TagManager is an autogenerated mock type for the TagManager type +type TagManager struct { + mock.Mock +} + +// AddTag provides a mock function with given fields: ctx, request +func (_m *TagManager) AddTag(ctx context.Context, request idl_datacatalog.AddTagRequest) (*idl_datacatalog.AddTagResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *idl_datacatalog.AddTagResponse + if rf, ok := ret.Get(0).(func(context.Context, idl_datacatalog.AddTagRequest) *idl_datacatalog.AddTagResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*idl_datacatalog.AddTagResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, idl_datacatalog.AddTagRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/datacatalog/pkg/repositories/config/database.go b/datacatalog/pkg/repositories/config/database.go new file mode 100644 index 0000000000..7ac5930bc4 --- /dev/null +++ b/datacatalog/pkg/repositories/config/database.go @@ -0,0 +1,27 @@ +package config + +//go:generate pflags DbConfigSection + +// This struct corresponds to the database section of in the config +type DbConfigSection struct { + Host string `json:"host"` + Port int `json:"port"` + DbName string `json:"dbname"` + User string `json:"username"` + // Either Password or PasswordPath must be set. + Password string `json:"password"` + PasswordPath string `json:"passwordPath"` + // See http://gorm.io/docs/connecting_to_the_database.html for available options passed, in addition to the above. + ExtraOptions string `json:"options"` +} + +// Database config. Contains values necessary to open a database connection. +type DbConfig struct { + BaseConfig + Host string `json:"host"` + Port int `json:"port"` + DbName string `json:"dbname"` + User string `json:"user"` + Password string `json:"password"` + ExtraOptions string `json:"options"` +} diff --git a/datacatalog/pkg/repositories/config/dbconfigsection_flags.go b/datacatalog/pkg/repositories/config/dbconfigsection_flags.go new file mode 100755 index 0000000000..42abddf106 --- /dev/null +++ b/datacatalog/pkg/repositories/config/dbconfigsection_flags.go @@ -0,0 +1,25 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots at +// 2019-08-01 11:59:32.35127988 -0700 PDT m=+3.149045889 + +package config + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in DbConfigSection and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
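+// As an illustrative sketch (the "database." prefix below is hypothetical, not
+// something this package mandates), the returned flag set can be attached to any
+// *pflag.FlagSet, for example a cobra command's flags:
+//
+//	cfg := DbConfigSection{}
+//	cmd.Flags().AddFlagSet(cfg.GetPFlagSet("database."))
+//	// exposes --database.host, --database.port, --database.dbname, ...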
+func (DbConfigSection) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("DbConfigSection", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "host"), *new(string), "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "port"), *new(int), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dbname"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "username"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "password"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "passwordPath"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "options"), *new(string), "") + return cmdFlags +} diff --git a/datacatalog/pkg/repositories/config/dbconfigsection_flags_test.go b/datacatalog/pkg/repositories/config/dbconfigsection_flags_test.go new file mode 100755 index 0000000000..f92cc4623e --- /dev/null +++ b/datacatalog/pkg/repositories/config/dbconfigsection_flags_test.go @@ -0,0 +1,232 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots at +// 2019-08-01 11:59:32.35127988 -0700 PDT m=+3.149045889 + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsDbConfigSection = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementDbConfigSection(t reflect.Kind) bool { + _, exists := dereferencableKindsDbConfigSection[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookDbConfigSection(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementDbConfigSection(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_DbConfigSection(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookDbConfigSection, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func testDecodeJson_DbConfigSection(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_DbConfigSection(val, result)) +} + +func testDecodeSlice_DbConfigSection(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_DbConfigSection(vStringSlice, result)) +} + +func TestDbConfigSection_GetPFlagSet(t *testing.T) { + val := DbConfigSection{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestDbConfigSection_SetFlags(t *testing.T) { + actual := DbConfigSection{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_host", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("host"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("host", "1") + if vString, err := cmdFlags.GetString("host"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vString), &actual.Host) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_port", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("port"); err == nil { + assert.Equal(t, *new(int), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("port", "1") + if vInt, err := cmdFlags.GetInt("port"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vInt), &actual.Port) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_dbname", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("dbname"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("dbname", "1") + if vString, err := cmdFlags.GetString("dbname"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vString), &actual.DbName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_username", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("username"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("username", "1") + if vString, err := cmdFlags.GetString("username"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vString), &actual.User) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_password", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that 
default value is set properly + if vString, err := cmdFlags.GetString("password"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("password", "1") + if vString, err := cmdFlags.GetString("password"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vString), &actual.Password) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_passwordPath", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("passwordPath"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("passwordPath", "1") + if vString, err := cmdFlags.GetString("passwordPath"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vString), &actual.PasswordPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_options", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("options"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("options", "1") + if vString, err := cmdFlags.GetString("options"); err == nil { + testDecodeJson_DbConfigSection(t, fmt.Sprintf("%v", vString), &actual.ExtraOptions) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/datacatalog/pkg/repositories/config/postgres.go b/datacatalog/pkg/repositories/config/postgres.go new file mode 100644 index 0000000000..9ee7ee52c1 --- /dev/null +++ b/datacatalog/pkg/repositories/config/postgres.go @@ -0,0 +1,81 @@ +package config + +import ( + "fmt" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/postgres" // Required to import database driver. +) + +const Postgres = "postgres" + +// Generic interface for providing a config necessary to open a database connection. +type DbConnectionConfigProvider interface { + // Returns the database type. For instance PostgreSQL or MySQL. + GetType() string + // Returns arguments specific for the database type necessary to open a database connection. + GetArgs() string + // Enables verbose logging. + WithDebugModeEnabled() + // Disables verbose logging. + WithDebugModeDisabled() + // Returns whether verbose logging is enabled or not. + IsDebug() bool +} + +type BaseConfig struct { + IsDebug bool +} + +// PostgreSQL implementation for DbConnectionConfigProvider. 
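+// For example (the host, port and names below are illustrative only), a DbConfig of
+//	{Host: "localhost", Port: 5432, DbName: "datacatalog", User: "postgres"}
+// with an empty Password produces the development connection string
+//	"host=localhost port=5432 dbname=datacatalog user=postgres sslmode=disable"
+// while a non-empty Password yields the same string with "password=..." appended and
+// without sslmode=disable, as implemented in GetArgs below.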
+type PostgresConfigProvider struct { + config DbConfig + scope promutils.Scope +} + +// TODO : Make the Config provider itself env based +func NewPostgresConfigProvider(config DbConfig, scope promutils.Scope) DbConnectionConfigProvider { + return &PostgresConfigProvider{ + config: config, + scope: scope, + } +} + +func (p *PostgresConfigProvider) GetType() string { + return Postgres +} + +func (p *PostgresConfigProvider) GetArgs() string { + if p.config.Password == "" { + // Switch for development + return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable", + p.config.Host, p.config.Port, p.config.DbName, p.config.User) + } + return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s", + p.config.Host, p.config.Port, p.config.DbName, p.config.User, p.config.Password) +} + +func (p *PostgresConfigProvider) WithDebugModeEnabled() { + p.config.IsDebug = true +} + +func (p *PostgresConfigProvider) WithDebugModeDisabled() { + p.config.IsDebug = false +} + +func (p *PostgresConfigProvider) IsDebug() bool { + return p.config.IsDebug +} + +// Opens a connection to the database specified in the config. +// You must call CloseDbConnection at the end of your session! +func OpenDbConnection(config DbConnectionConfigProvider) (*gorm.DB, error) { + db, err := gorm.Open(config.GetType(), config.GetArgs()) + if err != nil { + return nil, err + } + db.LogMode(config.IsDebug()) + return db, nil +} diff --git a/datacatalog/pkg/repositories/errors/error_transformer.go b/datacatalog/pkg/repositories/errors/error_transformer.go new file mode 100644 index 0000000000..ad35dfaaaf --- /dev/null +++ b/datacatalog/pkg/repositories/errors/error_transformer.go @@ -0,0 +1,6 @@ +package errors + +// Defines the basic error transformer interface that all database types must implement. 
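+// A minimal hypothetical implementation (not part of this package) for a driver
+// without rich error types could simply wrap every failure as an internal error:
+//
+//	type genericErrorTransformer struct{}
+//
+//	func (genericErrorTransformer) ToDataCatalogError(err error) error {
+//		return errors.NewDataCatalogErrorf(codes.Internal, "database error: %v", err)
+//	}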
+type ErrorTransformer interface { + ToDataCatalogError(err error) error +} diff --git a/datacatalog/pkg/repositories/errors/errors.go b/datacatalog/pkg/repositories/errors/errors.go new file mode 100644 index 0000000000..289ca5a720 --- /dev/null +++ b/datacatalog/pkg/repositories/errors/errors.go @@ -0,0 +1,16 @@ +// Generic errors used in the repos layer +package errors + +import ( + "github.com/golang/protobuf/proto" + "github.com/lyft/datacatalog/pkg/errors" + "google.golang.org/grpc/codes" +) + +const ( + notFound = "missing entity of type %s with identifier %v" +) + +func GetMissingEntityError(entityType string, identifier proto.Message) error { + return errors.NewDataCatalogErrorf(codes.NotFound, notFound, entityType, identifier) +} diff --git a/datacatalog/pkg/repositories/errors/postgres.go b/datacatalog/pkg/repositories/errors/postgres.go new file mode 100644 index 0000000000..69c7de79e0 --- /dev/null +++ b/datacatalog/pkg/repositories/errors/postgres.go @@ -0,0 +1,54 @@ +package errors + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/lib/pq" + "github.com/lyft/datacatalog/pkg/errors" + "google.golang.org/grpc/codes" +) + +// Postgres error codes +const ( + uniqueConstraintViolationCode = "23505" + undefinedTable = "42P01" +) + +type postgresErrorTransformer struct { +} + +const ( + unexpectedType = "unexpected error type for: %v" + uniqueConstraintViolation = "value with matching %s already exists (%s)" + defaultPgError = "failed database operation with %s" + unsupportedTableOperation = "cannot query with specified table attributes: %s" +) + +func (p *postgresErrorTransformer) fromGormError(err error) error { + switch err.Error() { + case gorm.ErrRecordNotFound.Error(): + return errors.NewDataCatalogErrorf(codes.NotFound, "entry not found") + default: + return errors.NewDataCatalogErrorf(codes.Internal, unexpectedType, err) + } +} + +func (p *postgresErrorTransformer) ToDataCatalogError(err error) error { + pqError, ok := err.(*pq.Error) + if !ok { + return p.fromGormError(err) + } + switch pqError.Code { + case uniqueConstraintViolationCode: + return errors.NewDataCatalogErrorf(codes.AlreadyExists, uniqueConstraintViolation, pqError.Constraint, pqError.Message) + case undefinedTable: + return errors.NewDataCatalogErrorf(codes.InvalidArgument, unsupportedTableOperation, pqError.Message) + default: + return errors.NewDataCatalogErrorf(codes.Unknown, fmt.Sprintf(defaultPgError, pqError.Message)) + } +} + +func NewPostgresErrorTransformer() ErrorTransformer { + return &postgresErrorTransformer{} +} diff --git a/datacatalog/pkg/repositories/factory.go b/datacatalog/pkg/repositories/factory.go new file mode 100644 index 0000000000..535b091db8 --- /dev/null +++ b/datacatalog/pkg/repositories/factory.go @@ -0,0 +1,45 @@ +package repositories + +import ( + "fmt" + + "github.com/lyft/datacatalog/pkg/repositories/config" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/interfaces" + "github.com/lyft/flytestdlib/promutils" +) + +type RepoConfig int32 + +const ( + POSTGRES RepoConfig = 0 +) + +var RepositoryConfigurationName = map[RepoConfig]string{ + POSTGRES: "POSTGRES", +} + +// The RepositoryInterface indicates the methods that each Repository must support. +// A Repository indicates a Database which is collection of Tables/models. +// The goal is allow databases to be Plugged in easily. 
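+// A rough usage sketch (the config values and scope are placeholders): a service
+// would typically build its repository layer once at startup and hand it to the
+// managers, e.g.
+//
+//	repo := GetRepository(POSTGRES, config.DbConfig{Host: "localhost", Port: 5432,
+//		DbName: "datacatalog", User: "postgres"}, scope)
+//	datasets := repo.DatasetRepo()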
+type RepositoryInterface interface { + DatasetRepo() interfaces.DatasetRepo + ArtifactRepo() interfaces.ArtifactRepo + TagRepo() interfaces.TagRepo +} + +func GetRepository(repoType RepoConfig, dbConfig config.DbConfig, scope promutils.Scope) RepositoryInterface { + switch repoType { + case POSTGRES: + db, err := config.OpenDbConnection(config.NewPostgresConfigProvider(dbConfig, scope.NewSubScope("postgres"))) + if err != nil { + panic(err) + } + return NewPostgresRepo( + db, + errors.NewPostgresErrorTransformer(), + scope.NewSubScope("repositories")) + default: + panic(fmt.Sprintf("Invalid repoType %v", repoType)) + } +} diff --git a/datacatalog/pkg/repositories/gormimpl/artifact.go b/datacatalog/pkg/repositories/gormimpl/artifact.go new file mode 100644 index 0000000000..b39093ebcd --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/artifact.go @@ -0,0 +1,67 @@ +package gormimpl + +import ( + "context" + + "github.com/jinzhu/gorm" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/interfaces" + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +type artifactRepo struct { + db *gorm.DB + errorTransformer errors.ErrorTransformer + // TODO: add metrics +} + +func NewArtifactRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer) interfaces.ArtifactRepo { + return &artifactRepo{ + db: db, + errorTransformer: errorTransformer, + } +} + +func (h *artifactRepo) Create(ctx context.Context, artifact models.Artifact) error { + // Create the artifact in a transaction because ArtifactData will be created and associated along with it + tx := h.db.Begin() + + tx = tx.Create(&artifact) + + if tx.Error != nil { + tx.Rollback() + return h.errorTransformer.ToDataCatalogError(tx.Error) + } + + tx = tx.Commit() + if tx.Error != nil { + return h.errorTransformer.ToDataCatalogError(tx.Error) + } + + return nil +} + +func (h *artifactRepo) Get(ctx context.Context, in models.ArtifactKey) (models.Artifact, error) { + var artifact models.Artifact + result := h.db.Preload("ArtifactData").Find(&artifact, &models.Artifact{ + ArtifactKey: in, + }) + + if result.Error != nil { + return models.Artifact{}, h.errorTransformer.ToDataCatalogError(result.Error) + } + if result.RecordNotFound() { + return models.Artifact{}, errors.GetMissingEntityError("Artifact", &datacatalog.Artifact{ + Dataset: &datacatalog.DatasetID{ + Project: in.DatasetProject, + Domain: in.DatasetDomain, + Name: in.DatasetName, + Version: in.DatasetVersion, + }, + Id: in.ArtifactID, + }) + } + + return artifact, nil +} diff --git a/datacatalog/pkg/repositories/gormimpl/artifact_test.go b/datacatalog/pkg/repositories/gormimpl/artifact_test.go new file mode 100644 index 0000000000..4446bbae5c --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/artifact_test.go @@ -0,0 +1,168 @@ +package gormimpl + +import ( + "testing" + + "context" + + mocket "github.com/Selvatico/go-mocket" + "github.com/stretchr/testify/assert" + + "database/sql/driver" + + apiErrors "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/models" + "github.com/lyft/datacatalog/pkg/repositories/utils" + "google.golang.org/grpc/codes" +) + +func getTestArtifact() models.Artifact { + return models.Artifact{ + ArtifactKey: models.ArtifactKey{ + ArtifactID: "123", + DatasetProject: "testProject", + DatasetDomain: "testDomain", + DatasetName: "testName", + 
DatasetVersion: "testVersion", + }, + } +} + +func TestCreateArtifact(t *testing.T) { + artifact := getTestArtifact() + + artifactCreated := false + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + numArtifactDataCreated := 0 + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `INSERT INTO "artifacts" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id","serialized_metadata") VALUES (?,?,?,?,?,?,?,?,?)`).WithCallback( + func(s string, values []driver.NamedValue) { + artifactCreated = true + }, + ) + + GlobalMock.NewMock().WithQuery( + `INSERT INTO "artifact_data" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id","name","location") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback( + func(s string, values []driver.NamedValue) { + numArtifactDataCreated++ + }, + ) + + data := make([]models.ArtifactData, 2) + data[0] = models.ArtifactData{ + Name: "test", + Location: "dataloc", + } + data[1] = models.ArtifactData{ + Name: "test2", + Location: "dataloc2", + } + + artifact.ArtifactData = data + + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + err := artifactRepo.Create(context.Background(), artifact) + assert.NoError(t, err) + assert.True(t, artifactCreated) + assert.Equal(t, 2, numArtifactDataCreated) +} + +func TestGetArtifact(t *testing.T) { + artifact := getTestArtifact() + + expectedArtifactDataResponse := make([]map[string]interface{}, 0) + sampleArtifactData := make(map[string]interface{}) + sampleArtifactData["dataset_project"] = artifact.DatasetProject + sampleArtifactData["dataset_domain"] = artifact.DatasetDomain + sampleArtifactData["dataset_name"] = artifact.DatasetName + sampleArtifactData["dataset_version"] = artifact.DatasetVersion + sampleArtifactData["artifact_id"] = artifact.ArtifactID + sampleArtifactData["name"] = "test-dataloc-name" + sampleArtifactData["location"] = "test-dataloc-location" + + expectedArtifactDataResponse = append(expectedArtifactDataResponse, sampleArtifactData) + + expectedArtifactResponse := make([]map[string]interface{}, 0) + sampleArtifact := make(map[string]interface{}) + sampleArtifact["dataset_project"] = artifact.DatasetProject + sampleArtifact["dataset_domain"] = artifact.DatasetDomain + sampleArtifact["dataset_name"] = artifact.DatasetName + sampleArtifact["dataset_version"] = artifact.DatasetVersion + sampleArtifact["artifact_id"] = artifact.ArtifactID + expectedArtifactResponse = append(expectedArtifactResponse, sampleArtifact) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."dataset_project" = testProject) AND ("artifacts"."dataset_name" = testName) AND ("artifacts"."dataset_domain" = testDomain) AND ("artifacts"."dataset_version" = testVersion) AND ("artifacts"."artifact_id" = 123))`).WithReply(expectedArtifactResponse) + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "artifact_data" WHERE "artifact_data"."deleted_at" IS NULL AND ((("dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id") IN ((testProject,testName,testDomain,testVersion,123))))`).WithReply(expectedArtifactDataResponse) + getInput := models.ArtifactKey{ + DatasetProject: artifact.DatasetProject, + DatasetDomain: 
artifact.DatasetDomain, + DatasetName: artifact.DatasetName, + DatasetVersion: artifact.DatasetVersion, + ArtifactID: artifact.ArtifactID, + } + + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + response, err := artifactRepo.Get(context.Background(), getInput) + assert.NoError(t, err) + assert.Equal(t, artifact.ArtifactID, response.ArtifactID) + assert.Equal(t, artifact.DatasetProject, response.DatasetProject) + assert.Equal(t, artifact.DatasetDomain, response.DatasetDomain) + assert.Equal(t, artifact.DatasetName, response.DatasetName) + assert.Equal(t, artifact.DatasetVersion, response.DatasetVersion) + + assert.Equal(t, 1, len(response.ArtifactData)) +} + +func TestGetArtifactDoesNotExist(t *testing.T) { + artifact := getTestArtifact() + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + getInput := models.ArtifactKey{ + DatasetProject: artifact.DatasetProject, + DatasetDomain: artifact.DatasetDomain, + DatasetName: artifact.DatasetName, + DatasetVersion: artifact.DatasetVersion, + ArtifactID: artifact.ArtifactID, + } + + // by default mocket will return nil for any queries + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + _, err := artifactRepo.Get(context.Background(), getInput) + assert.Error(t, err) + dcErr, ok := err.(apiErrors.DataCatalogError) + assert.True(t, ok) + assert.Equal(t, dcErr.Code(), codes.NotFound) +} + +func TestCreateArtifactAlreadyExists(t *testing.T) { + artifact := getTestArtifact() + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `INSERT INTO "artifacts" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id","serialized_metadata") VALUES (?,?,?,?,?,?,?,?,?)`).WithError( + getAlreadyExistsErr(), + ) + + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + err := artifactRepo.Create(context.Background(), artifact) + assert.Error(t, err) + dcErr, ok := err.(apiErrors.DataCatalogError) + assert.True(t, ok) + assert.Equal(t, dcErr.Code(), codes.AlreadyExists) +} diff --git a/datacatalog/pkg/repositories/gormimpl/dataset.go b/datacatalog/pkg/repositories/gormimpl/dataset.go new file mode 100644 index 0000000000..e74dca6260 --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/dataset.go @@ -0,0 +1,57 @@ +package gormimpl + +import ( + "context" + + "github.com/jinzhu/gorm" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/interfaces" + "github.com/lyft/datacatalog/pkg/repositories/models" + + idl_datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/logger" +) + +type dataSetRepo struct { + db *gorm.DB + errorTransformer errors.ErrorTransformer + + // TODO: add metrics +} + +func NewDatasetRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer) interfaces.DatasetRepo { + return &dataSetRepo{ + db: db, + errorTransformer: errorTransformer, + } +} + +// Create a Dataset model +func (h *dataSetRepo) Create(ctx context.Context, in models.Dataset) error { + result := h.db.Create(&in) + if result.Error != nil { + return h.errorTransformer.ToDataCatalogError(result.Error) + } + return nil +} + +// Get Dataset model +func (h *dataSetRepo) Get(ctx context.Context, in models.DatasetKey) (models.Dataset, error) { + var ds models.Dataset + result := 
h.db.Where(&models.Dataset{DatasetKey: in}).First(&ds) + + if result.Error != nil { + logger.Debugf(ctx, "Unable to find Dataset: [%+v], err: %v", in, result.Error) + return models.Dataset{}, h.errorTransformer.ToDataCatalogError(result.Error) + } + if result.RecordNotFound() { + return models.Dataset{}, errors.GetMissingEntityError("Dataset", &idl_datacatalog.DatasetID{ + Project: in.Project, + Domain: in.Domain, + Name: in.Name, + Version: in.Version, + }) + } + + return ds, nil +} diff --git a/datacatalog/pkg/repositories/gormimpl/dataset_test.go b/datacatalog/pkg/repositories/gormimpl/dataset_test.go new file mode 100644 index 0000000000..045cf1e7de --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/dataset_test.go @@ -0,0 +1,122 @@ +package gormimpl + +import ( + "testing" + + "context" + + mocket "github.com/Selvatico/go-mocket" + "google.golang.org/grpc/codes" + + "github.com/stretchr/testify/assert" + + "database/sql/driver" + + datacatalog_error "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/models" + "github.com/lyft/datacatalog/pkg/repositories/utils" +) + +func getTestDataset() models.Dataset { + return models.Dataset{ + DatasetKey: models.DatasetKey{ + Project: "testProject", + Domain: "testDomain", + Name: "testName", + Version: "testVersion", + }, + SerializedMetadata: []byte{1, 2, 3}, + } +} + +func TestCreateDataset(t *testing.T) { + dataset := getTestDataset() + datasetCreated := false + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `INSERT INTO "datasets" ("created_at","updated_at","deleted_at","project","name","domain","version","serialized_metadata") VALUES (?,?,?,?,?,?,?,?)`).WithCallback( + func(s string, values []driver.NamedValue) { + assert.EqualValues(t, dataset.Project, values[3].Value) + assert.EqualValues(t, dataset.Name, values[4].Value) + assert.EqualValues(t, dataset.Domain, values[5].Value) + assert.EqualValues(t, dataset.Version, values[6].Value) + assert.EqualValues(t, dataset.SerializedMetadata, values[7].Value) + datasetCreated = true + }, + ) + + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + err := datasetRepo.Create(context.Background(), getTestDataset()) + assert.NoError(t, err) + assert.True(t, datasetCreated) +} + +func TestGetDataset(t *testing.T) { + dataset := getTestDataset() + expectedResponse := make([]map[string]interface{}, 0) + sampleDataset := make(map[string]interface{}) + sampleDataset["project"] = dataset.Project + sampleDataset["domain"] = dataset.Domain + sampleDataset["name"] = dataset.Name + sampleDataset["version"] = dataset.Version + + expectedResponse = append(expectedResponse, sampleDataset) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery(`SELECT * FROM "datasets" WHERE "datasets"."deleted_at" IS NULL AND (("datasets"."project" = testProject) AND ("datasets"."name" = testName) AND ("datasets"."domain" = testDomain) AND ("datasets"."version" = testVersion)) ORDER BY "datasets"."project" ASC LIMIT 1`).WithReply(expectedResponse) + + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + actualDataset, err := datasetRepo.Get(context.Background(), dataset.DatasetKey) + assert.NoError(t, err) + assert.Equal(t, 
dataset.Project, actualDataset.Project) + assert.Equal(t, dataset.Domain, actualDataset.Domain) + assert.Equal(t, dataset.Name, actualDataset.Name) + assert.Equal(t, dataset.Version, actualDataset.Version) +} + +func TestGetDatasetNotFound(t *testing.T) { + dataset := getTestDataset() + sampleDataset := make(map[string]interface{}) + sampleDataset["project"] = dataset.Project + sampleDataset["domain"] = dataset.Domain + sampleDataset["name"] = dataset.Name + sampleDataset["version"] = dataset.Version + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery(`SELECT * FROM "datasets" WHERE "datasets"."deleted_at" IS NULL AND (("datasets"."project" = testProject) AND ("datasets"."name" = testName) AND ("datasets"."domain" = testDomain) AND ("datasets"."version" = testVersion)) ORDER BY "datasets"."id" ASC LIMIT 1`).WithReply(nil) + + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + _, err := datasetRepo.Get(context.Background(), dataset.DatasetKey) + assert.Error(t, err) + notFoundErr, ok := err.(datacatalog_error.DataCatalogError) + assert.True(t, ok) + assert.Equal(t, codes.NotFound, notFoundErr.Code()) +} + +func TestCreateDatasetAlreadyExists(t *testing.T) { + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `INSERT INTO "datasets" ("created_at","updated_at","deleted_at","project","name","domain","version","serialized_metadata") VALUES (?,?,?,?,?,?,?,?)`).WithError( + getAlreadyExistsErr(), + ) + + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + err := datasetRepo.Create(context.Background(), getTestDataset()) + assert.Error(t, err) + dcErr, ok := err.(datacatalog_error.DataCatalogError) + assert.True(t, ok) + assert.Equal(t, dcErr.Code(), codes.AlreadyExists) +} diff --git a/datacatalog/pkg/repositories/gormimpl/tag.go b/datacatalog/pkg/repositories/gormimpl/tag.go new file mode 100644 index 0000000000..c16fe937dd --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/tag.go @@ -0,0 +1,51 @@ +package gormimpl + +import ( + "context" + + "github.com/jinzhu/gorm" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/interfaces" + "github.com/lyft/datacatalog/pkg/repositories/models" + idl_datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +type tagRepo struct { + db *gorm.DB + errorTransformer errors.ErrorTransformer + // TODO: add metrics +} + +func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer) interfaces.TagRepo { + return &tagRepo{ + db: db, + errorTransformer: errorTransformer, + } +} + +func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { + db := h.db.Create(&tag) + + if db.Error != nil { + return h.errorTransformer.ToDataCatalogError(db.Error) + } + return nil +} + +func (h *tagRepo) Get(ctx context.Context, in models.TagKey) (models.Tag, error) { + var tag models.Tag + result := h.db.Preload("Artifact").Preload("Artifact.ArtifactData").Find(&tag, &models.Tag{ + TagKey: in, + }) + + if result.Error != nil { + return models.Tag{}, h.errorTransformer.ToDataCatalogError(result.Error) + } + if result.RecordNotFound() { + return models.Tag{}, errors.GetMissingEntityError("Tag", &idl_datacatalog.Tag{ + Name: tag.TagName, + }) + } + + return tag, nil +} diff --git 
a/datacatalog/pkg/repositories/gormimpl/tag_test.go b/datacatalog/pkg/repositories/gormimpl/tag_test.go new file mode 100644 index 0000000000..c1befea462 --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/tag_test.go @@ -0,0 +1,136 @@ +package gormimpl + +import ( + "testing" + + "context" + + mocket "github.com/Selvatico/go-mocket" + "github.com/stretchr/testify/assert" + + "database/sql/driver" + + datacatalog_error "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/datacatalog/pkg/repositories/errors" + + "github.com/lib/pq" + "github.com/lyft/datacatalog/pkg/repositories/models" + "github.com/lyft/datacatalog/pkg/repositories/utils" + "google.golang.org/grpc/codes" +) + +func getAlreadyExistsErr() error { + return &pq.Error{Code: "23505"} +} + +func getTestTag() models.Tag { + artifact := getTestArtifact() + return models.Tag{ + TagKey: models.TagKey{ + DatasetProject: artifact.DatasetProject, + DatasetDomain: artifact.DatasetDomain, + DatasetName: artifact.DatasetName, + DatasetVersion: artifact.DatasetVersion, + TagName: "test-tagname", + }, + ArtifactID: artifact.ArtifactID, + } +} + +func TestCreateTag(t *testing.T) { + tagCreated := false + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id") VALUES (?,?,?,?,?,?,?,?,?)`).WithCallback( + func(s string, values []driver.NamedValue) { + tagCreated = true + }, + ) + + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + err := tagRepo.Create(context.Background(), getTestTag()) + assert.NoError(t, err) + assert.True(t, tagCreated) +} + +func TestGetTag(t *testing.T) { + artifact := getTestArtifact() + + expectedArtifactDataResponse := make([]map[string]interface{}, 0) + sampleArtifactData := make(map[string]interface{}) + sampleArtifactData["dataset_project"] = artifact.DatasetProject + sampleArtifactData["dataset_domain"] = artifact.DatasetDomain + sampleArtifactData["dataset_name"] = artifact.DatasetName + sampleArtifactData["dataset_version"] = artifact.DatasetVersion + sampleArtifactData["artifact_id"] = artifact.ArtifactID + sampleArtifactData["name"] = "test-dataloc-name" + sampleArtifactData["location"] = "test-dataloc-location" + + expectedArtifactDataResponse = append(expectedArtifactDataResponse, sampleArtifactData) + + expectedArtifactResponse := make([]map[string]interface{}, 0) + sampleArtifact := make(map[string]interface{}) + sampleArtifact["dataset_project"] = artifact.DatasetProject + sampleArtifact["dataset_domain"] = artifact.DatasetDomain + sampleArtifact["dataset_name"] = artifact.DatasetName + sampleArtifact["dataset_version"] = artifact.DatasetVersion + sampleArtifact["artifact_id"] = artifact.ArtifactID + expectedArtifactResponse = append(expectedArtifactResponse, sampleArtifact) + + expectedTagResponse := make([]map[string]interface{}, 0) + sampleTag := make(map[string]interface{}) + sampleTag["dataset_project"] = artifact.DatasetProject + sampleTag["dataset_domain"] = artifact.DatasetDomain + sampleTag["dataset_name"] = artifact.DatasetName + sampleTag["dataset_version"] = artifact.DatasetVersion + sampleTag["artifact_id"] = artifact.ArtifactID + sampleTag["name"] = "test-tag" + expectedTagResponse = append(expectedTagResponse, sampleTag) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + 
+ // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag))`).WithReply(expectedTagResponse) + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND ((("dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id") IN ((testProject,testName,testDomain,testVersion,123))))`).WithReply(expectedArtifactResponse) + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "artifact_data" WHERE "artifact_data"."deleted_at" IS NULL AND ((("dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id") IN ((testProject,testName,testDomain,testVersion,123))))`).WithReply(expectedArtifactDataResponse) + + getInput := models.TagKey{ + DatasetProject: artifact.DatasetProject, + DatasetDomain: artifact.DatasetDomain, + DatasetName: artifact.DatasetName, + DatasetVersion: artifact.DatasetVersion, + TagName: "test-tag", + } + + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + response, err := tagRepo.Get(context.Background(), getInput) + assert.NoError(t, err) + assert.Equal(t, artifact.ArtifactID, response.ArtifactID) + assert.Equal(t, artifact.ArtifactID, response.Artifact.ArtifactID) + assert.Len(t, response.Artifact.ArtifactData, 1) +} + +func TestTagAlreadyExists(t *testing.T) { + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id") VALUES (?,?,?,?,?,?,?,?,?)`).WithError( + getAlreadyExistsErr(), + ) + + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + err := tagRepo.Create(context.Background(), getTestTag()) + assert.Error(t, err) + dcErr, ok := err.(datacatalog_error.DataCatalogError) + assert.True(t, ok) + assert.Equal(t, dcErr.Code(), codes.AlreadyExists) +} diff --git a/datacatalog/pkg/repositories/handle.go b/datacatalog/pkg/repositories/handle.go new file mode 100644 index 0000000000..069bf163f6 --- /dev/null +++ b/datacatalog/pkg/repositories/handle.go @@ -0,0 +1,77 @@ +package repositories + +import ( + "context" + + "fmt" + + "github.com/jinzhu/gorm" + "github.com/lyft/datacatalog/pkg/repositories/config" + "github.com/lyft/datacatalog/pkg/repositories/models" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +type DBHandle struct { + db *gorm.DB +} + +func NewDBHandle(dbConfigValues config.DbConfig, catalogScope promutils.Scope) (*DBHandle, error) { + dbConfig := config.DbConfig{ + Host: dbConfigValues.Host, + Port: dbConfigValues.Port, + DbName: dbConfigValues.DbName, + User: dbConfigValues.User, + Password: dbConfigValues.Password, + ExtraOptions: dbConfigValues.ExtraOptions, + } + + //TODO: abstract away the type of db we are connecting to + db, err := config.OpenDbConnection(config.NewPostgresConfigProvider(dbConfig, catalogScope.NewSubScope("postgres"))) + if err != nil { + return nil, err + } + + out := &DBHandle{ + db: db, + } + + return out, nil +} + +func (h *DBHandle) CreateDB(dbName string) error { + type DatabaseResult struct { + Exists bool + } 
+ var checkExists DatabaseResult + result := h.db.Raw("SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = ?)", dbName).Scan(&checkExists) + if result.Error != nil { + return result.Error + } + + // create db if it does not exist + if !checkExists.Exists { + logger.Infof(context.TODO(), "Creating Database %v since it does not exist", dbName) + + // NOTE: golang sql drivers do not support parameter injection for CREATE calls + createDBStatement := fmt.Sprintf("CREATE DATABASE %s", dbName) + result = h.db.Exec(createDBStatement) + + if result.Error != nil { + return result.Error + } + } + + return nil +} + +func (h *DBHandle) Migrate() { + h.db.AutoMigrate(&models.Dataset{}) + h.db.AutoMigrate(&models.Artifact{}) + h.db.AutoMigrate(&models.ArtifactData{}) + h.db.AutoMigrate(&models.Tag{}) +} + +func (h *DBHandle) Close() error { + return h.db.Close() +} diff --git a/datacatalog/pkg/repositories/handle_test.go b/datacatalog/pkg/repositories/handle_test.go new file mode 100644 index 0000000000..d3c6a307ae --- /dev/null +++ b/datacatalog/pkg/repositories/handle_test.go @@ -0,0 +1,78 @@ +package repositories + +import ( + "testing" + + mocket "github.com/Selvatico/go-mocket" + "github.com/stretchr/testify/assert" + + "database/sql/driver" + + "github.com/lyft/datacatalog/pkg/repositories/utils" +) + +func TestCreateDB(t *testing.T) { + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + checkExists := false + GlobalMock.NewMock().WithQuery( + `SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = testDB)`).WithCallback( + func(s string, values []driver.NamedValue) { + checkExists = true + }, + ).WithReply([]map[string]interface{}{ + {"exists": false}, + }) + + createdDB := false + + // NOTE: unfortunately mocket does not support checking CREATE statements, but let's match the suffix + GlobalMock.NewMock().WithQuery( + `DATABASE testDB`).WithCallback( + func(s string, values []driver.NamedValue) { + assert.Equal(t, "CREATE DATABASE testDB", s) + createdDB = true + }, + ) + + db := utils.GetDbForTest(t) + dbHandle := &DBHandle{ + db: db, + } + _ = dbHandle.CreateDB("testDB") + assert.True(t, checkExists) + assert.True(t, createdDB) +} + +func TestDBAlreadyExists(t *testing.T) { + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + checkExists := false + GlobalMock.NewMock().WithQuery( + `SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = testDB)`).WithCallback( + func(s string, values []driver.NamedValue) { + checkExists = true + }, + ).WithReply([]map[string]interface{}{ + {"exists": true}, + }) + + createdDB := false + GlobalMock.NewMock().WithQuery( + `DATABASE testDB`).WithCallback( + func(s string, values []driver.NamedValue) { + createdDB = false + }, + ) + + db := utils.GetDbForTest(t) + dbHandle := &DBHandle{ + db: db, + } + err := dbHandle.CreateDB("testDB") + assert.NoError(t, err) + assert.True(t, checkExists) + assert.False(t, createdDB) +} diff --git a/datacatalog/pkg/repositories/interfaces/artifact_repo.go b/datacatalog/pkg/repositories/interfaces/artifact_repo.go new file mode 100644 index 0000000000..db97c7efb2 --- /dev/null +++ b/datacatalog/pkg/repositories/interfaces/artifact_repo.go @@ -0,0 +1,12 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/repositories/models" +) + +type ArtifactRepo interface { + Create(ctx context.Context, in models.Artifact) error + Get(ctx context.Context, in models.ArtifactKey) (models.Artifact, error) +} diff --git 
a/datacatalog/pkg/repositories/interfaces/base.go b/datacatalog/pkg/repositories/interfaces/base.go new file mode 100644 index 0000000000..864d724dfe --- /dev/null +++ b/datacatalog/pkg/repositories/interfaces/base.go @@ -0,0 +1,7 @@ +package interfaces + +type DataCatalogRepo interface { + DatasetRepo() DatasetRepo + ArtifactRepo() ArtifactRepo + TagRepo() TagRepo +} diff --git a/datacatalog/pkg/repositories/interfaces/dataset_repo.go b/datacatalog/pkg/repositories/interfaces/dataset_repo.go new file mode 100644 index 0000000000..eee33f8898 --- /dev/null +++ b/datacatalog/pkg/repositories/interfaces/dataset_repo.go @@ -0,0 +1,12 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/repositories/models" +) + +type DatasetRepo interface { + Create(ctx context.Context, in models.Dataset) error + Get(ctx context.Context, in models.DatasetKey) (models.Dataset, error) +} diff --git a/datacatalog/pkg/repositories/interfaces/tag_repo.go b/datacatalog/pkg/repositories/interfaces/tag_repo.go new file mode 100644 index 0000000000..0746850edf --- /dev/null +++ b/datacatalog/pkg/repositories/interfaces/tag_repo.go @@ -0,0 +1,12 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/datacatalog/pkg/repositories/models" +) + +type TagRepo interface { + Create(ctx context.Context, in models.Tag) error + Get(ctx context.Context, in models.TagKey) (models.Tag, error) +} diff --git a/datacatalog/pkg/repositories/mocks/artifact.go b/datacatalog/pkg/repositories/mocks/artifact.go new file mode 100644 index 0000000000..8477577efa --- /dev/null +++ b/datacatalog/pkg/repositories/mocks/artifact.go @@ -0,0 +1,48 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" + +import mock "github.com/stretchr/testify/mock" +import models "github.com/lyft/datacatalog/pkg/repositories/models" + +// ArtifactRepo is an autogenerated mock type for the ArtifactRepo type +type ArtifactRepo struct { + mock.Mock +} + +// Create provides a mock function with given fields: ctx, in +func (_m *ArtifactRepo) Create(ctx context.Context, in models.Artifact) error { + ret := _m.Called(ctx, in) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, models.Artifact) error); ok { + r0 = rf(ctx, in) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: ctx, in +func (_m *ArtifactRepo) Get(ctx context.Context, in models.ArtifactKey) (models.Artifact, error) { + ret := _m.Called(ctx, in) + + var r0 models.Artifact + if rf, ok := ret.Get(0).(func(context.Context, models.ArtifactKey) models.Artifact); ok { + r0 = rf(ctx, in) + } else { + r0 = ret.Get(0).(models.Artifact) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, models.ArtifactKey) error); ok { + r1 = rf(ctx, in) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/datacatalog/pkg/repositories/mocks/base.go b/datacatalog/pkg/repositories/mocks/base.go new file mode 100644 index 0000000000..a728be0afe --- /dev/null +++ b/datacatalog/pkg/repositories/mocks/base.go @@ -0,0 +1,21 @@ +package mocks + +import "github.com/lyft/datacatalog/pkg/repositories/interfaces" + +type DataCatalogRepo struct { + MockDatasetRepo *DatasetRepo + MockArtifactRepo *ArtifactRepo + MockTagRepo *TagRepo +} + +func (m *DataCatalogRepo) DatasetRepo() interfaces.DatasetRepo { + return m.MockDatasetRepo +} + +func (m *DataCatalogRepo) ArtifactRepo() interfaces.ArtifactRepo { + return m.MockArtifactRepo +} + +func (m *DataCatalogRepo) 
TagRepo() interfaces.TagRepo { + return m.MockTagRepo +} diff --git a/datacatalog/pkg/repositories/mocks/dataset.go b/datacatalog/pkg/repositories/mocks/dataset.go new file mode 100644 index 0000000000..fa038bb748 --- /dev/null +++ b/datacatalog/pkg/repositories/mocks/dataset.go @@ -0,0 +1,48 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" + +import mock "github.com/stretchr/testify/mock" +import models "github.com/lyft/datacatalog/pkg/repositories/models" + +// DatasetRepo is an autogenerated mock type for the DatasetRepo type +type DatasetRepo struct { + mock.Mock +} + +// Create provides a mock function with given fields: ctx, in +func (_m *DatasetRepo) Create(ctx context.Context, in models.Dataset) error { + ret := _m.Called(ctx, in) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, models.Dataset) error); ok { + r0 = rf(ctx, in) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: ctx, in +func (_m *DatasetRepo) Get(ctx context.Context, in models.DatasetKey) (models.Dataset, error) { + ret := _m.Called(ctx, in) + + var r0 models.Dataset + if rf, ok := ret.Get(0).(func(context.Context, models.DatasetKey) models.Dataset); ok { + r0 = rf(ctx, in) + } else { + r0 = ret.Get(0).(models.Dataset) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, models.DatasetKey) error); ok { + r1 = rf(ctx, in) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/datacatalog/pkg/repositories/mocks/tag.go b/datacatalog/pkg/repositories/mocks/tag.go new file mode 100644 index 0000000000..0a2adf1580 --- /dev/null +++ b/datacatalog/pkg/repositories/mocks/tag.go @@ -0,0 +1,48 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
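+// Illustrative only: each generated repo mock embeds testify's mock.Mock, so a manager test can set +// expectations such as tagRepo.On("Get", mock.Anything, mock.Anything).Return(models.Tag{}, nil) and then +// hand the mock to the hand-written DataCatalogRepo bundle defined in mocks/base.go.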
+ +package mocks + +import context "context" + +import mock "github.com/stretchr/testify/mock" +import models "github.com/lyft/datacatalog/pkg/repositories/models" + +// TagRepo is an autogenerated mock type for the TagRepo type +type TagRepo struct { + mock.Mock +} + +// Create provides a mock function with given fields: ctx, in +func (_m *TagRepo) Create(ctx context.Context, in models.Tag) error { + ret := _m.Called(ctx, in) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, models.Tag) error); ok { + r0 = rf(ctx, in) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: ctx, in +func (_m *TagRepo) Get(ctx context.Context, in models.TagKey) (models.Tag, error) { + ret := _m.Called(ctx, in) + + var r0 models.Tag + if rf, ok := ret.Get(0).(func(context.Context, models.TagKey) models.Tag); ok { + r0 = rf(ctx, in) + } else { + r0 = ret.Get(0).(models.Tag) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, models.TagKey) error); ok { + r1 = rf(ctx, in) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/datacatalog/pkg/repositories/models/artifact.go b/datacatalog/pkg/repositories/models/artifact.go new file mode 100644 index 0000000000..10a1bdbafa --- /dev/null +++ b/datacatalog/pkg/repositories/models/artifact.go @@ -0,0 +1,24 @@ +package models + +type ArtifactKey struct { + DatasetProject string `gorm:"primary_key"` + DatasetName string `gorm:"primary_key"` + DatasetDomain string `gorm:"primary_key"` + DatasetVersion string `gorm:"primary_key"` + ArtifactID string `gorm:"primary_key"` +} + +type Artifact struct { + BaseModel + ArtifactKey + Dataset Dataset `gorm:"association_autocreate:false"` + ArtifactData []ArtifactData `gorm:"association_foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID;foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID"` + SerializedMetadata []byte +} + +type ArtifactData struct { + BaseModel + ArtifactKey + Name string `gorm:"primary_key"` + Location string +} diff --git a/datacatalog/pkg/repositories/models/base.go b/datacatalog/pkg/repositories/models/base.go new file mode 100644 index 0000000000..98b3885bbf --- /dev/null +++ b/datacatalog/pkg/repositories/models/base.go @@ -0,0 +1,9 @@ +package models + +import "time" + +type BaseModel struct { + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `sql:"index"` +} diff --git a/datacatalog/pkg/repositories/models/dataset.go b/datacatalog/pkg/repositories/models/dataset.go new file mode 100644 index 0000000000..18f6f17fb8 --- /dev/null +++ b/datacatalog/pkg/repositories/models/dataset.go @@ -0,0 +1,14 @@ +package models + +type DatasetKey struct { + Project string `gorm:"primary_key"` + Name string `gorm:"primary_key"` + Domain string `gorm:"primary_key"` + Version string `gorm:"primary_key"` +} + +type Dataset struct { + BaseModel + DatasetKey + SerializedMetadata []byte +} diff --git a/datacatalog/pkg/repositories/models/tag.go b/datacatalog/pkg/repositories/models/tag.go new file mode 100644 index 0000000000..790a83b59f --- /dev/null +++ b/datacatalog/pkg/repositories/models/tag.go @@ -0,0 +1,16 @@ +package models + +type TagKey struct { + DatasetProject string `gorm:"primary_key"` + DatasetName string `gorm:"primary_key"` + DatasetDomain string `gorm:"primary_key"` + DatasetVersion string `gorm:"primary_key"` + TagName string `gorm:"primary_key"` +} + +type Tag struct { + BaseModel + TagKey + ArtifactID string + Artifact Artifact 
`gorm:"association_foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID;foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID"` +} diff --git a/datacatalog/pkg/repositories/postgres_repo.go b/datacatalog/pkg/repositories/postgres_repo.go new file mode 100644 index 0000000000..f4418ffc21 --- /dev/null +++ b/datacatalog/pkg/repositories/postgres_repo.go @@ -0,0 +1,35 @@ +package repositories + +import ( + "github.com/jinzhu/gorm" + "github.com/lyft/datacatalog/pkg/repositories/errors" + "github.com/lyft/datacatalog/pkg/repositories/gormimpl" + "github.com/lyft/datacatalog/pkg/repositories/interfaces" + "github.com/lyft/flytestdlib/promutils" +) + +type PostgresRepo struct { + datasetRepo interfaces.DatasetRepo + artifactRepo interfaces.ArtifactRepo + tagRepo interfaces.TagRepo +} + +func (dc *PostgresRepo) DatasetRepo() interfaces.DatasetRepo { + return dc.datasetRepo +} + +func (dc *PostgresRepo) ArtifactRepo() interfaces.ArtifactRepo { + return dc.artifactRepo +} + +func (dc *PostgresRepo) TagRepo() interfaces.TagRepo { + return dc.tagRepo +} + +func NewPostgresRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.DataCatalogRepo { + return &PostgresRepo{ + datasetRepo: gormimpl.NewDatasetRepo(db, errorTransformer), + artifactRepo: gormimpl.NewArtifactRepo(db, errorTransformer), + tagRepo: gormimpl.NewTagRepo(db, errorTransformer), + } +} diff --git a/datacatalog/pkg/repositories/transformers/artifact.go b/datacatalog/pkg/repositories/transformers/artifact.go new file mode 100644 index 0000000000..a9b81d0fb5 --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/artifact.go @@ -0,0 +1,55 @@ +package transformers + +import ( + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +func CreateArtifactModel(request datacatalog.CreateArtifactRequest, artifactData []models.ArtifactData) (models.Artifact, error) { + datasetID := request.Artifact.Dataset + + serializedMetadata, err := marshalMetadata(request.Artifact.Metadata) + if err != nil { + return models.Artifact{}, err + } + + return models.Artifact{ + ArtifactKey: models.ArtifactKey{ + DatasetProject: datasetID.Project, + DatasetDomain: datasetID.Domain, + DatasetName: datasetID.Name, + DatasetVersion: datasetID.Version, + ArtifactID: request.Artifact.Id, + }, + ArtifactData: artifactData, + SerializedMetadata: serializedMetadata, + }, nil +} + +func FromArtifactModel(artifact models.Artifact) (datacatalog.Artifact, error) { + metadata, err := unmarshalMetadata(artifact.SerializedMetadata) + if err != nil { + return datacatalog.Artifact{}, err + } + + return datacatalog.Artifact{ + Id: artifact.ArtifactID, + Dataset: &datacatalog.DatasetID{ + Project: artifact.DatasetProject, + Domain: artifact.DatasetDomain, + Name: artifact.DatasetName, + Version: artifact.DatasetVersion, + }, + Metadata: metadata, + }, nil +} + +func ToArtifactKey(datasetID datacatalog.DatasetID, artifactID string) models.ArtifactKey { + return models.ArtifactKey{ + DatasetProject: datasetID.Project, + DatasetDomain: datasetID.Domain, + DatasetName: datasetID.Name, + DatasetVersion: datasetID.Version, + ArtifactID: artifactID, + } +} diff --git a/datacatalog/pkg/repositories/transformers/artifact_test.go b/datacatalog/pkg/repositories/transformers/artifact_test.go new file mode 100644 index 0000000000..8b6e8104b1 --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/artifact_test.go @@ -0,0 +1,102 @@ 
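+// Note: these tests reuse the datasetID and metadata fixtures declared in dataset_test.go within the same package.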
+package transformers + +import ( + "testing" + + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +var testInteger = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 1}}, + }, + }, + }, +} + +func TestCreateArtifactModel(t *testing.T) { + artifactDataList := []*datacatalog.ArtifactData{ + {Name: "data1", Value: testInteger}, + {Name: "data2", Value: testInteger}, + } + + createArtifactRequest := datacatalog.CreateArtifactRequest{ + Artifact: &datacatalog.Artifact{ + Id: "artifactID-1", + Dataset: &datasetID, + Data: artifactDataList, + Metadata: &metadata, + }, + } + + testArtifactData := []models.ArtifactData{ + {Name: "data1", Location: "s3://test1"}, + {Name: "data3", Location: "s3://test2"}, + } + artifactModel, err := CreateArtifactModel(createArtifactRequest, testArtifactData) + assert.NoError(t, err) + assert.Equal(t, artifactModel.ArtifactID, createArtifactRequest.Artifact.Id) + assert.Equal(t, artifactModel.ArtifactKey.DatasetProject, datasetID.Project) + assert.Equal(t, artifactModel.ArtifactKey.DatasetDomain, datasetID.Domain) + assert.Equal(t, artifactModel.ArtifactKey.DatasetName, datasetID.Name) + assert.Equal(t, artifactModel.ArtifactKey.DatasetVersion, datasetID.Version) + assert.EqualValues(t, testArtifactData, artifactModel.ArtifactData) +} + +func TestCreateArtifactModelNoMetdata(t *testing.T) { + artifactDataList := []*datacatalog.ArtifactData{ + {Name: "data1", Value: testInteger}, + {Name: "data2", Value: testInteger}, + } + + createArtifactRequest := datacatalog.CreateArtifactRequest{ + Artifact: &datacatalog.Artifact{ + Id: "artifactID-1", + Dataset: &datasetID, + Data: artifactDataList, + }, + } + + testArtifactData := []models.ArtifactData{ + {Name: "data1", Location: "s3://test1"}, + {Name: "data3", Location: "s3://test2"}, + } + artifactModel, err := CreateArtifactModel(createArtifactRequest, testArtifactData) + assert.NoError(t, err) + assert.Equal(t, []byte{}, artifactModel.SerializedMetadata) +} + +func TestFromArtifactModel(t *testing.T) { + artifactModel := models.Artifact{ + ArtifactKey: models.ArtifactKey{ + DatasetProject: "project1", + DatasetDomain: "domain1", + DatasetName: "name1", + DatasetVersion: "version1", + ArtifactID: "id1", + }, + SerializedMetadata: []byte{}, + } + + actual, err := FromArtifactModel(artifactModel) + assert.NoError(t, err) + assert.Equal(t, artifactModel.ArtifactID, actual.Id) + assert.Equal(t, artifactModel.DatasetProject, actual.Dataset.Project) + assert.Equal(t, artifactModel.DatasetDomain, actual.Dataset.Domain) + assert.Equal(t, artifactModel.DatasetName, actual.Dataset.Name) + assert.Equal(t, artifactModel.DatasetVersion, actual.Dataset.Version) +} + +func TestToArtifactKey(t *testing.T) { + artifactKey := ToArtifactKey(datasetID, "artifactID-1") + assert.Equal(t, datasetID.Project, artifactKey.DatasetProject) + assert.Equal(t, datasetID.Domain, artifactKey.DatasetDomain) + assert.Equal(t, datasetID.Name, artifactKey.DatasetName) + assert.Equal(t, datasetID.Version, artifactKey.DatasetVersion) + assert.Equal(t, artifactKey.ArtifactID, "artifactID-1") +} diff --git a/datacatalog/pkg/repositories/transformers/dataset.go b/datacatalog/pkg/repositories/transformers/dataset.go new file mode 100644 index 0000000000..2eda5e24ef --- /dev/null +++ 
b/datacatalog/pkg/repositories/transformers/dataset.go @@ -0,0 +1,52 @@ +package transformers + +import ( + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +// Create a dataset model from the Dataset api object. This will serialize the metadata in the dataset as part of the transform +func CreateDatasetModel(dataset *datacatalog.Dataset) (*models.Dataset, error) { + serializedMetadata, err := marshalMetadata(dataset.Metadata) + if err != nil { + return nil, err + } + + return &models.Dataset{ + DatasetKey: models.DatasetKey{ + Project: dataset.Id.Project, + Domain: dataset.Id.Domain, + Name: dataset.Id.Name, + Version: dataset.Id.Version, + }, + SerializedMetadata: serializedMetadata, + }, nil +} + +// Create a dataset key model from the dataset ID +func FromDatasetID(datasetID datacatalog.DatasetID) models.DatasetKey { + return models.DatasetKey{ + Project: datasetID.Project, + Domain: datasetID.Domain, + Name: datasetID.Name, + Version: datasetID.Version, + } +} + +// Create a Dataset api object given a model. This will unmarshal the metadata into the object as part of the transform +func FromDatasetModel(dataset models.Dataset) (*datacatalog.Dataset, error) { + metadata, err := unmarshalMetadata(dataset.SerializedMetadata) + if err != nil { + return nil, err + } + + return &datacatalog.Dataset{ + Id: &datacatalog.DatasetID{ + Project: dataset.Project, + Domain: dataset.Domain, + Name: dataset.Name, + Version: dataset.Version, + }, + Metadata: metadata, + }, nil +} diff --git a/datacatalog/pkg/repositories/transformers/dataset_test.go b/datacatalog/pkg/repositories/transformers/dataset_test.go new file mode 100644 index 0000000000..b37566c20f --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/dataset_test.go @@ -0,0 +1,66 @@ +package transformers + +import ( + "testing" + + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/stretchr/testify/assert" +) + +var metadata = datacatalog.Metadata{ + KeyMap: map[string]string{ + "testKey1": "testValue1", + "testKey2": "testValue2", + }, +} + +var datasetID = datacatalog.DatasetID{ + Project: "test-project", + Domain: "test-domain", + Name: "test-name", + Version: "test-version", +} + +func assertDatasetIDEqualsModel(t *testing.T, idlDataset *datacatalog.DatasetID, model *models.DatasetKey) { + assert.Equal(t, idlDataset.Project, model.Project) + assert.Equal(t, idlDataset.Domain, model.Domain) + assert.Equal(t, idlDataset.Name, model.Name) + assert.Equal(t, idlDataset.Version, model.Version) +} + +func TestCreateDatasetModel(t *testing.T) { + dataset := &datacatalog.Dataset{ + Id: &datasetID, + Metadata: &metadata, + } + + datasetModel, err := CreateDatasetModel(dataset) + assert.NoError(t, err) + assertDatasetIDEqualsModel(t, dataset.Id, &datasetModel.DatasetKey) + + unmarshaledMetadata, err := unmarshalMetadata(datasetModel.SerializedMetadata) + assert.NoError(t, err) + assert.EqualValues(t, unmarshaledMetadata.KeyMap, metadata.KeyMap) +} + +func TestFromDatasetID(t *testing.T) { + datasetKey := FromDatasetID(datasetID) + assertDatasetIDEqualsModel(t, &datasetID, &datasetKey) +} + +func TestFromDatasetModel(t *testing.T) { + datasetModel := &models.Dataset{ + DatasetKey: models.DatasetKey{ + Project: "test-project", + Domain: "test-domain", + Name: "test-name", + Version: "test-version", + }, + SerializedMetadata: []byte{}, + } + dataset, err := FromDatasetModel(*datasetModel) + assert.NoError(t,
err) + assertDatasetIDEqualsModel(t, dataset.Id, &datasetModel.DatasetKey) + assert.Len(t, dataset.Metadata.KeyMap, 0) +} diff --git a/datacatalog/pkg/repositories/transformers/tag.go b/datacatalog/pkg/repositories/transformers/tag.go new file mode 100644 index 0000000000..9921e22157 --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/tag.go @@ -0,0 +1,16 @@ +package transformers + +import ( + "github.com/lyft/datacatalog/pkg/repositories/models" + datacatalog "github.com/lyft/datacatalog/protos/gen" +) + +func ToTagKey(datasetID datacatalog.DatasetID, tagName string) models.TagKey { + return models.TagKey{ + DatasetProject: datasetID.Project, + DatasetDomain: datasetID.Domain, + DatasetName: datasetID.Name, + DatasetVersion: datasetID.Version, + TagName: tagName, + } +} diff --git a/datacatalog/pkg/repositories/transformers/tag_test.go b/datacatalog/pkg/repositories/transformers/tag_test.go new file mode 100644 index 0000000000..714605a073 --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/tag_test.go @@ -0,0 +1,26 @@ +package transformers + +import ( + "testing" + + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/stretchr/testify/assert" +) + +func TestToTagKey(t *testing.T) { + datasetID := datacatalog.DatasetID{ + Project: "testProj", + Domain: "testDomain", + Name: "testName", + Version: "testVersion", + } + + tagName := "testTag" + tagKey := ToTagKey(datasetID, tagName) + + assert.Equal(t, tagName, tagKey.TagName) + assert.Equal(t, datasetID.Project, tagKey.DatasetProject) + assert.Equal(t, datasetID.Domain, tagKey.DatasetDomain) + assert.Equal(t, datasetID.Name, tagKey.DatasetName) + assert.Equal(t, datasetID.Version, tagKey.DatasetVersion) +} diff --git a/datacatalog/pkg/repositories/transformers/util.go b/datacatalog/pkg/repositories/transformers/util.go new file mode 100644 index 0000000000..326154568d --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/util.go @@ -0,0 +1,25 @@ +package transformers + +import ( + "github.com/golang/protobuf/proto" + "github.com/lyft/datacatalog/pkg/errors" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "google.golang.org/grpc/codes" +) + +func marshalMetadata(metadata *datacatalog.Metadata) ([]byte, error) { + // if it is nil, marshal empty protobuf + if metadata == nil { + metadata = &datacatalog.Metadata{} + } + return proto.Marshal(metadata) +} + +func unmarshalMetadata(serializedMetadata []byte) (*datacatalog.Metadata, error) { + if serializedMetadata == nil { + return nil, errors.NewDataCatalogErrorf(codes.Unknown, "Serialized metadata should never be nil") + } + var metadata datacatalog.Metadata + err := proto.Unmarshal(serializedMetadata, &metadata) + return &metadata, err +} diff --git a/datacatalog/pkg/repositories/transformers/util_test.go b/datacatalog/pkg/repositories/transformers/util_test.go new file mode 100644 index 0000000000..76200f8127 --- /dev/null +++ b/datacatalog/pkg/repositories/transformers/util_test.go @@ -0,0 +1,25 @@ +package transformers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshaling(t *testing.T) { + marshaledMetadata, err := marshalMetadata(&metadata) + assert.NoError(t, err) + + unmarshaledMetadata, err := unmarshalMetadata(marshaledMetadata) + assert.NoError(t, err) + assert.EqualValues(t, unmarshaledMetadata.KeyMap, metadata.KeyMap) +} + +func TestMarshalingWithNil(t *testing.T) { + marshaledMetadata, err := marshalMetadata(nil) + assert.NoError(t, err) + var expectedKeymap map[string]string + 
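+ // marshalMetadata(nil) serializes an empty Metadata message, so the round-tripped KeyMap is nil, + // matching the zero value held by expectedKeymap.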
unmarshaledMetadata, err := unmarshalMetadata(marshaledMetadata) + assert.NoError(t, err) + assert.EqualValues(t, expectedKeymap, unmarshaledMetadata.KeyMap) +} diff --git a/datacatalog/pkg/repositories/utils/test_utils.go b/datacatalog/pkg/repositories/utils/test_utils.go new file mode 100644 index 0000000000..127ede235a --- /dev/null +++ b/datacatalog/pkg/repositories/utils/test_utils.go @@ -0,0 +1,19 @@ +// Shared utils for postgresql tests. +package utils + +import ( + "fmt" + "testing" + + mocket "github.com/Selvatico/go-mocket" + "github.com/jinzhu/gorm" +) + +func GetDbForTest(t *testing.T) *gorm.DB { + mocket.Catcher.Register() + db, err := gorm.Open(mocket.DriverName, "fake args") + if err != nil { + t.Fatal(fmt.Sprintf("Failed to open mock db with err %v", err)) + } + return db +} diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go new file mode 100644 index 0000000000..5260a517ff --- /dev/null +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -0,0 +1,94 @@ +package datacatalogservice + +import ( + "context" + "fmt" + "runtime/debug" + + "github.com/lyft/datacatalog/pkg/manager/impl" + "github.com/lyft/datacatalog/pkg/manager/interfaces" + "github.com/lyft/datacatalog/pkg/repositories" + "github.com/lyft/datacatalog/pkg/repositories/config" + "github.com/lyft/datacatalog/pkg/runtime" + catalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" +) + +type DataCatalogService struct { + DatasetManager interfaces.DatasetManager + ArtifactManager interfaces.ArtifactManager + TagManager interfaces.TagManager +} + +// TODO: add metrics and counters to these service methods +func (s *DataCatalogService) CreateDataset(ctx context.Context, request *catalog.CreateDatasetRequest) (*catalog.CreateDatasetResponse, error) { + return s.DatasetManager.CreateDataset(ctx, *request) +} + +func (s *DataCatalogService) CreateArtifact(ctx context.Context, request *catalog.CreateArtifactRequest) (*catalog.CreateArtifactResponse, error) { + return s.ArtifactManager.CreateArtifact(ctx, *request) +} + +func (s *DataCatalogService) GetDataset(ctx context.Context, request *catalog.GetDatasetRequest) (*catalog.GetDatasetResponse, error) { + return s.DatasetManager.GetDataset(ctx, *request) +} + +func (s *DataCatalogService) GetArtifact(ctx context.Context, request *catalog.GetArtifactRequest) (*catalog.GetArtifactResponse, error) { + return s.ArtifactManager.GetArtifact(ctx, *request) +} + +func (s *DataCatalogService) AddTag(ctx context.Context, request *catalog.AddTagRequest) (*catalog.AddTagResponse, error) { + return s.TagManager.AddTag(ctx, *request) +} + +func NewDataCatalogService() *DataCatalogService { + ctx := context.Background() + + dataCatalogScope := "datacatalog" + catalogScope := promutils.NewScope(dataCatalogScope).NewSubScope("service") + + defer func() { + if err := recover(); err != nil { + catalogScope.MustNewCounter("initialization_panic", + "panics encountered initializating the datacatalog service").Inc() + logger.Fatalf(context.Background(), fmt.Sprintf("caught panic: %v [%+v]", err, string(debug.Stack()))) + } + }() + + storeConfig := storage.GetConfig() + dataStorageClient, err := storage.NewDataStore(storeConfig, catalogScope.NewSubScope("storage")) + if err != nil { + logger.Errorf(ctx, "Failed to create DataStore %v, err %v", storeConfig, err) + panic(err) + } + logger.Infof(ctx, "Created data 
storage.") + + configProvider := runtime.NewConfigurationProvider() + baseStorageReference := dataStorageClient.GetBaseContainerFQN(ctx) + dataCatalogConfig := configProvider.ApplicationConfiguration().GetDataCatalogConfig() + storagePrefix, err := dataStorageClient.ConstructReference(ctx, baseStorageReference, dataCatalogConfig.StoragePrefix) + if err != nil { + logger.Errorf(ctx, "Failed to create prefix %v, err %v", dataCatalogConfig.StoragePrefix, err) + panic(err) + } + + dbConfigValues := configProvider.ApplicationConfiguration().GetDbConfig() + dbConfig := config.DbConfig{ + Host: dbConfigValues.Host, + Port: dbConfigValues.Port, + DbName: dbConfigValues.DbName, + User: dbConfigValues.User, + Password: dbConfigValues.Password, + ExtraOptions: dbConfigValues.ExtraOptions, + } + repos := repositories.GetRepository(repositories.POSTGRES, dbConfig, catalogScope) + logger.Infof(ctx, "Created DB connection.") + + return &DataCatalogService{ + DatasetManager: impl.NewDatasetManager(repos, dataStorageClient, catalogScope.NewSubScope("dataset")), + ArtifactManager: impl.NewArtifactManager(repos, dataStorageClient, storagePrefix, catalogScope.NewSubScope("artifact")), + TagManager: impl.NewTagManager(repos, dataStorageClient, catalogScope.NewSubScope("tag")), + } +} diff --git a/datacatalog/pkg/runtime/application_config_provider.go b/datacatalog/pkg/runtime/application_config_provider.go new file mode 100644 index 0000000000..94b6c7dce3 --- /dev/null +++ b/datacatalog/pkg/runtime/application_config_provider.go @@ -0,0 +1,59 @@ +package runtime + +import ( + "context" + "io/ioutil" + "os" + + dbconfig "github.com/lyft/datacatalog/pkg/repositories/config" + "github.com/lyft/datacatalog/pkg/runtime/configs" + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" +) + +const database = "database" +const datacatalog = "datacatalog" + +var databaseConfig = config.MustRegisterSection(database, &dbconfig.DbConfigSection{}) +var datacatalogConfig = config.MustRegisterSection(datacatalog, &configs.DataCatalogConfig{}) + +// Defines the interface to return top-level config structs necessary to start up a datacatalog application. 
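+// For illustration only, the sections registered above would be populated from a config file that looks roughly like: +// datacatalog: +//   storage-prefix: my-prefix +// database: +//   ... (keys follow the json tags on dbconfig.DbConfigSection, which is defined outside this file)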
+type ApplicationConfiguration interface { + GetDbConfig() dbconfig.DbConfig + GetDataCatalogConfig() configs.DataCatalogConfig +} + +type ApplicationConfigurationProvider struct{} + +func (p *ApplicationConfigurationProvider) GetDbConfig() dbconfig.DbConfig { + dbConfigSection := databaseConfig.GetConfig().(*dbconfig.DbConfigSection) + password := dbConfigSection.Password + if len(dbConfigSection.PasswordPath) > 0 { + if _, err := os.Stat(dbConfigSection.PasswordPath); os.IsNotExist(err) { + logger.Fatalf(context.Background(), + "missing database password at specified path [%s]", dbConfigSection.PasswordPath) + } + passwordVal, err := ioutil.ReadFile(dbConfigSection.PasswordPath) + if err != nil { + logger.Fatalf(context.Background(), "failed to read database password from path [%s] with err: %v", + dbConfigSection.PasswordPath, err) + } + password = string(passwordVal) + } + return dbconfig.DbConfig{ + Host: dbConfigSection.Host, + Port: dbConfigSection.Port, + DbName: dbConfigSection.DbName, + User: dbConfigSection.User, + Password: password, + ExtraOptions: dbConfigSection.ExtraOptions, + } +} + +func (p *ApplicationConfigurationProvider) GetDataCatalogConfig() configs.DataCatalogConfig { + return *datacatalogConfig.GetConfig().(*configs.DataCatalogConfig) +} + +func NewApplicationConfigurationProvider() ApplicationConfiguration { + return &ApplicationConfigurationProvider{} +} diff --git a/datacatalog/pkg/runtime/configs/data_catalog_config.go b/datacatalog/pkg/runtime/configs/data_catalog_config.go new file mode 100644 index 0000000000..f5ba55dd5d --- /dev/null +++ b/datacatalog/pkg/runtime/configs/data_catalog_config.go @@ -0,0 +1,8 @@ +package configs + +//go:generate pflags DataCatalogConfig + +// This configuration is the base configuration to start admin +type DataCatalogConfig struct { + StoragePrefix string `json:"storage-prefix" pflag:",StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. If not specified, the data will be stored in the base container directly."` +} diff --git a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go new file mode 100755 index 0000000000..946ff45b06 --- /dev/null +++ b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go @@ -0,0 +1,19 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots at +// 2019-08-07 10:50:49.390084414 -0700 PDT m=+0.045946390 + +package configs + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in DataCatalogConfig and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (DataCatalogConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("DataCatalogConfig", pflag.ExitOnError) + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "storage-prefix"), []string{}, "StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. If not specified, the data will be stored in the base container directly.") + return cmdFlags +} diff --git a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go new file mode 100755 index 0000000000..e666480aa4 --- /dev/null +++ b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go @@ -0,0 +1,113 @@ +// Code generated by go generate; DO NOT EDIT. 
+// This file was generated by robots at +// 2019-08-07 10:50:49.390084414 -0700 PDT m=+0.045946390 + +package configs + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsDataCatalogConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementDataCatalogConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsDataCatalogConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookDataCatalogConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementDataCatalogConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_DataCatalogConfig(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookDataCatalogConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func testDecodeJson_DataCatalogConfig(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_DataCatalogConfig(val, result)) +} + +func testDecodeSlice_DataCatalogConfig(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_DataCatalogConfig(vStringSlice, result)) +} + +func TestDataCatalogConfig_GetPFlagSet(t *testing.T) { + val := DataCatalogConfig{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestDataCatalogConfig_SetFlags(t *testing.T) { + actual := DataCatalogConfig{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_storage-prefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vStringSlice, err := cmdFlags.GetStringSlice("storage-prefix"); err == nil { + assert.Equal(t, []string{}, vStringSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("storage-prefix", "1,1") + if vStringSlice, err := cmdFlags.GetStringSlice("storage-prefix"); err == nil { + testDecodeSlice_DataCatalogConfig(t, strings.Join(vStringSlice, ","), &actual.StoragePrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/datacatalog/pkg/runtime/configuration_provider.go b/datacatalog/pkg/runtime/configuration_provider.go new file mode 100644 index 0000000000..fe11026c85 --- 
/dev/null +++ b/datacatalog/pkg/runtime/configuration_provider.go @@ -0,0 +1,21 @@ +package runtime + +// Interface for getting parsed values from a configuration file +type Configuration interface { + ApplicationConfiguration() ApplicationConfiguration +} + +// Implementation of a Configuration +type ConfigurationProvider struct { + applicationConfiguration ApplicationConfiguration +} + +func (p *ConfigurationProvider) ApplicationConfiguration() ApplicationConfiguration { + return p.applicationConfiguration +} + +func NewConfigurationProvider() Configuration { + return &ConfigurationProvider{ + applicationConfiguration: NewApplicationConfigurationProvider(), + } +} diff --git a/datacatalog/protos/gen/service.pb.go b/datacatalog/protos/gen/service.pb.go new file mode 100644 index 0000000000..d759937ba4 --- /dev/null +++ b/datacatalog/protos/gen/service.pb.go @@ -0,0 +1,1046 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: service.proto + +package datacatalog + +import ( + context "context" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type CreateDatasetRequest struct { + Dataset *Dataset `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CreateDatasetRequest) Reset() { *m = CreateDatasetRequest{} } +func (m *CreateDatasetRequest) String() string { return proto.CompactTextString(m) } +func (*CreateDatasetRequest) ProtoMessage() {} +func (*CreateDatasetRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{0} +} + +func (m *CreateDatasetRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CreateDatasetRequest.Unmarshal(m, b) +} +func (m *CreateDatasetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CreateDatasetRequest.Marshal(b, m, deterministic) +} +func (m *CreateDatasetRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_CreateDatasetRequest.Merge(m, src) +} +func (m *CreateDatasetRequest) XXX_Size() int { + return xxx_messageInfo_CreateDatasetRequest.Size(m) +} +func (m *CreateDatasetRequest) XXX_DiscardUnknown() { + xxx_messageInfo_CreateDatasetRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_CreateDatasetRequest proto.InternalMessageInfo + +func (m *CreateDatasetRequest) GetDataset() *Dataset { + if m != nil { + return m.Dataset + } + return nil +} + +type CreateDatasetResponse struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CreateDatasetResponse) Reset() { *m = CreateDatasetResponse{} } +func (m *CreateDatasetResponse) String() string { return proto.CompactTextString(m) } +func (*CreateDatasetResponse) ProtoMessage() {} +func (*CreateDatasetResponse) Descriptor() ([]byte, []int) { + return 
fileDescriptor_a0b84a42fa06f626, []int{1} +} + +func (m *CreateDatasetResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CreateDatasetResponse.Unmarshal(m, b) +} +func (m *CreateDatasetResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CreateDatasetResponse.Marshal(b, m, deterministic) +} +func (m *CreateDatasetResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_CreateDatasetResponse.Merge(m, src) +} +func (m *CreateDatasetResponse) XXX_Size() int { + return xxx_messageInfo_CreateDatasetResponse.Size(m) +} +func (m *CreateDatasetResponse) XXX_DiscardUnknown() { + xxx_messageInfo_CreateDatasetResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_CreateDatasetResponse proto.InternalMessageInfo + +type GetDatasetRequest struct { + Dataset *DatasetID `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetDatasetRequest) Reset() { *m = GetDatasetRequest{} } +func (m *GetDatasetRequest) String() string { return proto.CompactTextString(m) } +func (*GetDatasetRequest) ProtoMessage() {} +func (*GetDatasetRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{2} +} + +func (m *GetDatasetRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetDatasetRequest.Unmarshal(m, b) +} +func (m *GetDatasetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetDatasetRequest.Marshal(b, m, deterministic) +} +func (m *GetDatasetRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetDatasetRequest.Merge(m, src) +} +func (m *GetDatasetRequest) XXX_Size() int { + return xxx_messageInfo_GetDatasetRequest.Size(m) +} +func (m *GetDatasetRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetDatasetRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetDatasetRequest proto.InternalMessageInfo + +func (m *GetDatasetRequest) GetDataset() *DatasetID { + if m != nil { + return m.Dataset + } + return nil +} + +type GetDatasetResponse struct { + Dataset *Dataset `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetDatasetResponse) Reset() { *m = GetDatasetResponse{} } +func (m *GetDatasetResponse) String() string { return proto.CompactTextString(m) } +func (*GetDatasetResponse) ProtoMessage() {} +func (*GetDatasetResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{3} +} + +func (m *GetDatasetResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetDatasetResponse.Unmarshal(m, b) +} +func (m *GetDatasetResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetDatasetResponse.Marshal(b, m, deterministic) +} +func (m *GetDatasetResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetDatasetResponse.Merge(m, src) +} +func (m *GetDatasetResponse) XXX_Size() int { + return xxx_messageInfo_GetDatasetResponse.Size(m) +} +func (m *GetDatasetResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetDatasetResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetDatasetResponse proto.InternalMessageInfo + +func (m *GetDatasetResponse) GetDataset() *Dataset { + if m != nil { + return m.Dataset + } + return nil +} + +type GetArtifactRequest struct { + Dataset *DatasetID 
`protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"` + // Types that are valid to be assigned to QueryHandle: + // *GetArtifactRequest_ArtifactId + // *GetArtifactRequest_TagName + QueryHandle isGetArtifactRequest_QueryHandle `protobuf_oneof:"query_handle"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetArtifactRequest) Reset() { *m = GetArtifactRequest{} } +func (m *GetArtifactRequest) String() string { return proto.CompactTextString(m) } +func (*GetArtifactRequest) ProtoMessage() {} +func (*GetArtifactRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{4} +} + +func (m *GetArtifactRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetArtifactRequest.Unmarshal(m, b) +} +func (m *GetArtifactRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetArtifactRequest.Marshal(b, m, deterministic) +} +func (m *GetArtifactRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetArtifactRequest.Merge(m, src) +} +func (m *GetArtifactRequest) XXX_Size() int { + return xxx_messageInfo_GetArtifactRequest.Size(m) +} +func (m *GetArtifactRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetArtifactRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetArtifactRequest proto.InternalMessageInfo + +func (m *GetArtifactRequest) GetDataset() *DatasetID { + if m != nil { + return m.Dataset + } + return nil +} + +type isGetArtifactRequest_QueryHandle interface { + isGetArtifactRequest_QueryHandle() +} + +type GetArtifactRequest_ArtifactId struct { + ArtifactId string `protobuf:"bytes,2,opt,name=artifact_id,json=artifactId,proto3,oneof"` +} + +type GetArtifactRequest_TagName struct { + TagName string `protobuf:"bytes,3,opt,name=tag_name,json=tagName,proto3,oneof"` +} + +func (*GetArtifactRequest_ArtifactId) isGetArtifactRequest_QueryHandle() {} + +func (*GetArtifactRequest_TagName) isGetArtifactRequest_QueryHandle() {} + +func (m *GetArtifactRequest) GetQueryHandle() isGetArtifactRequest_QueryHandle { + if m != nil { + return m.QueryHandle + } + return nil +} + +func (m *GetArtifactRequest) GetArtifactId() string { + if x, ok := m.GetQueryHandle().(*GetArtifactRequest_ArtifactId); ok { + return x.ArtifactId + } + return "" +} + +func (m *GetArtifactRequest) GetTagName() string { + if x, ok := m.GetQueryHandle().(*GetArtifactRequest_TagName); ok { + return x.TagName + } + return "" +} + +// XXX_OneofWrappers is for the internal use of the proto package. 
+func (*GetArtifactRequest) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*GetArtifactRequest_ArtifactId)(nil), + (*GetArtifactRequest_TagName)(nil), + } +} + +type GetArtifactResponse struct { + Artifact *Artifact `protobuf:"bytes,1,opt,name=artifact,proto3" json:"artifact,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetArtifactResponse) Reset() { *m = GetArtifactResponse{} } +func (m *GetArtifactResponse) String() string { return proto.CompactTextString(m) } +func (*GetArtifactResponse) ProtoMessage() {} +func (*GetArtifactResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{5} +} + +func (m *GetArtifactResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetArtifactResponse.Unmarshal(m, b) +} +func (m *GetArtifactResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetArtifactResponse.Marshal(b, m, deterministic) +} +func (m *GetArtifactResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetArtifactResponse.Merge(m, src) +} +func (m *GetArtifactResponse) XXX_Size() int { + return xxx_messageInfo_GetArtifactResponse.Size(m) +} +func (m *GetArtifactResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetArtifactResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetArtifactResponse proto.InternalMessageInfo + +func (m *GetArtifactResponse) GetArtifact() *Artifact { + if m != nil { + return m.Artifact + } + return nil +} + +type CreateArtifactRequest struct { + Artifact *Artifact `protobuf:"bytes,1,opt,name=artifact,proto3" json:"artifact,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CreateArtifactRequest) Reset() { *m = CreateArtifactRequest{} } +func (m *CreateArtifactRequest) String() string { return proto.CompactTextString(m) } +func (*CreateArtifactRequest) ProtoMessage() {} +func (*CreateArtifactRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{6} +} + +func (m *CreateArtifactRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CreateArtifactRequest.Unmarshal(m, b) +} +func (m *CreateArtifactRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CreateArtifactRequest.Marshal(b, m, deterministic) +} +func (m *CreateArtifactRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_CreateArtifactRequest.Merge(m, src) +} +func (m *CreateArtifactRequest) XXX_Size() int { + return xxx_messageInfo_CreateArtifactRequest.Size(m) +} +func (m *CreateArtifactRequest) XXX_DiscardUnknown() { + xxx_messageInfo_CreateArtifactRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_CreateArtifactRequest proto.InternalMessageInfo + +func (m *CreateArtifactRequest) GetArtifact() *Artifact { + if m != nil { + return m.Artifact + } + return nil +} + +type CreateArtifactResponse struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CreateArtifactResponse) Reset() { *m = CreateArtifactResponse{} } +func (m *CreateArtifactResponse) String() string { return proto.CompactTextString(m) } +func (*CreateArtifactResponse) ProtoMessage() {} +func (*CreateArtifactResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{7} +} + +func (m *CreateArtifactResponse) XXX_Unmarshal(b []byte) error { + return 
xxx_messageInfo_CreateArtifactResponse.Unmarshal(m, b) +} +func (m *CreateArtifactResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CreateArtifactResponse.Marshal(b, m, deterministic) +} +func (m *CreateArtifactResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_CreateArtifactResponse.Merge(m, src) +} +func (m *CreateArtifactResponse) XXX_Size() int { + return xxx_messageInfo_CreateArtifactResponse.Size(m) +} +func (m *CreateArtifactResponse) XXX_DiscardUnknown() { + xxx_messageInfo_CreateArtifactResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_CreateArtifactResponse proto.InternalMessageInfo + +type AddTagRequest struct { + Tag *Tag `protobuf:"bytes,1,opt,name=tag,proto3" json:"tag,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AddTagRequest) Reset() { *m = AddTagRequest{} } +func (m *AddTagRequest) String() string { return proto.CompactTextString(m) } +func (*AddTagRequest) ProtoMessage() {} +func (*AddTagRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{8} +} + +func (m *AddTagRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AddTagRequest.Unmarshal(m, b) +} +func (m *AddTagRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AddTagRequest.Marshal(b, m, deterministic) +} +func (m *AddTagRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_AddTagRequest.Merge(m, src) +} +func (m *AddTagRequest) XXX_Size() int { + return xxx_messageInfo_AddTagRequest.Size(m) +} +func (m *AddTagRequest) XXX_DiscardUnknown() { + xxx_messageInfo_AddTagRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_AddTagRequest proto.InternalMessageInfo + +func (m *AddTagRequest) GetTag() *Tag { + if m != nil { + return m.Tag + } + return nil +} + +type AddTagResponse struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AddTagResponse) Reset() { *m = AddTagResponse{} } +func (m *AddTagResponse) String() string { return proto.CompactTextString(m) } +func (*AddTagResponse) ProtoMessage() {} +func (*AddTagResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{9} +} + +func (m *AddTagResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AddTagResponse.Unmarshal(m, b) +} +func (m *AddTagResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AddTagResponse.Marshal(b, m, deterministic) +} +func (m *AddTagResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_AddTagResponse.Merge(m, src) +} +func (m *AddTagResponse) XXX_Size() int { + return xxx_messageInfo_AddTagResponse.Size(m) +} +func (m *AddTagResponse) XXX_DiscardUnknown() { + xxx_messageInfo_AddTagResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_AddTagResponse proto.InternalMessageInfo + +type Dataset struct { + Id *DatasetID `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Metadata *Metadata `protobuf:"bytes,2,opt,name=metadata,proto3" json:"metadata,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Dataset) Reset() { *m = Dataset{} } +func (m *Dataset) String() string { return proto.CompactTextString(m) } +func (*Dataset) ProtoMessage() {} +func (*Dataset) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{10} +} + +func 
(m *Dataset) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Dataset.Unmarshal(m, b) +} +func (m *Dataset) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Dataset.Marshal(b, m, deterministic) +} +func (m *Dataset) XXX_Merge(src proto.Message) { + xxx_messageInfo_Dataset.Merge(m, src) +} +func (m *Dataset) XXX_Size() int { + return xxx_messageInfo_Dataset.Size(m) +} +func (m *Dataset) XXX_DiscardUnknown() { + xxx_messageInfo_Dataset.DiscardUnknown(m) +} + +var xxx_messageInfo_Dataset proto.InternalMessageInfo + +func (m *Dataset) GetId() *DatasetID { + if m != nil { + return m.Id + } + return nil +} + +func (m *Dataset) GetMetadata() *Metadata { + if m != nil { + return m.Metadata + } + return nil +} + +type DatasetID struct { + Project string `protobuf:"bytes,1,opt,name=project,proto3" json:"project,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + Version string `protobuf:"bytes,4,opt,name=version,proto3" json:"version,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DatasetID) Reset() { *m = DatasetID{} } +func (m *DatasetID) String() string { return proto.CompactTextString(m) } +func (*DatasetID) ProtoMessage() {} +func (*DatasetID) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{11} +} + +func (m *DatasetID) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DatasetID.Unmarshal(m, b) +} +func (m *DatasetID) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DatasetID.Marshal(b, m, deterministic) +} +func (m *DatasetID) XXX_Merge(src proto.Message) { + xxx_messageInfo_DatasetID.Merge(m, src) +} +func (m *DatasetID) XXX_Size() int { + return xxx_messageInfo_DatasetID.Size(m) +} +func (m *DatasetID) XXX_DiscardUnknown() { + xxx_messageInfo_DatasetID.DiscardUnknown(m) +} + +var xxx_messageInfo_DatasetID proto.InternalMessageInfo + +func (m *DatasetID) GetProject() string { + if m != nil { + return m.Project + } + return "" +} + +func (m *DatasetID) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *DatasetID) GetDomain() string { + if m != nil { + return m.Domain + } + return "" +} + +func (m *DatasetID) GetVersion() string { + if m != nil { + return m.Version + } + return "" +} + +type Artifact struct { + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Dataset *DatasetID `protobuf:"bytes,2,opt,name=dataset,proto3" json:"dataset,omitempty"` + Data []*ArtifactData `protobuf:"bytes,3,rep,name=data,proto3" json:"data,omitempty"` + Metadata *Metadata `protobuf:"bytes,4,opt,name=metadata,proto3" json:"metadata,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Artifact) Reset() { *m = Artifact{} } +func (m *Artifact) String() string { return proto.CompactTextString(m) } +func (*Artifact) ProtoMessage() {} +func (*Artifact) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{12} +} + +func (m *Artifact) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Artifact.Unmarshal(m, b) +} +func (m *Artifact) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Artifact.Marshal(b, m, deterministic) +} +func (m *Artifact) XXX_Merge(src proto.Message) { + 
xxx_messageInfo_Artifact.Merge(m, src) +} +func (m *Artifact) XXX_Size() int { + return xxx_messageInfo_Artifact.Size(m) +} +func (m *Artifact) XXX_DiscardUnknown() { + xxx_messageInfo_Artifact.DiscardUnknown(m) +} + +var xxx_messageInfo_Artifact proto.InternalMessageInfo + +func (m *Artifact) GetId() string { + if m != nil { + return m.Id + } + return "" +} + +func (m *Artifact) GetDataset() *DatasetID { + if m != nil { + return m.Dataset + } + return nil +} + +func (m *Artifact) GetData() []*ArtifactData { + if m != nil { + return m.Data + } + return nil +} + +func (m *Artifact) GetMetadata() *Metadata { + if m != nil { + return m.Metadata + } + return nil +} + +type ArtifactData struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Value *core.Literal `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ArtifactData) Reset() { *m = ArtifactData{} } +func (m *ArtifactData) String() string { return proto.CompactTextString(m) } +func (*ArtifactData) ProtoMessage() {} +func (*ArtifactData) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{13} +} + +func (m *ArtifactData) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ArtifactData.Unmarshal(m, b) +} +func (m *ArtifactData) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ArtifactData.Marshal(b, m, deterministic) +} +func (m *ArtifactData) XXX_Merge(src proto.Message) { + xxx_messageInfo_ArtifactData.Merge(m, src) +} +func (m *ArtifactData) XXX_Size() int { + return xxx_messageInfo_ArtifactData.Size(m) +} +func (m *ArtifactData) XXX_DiscardUnknown() { + xxx_messageInfo_ArtifactData.DiscardUnknown(m) +} + +var xxx_messageInfo_ArtifactData proto.InternalMessageInfo + +func (m *ArtifactData) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *ArtifactData) GetValue() *core.Literal { + if m != nil { + return m.Value + } + return nil +} + +type Tag struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + ArtifactId string `protobuf:"bytes,2,opt,name=artifact_id,json=artifactId,proto3" json:"artifact_id,omitempty"` + Dataset *DatasetID `protobuf:"bytes,3,opt,name=dataset,proto3" json:"dataset,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Tag) Reset() { *m = Tag{} } +func (m *Tag) String() string { return proto.CompactTextString(m) } +func (*Tag) ProtoMessage() {} +func (*Tag) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{14} +} + +func (m *Tag) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Tag.Unmarshal(m, b) +} +func (m *Tag) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Tag.Marshal(b, m, deterministic) +} +func (m *Tag) XXX_Merge(src proto.Message) { + xxx_messageInfo_Tag.Merge(m, src) +} +func (m *Tag) XXX_Size() int { + return xxx_messageInfo_Tag.Size(m) +} +func (m *Tag) XXX_DiscardUnknown() { + xxx_messageInfo_Tag.DiscardUnknown(m) +} + +var xxx_messageInfo_Tag proto.InternalMessageInfo + +func (m *Tag) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Tag) GetArtifactId() string { + if m != nil { + return m.ArtifactId + } + return "" +} + +func (m *Tag) GetDataset() *DatasetID { + if m != nil { + return m.Dataset + } + return nil 
+} + +type Metadata struct { + KeyMap map[string]string `protobuf:"bytes,1,rep,name=key_map,json=keyMap,proto3" json:"key_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Metadata) Reset() { *m = Metadata{} } +func (m *Metadata) String() string { return proto.CompactTextString(m) } +func (*Metadata) ProtoMessage() {} +func (*Metadata) Descriptor() ([]byte, []int) { + return fileDescriptor_a0b84a42fa06f626, []int{15} +} + +func (m *Metadata) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Metadata.Unmarshal(m, b) +} +func (m *Metadata) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Metadata.Marshal(b, m, deterministic) +} +func (m *Metadata) XXX_Merge(src proto.Message) { + xxx_messageInfo_Metadata.Merge(m, src) +} +func (m *Metadata) XXX_Size() int { + return xxx_messageInfo_Metadata.Size(m) +} +func (m *Metadata) XXX_DiscardUnknown() { + xxx_messageInfo_Metadata.DiscardUnknown(m) +} + +var xxx_messageInfo_Metadata proto.InternalMessageInfo + +func (m *Metadata) GetKeyMap() map[string]string { + if m != nil { + return m.KeyMap + } + return nil +} + +func init() { + proto.RegisterType((*CreateDatasetRequest)(nil), "datacatalog.CreateDatasetRequest") + proto.RegisterType((*CreateDatasetResponse)(nil), "datacatalog.CreateDatasetResponse") + proto.RegisterType((*GetDatasetRequest)(nil), "datacatalog.GetDatasetRequest") + proto.RegisterType((*GetDatasetResponse)(nil), "datacatalog.GetDatasetResponse") + proto.RegisterType((*GetArtifactRequest)(nil), "datacatalog.GetArtifactRequest") + proto.RegisterType((*GetArtifactResponse)(nil), "datacatalog.GetArtifactResponse") + proto.RegisterType((*CreateArtifactRequest)(nil), "datacatalog.CreateArtifactRequest") + proto.RegisterType((*CreateArtifactResponse)(nil), "datacatalog.CreateArtifactResponse") + proto.RegisterType((*AddTagRequest)(nil), "datacatalog.AddTagRequest") + proto.RegisterType((*AddTagResponse)(nil), "datacatalog.AddTagResponse") + proto.RegisterType((*Dataset)(nil), "datacatalog.Dataset") + proto.RegisterType((*DatasetID)(nil), "datacatalog.DatasetID") + proto.RegisterType((*Artifact)(nil), "datacatalog.Artifact") + proto.RegisterType((*ArtifactData)(nil), "datacatalog.ArtifactData") + proto.RegisterType((*Tag)(nil), "datacatalog.Tag") + proto.RegisterType((*Metadata)(nil), "datacatalog.Metadata") + proto.RegisterMapType((map[string]string)(nil), "datacatalog.Metadata.KeyMapEntry") +} + +func init() { proto.RegisterFile("service.proto", fileDescriptor_a0b84a42fa06f626) } + +var fileDescriptor_a0b84a42fa06f626 = []byte{ + // 647 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x95, 0x5f, 0x73, 0x93, 0x4c, + 0x14, 0xc6, 0x5f, 0x42, 0xde, 0xfc, 0x39, 0xb4, 0x99, 0xba, 0xb6, 0x15, 0xa9, 0x63, 0x53, 0x9c, + 0x71, 0x72, 0xa1, 0x54, 0xd3, 0x1b, 0xed, 0x5d, 0x6c, 0xaa, 0xad, 0x1a, 0xa7, 0xc3, 0x74, 0x9c, + 0xf1, 0x2a, 0xb3, 0x66, 0x4f, 0x11, 0x43, 0x20, 0x85, 0x4d, 0x66, 0xb8, 0xf4, 0x4b, 0xf8, 0x09, + 0xbc, 0xf6, 0x33, 0x3a, 0xc0, 0x42, 0x81, 0xd2, 0x1a, 0x7b, 0xc7, 0xb2, 0x67, 0x7f, 0x79, 0xce, + 0x73, 0xd8, 0x27, 0xb0, 0x1e, 0xa0, 0xbf, 0xb4, 0x27, 0x68, 0xcc, 0x7d, 0x8f, 0x7b, 0x44, 0x61, + 0x94, 0xd3, 0x09, 0xe5, 0xd4, 0xf1, 0x2c, 0xed, 0xd1, 0x85, 0x13, 0x72, 0xb4, 0x99, 0xb3, 0x3f, + 0xf1, 0x7c, 0xdc, 0x77, 0x6c, 0x8e, 0x3e, 0x75, 0x82, 0xa4, 0x54, 
0x7f, 0x0b, 0x9b, 0x47, 0x3e, + 0x52, 0x8e, 0x43, 0xca, 0x69, 0x80, 0xdc, 0xc4, 0xcb, 0x05, 0x06, 0x9c, 0x18, 0xd0, 0x64, 0xc9, + 0x1b, 0x55, 0xea, 0x4a, 0x3d, 0xa5, 0xbf, 0x69, 0xe4, 0xa0, 0x46, 0x5a, 0x9d, 0x16, 0xe9, 0x0f, + 0x60, 0xab, 0xc4, 0x09, 0xe6, 0x9e, 0x1b, 0xa0, 0x7e, 0x0c, 0xf7, 0xde, 0x21, 0x2f, 0xd1, 0x5f, + 0x94, 0xe9, 0xdb, 0x55, 0xf4, 0xd3, 0xe1, 0x15, 0x7f, 0x08, 0x24, 0x8f, 0x49, 0xe0, 0xff, 0xac, + 0xf2, 0xa7, 0x14, 0x63, 0x06, 0x3e, 0xb7, 0x2f, 0xe8, 0xe4, 0xee, 0x72, 0xc8, 0x1e, 0x28, 0x54, + 0x40, 0xc6, 0x36, 0x53, 0x6b, 0x5d, 0xa9, 0xd7, 0x3e, 0xf9, 0xcf, 0x84, 0xf4, 0xe5, 0x29, 0x23, + 0x3b, 0xd0, 0xe2, 0xd4, 0x1a, 0xbb, 0x74, 0x86, 0xaa, 0x2c, 0xf6, 0x9b, 0x9c, 0x5a, 0x9f, 0xe8, + 0x0c, 0xdf, 0x74, 0x60, 0xed, 0x72, 0x81, 0x7e, 0x38, 0xfe, 0x46, 0x5d, 0xe6, 0xa0, 0x7e, 0x02, + 0xf7, 0x0b, 0xba, 0x44, 0x7f, 0x2f, 0xa1, 0x95, 0x12, 0x85, 0xb2, 0xad, 0x82, 0xb2, 0xec, 0x40, + 0x56, 0xa6, 0xbf, 0x4f, 0x07, 0x51, 0x6e, 0xf2, 0x0e, 0x2c, 0x15, 0xb6, 0xcb, 0x2c, 0x31, 0xd5, + 0x03, 0x58, 0x1f, 0x30, 0x76, 0x4e, 0xad, 0x94, 0xae, 0x83, 0xcc, 0xa9, 0x25, 0xc0, 0x1b, 0x05, + 0x70, 0x54, 0x15, 0x6d, 0xea, 0x1b, 0xd0, 0x49, 0x0f, 0x09, 0x0c, 0x83, 0xa6, 0x30, 0x97, 0x3c, + 0x85, 0x9a, 0xcd, 0xfe, 0x62, 0x7f, 0xcd, 0x66, 0x51, 0x1b, 0x33, 0xe4, 0x34, 0x2a, 0x88, 0x6d, + 0x2f, 0xb7, 0x31, 0x12, 0x9b, 0x66, 0x56, 0xa6, 0x4f, 0xa1, 0x9d, 0x31, 0x88, 0x0a, 0xcd, 0xb9, + 0xef, 0x7d, 0x47, 0xe1, 0x42, 0xdb, 0x4c, 0x97, 0x84, 0x40, 0x3d, 0x1e, 0x56, 0x3c, 0x4c, 0x33, + 0x7e, 0x26, 0xdb, 0xd0, 0x60, 0xde, 0x8c, 0xda, 0x6e, 0x32, 0x42, 0x53, 0xac, 0x22, 0xca, 0x12, + 0xfd, 0xc0, 0xf6, 0x5c, 0xb5, 0x9e, 0x50, 0xc4, 0x52, 0xff, 0x2d, 0x41, 0x2b, 0xb5, 0x8b, 0x74, + 0xb2, 0xa6, 0xda, 0xb1, 0xf8, 0xdc, 0x87, 0x56, 0x5b, 0xed, 0x43, 0x7b, 0x0e, 0xf5, 0xb8, 0x55, + 0xb9, 0x2b, 0xf7, 0x94, 0xfe, 0xc3, 0xca, 0x89, 0x45, 0xc7, 0xcc, 0xb8, 0xac, 0xe0, 0x4e, 0x7d, + 0x35, 0x77, 0xce, 0x60, 0x2d, 0x0f, 0xca, 0x6c, 0x90, 0x72, 0x36, 0x3c, 0x83, 0xff, 0x97, 0xd4, + 0x59, 0x60, 0xa6, 0x3a, 0xcd, 0x14, 0x23, 0xca, 0x14, 0xe3, 0x63, 0x92, 0x29, 0x66, 0x52, 0xa4, + 0x3b, 0x20, 0x9f, 0x53, 0xab, 0x12, 0xb4, 0x5b, 0x71, 0x6f, 0x0a, 0xb7, 0x26, 0xe7, 0x90, 0xbc, + 0x5a, 0x32, 0xfc, 0x90, 0xa0, 0x95, 0xb6, 0x45, 0x0e, 0xa1, 0x39, 0xc5, 0x70, 0x3c, 0xa3, 0x73, + 0x55, 0x8a, 0x1d, 0xdb, 0xab, 0x6c, 0xdf, 0xf8, 0x80, 0xe1, 0x88, 0xce, 0x8f, 0x5d, 0xee, 0x87, + 0x66, 0x63, 0x1a, 0x2f, 0xb4, 0xd7, 0xa0, 0xe4, 0x5e, 0x93, 0x0d, 0x90, 0xa7, 0x18, 0x0a, 0xf5, + 0xd1, 0x23, 0xd9, 0xcc, 0xbb, 0xd0, 0x16, 0xdd, 0x1e, 0xd6, 0x5e, 0x49, 0xfd, 0x5f, 0x32, 0x28, + 0x91, 0xb4, 0xa3, 0xe4, 0x77, 0xc8, 0x67, 0x58, 0x2f, 0xa4, 0x21, 0x29, 0xca, 0xa8, 0x4a, 0x5c, + 0x4d, 0xbf, 0xad, 0x44, 0xe4, 0xc1, 0x08, 0xe0, 0x2a, 0x05, 0xc9, 0xe3, 0xc2, 0x89, 0x6b, 0x29, + 0xab, 0xed, 0xde, 0xb8, 0x2f, 0x70, 0x5f, 0xa0, 0x53, 0xbc, 0xdf, 0xa4, 0x4a, 0x44, 0x29, 0x48, + 0xb4, 0x27, 0xb7, 0xd6, 0x08, 0xf4, 0x19, 0x28, 0xb9, 0x40, 0x23, 0xd7, 0xa4, 0x94, 0xa1, 0xdd, + 0x9b, 0x0b, 0x04, 0x71, 0x00, 0x8d, 0x24, 0x3d, 0x88, 0x56, 0xbc, 0x05, 0xf9, 0x1c, 0xd2, 0x76, + 0x2a, 0xf7, 0x12, 0xc4, 0xd7, 0x46, 0xfc, 0x9f, 0x77, 0xf0, 0x27, 0x00, 0x00, 0xff, 0xff, 0xa6, + 0x2d, 0x08, 0xae, 0x2f, 0x07, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. 
+const _ = grpc.SupportPackageIsVersion4 + +// DataCatalogClient is the client API for DataCatalog service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type DataCatalogClient interface { + CreateDataset(ctx context.Context, in *CreateDatasetRequest, opts ...grpc.CallOption) (*CreateDatasetResponse, error) + GetDataset(ctx context.Context, in *GetDatasetRequest, opts ...grpc.CallOption) (*GetDatasetResponse, error) + CreateArtifact(ctx context.Context, in *CreateArtifactRequest, opts ...grpc.CallOption) (*CreateArtifactResponse, error) + GetArtifact(ctx context.Context, in *GetArtifactRequest, opts ...grpc.CallOption) (*GetArtifactResponse, error) + AddTag(ctx context.Context, in *AddTagRequest, opts ...grpc.CallOption) (*AddTagResponse, error) +} + +type dataCatalogClient struct { + cc *grpc.ClientConn +} + +func NewDataCatalogClient(cc *grpc.ClientConn) DataCatalogClient { + return &dataCatalogClient{cc} +} + +func (c *dataCatalogClient) CreateDataset(ctx context.Context, in *CreateDatasetRequest, opts ...grpc.CallOption) (*CreateDatasetResponse, error) { + out := new(CreateDatasetResponse) + err := c.cc.Invoke(ctx, "/datacatalog.DataCatalog/CreateDataset", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dataCatalogClient) GetDataset(ctx context.Context, in *GetDatasetRequest, opts ...grpc.CallOption) (*GetDatasetResponse, error) { + out := new(GetDatasetResponse) + err := c.cc.Invoke(ctx, "/datacatalog.DataCatalog/GetDataset", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dataCatalogClient) CreateArtifact(ctx context.Context, in *CreateArtifactRequest, opts ...grpc.CallOption) (*CreateArtifactResponse, error) { + out := new(CreateArtifactResponse) + err := c.cc.Invoke(ctx, "/datacatalog.DataCatalog/CreateArtifact", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dataCatalogClient) GetArtifact(ctx context.Context, in *GetArtifactRequest, opts ...grpc.CallOption) (*GetArtifactResponse, error) { + out := new(GetArtifactResponse) + err := c.cc.Invoke(ctx, "/datacatalog.DataCatalog/GetArtifact", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dataCatalogClient) AddTag(ctx context.Context, in *AddTagRequest, opts ...grpc.CallOption) (*AddTagResponse, error) { + out := new(AddTagResponse) + err := c.cc.Invoke(ctx, "/datacatalog.DataCatalog/AddTag", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// DataCatalogServer is the server API for DataCatalog service. +type DataCatalogServer interface { + CreateDataset(context.Context, *CreateDatasetRequest) (*CreateDatasetResponse, error) + GetDataset(context.Context, *GetDatasetRequest) (*GetDatasetResponse, error) + CreateArtifact(context.Context, *CreateArtifactRequest) (*CreateArtifactResponse, error) + GetArtifact(context.Context, *GetArtifactRequest) (*GetArtifactResponse, error) + AddTag(context.Context, *AddTagRequest) (*AddTagResponse, error) +} + +// UnimplementedDataCatalogServer can be embedded to have forward compatible implementations. 
+type UnimplementedDataCatalogServer struct { +} + +func (*UnimplementedDataCatalogServer) CreateDataset(ctx context.Context, req *CreateDatasetRequest) (*CreateDatasetResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateDataset not implemented") +} +func (*UnimplementedDataCatalogServer) GetDataset(ctx context.Context, req *GetDatasetRequest) (*GetDatasetResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetDataset not implemented") +} +func (*UnimplementedDataCatalogServer) CreateArtifact(ctx context.Context, req *CreateArtifactRequest) (*CreateArtifactResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateArtifact not implemented") +} +func (*UnimplementedDataCatalogServer) GetArtifact(ctx context.Context, req *GetArtifactRequest) (*GetArtifactResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetArtifact not implemented") +} +func (*UnimplementedDataCatalogServer) AddTag(ctx context.Context, req *AddTagRequest) (*AddTagResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method AddTag not implemented") +} + +func RegisterDataCatalogServer(s *grpc.Server, srv DataCatalogServer) { + s.RegisterService(&_DataCatalog_serviceDesc, srv) +} + +func _DataCatalog_CreateDataset_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateDatasetRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataCatalogServer).CreateDataset(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/datacatalog.DataCatalog/CreateDataset", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataCatalogServer).CreateDataset(ctx, req.(*CreateDatasetRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DataCatalog_GetDataset_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetDatasetRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataCatalogServer).GetDataset(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/datacatalog.DataCatalog/GetDataset", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataCatalogServer).GetDataset(ctx, req.(*GetDatasetRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DataCatalog_CreateArtifact_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateArtifactRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataCatalogServer).CreateArtifact(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/datacatalog.DataCatalog/CreateArtifact", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataCatalogServer).CreateArtifact(ctx, req.(*CreateArtifactRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DataCatalog_GetArtifact_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetArtifactRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == 
nil { + return srv.(DataCatalogServer).GetArtifact(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/datacatalog.DataCatalog/GetArtifact", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataCatalogServer).GetArtifact(ctx, req.(*GetArtifactRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DataCatalog_AddTag_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AddTagRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataCatalogServer).AddTag(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/datacatalog.DataCatalog/AddTag", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataCatalogServer).AddTag(ctx, req.(*AddTagRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _DataCatalog_serviceDesc = grpc.ServiceDesc{ + ServiceName: "datacatalog.DataCatalog", + HandlerType: (*DataCatalogServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "CreateDataset", + Handler: _DataCatalog_CreateDataset_Handler, + }, + { + MethodName: "GetDataset", + Handler: _DataCatalog_GetDataset_Handler, + }, + { + MethodName: "CreateArtifact", + Handler: _DataCatalog_CreateArtifact_Handler, + }, + { + MethodName: "GetArtifact", + Handler: _DataCatalog_GetArtifact_Handler, + }, + { + MethodName: "AddTag", + Handler: _DataCatalog_AddTag_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "service.proto", +} diff --git a/datacatalog/protos/idl/service.proto b/datacatalog/protos/idl/service.proto new file mode 100644 index 0000000000..9afecc383a --- /dev/null +++ b/datacatalog/protos/idl/service.proto @@ -0,0 +1,92 @@ +syntax = "proto3"; + +package datacatalog; + +import "flyteidl/core/literals.proto"; + +service DataCatalog { + rpc CreateDataset (CreateDatasetRequest) returns (CreateDatasetResponse); + rpc GetDataset (GetDatasetRequest) returns (GetDatasetResponse); + rpc CreateArtifact (CreateArtifactRequest) returns (CreateArtifactResponse); + rpc GetArtifact (GetArtifactRequest) returns (GetArtifactResponse); + rpc AddTag (AddTagRequest) returns (AddTagResponse); +} + +message CreateDatasetRequest { + Dataset dataset = 1; +} + +message CreateDatasetResponse { + +} + +message GetDatasetRequest { + DatasetID dataset = 1; +} + +message GetDatasetResponse { + Dataset dataset = 1; +} + +message GetArtifactRequest { + DatasetID dataset = 1; + + oneof query_handle { + string artifact_id = 2; + string tag_name = 3; + } +} + +message GetArtifactResponse { + Artifact artifact = 1; +} + +message CreateArtifactRequest { + Artifact artifact = 1; +} + +message CreateArtifactResponse { + +} + +message AddTagRequest { + Tag tag = 1; +} + +message AddTagResponse { + +} + +message Dataset { + DatasetID id = 1; + Metadata metadata = 2; +} + +message DatasetID { + string project = 1; // The name of the project + string name = 2; // The name of the dataset + string domain = 3; // The domain (eg. 
environment) in which it's desired to run (optional) + string version = 4; // Version of the data schema +} + +message Artifact { + string id = 1; + DatasetID dataset = 2; + repeated ArtifactData data = 3; + Metadata metadata = 4; +} + +message ArtifactData { + string name = 1; + flyteidl.core.Literal value = 2; +} + +message Tag { + string name = 1; + string artifact_id = 2; + DatasetID dataset = 3; +} + +message Metadata { + map<string, string> key_map = 1; // key map is a dictionary of key/val strings that represent metadata +} From 30046abccbc52e23b541b382649e1222a00701e2 Mon Sep 17 00:00:00 2001 From: Johnny Burns Date: Wed, 21 Aug 2019 12:28:21 -0700 Subject: [PATCH 0075/1918] use https for all deps --- flyteplugins/Gopkg.lock | 4 ++-- flyteplugins/Gopkg.toml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flyteplugins/Gopkg.lock b/flyteplugins/Gopkg.lock index dd76c0bb7d..a0b6d14518 100755 --- a/flyteplugins/Gopkg.lock +++ b/flyteplugins/Gopkg.lock @@ -363,7 +363,7 @@ ] pruneopts = "" revision = "793b09d190148236f41ad8160b5cec9a3325c16f" - source = "git@github.com:lyft/flyteidl" + source = "https://github.com/lyft/flyteidl" version = "v0.1.0" [[projects]] @@ -386,7 +386,7 @@ ] pruneopts = "" revision = "c0e1a9369cb442d70093564fbbc21d8298f5aeb6" - source = "git@github.com:lyft/flytestdlib" + source = "https://github.com/lyft/flytestdlib" version = "v0.2.11" [[projects]] diff --git a/flyteplugins/Gopkg.toml b/flyteplugins/Gopkg.toml index 7969a7306a..99b9a880f5 100755 --- a/flyteplugins/Gopkg.toml +++ b/flyteplugins/Gopkg.toml @@ -9,12 +9,12 @@ ignored = ["k8s.io/spark-on-k8s-operator", [[constraint]] name = "github.com/lyft/flyteidl" - source = "git@github.com:lyft/flyteidl" + source = "https://github.com/lyft/flyteidl" version = "^0.1.x" [[constraint]] name = "github.com/lyft/flytestdlib" - source = "git@github.com:lyft/flytestdlib" + source = "https://github.com/lyft/flytestdlib" version = "^0.2.x" [[constraint]] From 96dba2d99a439c04a00e0165e187dbf71c67835b Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Mon, 26 Aug 2019 14:41:11 -0700 Subject: [PATCH 0076/1918] Add prometheus server --- datacatalog/cmd/entrypoints/serve.go | 2 +- datacatalog/datacatalog_config.yaml | 4 +- .../pkg/rpc/datacatalogservice/service.go | 10 ++++ .../runtime/configs/data_catalog_config.go | 2 + .../configs/datacatalogconfig_flags.go | 6 ++- .../configs/datacatalogconfig_flags_test.go | 53 ++++++++++++++++--- 6 files changed, 66 insertions(+), 11 deletions(-) diff --git a/datacatalog/cmd/entrypoints/serve.go b/datacatalog/cmd/entrypoints/serve.go index ee84033969..ea92cb1215 100644 --- a/datacatalog/cmd/entrypoints/serve.go +++ b/datacatalog/cmd/entrypoints/serve.go @@ -28,7 +28,7 @@ var serveCmd = &cobra.Command{ go func() { err := serveHealthcheck(ctx, cfg) if err != nil { - logger.Errorf(ctx, "Unable to serve http", config.GetConfig().GetGrpcHostAddress(), err) + logger.Errorf(ctx, "Unable to serve http", config.GetConfig().GetHTTPHostAddress(), err) } }() diff --git a/datacatalog/datacatalog_config.yaml b/datacatalog/datacatalog_config.yaml index b33bb9acc5..2287f7061f 100644 --- a/datacatalog/datacatalog_config.yaml +++ b/datacatalog/datacatalog_config.yaml @@ -2,10 +2,12 @@ # Real configuration when running inside K8s (local or otherwise) lives in a ConfigMap # Look in the artifacts directory in the flyte repo for what's actually run application: - grpcPort: 8089 + grpcPort: 8081 httpPort: 8080 datacatalog: storage-prefix: "metadata" + metrics-scope: "datacatalog" + profiler-port: 10254
storage: connection: access-key: minio diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index 5260a517ff..145f33b389 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -12,6 +12,7 @@ import ( "github.com/lyft/datacatalog/pkg/runtime" catalog "github.com/lyft/datacatalog/protos/gen" "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/profutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/storage" ) @@ -86,6 +87,15 @@ func NewDataCatalogService() *DataCatalogService { repos := repositories.GetRepository(repositories.POSTGRES, dbConfig, catalogScope) logger.Infof(ctx, "Created DB connection.") + // Serve profiling endpoint. + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers( + context.Background(), dataCatalogConfig.ProfilerPort, nil) + if err != nil { + logger.Panicf(context.Background(), "Failed to Start profiling and Metrics server. Error, %v", err) + } + }() + return &DataCatalogService{ DatasetManager: impl.NewDatasetManager(repos, dataStorageClient, catalogScope.NewSubScope("dataset")), ArtifactManager: impl.NewArtifactManager(repos, dataStorageClient, storagePrefix, catalogScope.NewSubScope("artifact")), diff --git a/datacatalog/pkg/runtime/configs/data_catalog_config.go b/datacatalog/pkg/runtime/configs/data_catalog_config.go index f5ba55dd5d..a7bf8de8bf 100644 --- a/datacatalog/pkg/runtime/configs/data_catalog_config.go +++ b/datacatalog/pkg/runtime/configs/data_catalog_config.go @@ -5,4 +5,6 @@ package configs // This configuration is the base configuration to start admin type DataCatalogConfig struct { StoragePrefix string `json:"storage-prefix" pflag:",StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. If not specified, the data will be stored in the base container directly."` + MetricsScope string `json:"metrics-ccope" pflag:",Scope that the metrics will record under."` + ProfilerPort int `json:"profiler-port" pflag:",Port that the profiling service is listening on."` } diff --git a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go index 946ff45b06..2f41ca13b7 100755 --- a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go +++ b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go @@ -1,6 +1,6 @@ // Code generated by go generate; DO NOT EDIT. // This file was generated by robots at -// 2019-08-07 10:50:49.390084414 -0700 PDT m=+0.045946390 +// 2019-08-23 21:53:46.865656622 -0700 PDT m=+1.391823141 package configs @@ -14,6 +14,8 @@ import ( // flags is json-name.json-sub-name... etc. func (DataCatalogConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("DataCatalogConfig", pflag.ExitOnError) - cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "storage-prefix"), []string{}, "StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. If not specified, the data will be stored in the base container directly.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage-prefix"), *new(string), "StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. 
If not specified, the data will be stored in the base container directly.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-ccope"), *new(string), "Scope that the metrics will record under.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "profiler-port"), *new(int), "Port that the profiling service is listening on.") return cmdFlags } diff --git a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go index e666480aa4..41f53cdd88 100755 --- a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go +++ b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go @@ -1,6 +1,6 @@ // Code generated by go generate; DO NOT EDIT. // This file was generated by robots at -// 2019-08-07 10:50:49.390084414 -0700 PDT m=+0.045946390 +// 2019-08-23 21:53:46.865656622 -0700 PDT m=+1.391823141 package configs @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "reflect" - "strings" "testing" "github.com/mitchellh/mapstructure" @@ -93,17 +92,57 @@ func TestDataCatalogConfig_SetFlags(t *testing.T) { t.Run("Test_storage-prefix", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vStringSlice, err := cmdFlags.GetStringSlice("storage-prefix"); err == nil { - assert.Equal(t, []string{}, vStringSlice) + if vString, err := cmdFlags.GetString("storage-prefix"); err == nil { + assert.Equal(t, *new(string), vString) } else { assert.FailNow(t, err.Error()) } }) t.Run("Override", func(t *testing.T) { - cmdFlags.Set("storage-prefix", "1,1") - if vStringSlice, err := cmdFlags.GetStringSlice("storage-prefix"); err == nil { - testDecodeSlice_DataCatalogConfig(t, strings.Join(vStringSlice, ","), &actual.StoragePrefix) + cmdFlags.Set("storage-prefix", "1") + if vString, err := cmdFlags.GetString("storage-prefix"); err == nil { + testDecodeJson_DataCatalogConfig(t, fmt.Sprintf("%v", vString), &actual.StoragePrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_metrics-ccope", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("metrics-ccope"); err == nil { + assert.Equal(t, *new(string), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("metrics-ccope", "1") + if vString, err := cmdFlags.GetString("metrics-ccope"); err == nil { + testDecodeJson_DataCatalogConfig(t, fmt.Sprintf("%v", vString), &actual.MetricsScope) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_profiler-port", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("profiler-port"); err == nil { + assert.Equal(t, *new(int), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("profiler-port", "1") + if vInt, err := cmdFlags.GetInt("profiler-port"); err == nil { + testDecodeJson_DataCatalogConfig(t, fmt.Sprintf("%v", vInt), &actual.ProfilerPort) } else { assert.FailNow(t, err.Error()) From 88144a09332778056d44adabd2527e0e7049920a Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Wed, 28 Aug 2019 11:08:36 -0700 Subject: [PATCH 0077/1918] add prometheus stats --- datacatalog/Gopkg.lock | 11 +++-- datacatalog/cmd/entrypoints/serve.go | 15 +++--- datacatalog/cmd/entrypoints/serve_dummy.go | 4 -- .../pkg/manager/impl/artifact_manager.go | 49 
+++++++++++++++++++ .../pkg/manager/impl/artifact_manager_test.go | 4 ++ .../pkg/manager/impl/dataset_manager.go | 44 ++++++++++++++++- .../pkg/manager/impl/dataset_manager_test.go | 6 +++ datacatalog/pkg/manager/impl/tag_manager.go | 32 ++++++++++-- .../pkg/manager/impl/tag_manager_test.go | 6 +++ .../pkg/repositories/gormimpl/artifact.go | 14 ++++-- .../repositories/gormimpl/artifact_test.go | 15 ++++-- .../pkg/repositories/gormimpl/dataset.go | 14 ++++-- .../pkg/repositories/gormimpl/dataset_test.go | 15 ++++-- .../pkg/repositories/gormimpl/metrics.go | 25 ++++++++++ datacatalog/pkg/repositories/gormimpl/tag.go | 12 ++++- .../pkg/repositories/gormimpl/tag_test.go | 13 +++-- datacatalog/pkg/repositories/postgres_repo.go | 6 +-- .../pkg/rpc/datacatalogservice/service.go | 6 ++- 18 files changed, 244 insertions(+), 47 deletions(-) create mode 100644 datacatalog/pkg/repositories/gormimpl/metrics.go diff --git a/datacatalog/Gopkg.lock b/datacatalog/Gopkg.lock index 130738217c..f582f67f54 100644 --- a/datacatalog/Gopkg.lock +++ b/datacatalog/Gopkg.lock @@ -212,7 +212,7 @@ version = "v1.2.0" [[projects]] - digest = "1:72841617053e049a34e8d98232b863fe5c173cdf03a4d8bc8dc9039303ad418e" + digest = "1:87d447239dd6752a762a4f94652b69c46cb53ca3d32dc52898857f7d1776c701" name = "github.com/lyft/flyteidl" packages = ["gen/pb-go/flyteidl/core"] pruneopts = "T" @@ -220,7 +220,7 @@ version = "v0.1.0" [[projects]] - digest = "1:6cc3cfda698262608d464cd89bfab217bb8fa8d507bf13b82cd585a520aed37d" + digest = "1:ef46e605151d1cacde38e96f3dac8d458d5380d3be4b58cae9e7c8220c8fffd2" name = "github.com/lyft/flytestdlib" packages = [ "atomic", @@ -230,9 +230,11 @@ "contextutils", "ioutils", "logger", + "profutils", "promutils", "promutils/labeled", "storage", + "version", ] pruneopts = "NUT" revision = "2577ff228d559b8fdf687f6cfad196bfbf1bd50a" @@ -311,11 +313,12 @@ version = "v1.0.0" [[projects]] - digest = "1:0f362379987ecc2cf4df1b8e4c1653a782f6f9f77f749547b734499b3c543080" + digest = "1:3d7ad9c93ee2e43dc7dd48b76f96804c365be7fe30e75f7405028c4e633d189a" name = "github.com/prometheus/client_golang" packages = [ "prometheus", "prometheus/internal", + "prometheus/promhttp", ] pruneopts = "NUT" revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" @@ -584,11 +587,11 @@ "github.com/lyft/flytestdlib/config/viper", "github.com/lyft/flytestdlib/contextutils", "github.com/lyft/flytestdlib/logger", + "github.com/lyft/flytestdlib/profutils", "github.com/lyft/flytestdlib/promutils", "github.com/lyft/flytestdlib/promutils/labeled", "github.com/lyft/flytestdlib/storage", "github.com/mitchellh/mapstructure", - "github.com/pkg/errors", "github.com/spf13/cobra", "github.com/spf13/pflag", "github.com/stretchr/testify/assert", diff --git a/datacatalog/cmd/entrypoints/serve.go b/datacatalog/cmd/entrypoints/serve.go index ea92cb1215..a52f745823 100644 --- a/datacatalog/cmd/entrypoints/serve.go +++ b/datacatalog/cmd/entrypoints/serve.go @@ -2,17 +2,13 @@ package entrypoints import ( "context" - "fmt" - "html" "net" "net/http" "github.com/lyft/datacatalog/pkg/config" "github.com/lyft/datacatalog/pkg/rpc/datacatalogservice" datacatalog "github.com/lyft/datacatalog/protos/gen" - "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/logger" - "github.com/lyft/flytestdlib/promutils/labeled" "github.com/spf13/cobra" "google.golang.org/grpc" ) @@ -38,8 +34,6 @@ var serveCmd = &cobra.Command{ func init() { RootCmd.AddCommand(serveCmd) - - labeled.SetMetricKeys(contextutils.AppNameKey) } // Create and start the gRPC server @@ -63,10 
+57,13 @@ func newGRPCServer(_ context.Context) *grpc.Server { } func serveHealthcheck(ctx context.Context, cfg *config.Config) error { - http.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Healthcheck success on %v", html.EscapeString(r.URL.Path)) + mux := http.NewServeMux() + + // Register Healthcheck + mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) }) logger.Infof(ctx, "Serving DataCatalog http on port %v", cfg.GetHTTPHostAddress()) - return http.ListenAndServe(cfg.GetHTTPHostAddress(), nil) + return http.ListenAndServe(cfg.GetHTTPHostAddress(), mux) } diff --git a/datacatalog/cmd/entrypoints/serve_dummy.go b/datacatalog/cmd/entrypoints/serve_dummy.go index 8eacd00ea6..c8bf40c3b3 100644 --- a/datacatalog/cmd/entrypoints/serve_dummy.go +++ b/datacatalog/cmd/entrypoints/serve_dummy.go @@ -7,9 +7,7 @@ import ( "github.com/lyft/datacatalog/pkg/config" "github.com/lyft/datacatalog/pkg/rpc/datacatalogservice" datacatalog "github.com/lyft/datacatalog/protos/gen" - "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/logger" - "github.com/lyft/flytestdlib/promutils/labeled" "github.com/spf13/cobra" "google.golang.org/grpc" ) @@ -26,8 +24,6 @@ var serveDummyCmd = &cobra.Command{ func init() { RootCmd.AddCommand(serveDummyCmd) - - labeled.SetMetricKeys(contextutils.AppNameKey) } // Create and start the gRPC server and http healthcheck endpoint diff --git a/datacatalog/pkg/manager/impl/artifact_manager.go b/datacatalog/pkg/manager/impl/artifact_manager.go index 6b65f46ca5..7ce5e144e4 100644 --- a/datacatalog/pkg/manager/impl/artifact_manager.go +++ b/datacatalog/pkg/manager/impl/artifact_manager.go @@ -2,6 +2,7 @@ package impl import ( "context" + "time" "github.com/lyft/datacatalog/pkg/errors" "github.com/lyft/datacatalog/pkg/manager/impl/validators" @@ -13,20 +14,40 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/transformers" "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" "google.golang.org/grpc/codes" ) +type artifactMetrics struct { + scope promutils.Scope + createSuccessCounter labeled.Counter + createFailureCounter labeled.Counter + getSuccessCounter labeled.Counter + getFailureCounter labeled.Counter + createDataFailureCounter labeled.Counter + createDataSuccessCounter labeled.Counter + transformerErrorCounter labeled.Counter + validationErrorCounter labeled.Counter + createResponseTime labeled.StopWatch + getResponseTime labeled.StopWatch +} + type artifactManager struct { repo repositories.RepositoryInterface artifactStore ArtifactDataStore + systemMetrics artifactMetrics } // Create an Artifact along with the associated ArtifactData. The ArtifactData will be stored in an offloaded location. 
func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalog.CreateArtifactRequest) (*datacatalog.CreateArtifactResponse, error) { + timer := m.systemMetrics.createResponseTime.Start(ctx) + defer timer.Stop() + artifact := request.Artifact err := validators.ValidateArtifact(artifact) if err != nil { + m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -35,6 +56,7 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo // The dataset must exist for the artifact, let's verify that first _, err = m.repo.DatasetRepo().Get(ctx, datasetKey) if err != nil { + m.systemMetrics.createFailureCounter.Inc(ctx) return nil, err } @@ -43,22 +65,29 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo for i, artifactData := range request.Artifact.Data { dataLocation, err := m.artifactStore.PutData(ctx, *artifact, *artifactData) if err != nil { + m.systemMetrics.createDataFailureCounter.Inc(ctx) + m.systemMetrics.createFailureCounter.Inc(ctx) return nil, err } artifactDataModels[i].Name = artifactData.Name artifactDataModels[i].Location = dataLocation.String() + m.systemMetrics.createDataSuccessCounter.Inc(ctx) } artifactModel, err := transformers.CreateArtifactModel(request, artifactDataModels) if err != nil { + m.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } err = m.repo.ArtifactRepo().Create(ctx, artifactModel) if err != nil { + m.systemMetrics.createFailureCounter.Inc(ctx) return nil, err } + + m.systemMetrics.createSuccessCounter.Inc(ctx) return &datacatalog.CreateArtifactResponse{}, nil } @@ -67,6 +96,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G datasetID := request.Dataset err := validators.ValidateGetArtifactRequest(request) if err != nil { + m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -77,6 +107,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G artifactModel, err = m.repo.ArtifactRepo().Get(ctx, artifactKey) if err != nil { + m.systemMetrics.getFailureCounter.Inc(ctx) return nil, err } case *datacatalog.GetArtifactRequest_TagName: @@ -84,6 +115,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G tag, err := m.repo.TagRepo().Get(ctx, tagKey) if err != nil { + m.systemMetrics.getFailureCounter.Inc(ctx) return nil, err } @@ -96,6 +128,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G artifact, err := transformers.FromArtifactModel(artifactModel) if err != nil { + m.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } @@ -113,14 +146,30 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G } artifact.Data = artifactDataList + m.systemMetrics.getSuccessCounter.Inc(ctx) return &datacatalog.GetArtifactResponse{ Artifact: &artifact, }, nil } func NewArtifactManager(repo repositories.RepositoryInterface, store *storage.DataStore, storagePrefix storage.DataReference, artifactScope promutils.Scope) interfaces.ArtifactManager { + artifactMetrics := artifactMetrics{ + scope: artifactScope, + createResponseTime: labeled.NewStopWatch("create_artifact_duration", "The duration of the create artifact calls.", time.Millisecond, artifactScope, labeled.EmitUnlabeledMetric), + getResponseTime: labeled.NewStopWatch("get_artifact_duration", "The duration of the get artifact calls.", time.Millisecond, artifactScope, labeled.EmitUnlabeledMetric), + createSuccessCounter: 
labeled.NewCounter("create_artifact_success_count", "The number of times create artifact was called", artifactScope, labeled.EmitUnlabeledMetric), + getSuccessCounter: labeled.NewCounter("get_artifact_success_count", "The number of times get artifact was called", artifactScope, labeled.EmitUnlabeledMetric), + createFailureCounter: labeled.NewCounter("create_artifact_failure_count", "The number of times create artifact failed", artifactScope, labeled.EmitUnlabeledMetric), + getFailureCounter: labeled.NewCounter("get_artifact_failure_count", "The number of times get artifact failed", artifactScope, labeled.EmitUnlabeledMetric), + createDataFailureCounter: labeled.NewCounter("create_artifact_data_failure_count", "The number of times create artifact data failed", artifactScope, labeled.EmitUnlabeledMetric), + createDataSuccessCounter: labeled.NewCounter("create_artifact_data_succeeded_count", "The number of times create artifact data succeeded", artifactScope, labeled.EmitUnlabeledMetric), + transformerErrorCounter: labeled.NewCounter("transformer_failed_count", "The number of times transformations failed", artifactScope, labeled.EmitUnlabeledMetric), + validationErrorCounter: labeled.NewCounter("validation_failed_count", "The number of times validation failed", artifactScope, labeled.EmitUnlabeledMetric), + } + return &artifactManager{ repo: repo, artifactStore: NewArtifactDataStore(store, storagePrefix), + systemMetrics: artifactMetrics, } } diff --git a/datacatalog/pkg/manager/impl/artifact_manager_test.go b/datacatalog/pkg/manager/impl/artifact_manager_test.go index 5800768116..707a74166e 100644 --- a/datacatalog/pkg/manager/impl/artifact_manager_test.go +++ b/datacatalog/pkg/manager/impl/artifact_manager_test.go @@ -20,6 +20,10 @@ import ( "google.golang.org/grpc/status" ) +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + func createInmemoryDataStore(t testing.TB, scope mockScope.Scope) *storage.DataStore { labeled.SetMetricKeys(contextutils.AppNameKey) cfg := storage.Config{ diff --git a/datacatalog/pkg/manager/impl/dataset_manager.go b/datacatalog/pkg/manager/impl/dataset_manager.go index a4a318aa56..614c45d612 100644 --- a/datacatalog/pkg/manager/impl/dataset_manager.go +++ b/datacatalog/pkg/manager/impl/dataset_manager.go @@ -2,6 +2,7 @@ package impl import ( "context" + "time" "github.com/lyft/datacatalog/pkg/manager/impl/validators" "github.com/lyft/datacatalog/pkg/manager/interfaces" @@ -10,38 +11,63 @@ import ( datacatalog "github.com/lyft/datacatalog/protos/gen" "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" ) +type datasetMetrics struct { + scope promutils.Scope + createSuccessCounter labeled.Counter + createErrorCounter labeled.Counter + getSuccessCounter labeled.Counter + getErrorCounter labeled.Counter + transformerErrorCounter labeled.Counter + validationErrorCounter labeled.Counter + createResponseTime labeled.StopWatch + getResponseTime labeled.StopWatch +} + type datasetManager struct { - repo repositories.RepositoryInterface - store *storage.DataStore + repo repositories.RepositoryInterface + store *storage.DataStore + systemMetrics datasetMetrics } // Create a Dataset with optional metadata. 
If one already exists a grpc AlreadyExists err will be returned func (dm *datasetManager) CreateDataset(ctx context.Context, request datacatalog.CreateDatasetRequest) (*datacatalog.CreateDatasetResponse, error) { + t := dm.systemMetrics.createResponseTime.Start(ctx) + defer t.Stop() + err := validators.ValidateDatasetID(request.Dataset.Id) if err != nil { + dm.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } datasetModel, err := transformers.CreateDatasetModel(request.Dataset) if err != nil { + dm.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } err = dm.repo.DatasetRepo().Create(ctx, *datasetModel) if err != nil { + dm.systemMetrics.createErrorCounter.Inc(ctx) return nil, err } + dm.systemMetrics.createSuccessCounter.Inc(ctx) return &datacatalog.CreateDatasetResponse{}, nil } // Get a Dataset with the given DatasetID if it exists. If none exist a grpc NotFound err will be returned func (dm *datasetManager) GetDataset(ctx context.Context, request datacatalog.GetDatasetRequest) (*datacatalog.GetDatasetResponse, error) { + t := dm.systemMetrics.getResponseTime.Start(ctx) + defer t.Stop() + err := validators.ValidateDatasetID(request.Dataset) if err != nil { + dm.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -49,14 +75,17 @@ func (dm *datasetManager) GetDataset(ctx context.Context, request datacatalog.Ge datasetModel, err := dm.repo.DatasetRepo().Get(ctx, datasetKey) if err != nil { + dm.systemMetrics.getErrorCounter.Inc(ctx) return nil, err } datasetResponse, err := transformers.FromDatasetModel(datasetModel) if err != nil { + dm.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } + dm.systemMetrics.getSuccessCounter.Inc(ctx) return &datacatalog.GetDatasetResponse{ Dataset: datasetResponse, }, nil @@ -66,5 +95,16 @@ func NewDatasetManager(repo repositories.RepositoryInterface, store *storage.Dat return &datasetManager{ repo: repo, store: store, + systemMetrics: datasetMetrics{ + scope: datasetScope, + createResponseTime: labeled.NewStopWatch("create_duration", "The duration of the create dataset calls.", time.Millisecond, datasetScope, labeled.EmitUnlabeledMetric), + getResponseTime: labeled.NewStopWatch("get_duration", "The duration of the get dataset calls.", time.Millisecond, datasetScope, labeled.EmitUnlabeledMetric), + createSuccessCounter: labeled.NewCounter("create_success_count", "The number of times create dataset was called", datasetScope, labeled.EmitUnlabeledMetric), + getSuccessCounter: labeled.NewCounter("get_success_count", "The number of times get dataset was called", datasetScope, labeled.EmitUnlabeledMetric), + createErrorCounter: labeled.NewCounter("create_failed_count", "The number of times create dataset failed", datasetScope, labeled.EmitUnlabeledMetric), + getErrorCounter: labeled.NewCounter("get_failed_count", "The number of times get dataset failed", datasetScope, labeled.EmitUnlabeledMetric), + transformerErrorCounter: labeled.NewCounter("transformer_failed_count", "The number of times transformations failed", datasetScope, labeled.EmitUnlabeledMetric), + validationErrorCounter: labeled.NewCounter("validation_failed_count", "The number of times validation failed", datasetScope, labeled.EmitUnlabeledMetric), + }, } } diff --git a/datacatalog/pkg/manager/impl/dataset_manager_test.go b/datacatalog/pkg/manager/impl/dataset_manager_test.go index 9f9c39e611..ef0ceaf7b8 100644 --- a/datacatalog/pkg/manager/impl/dataset_manager_test.go +++ b/datacatalog/pkg/manager/impl/dataset_manager_test.go @@ -10,13 +10,19 
@@ import ( "github.com/lyft/datacatalog/pkg/repositories/mocks" "github.com/lyft/datacatalog/pkg/repositories/models" datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/contextutils" mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + func getTestDataset() *datacatalog.Dataset { return &datacatalog.Dataset{ Id: &datacatalog.DatasetID{ diff --git a/datacatalog/pkg/manager/impl/tag_manager.go b/datacatalog/pkg/manager/impl/tag_manager.go index 2c2de65cb3..46e085b3d7 100644 --- a/datacatalog/pkg/manager/impl/tag_manager.go +++ b/datacatalog/pkg/manager/impl/tag_manager.go @@ -2,6 +2,7 @@ package impl import ( "context" + "time" "github.com/lyft/datacatalog/pkg/manager/impl/validators" "github.com/lyft/datacatalog/pkg/manager/interfaces" @@ -12,17 +13,28 @@ import ( datacatalog "github.com/lyft/datacatalog/protos/gen" "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" ) +type tagMetrics struct { + scope promutils.Scope + addTagSuccessCounter labeled.Counter + addTagFailureCounter labeled.Counter + addTagResponseTime labeled.StopWatch + validationErrorCounter labeled.Counter +} + type tagManager struct { - repo repositories.RepositoryInterface - store *storage.DataStore + repo repositories.RepositoryInterface + store *storage.DataStore + systemMetrics tagMetrics } func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagRequest) (*datacatalog.AddTagResponse, error) { if err := validators.ValidateTag(request.Tag); err != nil { + m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -31,6 +43,7 @@ func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagReque artifactKey := transformers.ToArtifactKey(datasetID, request.Tag.ArtifactId) _, err := m.repo.ArtifactRepo().Get(ctx, artifactKey) if err != nil { + m.systemMetrics.addTagFailureCounter.Inc(ctx) return nil, err } @@ -40,15 +53,26 @@ func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagReque ArtifactID: request.Tag.ArtifactId, }) if err != nil { + m.systemMetrics.addTagFailureCounter.Inc(ctx) return nil, err } + m.systemMetrics.addTagSuccessCounter.Inc(ctx) return &datacatalog.AddTagResponse{}, nil } func NewTagManager(repo repositories.RepositoryInterface, store *storage.DataStore, tagScope promutils.Scope) interfaces.TagManager { + systemMetrics := tagMetrics{ + scope: tagScope, + addTagResponseTime: labeled.NewStopWatch("add_tag_duration", "The duration of tagging an artifact.", time.Millisecond, tagScope, labeled.EmitUnlabeledMetric), + addTagSuccessCounter: labeled.NewCounter("add_tag_success_count", "The number of times an artifact was tagged successfully", tagScope, labeled.EmitUnlabeledMetric), + addTagFailureCounter: labeled.NewCounter("add_tag_failure_count", "The number of times we failed to tag an artifact", tagScope, labeled.EmitUnlabeledMetric), + validationErrorCounter: labeled.NewCounter("validation_error_count", "The number of times we failed validate a tag", tagScope, labeled.EmitUnlabeledMetric), + } + return &tagManager{ - repo: repo, - store: store, + repo: repo, + store: store, + systemMetrics: systemMetrics, } } diff --git a/datacatalog/pkg/manager/impl/tag_manager_test.go 
b/datacatalog/pkg/manager/impl/tag_manager_test.go index 7f8a75db85..1bb985e91c 100644 --- a/datacatalog/pkg/manager/impl/tag_manager_test.go +++ b/datacatalog/pkg/manager/impl/tag_manager_test.go @@ -8,13 +8,19 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/models" datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/contextutils" mockScope "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + func getTestTag() models.Tag { return models.Tag{ TagKey: models.TagKey{ diff --git a/datacatalog/pkg/repositories/gormimpl/artifact.go b/datacatalog/pkg/repositories/gormimpl/artifact.go index b39093ebcd..e67ab49f87 100644 --- a/datacatalog/pkg/repositories/gormimpl/artifact.go +++ b/datacatalog/pkg/repositories/gormimpl/artifact.go @@ -8,23 +8,28 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/interfaces" "github.com/lyft/datacatalog/pkg/repositories/models" datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/promutils" ) type artifactRepo struct { db *gorm.DB errorTransformer errors.ErrorTransformer - // TODO: add metrics + repoMetrics gormMetrics } -func NewArtifactRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer) interfaces.ArtifactRepo { +func NewArtifactRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.ArtifactRepo { return &artifactRepo{ db: db, errorTransformer: errorTransformer, + repoMetrics: newGormMetrics(scope), } } +// Create the artifact in a transaction because ArtifactData will be created and associated along with it func (h *artifactRepo) Create(ctx context.Context, artifact models.Artifact) error { - // Create the artifact in a transaction because ArtifactData will be created and associated along with it + timer := h.repoMetrics.CreateDuration.Start(ctx) + defer timer.Stop() + tx := h.db.Begin() tx = tx.Create(&artifact) @@ -43,6 +48,9 @@ func (h *artifactRepo) Create(ctx context.Context, artifact models.Artifact) err } func (h *artifactRepo) Get(ctx context.Context, in models.ArtifactKey) (models.Artifact, error) { + timer := h.repoMetrics.GetDuration.Start(ctx) + defer timer.Stop() + var artifact models.Artifact result := h.db.Preload("ArtifactData").Find(&artifact, &models.Artifact{ ArtifactKey: in, diff --git a/datacatalog/pkg/repositories/gormimpl/artifact_test.go b/datacatalog/pkg/repositories/gormimpl/artifact_test.go index 4446bbae5c..7192ae98be 100644 --- a/datacatalog/pkg/repositories/gormimpl/artifact_test.go +++ b/datacatalog/pkg/repositories/gormimpl/artifact_test.go @@ -14,9 +14,16 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/errors" "github.com/lyft/datacatalog/pkg/repositories/models" "github.com/lyft/datacatalog/pkg/repositories/utils" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "google.golang.org/grpc/codes" ) +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + func getTestArtifact() models.Artifact { return models.Artifact{ ArtifactKey: models.ArtifactKey{ @@ -65,7 +72,7 @@ func TestCreateArtifact(t *testing.T) { artifact.ArtifactData = data - artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + artifactRepo := 
NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) err := artifactRepo.Create(context.Background(), artifact) assert.NoError(t, err) assert.True(t, artifactCreated) @@ -112,7 +119,7 @@ func TestGetArtifact(t *testing.T) { ArtifactID: artifact.ArtifactID, } - artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) response, err := artifactRepo.Get(context.Background(), getInput) assert.NoError(t, err) assert.Equal(t, artifact.ArtifactID, response.ArtifactID) @@ -139,7 +146,7 @@ func TestGetArtifactDoesNotExist(t *testing.T) { } // by default mocket will return nil for any queries - artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) _, err := artifactRepo.Get(context.Background(), getInput) assert.Error(t, err) dcErr, ok := err.(apiErrors.DataCatalogError) @@ -159,7 +166,7 @@ func TestCreateArtifactAlreadyExists(t *testing.T) { getAlreadyExistsErr(), ) - artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) err := artifactRepo.Create(context.Background(), artifact) assert.Error(t, err) dcErr, ok := err.(apiErrors.DataCatalogError) diff --git a/datacatalog/pkg/repositories/gormimpl/dataset.go b/datacatalog/pkg/repositories/gormimpl/dataset.go index e74dca6260..d0e5dc1526 100644 --- a/datacatalog/pkg/repositories/gormimpl/dataset.go +++ b/datacatalog/pkg/repositories/gormimpl/dataset.go @@ -7,27 +7,30 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/errors" "github.com/lyft/datacatalog/pkg/repositories/interfaces" "github.com/lyft/datacatalog/pkg/repositories/models" - idl_datacatalog "github.com/lyft/datacatalog/protos/gen" "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" ) type dataSetRepo struct { db *gorm.DB errorTransformer errors.ErrorTransformer - - // TODO: add metrics + repoMetrics gormMetrics } -func NewDatasetRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer) interfaces.DatasetRepo { +func NewDatasetRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.DatasetRepo { return &dataSetRepo{ db: db, errorTransformer: errorTransformer, + repoMetrics: newGormMetrics(scope), } } // Create a Dataset model func (h *dataSetRepo) Create(ctx context.Context, in models.Dataset) error { + timer := h.repoMetrics.CreateDuration.Start(ctx) + defer timer.Stop() + result := h.db.Create(&in) if result.Error != nil { return h.errorTransformer.ToDataCatalogError(result.Error) @@ -37,6 +40,9 @@ func (h *dataSetRepo) Create(ctx context.Context, in models.Dataset) error { // Get Dataset model func (h *dataSetRepo) Get(ctx context.Context, in models.DatasetKey) (models.Dataset, error) { + timer := h.repoMetrics.GetDuration.Start(ctx) + defer timer.Stop() + var ds models.Dataset result := h.db.Where(&models.Dataset{DatasetKey: in}).First(&ds) diff --git a/datacatalog/pkg/repositories/gormimpl/dataset_test.go b/datacatalog/pkg/repositories/gormimpl/dataset_test.go index 045cf1e7de..e20ff156ff 100644 --- a/datacatalog/pkg/repositories/gormimpl/dataset_test.go +++ 
b/datacatalog/pkg/repositories/gormimpl/dataset_test.go @@ -16,8 +16,15 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/errors" "github.com/lyft/datacatalog/pkg/repositories/models" "github.com/lyft/datacatalog/pkg/repositories/utils" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" ) +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + func getTestDataset() models.Dataset { return models.Dataset{ DatasetKey: models.DatasetKey{ @@ -49,7 +56,7 @@ func TestCreateDataset(t *testing.T) { }, ) - datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) err := datasetRepo.Create(context.Background(), getTestDataset()) assert.NoError(t, err) assert.True(t, datasetCreated) @@ -72,7 +79,7 @@ func TestGetDataset(t *testing.T) { // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery(`SELECT * FROM "datasets" WHERE "datasets"."deleted_at" IS NULL AND (("datasets"."project" = testProject) AND ("datasets"."name" = testName) AND ("datasets"."domain" = testDomain) AND ("datasets"."version" = testVersion)) ORDER BY "datasets"."project" ASC LIMIT 1`).WithReply(expectedResponse) - datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) actualDataset, err := datasetRepo.Get(context.Background(), dataset.DatasetKey) assert.NoError(t, err) assert.Equal(t, dataset.Project, actualDataset.Project) @@ -95,7 +102,7 @@ func TestGetDatasetNotFound(t *testing.T) { // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery(`SELECT * FROM "datasets" WHERE "datasets"."deleted_at" IS NULL AND (("datasets"."project" = testProject) AND ("datasets"."name" = testName) AND ("datasets"."domain" = testDomain) AND ("datasets"."version" = testVersion)) ORDER BY "datasets"."id" ASC LIMIT 1`).WithReply(nil) - datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) _, err := datasetRepo.Get(context.Background(), dataset.DatasetKey) assert.Error(t, err) notFoundErr, ok := err.(datacatalog_error.DataCatalogError) @@ -113,7 +120,7 @@ func TestCreateDatasetAlreadyExists(t *testing.T) { getAlreadyExistsErr(), ) - datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + datasetRepo := NewDatasetRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) err := datasetRepo.Create(context.Background(), getTestDataset()) assert.Error(t, err) dcErr, ok := err.(datacatalog_error.DataCatalogError) diff --git a/datacatalog/pkg/repositories/gormimpl/metrics.go b/datacatalog/pkg/repositories/gormimpl/metrics.go new file mode 100644 index 0000000000..d80c3d8b11 --- /dev/null +++ b/datacatalog/pkg/repositories/gormimpl/metrics.go @@ -0,0 +1,25 @@ +package gormimpl + +import ( + "time" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" +) + +// Common metrics for DB CRUD operations +type gormMetrics struct { + Scope promutils.Scope + CreateDuration labeled.StopWatch + GetDuration labeled.StopWatch +} + +func 
newGormMetrics(scope promutils.Scope) gormMetrics { + return gormMetrics{ + Scope: scope, + CreateDuration: labeled.NewStopWatch( + "create", "Duration for creating a new entity", time.Millisecond, scope), + GetDuration: labeled.NewStopWatch( + "get", "Duration for retrieving an entity ", time.Millisecond, scope), + } +} diff --git a/datacatalog/pkg/repositories/gormimpl/tag.go b/datacatalog/pkg/repositories/gormimpl/tag.go index c16fe937dd..adb9cbbf1a 100644 --- a/datacatalog/pkg/repositories/gormimpl/tag.go +++ b/datacatalog/pkg/repositories/gormimpl/tag.go @@ -8,22 +8,27 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/interfaces" "github.com/lyft/datacatalog/pkg/repositories/models" idl_datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/promutils" ) type tagRepo struct { db *gorm.DB errorTransformer errors.ErrorTransformer - // TODO: add metrics + repoMetrics gormMetrics } -func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer) interfaces.TagRepo { +func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.TagRepo { return &tagRepo{ db: db, errorTransformer: errorTransformer, + repoMetrics: newGormMetrics(scope), } } func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { + timer := h.repoMetrics.CreateDuration.Start(ctx) + defer timer.Stop() + db := h.db.Create(&tag) if db.Error != nil { @@ -33,6 +38,9 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { } func (h *tagRepo) Get(ctx context.Context, in models.TagKey) (models.Tag, error) { + timer := h.repoMetrics.GetDuration.Start(ctx) + defer timer.Stop() + var tag models.Tag result := h.db.Preload("Artifact").Preload("Artifact.ArtifactData").Find(&tag, &models.Tag{ TagKey: in, diff --git a/datacatalog/pkg/repositories/gormimpl/tag_test.go b/datacatalog/pkg/repositories/gormimpl/tag_test.go index c1befea462..acde271b28 100644 --- a/datacatalog/pkg/repositories/gormimpl/tag_test.go +++ b/datacatalog/pkg/repositories/gormimpl/tag_test.go @@ -16,9 +16,16 @@ import ( "github.com/lib/pq" "github.com/lyft/datacatalog/pkg/repositories/models" "github.com/lyft/datacatalog/pkg/repositories/utils" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "google.golang.org/grpc/codes" ) +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + func getAlreadyExistsErr() error { return &pq.Error{Code: "23505"} } @@ -50,7 +57,7 @@ func TestCreateTag(t *testing.T) { }, ) - tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) err := tagRepo.Create(context.Background(), getTestTag()) assert.NoError(t, err) assert.True(t, tagCreated) @@ -109,7 +116,7 @@ func TestGetTag(t *testing.T) { TagName: "test-tag", } - tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) response, err := tagRepo.Get(context.Background(), getInput) assert.NoError(t, err) assert.Equal(t, artifact.ArtifactID, response.ArtifactID) @@ -127,7 +134,7 @@ func TestTagAlreadyExists(t *testing.T) { getAlreadyExistsErr(), ) - tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer()) + tagRepo := NewTagRepo(utils.GetDbForTest(t), 
errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) err := tagRepo.Create(context.Background(), getTestTag()) assert.Error(t, err) dcErr, ok := err.(datacatalog_error.DataCatalogError) diff --git a/datacatalog/pkg/repositories/postgres_repo.go b/datacatalog/pkg/repositories/postgres_repo.go index f4418ffc21..6024c13d81 100644 --- a/datacatalog/pkg/repositories/postgres_repo.go +++ b/datacatalog/pkg/repositories/postgres_repo.go @@ -28,8 +28,8 @@ func (dc *PostgresRepo) TagRepo() interfaces.TagRepo { func NewPostgresRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.DataCatalogRepo { return &PostgresRepo{ - datasetRepo: gormimpl.NewDatasetRepo(db, errorTransformer), - artifactRepo: gormimpl.NewArtifactRepo(db, errorTransformer), - tagRepo: gormimpl.NewTagRepo(db, errorTransformer), + datasetRepo: gormimpl.NewDatasetRepo(db, errorTransformer, scope.NewSubScope("dataset")), + artifactRepo: gormimpl.NewArtifactRepo(db, errorTransformer, scope.NewSubScope("artifact")), + tagRepo: gormimpl.NewTagRepo(db, errorTransformer, scope.NewSubScope("tag")), } } diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index 145f33b389..2c2f4b4683 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -11,9 +11,11 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/config" "github.com/lyft/datacatalog/pkg/runtime" catalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/profutils" "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" ) @@ -23,7 +25,6 @@ type DataCatalogService struct { TagManager interfaces.TagManager } -// TODO: add metrics and counters to these service methods func (s *DataCatalogService) CreateDataset(ctx context.Context, request *catalog.CreateDatasetRequest) (*catalog.CreateDatasetResponse, error) { return s.DatasetManager.CreateDataset(ctx, *request) } @@ -50,6 +51,9 @@ func NewDataCatalogService() *DataCatalogService { dataCatalogScope := "datacatalog" catalogScope := promutils.NewScope(dataCatalogScope).NewSubScope("service") + // Set Keys + labeled.SetMetricKeys(contextutils.AppNameKey, contextutils.ProjectKey, contextutils.DomainKey) + defer func() { if err := recover(); err != nil { catalogScope.MustNewCounter("initialization_panic", From 50d427ac019c385c21fae6dac95229b7f997ba9d Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Wed, 28 Aug 2019 11:20:53 -0700 Subject: [PATCH 0078/1918] add labels for appName, project and domain --- datacatalog/pkg/manager/impl/artifact_manager.go | 3 +++ datacatalog/pkg/manager/impl/tag_manager.go | 4 +++- datacatalog/pkg/rpc/datacatalogservice/service.go | 7 +++---- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/datacatalog/pkg/manager/impl/artifact_manager.go b/datacatalog/pkg/manager/impl/artifact_manager.go index 7ce5e144e4..740eaeda45 100644 --- a/datacatalog/pkg/manager/impl/artifact_manager.go +++ b/datacatalog/pkg/manager/impl/artifact_manager.go @@ -13,6 +13,7 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/models" "github.com/lyft/datacatalog/pkg/repositories/transformers" + "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" @@ -51,6 +52,7 
@@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo return nil, err } + ctx = contextutils.WithProjectDomain(ctx, artifact.Dataset.Project, artifact.Dataset.Domain) datasetKey := transformers.FromDatasetID(*artifact.Dataset) // The dataset must exist for the artifact, let's verify that first @@ -100,6 +102,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G return nil, err } + ctx = contextutils.WithProjectDomain(ctx, datasetID.Project, datasetID.Domain) var artifactModel models.Artifact switch request.QueryHandle.(type) { case *datacatalog.GetArtifactRequest_ArtifactId: diff --git a/datacatalog/pkg/manager/impl/tag_manager.go b/datacatalog/pkg/manager/impl/tag_manager.go index 46e085b3d7..646e8f15a5 100644 --- a/datacatalog/pkg/manager/impl/tag_manager.go +++ b/datacatalog/pkg/manager/impl/tag_manager.go @@ -12,6 +12,7 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/transformers" datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" @@ -32,7 +33,6 @@ type tagManager struct { } func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagRequest) (*datacatalog.AddTagResponse, error) { - if err := validators.ValidateTag(request.Tag); err != nil { m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err @@ -40,6 +40,8 @@ func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagReque // verify the artifact exists before adding a tag to it datasetID := *request.Tag.Dataset + ctx = contextutils.WithProjectDomain(ctx, datasetID.Project, datasetID.Domain) + artifactKey := transformers.ToArtifactKey(datasetID, request.Tag.ArtifactId) _, err := m.repo.ArtifactRepo().Get(ctx, artifactKey) if err != nil { diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index 2c2f4b4683..614411ccfa 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -46,10 +46,9 @@ func (s *DataCatalogService) AddTag(ctx context.Context, request *catalog.AddTag } func NewDataCatalogService() *DataCatalogService { - ctx := context.Background() - - dataCatalogScope := "datacatalog" - catalogScope := promutils.NewScope(dataCatalogScope).NewSubScope("service") + dataCatalogName := "datacatalog" + catalogScope := promutils.NewScope(dataCatalogName).NewSubScope("service") + ctx := contextutils.WithAppName(context.Background(), dataCatalogName) // Set Keys labeled.SetMetricKeys(contextutils.AppNameKey, contextutils.ProjectKey, contextutils.DomainKey) From a70231ae49cf1334f50a9d1f2551b5ae939af2b3 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Wed, 28 Aug 2019 12:20:12 -0700 Subject: [PATCH 0079/1918] special case soft failures like already exists and not found --- datacatalog/pkg/errors/errors.go | 16 +++++++ datacatalog/pkg/errors/errors_test.go | 22 ++++++++++ .../pkg/manager/impl/artifact_manager.go | 44 +++++++++++++++++-- .../pkg/manager/impl/dataset_manager.go | 26 ++++++++++- datacatalog/pkg/manager/impl/tag_manager.go | 13 +++++- 5 files changed, 114 insertions(+), 7 deletions(-) create mode 100644 datacatalog/pkg/errors/errors_test.go diff --git a/datacatalog/pkg/errors/errors.go b/datacatalog/pkg/errors/errors.go index a6d3ef0910..8931444f16 100644 --- a/datacatalog/pkg/errors/errors.go +++ 
b/datacatalog/pkg/errors/errors.go @@ -43,3 +43,19 @@ func NewDataCatalogError(code codes.Code, message string) error { func NewDataCatalogErrorf(code codes.Code, format string, a ...interface{}) error { return NewDataCatalogError(code, fmt.Sprintf(format, a...)) } + +func IsAlreadyExistsError(err error) bool { + dcErr, ok := err.(DataCatalogError) + if ok && dcErr.GRPCStatus().Code() == codes.AlreadyExists { + return true + } + return false +} + +func IsDoesNotExistError(err error) bool { + dcErr, ok := err.(DataCatalogError) + if ok && dcErr.GRPCStatus().Code() == codes.NotFound { + return true + } + return false +} diff --git a/datacatalog/pkg/errors/errors_test.go b/datacatalog/pkg/errors/errors_test.go new file mode 100644 index 0000000000..1d75360374 --- /dev/null +++ b/datacatalog/pkg/errors/errors_test.go @@ -0,0 +1,22 @@ +package errors + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" +) + +func TestAlreadyExists(t *testing.T) { + alreadyExistsErr := NewDataCatalogError(codes.AlreadyExists, "already exists") + notFoundErr := NewDataCatalogError(codes.NotFound, "not found") + assert.True(t, IsAlreadyExistsError(alreadyExistsErr)) + assert.False(t, IsAlreadyExistsError(notFoundErr)) +} + +func TestNotFoundErr(t *testing.T) { + alreadyExistsErr := NewDataCatalogError(codes.AlreadyExists, "already exists") + notFoundErr := NewDataCatalogError(codes.NotFound, "not found") + assert.False(t, IsDoesNotExistError(alreadyExistsErr)) + assert.True(t, IsDoesNotExistError(notFoundErr)) +} diff --git a/datacatalog/pkg/manager/impl/artifact_manager.go b/datacatalog/pkg/manager/impl/artifact_manager.go index 740eaeda45..4944d66010 100644 --- a/datacatalog/pkg/manager/impl/artifact_manager.go +++ b/datacatalog/pkg/manager/impl/artifact_manager.go @@ -14,6 +14,7 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/transformers" "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" @@ -32,6 +33,8 @@ type artifactMetrics struct { validationErrorCounter labeled.Counter createResponseTime labeled.StopWatch getResponseTime labeled.StopWatch + alreadyExistsCounter labeled.Counter + doesNotExistCounter labeled.Counter } type artifactManager struct { @@ -48,6 +51,7 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo artifact := request.Artifact err := validators.ValidateArtifact(artifact) if err != nil { + logger.Errorf(ctx, "Invalid create artifact request %v, err: %v", request, err) m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -58,6 +62,7 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo // The dataset must exist for the artifact, let's verify that first _, err = m.repo.DatasetRepo().Get(ctx, datasetKey) if err != nil { + logger.Errorf(ctx, "Failed to get dataset for artifact creation %v, err: %v", datasetKey, err) m.systemMetrics.createFailureCounter.Inc(ctx) return nil, err } @@ -67,8 +72,8 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo for i, artifactData := range request.Artifact.Data { dataLocation, err := m.artifactStore.PutData(ctx, *artifact, *artifactData) if err != nil { + logger.Errorf(ctx, "Failed to store artifact data err: %v", err) m.systemMetrics.createDataFailureCounter.Inc(ctx) - m.systemMetrics.createFailureCounter.Inc(ctx) return nil, err } @@ -77,18 
+82,29 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo m.systemMetrics.createDataSuccessCounter.Inc(ctx) } + logger.Debugf(ctx, "Stored %v data for artifact %+v", len(artifactDataModels), artifact.Id) + artifactModel, err := transformers.CreateArtifactModel(request, artifactDataModels) if err != nil { + logger.Errorf(ctx, "Failed to transform artifact err: %v", err) m.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } err = m.repo.ArtifactRepo().Create(ctx, artifactModel) if err != nil { - m.systemMetrics.createFailureCounter.Inc(ctx) + if errors.IsAlreadyExistsError(err) { + logger.Warnf(ctx, "Artifact already exists key: %+v, err %v", artifact.Id, err) + m.systemMetrics.alreadyExistsCounter.Inc(ctx) + } else { + logger.Errorf(ctx, "Failed to create artifact %v, err: %v", artifactDataModels, err) + m.systemMetrics.createFailureCounter.Inc(ctx) + } return nil, err } + logger.Debugf(ctx, "Successfully created artifact id: %v", artifact.Id) + m.systemMetrics.createSuccessCounter.Inc(ctx) return &datacatalog.CreateArtifactResponse{}, nil } @@ -98,6 +114,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G datasetID := request.Dataset err := validators.ValidateGetArtifactRequest(request) if err != nil { + logger.Errorf(ctx, "Invalid get artifact request %v, err: %v", request, err) m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -106,19 +123,33 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G var artifactModel models.Artifact switch request.QueryHandle.(type) { case *datacatalog.GetArtifactRequest_ArtifactId: + logger.Debugf(ctx, "Get artifact by id %v", request.GetArtifactId()) artifactKey := transformers.ToArtifactKey(*datasetID, request.GetArtifactId()) artifactModel, err = m.repo.ArtifactRepo().Get(ctx, artifactKey) if err != nil { - m.systemMetrics.getFailureCounter.Inc(ctx) + if errors.IsDoesNotExistError(err) { + logger.Warnf(ctx, "Artifact does not exist id: %+v, err %v", request.GetArtifactId(), err) + m.systemMetrics.doesNotExistCounter.Inc(ctx) + } else { + logger.Errorf(ctx, "Unable to retrieve artifact by id: %+v, err %v", request.GetArtifactId(), err) + m.systemMetrics.getFailureCounter.Inc(ctx) + } return nil, err } case *datacatalog.GetArtifactRequest_TagName: + logger.Debugf(ctx, "Get artifact by id %v", request.GetTagName()) tagKey := transformers.ToTagKey(*datasetID, request.GetTagName()) tag, err := m.repo.TagRepo().Get(ctx, tagKey) if err != nil { - m.systemMetrics.getFailureCounter.Inc(ctx) + if errors.IsDoesNotExistError(err) { + logger.Warnf(ctx, "Artifact does not exist tag: %+v, err %v", request.GetTagName(), err) + m.systemMetrics.doesNotExistCounter.Inc(ctx) + } else { + logger.Errorf(ctx, "Unable to retrieve Artifact by tag %v, err: %v", request.GetTagName(), err) + m.systemMetrics.getFailureCounter.Inc(ctx) + } return nil, err } @@ -131,6 +162,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G artifact, err := transformers.FromArtifactModel(artifactModel) if err != nil { + logger.Errorf(ctx, "Error in transforming get artifact request %+v, err %v", artifactModel, err) m.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } @@ -139,6 +171,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G for i, artifactData := range artifactModel.ArtifactData { value, err := m.artifactStore.GetData(ctx, artifactData) if err != nil { + logger.Errorf(ctx, "Error in 
getting artifact data from datastore %+v, err %v", artifactData.Location, err) return nil, err } @@ -149,6 +182,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G } artifact.Data = artifactDataList + logger.Debugf(ctx, "Retrieved artifact dataset %v, id: %v", artifact.Dataset, artifact.Id) m.systemMetrics.getSuccessCounter.Inc(ctx) return &datacatalog.GetArtifactResponse{ Artifact: &artifact, @@ -168,6 +202,8 @@ func NewArtifactManager(repo repositories.RepositoryInterface, store *storage.Da createDataSuccessCounter: labeled.NewCounter("create_artifact_data_succeeded_count", "The number of times create artifact data succeeded", artifactScope, labeled.EmitUnlabeledMetric), transformerErrorCounter: labeled.NewCounter("transformer_failed_count", "The number of times transformations failed", artifactScope, labeled.EmitUnlabeledMetric), validationErrorCounter: labeled.NewCounter("validation_failed_count", "The number of times validation failed", artifactScope, labeled.EmitUnlabeledMetric), + alreadyExistsCounter: labeled.NewCounter("already_exists_count", "The number of times an artifact already exists", artifactScope, labeled.EmitUnlabeledMetric), + doesNotExistCounter: labeled.NewCounter("does_not_exists_count", "The number of times an artifact was not found", artifactScope, labeled.EmitUnlabeledMetric), } return &artifactManager{ diff --git a/datacatalog/pkg/manager/impl/dataset_manager.go b/datacatalog/pkg/manager/impl/dataset_manager.go index 614c45d612..35a7ab7438 100644 --- a/datacatalog/pkg/manager/impl/dataset_manager.go +++ b/datacatalog/pkg/manager/impl/dataset_manager.go @@ -10,6 +10,8 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/transformers" datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/datacatalog/pkg/errors" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" @@ -25,6 +27,8 @@ type datasetMetrics struct { validationErrorCounter labeled.Counter createResponseTime labeled.StopWatch getResponseTime labeled.StopWatch + alreadyExistsCounter labeled.Counter + doesNotExistCounter labeled.Counter } type datasetManager struct { @@ -40,22 +44,31 @@ func (dm *datasetManager) CreateDataset(ctx context.Context, request datacatalog err := validators.ValidateDatasetID(request.Dataset.Id) if err != nil { + logger.Errorf(ctx, "Invalid create dataset request %+v err: %v", request, err) dm.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } datasetModel, err := transformers.CreateDatasetModel(request.Dataset) if err != nil { + logger.Errorf(ctx, "Unable to transform create dataset request %+v err: %v", request, err) dm.systemMetrics.transformerErrorCounter.Inc(ctx) return nil, err } err = dm.repo.DatasetRepo().Create(ctx, *datasetModel) if err != nil { - dm.systemMetrics.createErrorCounter.Inc(ctx) + if errors.IsAlreadyExistsError(err) { + logger.Warnf(ctx, "Dataset already exists key: %+v, err %v", request.Dataset, err) + dm.systemMetrics.alreadyExistsCounter.Inc(ctx) + } else { + logger.Errorf(ctx, "Failed to create dataset model: %+v err: %v", datasetModel, err) + dm.systemMetrics.createErrorCounter.Inc(ctx) + } return nil, err } + logger.Debugf(ctx, "Successfully created dataset %+v", request.Dataset) dm.systemMetrics.createSuccessCounter.Inc(ctx) return &datacatalog.CreateDatasetResponse{}, nil } @@ -67,6 +80,7 @@ func (dm *datasetManager) GetDataset(ctx context.Context, request datacatalog.Ge err 
:= validators.ValidateDatasetID(request.Dataset) if err != nil { + logger.Errorf(ctx, "Invalid get dataset request %+v err: %v", request, err) dm.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -75,7 +89,13 @@ func (dm *datasetManager) GetDataset(ctx context.Context, request datacatalog.Ge datasetModel, err := dm.repo.DatasetRepo().Get(ctx, datasetKey) if err != nil { - dm.systemMetrics.getErrorCounter.Inc(ctx) + if errors.IsDoesNotExistError(err) { + logger.Warnf(ctx, "Dataset does not exist key: %+v, err %v", datasetKey, err) + dm.systemMetrics.doesNotExistCounter.Inc(ctx) + } else { + logger.Errorf(ctx, "Unable to get dataset request %+v err: %v", request, err) + dm.systemMetrics.getErrorCounter.Inc(ctx) + } return nil, err } @@ -105,6 +125,8 @@ func NewDatasetManager(repo repositories.RepositoryInterface, store *storage.Dat getErrorCounter: labeled.NewCounter("get_failed_count", "The number of times get dataset failed", datasetScope, labeled.EmitUnlabeledMetric), transformerErrorCounter: labeled.NewCounter("transformer_failed_count", "The number of times transformations failed", datasetScope, labeled.EmitUnlabeledMetric), validationErrorCounter: labeled.NewCounter("validation_failed_count", "The number of times validation failed", datasetScope, labeled.EmitUnlabeledMetric), + alreadyExistsCounter: labeled.NewCounter("already_exists_count", "The number of times a dataset already exists", datasetScope, labeled.EmitUnlabeledMetric), + doesNotExistCounter: labeled.NewCounter("does_not_exists_count", "The number of times a dataset was not found", datasetScope, labeled.EmitUnlabeledMetric), }, } } diff --git a/datacatalog/pkg/manager/impl/tag_manager.go b/datacatalog/pkg/manager/impl/tag_manager.go index 646e8f15a5..4ae4c80da6 100644 --- a/datacatalog/pkg/manager/impl/tag_manager.go +++ b/datacatalog/pkg/manager/impl/tag_manager.go @@ -12,7 +12,9 @@ import ( "github.com/lyft/datacatalog/pkg/repositories/transformers" datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/datacatalog/pkg/errors" "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" @@ -24,6 +26,7 @@ type tagMetrics struct { addTagFailureCounter labeled.Counter addTagResponseTime labeled.StopWatch validationErrorCounter labeled.Counter + alreadyExistsCounter labeled.Counter } type tagManager struct { @@ -55,7 +58,14 @@ func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagReque ArtifactID: request.Tag.ArtifactId, }) if err != nil { - m.systemMetrics.addTagFailureCounter.Inc(ctx) + if errors.IsAlreadyExistsError(err) { + logger.Warnf(ctx, "Tag already exists key: %+v, err %v", request, err) + m.systemMetrics.alreadyExistsCounter.Inc(ctx) + } else { + logger.Errorf(ctx, "Failed to tag artifact: %+v err: %v", request, err) + m.systemMetrics.addTagFailureCounter.Inc(ctx) + } + return nil, err } @@ -70,6 +80,7 @@ func NewTagManager(repo repositories.RepositoryInterface, store *storage.DataSto addTagSuccessCounter: labeled.NewCounter("add_tag_success_count", "The number of times an artifact was tagged successfully", tagScope, labeled.EmitUnlabeledMetric), addTagFailureCounter: labeled.NewCounter("add_tag_failure_count", "The number of times we failed to tag an artifact", tagScope, labeled.EmitUnlabeledMetric), validationErrorCounter: labeled.NewCounter("validation_error_count", "The number of times we failed validate a tag", 
tagScope, labeled.EmitUnlabeledMetric), + alreadyExistsCounter: labeled.NewCounter("already_exists_count", "The number of times a tag already exists", tagScope, labeled.EmitUnlabeledMetric), } return &tagManager{ From a0587e768186bb5b84d207a95cc20c4b0899eb1b Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Wed, 28 Aug 2019 17:40:50 -0700 Subject: [PATCH 0080/1918] CR feedback --- datacatalog/pkg/errors/errors.go | 10 ++----- .../pkg/manager/impl/artifact_manager.go | 14 +++------- .../pkg/manager/impl/artifact_manager_test.go | 1 - .../pkg/manager/impl/dataset_manager.go | 15 ++--------- datacatalog/pkg/manager/impl/tag_manager.go | 4 +-- .../pkg/rpc/datacatalogservice/service.go | 27 +++++++++++++++++++ 6 files changed, 35 insertions(+), 36 deletions(-) diff --git a/datacatalog/pkg/errors/errors.go b/datacatalog/pkg/errors/errors.go index 8931444f16..d6fdbdb9a0 100644 --- a/datacatalog/pkg/errors/errors.go +++ b/datacatalog/pkg/errors/errors.go @@ -46,16 +46,10 @@ func NewDataCatalogErrorf(code codes.Code, format string, a ...interface{}) erro func IsAlreadyExistsError(err error) bool { dcErr, ok := err.(DataCatalogError) - if ok && dcErr.GRPCStatus().Code() == codes.AlreadyExists { - return true - } - return false + return ok && dcErr.GRPCStatus().Code() == codes.AlreadyExists } func IsDoesNotExistError(err error) bool { dcErr, ok := err.(DataCatalogError) - if ok && dcErr.GRPCStatus().Code() == codes.NotFound { - return true - } - return false + return ok && dcErr.GRPCStatus().Code() == codes.NotFound } diff --git a/datacatalog/pkg/manager/impl/artifact_manager.go b/datacatalog/pkg/manager/impl/artifact_manager.go index 4944d66010..30f0d76433 100644 --- a/datacatalog/pkg/manager/impl/artifact_manager.go +++ b/datacatalog/pkg/manager/impl/artifact_manager.go @@ -2,7 +2,6 @@ package impl import ( "context" - "time" "github.com/lyft/datacatalog/pkg/errors" "github.com/lyft/datacatalog/pkg/manager/impl/validators" @@ -31,8 +30,6 @@ type artifactMetrics struct { createDataSuccessCounter labeled.Counter transformerErrorCounter labeled.Counter validationErrorCounter labeled.Counter - createResponseTime labeled.StopWatch - getResponseTime labeled.StopWatch alreadyExistsCounter labeled.Counter doesNotExistCounter labeled.Counter } @@ -45,13 +42,10 @@ type artifactManager struct { // Create an Artifact along with the associated ArtifactData. The ArtifactData will be stored in an offloaded location. 
func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalog.CreateArtifactRequest) (*datacatalog.CreateArtifactResponse, error) { - timer := m.systemMetrics.createResponseTime.Start(ctx) - defer timer.Stop() - artifact := request.Artifact err := validators.ValidateArtifact(artifact) if err != nil { - logger.Errorf(ctx, "Invalid create artifact request %v, err: %v", request, err) + logger.Warningf(ctx, "Invalid create artifact request %v, err: %v", request, err) m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -62,7 +56,7 @@ func (m *artifactManager) CreateArtifact(ctx context.Context, request datacatalo // The dataset must exist for the artifact, let's verify that first _, err = m.repo.DatasetRepo().Get(ctx, datasetKey) if err != nil { - logger.Errorf(ctx, "Failed to get dataset for artifact creation %v, err: %v", datasetKey, err) + logger.Warnf(ctx, "Failed to get dataset for artifact creation %v, err: %v", datasetKey, err) m.systemMetrics.createFailureCounter.Inc(ctx) return nil, err } @@ -114,7 +108,7 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G datasetID := request.Dataset err := validators.ValidateGetArtifactRequest(request) if err != nil { - logger.Errorf(ctx, "Invalid get artifact request %v, err: %v", request, err) + logger.Warningf(ctx, "Invalid get artifact request %v, err: %v", request, err) m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -192,8 +186,6 @@ func (m *artifactManager) GetArtifact(ctx context.Context, request datacatalog.G func NewArtifactManager(repo repositories.RepositoryInterface, store *storage.DataStore, storagePrefix storage.DataReference, artifactScope promutils.Scope) interfaces.ArtifactManager { artifactMetrics := artifactMetrics{ scope: artifactScope, - createResponseTime: labeled.NewStopWatch("create_artifact_duration", "The duration of the create artifact calls.", time.Millisecond, artifactScope, labeled.EmitUnlabeledMetric), - getResponseTime: labeled.NewStopWatch("get_artifact_duration", "The duration of the get artifact calls.", time.Millisecond, artifactScope, labeled.EmitUnlabeledMetric), createSuccessCounter: labeled.NewCounter("create_artifact_success_count", "The number of times create artifact was called", artifactScope, labeled.EmitUnlabeledMetric), getSuccessCounter: labeled.NewCounter("get_artifact_success_count", "The number of times get artifact was called", artifactScope, labeled.EmitUnlabeledMetric), createFailureCounter: labeled.NewCounter("create_artifact_failure_count", "The number of times create artifact failed", artifactScope, labeled.EmitUnlabeledMetric), diff --git a/datacatalog/pkg/manager/impl/artifact_manager_test.go b/datacatalog/pkg/manager/impl/artifact_manager_test.go index 707a74166e..051e3eb602 100644 --- a/datacatalog/pkg/manager/impl/artifact_manager_test.go +++ b/datacatalog/pkg/manager/impl/artifact_manager_test.go @@ -25,7 +25,6 @@ func init() { } func createInmemoryDataStore(t testing.TB, scope mockScope.Scope) *storage.DataStore { - labeled.SetMetricKeys(contextutils.AppNameKey) cfg := storage.Config{ Type: storage.TypeMemory, } diff --git a/datacatalog/pkg/manager/impl/dataset_manager.go b/datacatalog/pkg/manager/impl/dataset_manager.go index 35a7ab7438..0aaee32990 100644 --- a/datacatalog/pkg/manager/impl/dataset_manager.go +++ b/datacatalog/pkg/manager/impl/dataset_manager.go @@ -2,7 +2,6 @@ package impl import ( "context" - "time" "github.com/lyft/datacatalog/pkg/manager/impl/validators" 
"github.com/lyft/datacatalog/pkg/manager/interfaces" @@ -25,8 +24,6 @@ type datasetMetrics struct { getErrorCounter labeled.Counter transformerErrorCounter labeled.Counter validationErrorCounter labeled.Counter - createResponseTime labeled.StopWatch - getResponseTime labeled.StopWatch alreadyExistsCounter labeled.Counter doesNotExistCounter labeled.Counter } @@ -39,12 +36,9 @@ type datasetManager struct { // Create a Dataset with optional metadata. If one already exists a grpc AlreadyExists err will be returned func (dm *datasetManager) CreateDataset(ctx context.Context, request datacatalog.CreateDatasetRequest) (*datacatalog.CreateDatasetResponse, error) { - t := dm.systemMetrics.createResponseTime.Start(ctx) - defer t.Stop() - err := validators.ValidateDatasetID(request.Dataset.Id) if err != nil { - logger.Errorf(ctx, "Invalid create dataset request %+v err: %v", request, err) + logger.Warnf(ctx, "Invalid create dataset request %+v err: %v", request, err) dm.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -75,12 +69,9 @@ func (dm *datasetManager) CreateDataset(ctx context.Context, request datacatalog // Get a Dataset with the given DatasetID if it exists. If none exist a grpc NotFound err will be returned func (dm *datasetManager) GetDataset(ctx context.Context, request datacatalog.GetDatasetRequest) (*datacatalog.GetDatasetResponse, error) { - t := dm.systemMetrics.getResponseTime.Start(ctx) - defer t.Stop() - err := validators.ValidateDatasetID(request.Dataset) if err != nil { - logger.Errorf(ctx, "Invalid get dataset request %+v err: %v", request, err) + logger.Warnf(ctx, "Invalid get dataset request %+v err: %v", request, err) dm.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -117,8 +108,6 @@ func NewDatasetManager(repo repositories.RepositoryInterface, store *storage.Dat store: store, systemMetrics: datasetMetrics{ scope: datasetScope, - createResponseTime: labeled.NewStopWatch("create_duration", "The duration of the create dataset calls.", time.Millisecond, datasetScope, labeled.EmitUnlabeledMetric), - getResponseTime: labeled.NewStopWatch("get_duration", "The duration of the get dataset calls.", time.Millisecond, datasetScope, labeled.EmitUnlabeledMetric), createSuccessCounter: labeled.NewCounter("create_success_count", "The number of times create dataset was called", datasetScope, labeled.EmitUnlabeledMetric), getSuccessCounter: labeled.NewCounter("get_success_count", "The number of times get dataset was called", datasetScope, labeled.EmitUnlabeledMetric), createErrorCounter: labeled.NewCounter("create_failed_count", "The number of times create dataset failed", datasetScope, labeled.EmitUnlabeledMetric), diff --git a/datacatalog/pkg/manager/impl/tag_manager.go b/datacatalog/pkg/manager/impl/tag_manager.go index 4ae4c80da6..dec654a288 100644 --- a/datacatalog/pkg/manager/impl/tag_manager.go +++ b/datacatalog/pkg/manager/impl/tag_manager.go @@ -2,7 +2,6 @@ package impl import ( "context" - "time" "github.com/lyft/datacatalog/pkg/manager/impl/validators" "github.com/lyft/datacatalog/pkg/manager/interfaces" @@ -24,7 +23,6 @@ type tagMetrics struct { scope promutils.Scope addTagSuccessCounter labeled.Counter addTagFailureCounter labeled.Counter - addTagResponseTime labeled.StopWatch validationErrorCounter labeled.Counter alreadyExistsCounter labeled.Counter } @@ -37,6 +35,7 @@ type tagManager struct { func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagRequest) (*datacatalog.AddTagResponse, error) { if err := 
validators.ValidateTag(request.Tag); err != nil { + logger.Warnf(ctx, "Invalid add tag request %+v err: %v", request, err) m.systemMetrics.validationErrorCounter.Inc(ctx) return nil, err } @@ -76,7 +75,6 @@ func (m *tagManager) AddTag(ctx context.Context, request datacatalog.AddTagReque func NewTagManager(repo repositories.RepositoryInterface, store *storage.DataStore, tagScope promutils.Scope) interfaces.TagManager { systemMetrics := tagMetrics{ scope: tagScope, - addTagResponseTime: labeled.NewStopWatch("add_tag_duration", "The duration of tagging an artifact.", time.Millisecond, tagScope, labeled.EmitUnlabeledMetric), addTagSuccessCounter: labeled.NewCounter("add_tag_success_count", "The number of times an artifact was tagged successfully", tagScope, labeled.EmitUnlabeledMetric), addTagFailureCounter: labeled.NewCounter("add_tag_failure_count", "The number of times we failed to tag an artifact", tagScope, labeled.EmitUnlabeledMetric), validationErrorCounter: labeled.NewCounter("validation_error_count", "The number of times we failed validate a tag", tagScope, labeled.EmitUnlabeledMetric), diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index 614411ccfa..1421c9a837 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "runtime/debug" + "time" "github.com/lyft/datacatalog/pkg/manager/impl" "github.com/lyft/datacatalog/pkg/manager/interfaces" @@ -19,29 +20,48 @@ import ( "github.com/lyft/flytestdlib/storage" ) +type serviceMetrics struct { + createDatasetResponseTime labeled.StopWatch + getDatasetResponseTime labeled.StopWatch + createArtifactResponseTime labeled.StopWatch + getArtifactResponseTime labeled.StopWatch + addTagResponseTime labeled.StopWatch +} + type DataCatalogService struct { DatasetManager interfaces.DatasetManager ArtifactManager interfaces.ArtifactManager TagManager interfaces.TagManager + serviceMetrics serviceMetrics } func (s *DataCatalogService) CreateDataset(ctx context.Context, request *catalog.CreateDatasetRequest) (*catalog.CreateDatasetResponse, error) { + timer := s.serviceMetrics.createDatasetResponseTime.Start(ctx) + defer timer.Stop() return s.DatasetManager.CreateDataset(ctx, *request) } func (s *DataCatalogService) CreateArtifact(ctx context.Context, request *catalog.CreateArtifactRequest) (*catalog.CreateArtifactResponse, error) { + timer := s.serviceMetrics.createArtifactResponseTime.Start(ctx) + defer timer.Stop() return s.ArtifactManager.CreateArtifact(ctx, *request) } func (s *DataCatalogService) GetDataset(ctx context.Context, request *catalog.GetDatasetRequest) (*catalog.GetDatasetResponse, error) { + timer := s.serviceMetrics.getDatasetResponseTime.Start(ctx) + defer timer.Stop() return s.DatasetManager.GetDataset(ctx, *request) } func (s *DataCatalogService) GetArtifact(ctx context.Context, request *catalog.GetArtifactRequest) (*catalog.GetArtifactResponse, error) { + timer := s.serviceMetrics.getArtifactResponseTime.Start(ctx) + defer timer.Stop() return s.ArtifactManager.GetArtifact(ctx, *request) } func (s *DataCatalogService) AddTag(ctx context.Context, request *catalog.AddTagRequest) (*catalog.AddTagResponse, error) { + timer := s.serviceMetrics.addTagResponseTime.Start(ctx) + defer timer.Stop() return s.TagManager.AddTag(ctx, *request) } @@ -103,5 +123,12 @@ func NewDataCatalogService() *DataCatalogService { DatasetManager: impl.NewDatasetManager(repos, dataStorageClient, 
catalogScope.NewSubScope("dataset")), ArtifactManager: impl.NewArtifactManager(repos, dataStorageClient, storagePrefix, catalogScope.NewSubScope("artifact")), TagManager: impl.NewTagManager(repos, dataStorageClient, catalogScope.NewSubScope("tag")), + serviceMetrics: serviceMetrics{ + createDatasetResponseTime: labeled.NewStopWatch("create_dataset_duration", "The duration of the create dataset calls.", time.Millisecond, catalogScope, labeled.EmitUnlabeledMetric), + getDatasetResponseTime: labeled.NewStopWatch("get_dataset_duration", "The duration of the get dataset calls.", time.Millisecond, catalogScope, labeled.EmitUnlabeledMetric), + createArtifactResponseTime: labeled.NewStopWatch("create_artifact_duration", "The duration of the create artifact calls.", time.Millisecond, catalogScope, labeled.EmitUnlabeledMetric), + getArtifactResponseTime: labeled.NewStopWatch("get_artifact_duration", "The duration of the get artifact calls.", time.Millisecond, catalogScope, labeled.EmitUnlabeledMetric), + addTagResponseTime: labeled.NewStopWatch("add_tag_duration", "The duration of the add tag calls.", time.Millisecond, catalogScope, labeled.EmitUnlabeledMetric), + }, } } From 289cb8c2e61d3f2b02b4f4d0ea456c1cb573fa9b Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 29 Aug 2019 07:54:23 -0700 Subject: [PATCH 0081/1918] Use the metricscope config --- datacatalog/pkg/rpc/datacatalogservice/service.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index 1421c9a837..0eac919e4e 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -66,9 +66,11 @@ func (s *DataCatalogService) AddTag(ctx context.Context, request *catalog.AddTag } func NewDataCatalogService() *DataCatalogService { - dataCatalogName := "datacatalog" - catalogScope := promutils.NewScope(dataCatalogName).NewSubScope("service") - ctx := contextutils.WithAppName(context.Background(), dataCatalogName) + configProvider := runtime.NewConfigurationProvider() + dataCatalogConfig := configProvider.ApplicationConfiguration().GetDataCatalogConfig() + + catalogScope := promutils.NewScope(dataCatalogConfig.MetricsScope).NewSubScope("service") + ctx := contextutils.WithAppName(context.Background(), "datacatalog") // Set Keys labeled.SetMetricKeys(contextutils.AppNameKey, contextutils.ProjectKey, contextutils.DomainKey) @@ -89,9 +91,7 @@ func NewDataCatalogService() *DataCatalogService { } logger.Infof(ctx, "Created data storage.") - configProvider := runtime.NewConfigurationProvider() baseStorageReference := dataStorageClient.GetBaseContainerFQN(ctx) - dataCatalogConfig := configProvider.ApplicationConfiguration().GetDataCatalogConfig() storagePrefix, err := dataStorageClient.ConstructReference(ctx, baseStorageReference, dataCatalogConfig.StoragePrefix) if err != nil { logger.Errorf(ctx, "Failed to create prefix %v, err %v", dataCatalogConfig.StoragePrefix, err) From 6f459ad58ab4c2d722fcf61ff7089e9c3f74c828 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 29 Aug 2019 10:09:10 -0700 Subject: [PATCH 0082/1918] typo in flag name --- datacatalog/pkg/rpc/datacatalogservice/service.go | 1 - datacatalog/pkg/runtime/configs/data_catalog_config.go | 2 +- .../pkg/runtime/configs/datacatalogconfig_flags.go | 4 ++-- .../runtime/configs/datacatalogconfig_flags_test.go | 10 +++++----- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git 
a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index 0eac919e4e..ce3cfc79c3 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -68,7 +68,6 @@ func (s *DataCatalogService) AddTag(ctx context.Context, request *catalog.AddTag func NewDataCatalogService() *DataCatalogService { configProvider := runtime.NewConfigurationProvider() dataCatalogConfig := configProvider.ApplicationConfiguration().GetDataCatalogConfig() - catalogScope := promutils.NewScope(dataCatalogConfig.MetricsScope).NewSubScope("service") ctx := contextutils.WithAppName(context.Background(), "datacatalog") diff --git a/datacatalog/pkg/runtime/configs/data_catalog_config.go b/datacatalog/pkg/runtime/configs/data_catalog_config.go index a7bf8de8bf..c7160dca55 100644 --- a/datacatalog/pkg/runtime/configs/data_catalog_config.go +++ b/datacatalog/pkg/runtime/configs/data_catalog_config.go @@ -5,6 +5,6 @@ package configs // This configuration is the base configuration to start admin type DataCatalogConfig struct { StoragePrefix string `json:"storage-prefix" pflag:",StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. If not specified, the data will be stored in the base container directly."` - MetricsScope string `json:"metrics-ccope" pflag:",Scope that the metrics will record under."` + MetricsScope string `json:"metrics-scope" pflag:",Scope that the metrics will record under."` ProfilerPort int `json:"profiler-port" pflag:",Port that the profiling service is listening on."` } diff --git a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go index 2f41ca13b7..cb4d24390d 100755 --- a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go +++ b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags.go @@ -1,6 +1,6 @@ // Code generated by go generate; DO NOT EDIT. // This file was generated by robots at -// 2019-08-23 21:53:46.865656622 -0700 PDT m=+1.391823141 +// 2019-08-29 10:08:04.85326469 -0700 PDT m=+1.565524498 package configs @@ -15,7 +15,7 @@ import ( func (DataCatalogConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("DataCatalogConfig", pflag.ExitOnError) cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage-prefix"), *new(string), "StoragePrefix specifies the prefix where DataCatalog stores offloaded ArtifactData in CloudStorage. If not specified, the data will be stored in the base container directly.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-ccope"), *new(string), "Scope that the metrics will record under.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-scope"), *new(string), "Scope that the metrics will record under.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "profiler-port"), *new(int), "Port that the profiling service is listening on.") return cmdFlags } diff --git a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go index 41f53cdd88..6411058716 100755 --- a/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go +++ b/datacatalog/pkg/runtime/configs/datacatalogconfig_flags_test.go @@ -1,6 +1,6 @@ // Code generated by go generate; DO NOT EDIT. 
// This file was generated by robots at -// 2019-08-23 21:53:46.865656622 -0700 PDT m=+1.391823141 +// 2019-08-29 10:08:04.85326469 -0700 PDT m=+1.565524498 package configs @@ -109,10 +109,10 @@ func TestDataCatalogConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_metrics-ccope", func(t *testing.T) { + t.Run("Test_metrics-scope", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("metrics-ccope"); err == nil { + if vString, err := cmdFlags.GetString("metrics-scope"); err == nil { assert.Equal(t, *new(string), vString) } else { assert.FailNow(t, err.Error()) @@ -120,8 +120,8 @@ func TestDataCatalogConfig_SetFlags(t *testing.T) { }) t.Run("Override", func(t *testing.T) { - cmdFlags.Set("metrics-ccope", "1") - if vString, err := cmdFlags.GetString("metrics-ccope"); err == nil { + cmdFlags.Set("metrics-scope", "1") + if vString, err := cmdFlags.GetString("metrics-scope"); err == nil { testDecodeJson_DataCatalogConfig(t, fmt.Sprintf("%v", vString), &actual.MetricsScope) } else { From 58cc4336fb05493fbe5d4b18c7621b41e9822af3 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Fri, 6 Sep 2019 15:12:06 -0700 Subject: [PATCH 0083/1918] Initial Commit --- flytepropeller/.dockerignore | 2 + flytepropeller/.gitignore | 4 + flytepropeller/.golangci.yml | 30 + flytepropeller/.travis.yml | 26 + flytepropeller/CODE_OF_CONDUCT.md | 3 + flytepropeller/Dockerfile | 33 + flytepropeller/Gopkg.lock | 1433 ++++++++++++++++ flytepropeller/Gopkg.toml | 102 ++ flytepropeller/LICENSE | 202 +++ flytepropeller/Makefile | 44 + flytepropeller/NOTICE | 5 + flytepropeller/README.rst | 129 ++ .../boilerplate/lyft/docker_build/Makefile | 12 + .../boilerplate/lyft/docker_build/Readme.rst | 23 + .../lyft/docker_build/docker_build.sh | 67 + .../golang_dockerfile/Dockerfile.GoTemplate | 33 + .../lyft/golang_dockerfile/Readme.rst | 16 + .../lyft/golang_dockerfile/update.sh | 13 + .../lyft/golang_test_targets/Makefile | 38 + .../lyft/golang_test_targets/Readme.rst | 31 + .../lyft/golang_test_targets/goimports | 8 + .../lyft/golangci_file/.golangci.yml | 30 + .../boilerplate/lyft/golangci_file/Readme.rst | 8 + .../boilerplate/lyft/golangci_file/update.sh | 14 + flytepropeller/boilerplate/update.cfg | 4 + flytepropeller/boilerplate/update.sh | 53 + flytepropeller/cmd/controller/cmd/root.go | 235 +++ flytepropeller/cmd/controller/main.go | 9 + .../cmd/kubectl-flyte/cmd/compile.go | 106 ++ .../cmd/kubectl-flyte/cmd/create.go | 228 +++ .../cmd/kubectl-flyte/cmd/create_test.go | 292 ++++ .../cmd/kubectl-flyte/cmd/delete.go | 87 + flytepropeller/cmd/kubectl-flyte/cmd/get.go | 141 ++ .../cmd/kubectl-flyte/cmd/printers/node.go | 157 ++ .../kubectl-flyte/cmd/printers/workflow.go | 63 + flytepropeller/cmd/kubectl-flyte/cmd/root.go | 116 ++ .../cmd/kubectl-flyte/cmd/string_map_value.go | 83 + .../cmd/string_map_value_test.go | 65 + .../cmd/testdata/inputs.json.golden | 1 + .../cmd/testdata/inputs.pb.golden | 19 + .../cmd/testdata/inputs.yaml.golden | 17 + .../cmd/testdata/workflow.json.golden | 1 + .../cmd/testdata/workflow.pb.golden | Bin 0 -> 193 bytes .../cmd/testdata/workflow.yaml.golden | 38 + .../testdata/workflow_w_inputs.json.golden | 1 + .../cmd/testdata/workflow_w_inputs.pb.golden | Bin 0 -> 291 bytes .../testdata/workflow_w_inputs.yaml.golden | 67 + flytepropeller/cmd/kubectl-flyte/cmd/util.go | 18 + .../cmd/kubectl-flyte/cmd/visualize.go | 39 + flytepropeller/cmd/kubectl-flyte/main.go | 17 + flytepropeller/config.yaml | 95 ++ 
flytepropeller/hack/boilerplate.go.txt | 0 flytepropeller/hack/custom-boilerplate.go.txt | 0 flytepropeller/hack/update-codegen.sh | 43 + flytepropeller/hack/verify-codegen.sh | 50 + .../pkg/apis/flyteworkflow/register.go | 5 + .../pkg/apis/flyteworkflow/v1alpha1/branch.go | 100 ++ .../flyteworkflow/v1alpha1/branch_test.go | 22 + .../pkg/apis/flyteworkflow/v1alpha1/doc.go | 5 + .../apis/flyteworkflow/v1alpha1/identifier.go | 45 + .../pkg/apis/flyteworkflow/v1alpha1/iface.go | 391 +++++ .../flyteworkflow/v1alpha1/mocks/BaseNode.go | 39 + .../v1alpha1/mocks/BaseWorkflow.go | 87 + .../v1alpha1/mocks/BaseWorkflowWithStatus.go | 103 ++ .../v1alpha1/mocks/ExecutableBranchNode.go | 76 + .../mocks/ExecutableBranchNodeStatus.go | 41 + .../mocks/ExecutableDynamicNodeStatus.go | 25 + .../v1alpha1/mocks/ExecutableIfBlock.go | 43 + .../v1alpha1/mocks/ExecutableNode.go | 196 +++ .../v1alpha1/mocks/ExecutableNodeStatus.go | 407 +++++ .../v1alpha1/mocks/ExecutableSubWorkflow.go | 167 ++ .../mocks/ExecutableSubWorkflowNodeStatus.go | 25 + .../v1alpha1/mocks/ExecutableTask.go | 41 + .../mocks/ExecutableTaskNodeStatus.go | 55 + .../v1alpha1/mocks/ExecutableWorkflow.go | 370 +++++ .../v1alpha1/mocks/ExecutableWorkflowNode.go | 43 + .../mocks/ExecutableWorkflowNodeStatus.go | 24 + .../mocks/ExecutableWorkflowStatus.go | 194 +++ .../v1alpha1/mocks/MutableBranchNodeStatus.go | 51 + .../mocks/MutableDynamicNodeStatus.go | 30 + .../v1alpha1/mocks/MutableNodeStatus.go | 158 ++ .../mocks/MutableSubWorkflowNodeStatus.go | 30 + .../v1alpha1/mocks/MutableTaskNodeStatus.go | 70 + .../mocks/MutableWorkflowNodeStatus.go | 29 + .../v1alpha1/mocks/NodeStatusGetter.go | 27 + .../v1alpha1/mocks/NodeStatusVisitor.go | 16 + .../v1alpha1/mocks/WorkflowMeta.go | 143 ++ .../v1alpha1/mocks/WorkflowMetaExtended.go | 198 +++ .../flyteworkflow/v1alpha1/node_status.go | 512 ++++++ .../v1alpha1/node_status_test.go | 156 ++ .../pkg/apis/flyteworkflow/v1alpha1/nodes.go | 193 +++ .../apis/flyteworkflow/v1alpha1/register.go | 38 + .../flyteworkflow/v1alpha1/subworkflow.go | 23 + .../pkg/apis/flyteworkflow/v1alpha1/tasks.go | 39 + .../apis/flyteworkflow/v1alpha1/tasks_test.go | 20 + .../v1alpha1/testdata/branch.json | 34 + .../v1alpha1/testdata/connections.json | 15 + .../flyteworkflow/v1alpha1/testdata/task.yaml | 33 + .../v1alpha1/testdata/workflowspec.yaml | 220 +++ .../apis/flyteworkflow/v1alpha1/workflow.go | 232 +++ .../flyteworkflow/v1alpha1/workflow_status.go | 154 ++ .../v1alpha1/workflow_status_test.go | 54 + .../flyteworkflow/v1alpha1/workflow_test.go | 51 + .../v1alpha1/zz_generated.deepcopy.go | 676 ++++++++ .../client/clientset/versioned/clientset.go | 82 + .../pkg/client/clientset/versioned/doc.go | 4 + .../versioned/fake/clientset_generated.go | 66 + .../client/clientset/versioned/fake/doc.go | 4 + .../clientset/versioned/fake/register.go | 40 + .../client/clientset/versioned/scheme/doc.go | 4 + .../clientset/versioned/scheme/register.go | 40 + .../typed/flyteworkflow/v1alpha1/doc.go | 4 + .../typed/flyteworkflow/v1alpha1/fake/doc.go | 4 + .../v1alpha1/fake/fake_flyteworkflow.go | 124 ++ .../fake/fake_flyteworkflow_client.go | 24 + .../flyteworkflow/v1alpha1/flyteworkflow.go | 175 ++ .../v1alpha1/flyteworkflow_client.go | 74 + .../v1alpha1/generated_expansion.go | 5 + .../informers/externalversions/factory.go | 164 ++ .../flyteworkflow/interface.go | 30 + .../flyteworkflow/v1alpha1/flyteworkflow.go | 73 + .../flyteworkflow/v1alpha1/interface.go | 29 + .../informers/externalversions/generic.go | 46 + 
.../internalinterfaces/factory_interfaces.go | 24 + .../v1alpha1/expansion_generated.go | 11 + .../flyteworkflow/v1alpha1/flyteworkflow.go | 78 + flytepropeller/pkg/compiler/builders.go | 136 ++ flytepropeller/pkg/compiler/common/builder.go | 32 + flytepropeller/pkg/compiler/common/id_set.go | 99 ++ flytepropeller/pkg/compiler/common/index.go | 71 + .../common/mocks/interface_provider.go | 59 + .../pkg/compiler/common/mocks/node.go | 202 +++ .../pkg/compiler/common/mocks/node_builder.go | 222 +++ .../pkg/compiler/common/mocks/task.go | 57 + .../pkg/compiler/common/mocks/workflow.go | 200 +++ .../compiler/common/mocks/workflow_builder.go | 268 +++ flytepropeller/pkg/compiler/common/reader.go | 55 + .../compiler/errors/compiler_error_test.go | 52 + .../pkg/compiler/errors/compiler_errors.go | 272 +++ flytepropeller/pkg/compiler/errors/config.go | 27 + flytepropeller/pkg/compiler/errors/error.go | 125 ++ .../pkg/compiler/errors/error_test.go | 43 + flytepropeller/pkg/compiler/errors/sets.go | 44 + .../pkg/compiler/errors/sets_test.go | 20 + flytepropeller/pkg/compiler/requirements.go | 88 + .../pkg/compiler/requirements_test.go | 125 ++ flytepropeller/pkg/compiler/task_compiler.go | 98 ++ .../pkg/compiler/task_compiler_test.go | 82 + .../pkg/compiler/test/compiler_test.go | 247 +++ ...rkflows-work-one-python-task-w-f-inputs.pb | 0 ...flows-work-one-python-task-w-f-inputs.yaml | 1 + .../app-workflows-work-one-python-task-w-f.pb | Bin 0 -> 560 bytes ...pp-workflows-work-one-python-task-w-f.yaml | 79 + ...beta-one-second-functional-test.dot.golden | 1 + .../transformers/k8s/builder_mock_test.go | 140 ++ .../pkg/compiler/transformers/k8s/inputs.go | 53 + .../compiler/transformers/k8s/inputs_test.go | 1 + .../pkg/compiler/transformers/k8s/node.go | 169 ++ .../compiler/transformers/k8s/node_test.go | 181 ++ .../pkg/compiler/transformers/k8s/utils.go | 86 + .../compiler/transformers/k8s/utils_test.go | 58 + .../pkg/compiler/transformers/k8s/workflow.go | 208 +++ .../transformers/k8s/workflow_test.go | 238 +++ .../pkg/compiler/typing/variable.go | 36 + flytepropeller/pkg/compiler/utils.go | 79 + flytepropeller/pkg/compiler/utils_test.go | 65 + .../pkg/compiler/validators/bindings.go | 105 ++ .../pkg/compiler/validators/branch.go | 92 + .../pkg/compiler/validators/condition.go | 55 + .../pkg/compiler/validators/interface.go | 128 ++ .../pkg/compiler/validators/interface_test.go | 282 ++++ .../pkg/compiler/validators/node.go | 116 ++ .../pkg/compiler/validators/typing.go | 154 ++ .../pkg/compiler/validators/typing_test.go | 359 ++++ .../pkg/compiler/validators/utils.go | 199 +++ .../pkg/compiler/validators/vars.go | 81 + .../pkg/compiler/workflow_compiler.go | 330 ++++ .../pkg/compiler/workflow_compiler_test.go | 659 ++++++++ .../pkg/controller/catalog/catalog_client.go | 28 + .../pkg/controller/catalog/config_flags.go | 47 + .../controller/catalog/config_flags_test.go | 146 ++ .../controller/catalog/discovery_config.go | 34 + .../controller/catalog/legacy_discovery.go | 202 +++ .../catalog/legacy_discovery_test.go | 286 ++++ .../pkg/controller/catalog/mock_catalog.go | 21 + .../pkg/controller/catalog/no_op_discovery.go | 29 + .../catalog/no_op_discovery_test.go | 26 + .../pkg/controller/completed_workflows.go | 87 + .../controller/completed_workflows_test.go | 160 ++ .../pkg/controller/composite_workqueue.go | 172 ++ .../controller/composite_workqueue_test.go | 146 ++ .../pkg/controller/config/config.go | 93 ++ .../pkg/controller/config/config_flags.go | 78 + .../controller/config/config_flags_test.go 
| 828 +++++++++ flytepropeller/pkg/controller/controller.go | 305 ++++ .../pkg/controller/executors/contextual.go | 30 + .../pkg/controller/executors/kube.go | 56 + .../pkg/controller/executors/mocks/Client.go | 45 + .../pkg/controller/executors/mocks/fake.go | 13 + .../pkg/controller/executors/node.go | 100 ++ .../pkg/controller/executors/workflow.go | 13 + flytepropeller/pkg/controller/finalizer.go | 36 + .../pkg/controller/finalizer_test.go | 70 + .../pkg/controller/garbage_collector.go | 145 ++ .../pkg/controller/garbage_collector_test.go | 166 ++ flytepropeller/pkg/controller/handler.go | 188 +++ flytepropeller/pkg/controller/handler_test.go | 408 +++++ .../pkg/controller/leaderelection.go | 78 + .../pkg/controller/nodes/branch/comparator.go | 139 ++ .../nodes/branch/comparator_test.go | 403 +++++ .../pkg/controller/nodes/branch/evaluator.go | 139 ++ .../controller/nodes/branch/evaluator_test.go | 667 ++++++++ .../pkg/controller/nodes/branch/handler.go | 135 ++ .../controller/nodes/branch/handler_test.go | 236 +++ .../nodes/common/output_resolver.go | 62 + .../nodes/common/output_resolver_test.go | 30 + .../pkg/controller/nodes/dynamic/handler.go | 391 +++++ .../controller/nodes/dynamic/handler_test.go | 261 +++ .../controller/nodes/dynamic/subworkflow.go | 89 + .../nodes/dynamic/subworkflow_test.go | 52 + .../pkg/controller/nodes/dynamic/utils.go | 105 ++ .../controller/nodes/dynamic/utils_test.go | 77 + .../pkg/controller/nodes/end/handler.go | 52 + .../pkg/controller/nodes/end/handler_test.go | 135 ++ .../pkg/controller/nodes/errors/codes.go | 26 + .../pkg/controller/nodes/errors/errors.go | 80 + .../controller/nodes/errors/errors_test.go | 48 + .../pkg/controller/nodes/executor.go | 540 ++++++ .../pkg/controller/nodes/executor_test.go | 1479 +++++++++++++++++ .../pkg/controller/nodes/handler/iface.go | 128 ++ .../controller/nodes/handler/mocks/IFace.go | 105 ++ .../nodes/handler/mocks/OutputResolver.go | 37 + .../pkg/controller/nodes/handler_factory.go | 76 + .../controller/nodes/mocks/HandlerFactory.go | 36 + .../pkg/controller/nodes/predicate.go | 111 ++ .../pkg/controller/nodes/predicate_test.go | 550 ++++++ .../pkg/controller/nodes/resolve.go | 104 ++ .../pkg/controller/nodes/resolve_test.go | 434 +++++ .../pkg/controller/nodes/start/handler.go | 40 + .../controller/nodes/start/handler_test.go | 77 + .../controller/nodes/subworkflow/handler.go | 78 + .../nodes/subworkflow/handler_test.go | 230 +++ .../nodes/subworkflow/launchplan.go | 139 ++ .../nodes/subworkflow/launchplan/admin.go | 167 ++ .../subworkflow/launchplan/admin_test.go | 277 +++ .../subworkflow/launchplan/adminconfig.go | 33 + .../launchplan/adminconfig_flags.go | 48 + .../launchplan/adminconfig_flags_test.go | 168 ++ .../nodes/subworkflow/launchplan/errors.go | 57 + .../subworkflow/launchplan/errors_test.go | 36 + .../subworkflow/launchplan/launchplan.go | 36 + .../subworkflow/launchplan/mocks/Executor.go | 79 + .../nodes/subworkflow/launchplan/noop.go | 37 + .../nodes/subworkflow/launchplan/noop_test.go | 51 + .../nodes/subworkflow/launchplan_test.go | 640 +++++++ .../nodes/subworkflow/sub_workflow.go | 195 +++ .../pkg/controller/nodes/subworkflow/util.go | 24 + .../controller/nodes/subworkflow/util_test.go | 19 + .../pkg/controller/nodes/task/factory.go | 72 + .../pkg/controller/nodes/task/handler.go | 439 +++++ .../pkg/controller/nodes/task/handler_test.go | 769 +++++++++ flytepropeller/pkg/controller/workers.go | 177 ++ flytepropeller/pkg/controller/workers_test.go | 93 ++ 
.../pkg/controller/workflow/errors/codes.go | 15 + .../pkg/controller/workflow/errors/errors.go | 80 + .../controller/workflow/errors/errors_test.go | 48 + .../pkg/controller/workflow/executor.go | 420 +++++ .../pkg/controller/workflow/executor_test.go | 615 +++++++ .../workflow/testdata/benchmark_wf.yaml | 378 +++++ .../pkg/controller/workflowstore/errors.go | 18 + .../pkg/controller/workflowstore/iface.go | 20 + .../pkg/controller/workflowstore/inmemory.go | 70 + .../controller/workflowstore/passthrough.go | 106 ++ .../workflowstore/passthrough_test.go | 130 ++ .../workflowstore/resource_version_caching.go | 96 ++ .../resource_version_caching_test.go | 153 ++ flytepropeller/pkg/controller/workqueue.go | 47 + .../pkg/controller/workqueue_test.go | 54 + flytepropeller/pkg/signals/signal.go | 46 + flytepropeller/pkg/signals/signal_posix.go | 28 + flytepropeller/pkg/signals/signal_windows.go | 25 + flytepropeller/pkg/utils/assert/literals.go | 74 + flytepropeller/pkg/utils/bindings.go | 85 + flytepropeller/pkg/utils/bindings_test.go | 156 ++ flytepropeller/pkg/utils/encoder.go | 55 + flytepropeller/pkg/utils/encoder_test.go | 61 + flytepropeller/pkg/utils/event_helpers.go | 32 + flytepropeller/pkg/utils/failing_datastore.go | 32 + .../pkg/utils/failing_datastore_test.go | 26 + flytepropeller/pkg/utils/helpers.go | 12 + flytepropeller/pkg/utils/helpers_test.go | 19 + flytepropeller/pkg/utils/k8s.go | 105 ++ flytepropeller/pkg/utils/k8s_test.go | 194 +++ flytepropeller/pkg/utils/literals.go | 276 +++ flytepropeller/pkg/utils/literals_test.go | 200 +++ flytepropeller/pkg/visualize/nodeq.go | 35 + flytepropeller/pkg/visualize/sort.go | 72 + flytepropeller/pkg/visualize/visualize.go | 244 +++ flytepropeller/raw_examples/README.md | 3 + .../raw_examples/example-condition.yaml | 104 ++ .../raw_examples/example-inputs.yaml | 61 + .../raw_examples/example-noinputs.yaml | 41 + 302 files changed, 36721 insertions(+) create mode 100644 flytepropeller/.dockerignore create mode 100644 flytepropeller/.gitignore create mode 100644 flytepropeller/.golangci.yml create mode 100644 flytepropeller/.travis.yml create mode 100644 flytepropeller/CODE_OF_CONDUCT.md create mode 100644 flytepropeller/Dockerfile create mode 100644 flytepropeller/Gopkg.lock create mode 100644 flytepropeller/Gopkg.toml create mode 100644 flytepropeller/LICENSE create mode 100644 flytepropeller/Makefile create mode 100644 flytepropeller/NOTICE create mode 100644 flytepropeller/README.rst create mode 100644 flytepropeller/boilerplate/lyft/docker_build/Makefile create mode 100644 flytepropeller/boilerplate/lyft/docker_build/Readme.rst create mode 100755 flytepropeller/boilerplate/lyft/docker_build/docker_build.sh create mode 100644 flytepropeller/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate create mode 100644 flytepropeller/boilerplate/lyft/golang_dockerfile/Readme.rst create mode 100755 flytepropeller/boilerplate/lyft/golang_dockerfile/update.sh create mode 100644 flytepropeller/boilerplate/lyft/golang_test_targets/Makefile create mode 100644 flytepropeller/boilerplate/lyft/golang_test_targets/Readme.rst create mode 100755 flytepropeller/boilerplate/lyft/golang_test_targets/goimports create mode 100644 flytepropeller/boilerplate/lyft/golangci_file/.golangci.yml create mode 100644 flytepropeller/boilerplate/lyft/golangci_file/Readme.rst create mode 100755 flytepropeller/boilerplate/lyft/golangci_file/update.sh create mode 100644 flytepropeller/boilerplate/update.cfg create mode 100755 flytepropeller/boilerplate/update.sh create 
mode 100644 flytepropeller/cmd/controller/cmd/root.go create mode 100644 flytepropeller/cmd/controller/main.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/compile.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/create.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/create_test.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/delete.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/get.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/printers/node.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/printers/workflow.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/root.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/string_map_value.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/string_map_value_test.go create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.pb.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.pb.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/util.go create mode 100644 flytepropeller/cmd/kubectl-flyte/cmd/visualize.go create mode 100644 flytepropeller/cmd/kubectl-flyte/main.go create mode 100644 flytepropeller/config.yaml create mode 100644 flytepropeller/hack/boilerplate.go.txt create mode 100644 flytepropeller/hack/custom-boilerplate.go.txt create mode 100755 flytepropeller/hack/update-codegen.sh create mode 100755 flytepropeller/hack/verify-codegen.sh create mode 100644 flytepropeller/pkg/apis/flyteworkflow/register.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch_test.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/doc.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go create mode 100644 
flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/clientset.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/doc.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/fake/clientset_generated.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/fake/doc.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/fake/register.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/scheme/doc.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/scheme/register.go create mode 
100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go create mode 100644 flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go create mode 100644 flytepropeller/pkg/client/informers/externalversions/factory.go create mode 100644 flytepropeller/pkg/client/informers/externalversions/flyteworkflow/interface.go create mode 100644 flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go create mode 100644 flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go create mode 100644 flytepropeller/pkg/client/informers/externalversions/generic.go create mode 100644 flytepropeller/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go create mode 100644 flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go create mode 100644 flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go create mode 100755 flytepropeller/pkg/compiler/builders.go create mode 100644 flytepropeller/pkg/compiler/common/builder.go create mode 100644 flytepropeller/pkg/compiler/common/id_set.go create mode 100644 flytepropeller/pkg/compiler/common/index.go create mode 100644 flytepropeller/pkg/compiler/common/mocks/interface_provider.go create mode 100644 flytepropeller/pkg/compiler/common/mocks/node.go create mode 100644 flytepropeller/pkg/compiler/common/mocks/node_builder.go create mode 100644 flytepropeller/pkg/compiler/common/mocks/task.go create mode 100644 flytepropeller/pkg/compiler/common/mocks/workflow.go create mode 100644 flytepropeller/pkg/compiler/common/mocks/workflow_builder.go create mode 100644 flytepropeller/pkg/compiler/common/reader.go create mode 100644 flytepropeller/pkg/compiler/errors/compiler_error_test.go create mode 100755 flytepropeller/pkg/compiler/errors/compiler_errors.go create mode 100644 flytepropeller/pkg/compiler/errors/config.go create mode 100755 flytepropeller/pkg/compiler/errors/error.go create mode 100755 flytepropeller/pkg/compiler/errors/error_test.go create mode 100755 flytepropeller/pkg/compiler/errors/sets.go create mode 100644 flytepropeller/pkg/compiler/errors/sets_test.go create mode 100755 flytepropeller/pkg/compiler/requirements.go create mode 100755 flytepropeller/pkg/compiler/requirements_test.go create mode 100644 flytepropeller/pkg/compiler/task_compiler.go create mode 100644 flytepropeller/pkg/compiler/task_compiler_test.go create mode 100644 flytepropeller/pkg/compiler/test/compiler_test.go create mode 100755 flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.pb create mode 100755 flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml create mode 100755 flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb create mode 100644 
flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.yaml create mode 100755 flytepropeller/pkg/compiler/testdata/beta-one-second-functional-test.dot.golden create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/builder_mock_test.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/inputs.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/node.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/node_test.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/utils.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/utils_test.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/workflow.go create mode 100644 flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go create mode 100644 flytepropeller/pkg/compiler/typing/variable.go create mode 100755 flytepropeller/pkg/compiler/utils.go create mode 100644 flytepropeller/pkg/compiler/utils_test.go create mode 100644 flytepropeller/pkg/compiler/validators/bindings.go create mode 100644 flytepropeller/pkg/compiler/validators/branch.go create mode 100644 flytepropeller/pkg/compiler/validators/condition.go create mode 100644 flytepropeller/pkg/compiler/validators/interface.go create mode 100644 flytepropeller/pkg/compiler/validators/interface_test.go create mode 100644 flytepropeller/pkg/compiler/validators/node.go create mode 100644 flytepropeller/pkg/compiler/validators/typing.go create mode 100644 flytepropeller/pkg/compiler/validators/typing_test.go create mode 100644 flytepropeller/pkg/compiler/validators/utils.go create mode 100644 flytepropeller/pkg/compiler/validators/vars.go create mode 100755 flytepropeller/pkg/compiler/workflow_compiler.go create mode 100755 flytepropeller/pkg/compiler/workflow_compiler_test.go create mode 100644 flytepropeller/pkg/controller/catalog/catalog_client.go create mode 100755 flytepropeller/pkg/controller/catalog/config_flags.go create mode 100755 flytepropeller/pkg/controller/catalog/config_flags_test.go create mode 100644 flytepropeller/pkg/controller/catalog/discovery_config.go create mode 100644 flytepropeller/pkg/controller/catalog/legacy_discovery.go create mode 100644 flytepropeller/pkg/controller/catalog/legacy_discovery_test.go create mode 100644 flytepropeller/pkg/controller/catalog/mock_catalog.go create mode 100644 flytepropeller/pkg/controller/catalog/no_op_discovery.go create mode 100644 flytepropeller/pkg/controller/catalog/no_op_discovery_test.go create mode 100644 flytepropeller/pkg/controller/completed_workflows.go create mode 100644 flytepropeller/pkg/controller/completed_workflows_test.go create mode 100644 flytepropeller/pkg/controller/composite_workqueue.go create mode 100644 flytepropeller/pkg/controller/composite_workqueue_test.go create mode 100644 flytepropeller/pkg/controller/config/config.go create mode 100755 flytepropeller/pkg/controller/config/config_flags.go create mode 100755 flytepropeller/pkg/controller/config/config_flags_test.go create mode 100644 flytepropeller/pkg/controller/controller.go create mode 100644 flytepropeller/pkg/controller/executors/contextual.go create mode 100644 flytepropeller/pkg/controller/executors/kube.go create mode 100644 flytepropeller/pkg/controller/executors/mocks/Client.go create mode 100644 flytepropeller/pkg/controller/executors/mocks/fake.go create mode 100644 flytepropeller/pkg/controller/executors/node.go create mode 100644 
flytepropeller/pkg/controller/executors/workflow.go create mode 100644 flytepropeller/pkg/controller/finalizer.go create mode 100644 flytepropeller/pkg/controller/finalizer_test.go create mode 100644 flytepropeller/pkg/controller/garbage_collector.go create mode 100644 flytepropeller/pkg/controller/garbage_collector_test.go create mode 100644 flytepropeller/pkg/controller/handler.go create mode 100644 flytepropeller/pkg/controller/handler_test.go create mode 100644 flytepropeller/pkg/controller/leaderelection.go create mode 100644 flytepropeller/pkg/controller/nodes/branch/comparator.go create mode 100644 flytepropeller/pkg/controller/nodes/branch/comparator_test.go create mode 100644 flytepropeller/pkg/controller/nodes/branch/evaluator.go create mode 100644 flytepropeller/pkg/controller/nodes/branch/evaluator_test.go create mode 100644 flytepropeller/pkg/controller/nodes/branch/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/branch/handler_test.go create mode 100644 flytepropeller/pkg/controller/nodes/common/output_resolver.go create mode 100644 flytepropeller/pkg/controller/nodes/common/output_resolver_test.go create mode 100644 flytepropeller/pkg/controller/nodes/dynamic/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/dynamic/handler_test.go create mode 100644 flytepropeller/pkg/controller/nodes/dynamic/subworkflow.go create mode 100644 flytepropeller/pkg/controller/nodes/dynamic/subworkflow_test.go create mode 100644 flytepropeller/pkg/controller/nodes/dynamic/utils.go create mode 100644 flytepropeller/pkg/controller/nodes/dynamic/utils_test.go create mode 100644 flytepropeller/pkg/controller/nodes/end/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/end/handler_test.go create mode 100644 flytepropeller/pkg/controller/nodes/errors/codes.go create mode 100644 flytepropeller/pkg/controller/nodes/errors/errors.go create mode 100644 flytepropeller/pkg/controller/nodes/errors/errors_test.go create mode 100644 flytepropeller/pkg/controller/nodes/executor.go create mode 100644 flytepropeller/pkg/controller/nodes/executor_test.go create mode 100644 flytepropeller/pkg/controller/nodes/handler/iface.go create mode 100644 flytepropeller/pkg/controller/nodes/handler/mocks/IFace.go create mode 100644 flytepropeller/pkg/controller/nodes/handler/mocks/OutputResolver.go create mode 100644 flytepropeller/pkg/controller/nodes/handler_factory.go create mode 100644 flytepropeller/pkg/controller/nodes/mocks/HandlerFactory.go create mode 100644 flytepropeller/pkg/controller/nodes/predicate.go create mode 100644 flytepropeller/pkg/controller/nodes/predicate_test.go create mode 100644 flytepropeller/pkg/controller/nodes/resolve.go create mode 100644 flytepropeller/pkg/controller/nodes/resolve_test.go create mode 100644 flytepropeller/pkg/controller/nodes/start/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/start/handler_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go create mode 100755 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go create 
mode 100755 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/sub_workflow.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/util.go create mode 100644 flytepropeller/pkg/controller/nodes/subworkflow/util_test.go create mode 100644 flytepropeller/pkg/controller/nodes/task/factory.go create mode 100644 flytepropeller/pkg/controller/nodes/task/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/task/handler_test.go create mode 100644 flytepropeller/pkg/controller/workers.go create mode 100644 flytepropeller/pkg/controller/workers_test.go create mode 100644 flytepropeller/pkg/controller/workflow/errors/codes.go create mode 100644 flytepropeller/pkg/controller/workflow/errors/errors.go create mode 100644 flytepropeller/pkg/controller/workflow/errors/errors_test.go create mode 100644 flytepropeller/pkg/controller/workflow/executor.go create mode 100644 flytepropeller/pkg/controller/workflow/executor_test.go create mode 100644 flytepropeller/pkg/controller/workflow/testdata/benchmark_wf.yaml create mode 100644 flytepropeller/pkg/controller/workflowstore/errors.go create mode 100644 flytepropeller/pkg/controller/workflowstore/iface.go create mode 100644 flytepropeller/pkg/controller/workflowstore/inmemory.go create mode 100644 flytepropeller/pkg/controller/workflowstore/passthrough.go create mode 100644 flytepropeller/pkg/controller/workflowstore/passthrough_test.go create mode 100644 flytepropeller/pkg/controller/workflowstore/resource_version_caching.go create mode 100644 flytepropeller/pkg/controller/workflowstore/resource_version_caching_test.go create mode 100644 flytepropeller/pkg/controller/workqueue.go create mode 100644 flytepropeller/pkg/controller/workqueue_test.go create mode 100644 flytepropeller/pkg/signals/signal.go create mode 100644 flytepropeller/pkg/signals/signal_posix.go create mode 100644 flytepropeller/pkg/signals/signal_windows.go create mode 100644 flytepropeller/pkg/utils/assert/literals.go create mode 100644 flytepropeller/pkg/utils/bindings.go create mode 100644 flytepropeller/pkg/utils/bindings_test.go create mode 100644 flytepropeller/pkg/utils/encoder.go create mode 100644 flytepropeller/pkg/utils/encoder_test.go create mode 100644 flytepropeller/pkg/utils/event_helpers.go create mode 100644 flytepropeller/pkg/utils/failing_datastore.go create mode 100644 flytepropeller/pkg/utils/failing_datastore_test.go create mode 100644 flytepropeller/pkg/utils/helpers.go create mode 100644 flytepropeller/pkg/utils/helpers_test.go create mode 100644 flytepropeller/pkg/utils/k8s.go create mode 100644 flytepropeller/pkg/utils/k8s_test.go create mode 100644 flytepropeller/pkg/utils/literals.go create mode 100644 flytepropeller/pkg/utils/literals_test.go create mode 100644 flytepropeller/pkg/visualize/nodeq.go create mode 100644 
flytepropeller/pkg/visualize/sort.go create mode 100644 flytepropeller/pkg/visualize/visualize.go create mode 100644 flytepropeller/raw_examples/README.md create mode 100644 flytepropeller/raw_examples/example-condition.yaml create mode 100644 flytepropeller/raw_examples/example-inputs.yaml create mode 100644 flytepropeller/raw_examples/example-noinputs.yaml diff --git a/flytepropeller/.dockerignore b/flytepropeller/.dockerignore new file mode 100644 index 0000000000..ce85d3b8c5 --- /dev/null +++ b/flytepropeller/.dockerignore @@ -0,0 +1,2 @@ +vendor/* +bin/* diff --git a/flytepropeller/.gitignore b/flytepropeller/.gitignore new file mode 100644 index 0000000000..ff8a9b166d --- /dev/null +++ b/flytepropeller/.gitignore @@ -0,0 +1,4 @@ +.idea +vendor +bin +.DS_Store diff --git a/flytepropeller/.golangci.yml b/flytepropeller/.golangci.yml new file mode 100644 index 0000000000..a414f33f79 --- /dev/null +++ b/flytepropeller/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/flytepropeller/.travis.yml b/flytepropeller/.travis.yml new file mode 100644 index 0000000000..9ec4f2c9c6 --- /dev/null +++ b/flytepropeller/.travis.yml @@ -0,0 +1,26 @@ +sudo: required +language: go +go: + - "1.12" +services: + - docker +jobs: + include: + - if: fork = true + stage: test + name: docker build + install: true + script: make docker_build + - if: fork = false + stage: test + name: docker build and push + install: true + script: make dockerhub_push + - stage: test + install: make install + name: lint + script: make lint + - stage: test + name: unit tests + install: make install + script: make test_unit diff --git a/flytepropeller/CODE_OF_CONDUCT.md b/flytepropeller/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..803d8a77f3 --- /dev/null +++ b/flytepropeller/CODE_OF_CONDUCT.md @@ -0,0 +1,3 @@ +This project is governed by [Lyft's code of +conduct](https://github.com/lyft/code-of-conduct). All contributors +and participants agree to abide by its terms. diff --git a/flytepropeller/Dockerfile b/flytepropeller/Dockerfile new file mode 100644 index 0000000000..2ad29f2ff6 --- /dev/null +++ b/flytepropeller/Dockerfile @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/flytepropeller/ +WORKDIR /go/src/github.com/lyft/flytepropeller + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . 
/go/src/github.com/lyft/flytepropeller/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/flytepropeller +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["flytepropeller"] diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock new file mode 100644 index 0000000000..08a0aa8c73 --- /dev/null +++ b/flytepropeller/Gopkg.lock @@ -0,0 +1,1433 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + digest = "1:80e5d0810f1448259385b08f381852a83f87b6c958d8500e621db821e15c3771" + name = "cloud.google.com/go" + packages = ["compute/metadata"] + pruneopts = "" + revision = "cdaaf98f9226c39dc162b8e55083b2fbc67b4674" + version = "v0.43.0" + +[[projects]] + digest = "1:6158256042564abf0da300ea7cb016f79ddaf24fdda2cc06c9712b0c2e06dd2a" + name = "contrib.go.opencensus.io/exporter/ocagent" + packages = ["."] + pruneopts = "" + revision = "dcb33c7f3b7cfe67e8a2cea10207ede1b7c40764" + version = "v0.4.12" + +[[projects]] + digest = "1:9a11be778d5fcb8e4873e64a097dfd2862d8665d9e2d969b90810d5272e51acb" + name = "github.com/Azure/azure-sdk-for-go" + packages = ["storage"] + pruneopts = "" + revision = "2d49bb8f2cee530cc16f1f1a9f0aae763dee257d" + version = "v10.2.1-beta" + +[[projects]] + digest = "1:5cb9540799639936e705a6ac54cfb6744b598519485fb357acb6e3285f43fbfb" + name = "github.com/Azure/go-autorest" + packages = [ + "autorest", + "autorest/adal", + "autorest/azure", + "autorest/date", + "logger", + "tracing", + ] + pruneopts = "" + revision = "7166fb346dbf8978ad28211a1937b20fdabc08c8" + version = "v12.4.2" + +[[projects]] + digest = "1:558b53577dc0c9fde49b08405d706b202bcac3064320e9be53a75fc866280ee3" + name = "github.com/DiSiqueira/GoTree" + packages = ["."] + pruneopts = "" + revision = "53a8e837f2952215f256fc9acf4ecb2045b056fb" + version = "2.0.3" + +[[projects]] + digest = "1:e1549ae10031ac55dd7d26ac4d480130ddbdf97f9a26ebbedff089aa0335798f" + name = "github.com/GoogleCloudPlatform/spark-on-k8s-operator" + packages = [ + "pkg/apis/sparkoperator.k8s.io", + "pkg/apis/sparkoperator.k8s.io/v1beta1", + ] + pruneopts = "" + revision = "5306d013b4dbd6a9c75879c1643c7fcb237560ec" + source = "https://github.com/lyft/spark-on-k8s-operator" + version = "v0.1.3" + +[[projects]] + digest = "1:60942d250d0e06d3722ddc8e22bc52f8cef7961ba6d8d3e95327a32b6b024a7b" + name = "github.com/appscode/jsonpatch" + packages = ["."] + pruneopts = "" + revision = "7c0e3b262f30165a8ec3d0b4c6059fd92703bfb2" + version = "1.0.0" + +[[projects]] + digest = "1:cfe39a015adcf9cc2bce0e8bd38ecf041cb516b8ab7a2ecb11b1c84a4b8acabf" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", 
+ "private/protocol/eventstream/eventstreamapi", + "private/protocol/json/jsonutil", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/s3", + "service/s3/s3iface", + "service/s3/s3manager", + "service/sts", + "service/sts/stsiface", + ] + pruneopts = "" + revision = "14379de571db1ac1b08f2f723a1acc1810c4dd0d" + version = "v1.22.2" + +[[projects]] + branch = "master" + digest = "1:d700667a9f768e1c6db34b091ec28b709dd8de6a62a61cd9a3ceb39d442154a7" + name = "github.com/benlaurie/objecthash" + packages = ["go/objecthash"] + pruneopts = "" + revision = "d1e3d6079fc16f8f542183fb5b2fdc11d9f00866" + +[[projects]] + digest = "1:ac2a05be7167c495fe8aaf8aaf62ecf81e78d2180ecb04e16778dc6c185c96a5" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "" + revision = "37c8de3658fcb183f997c4e13e8337516ab753e6" + version = "v1.0.1" + +[[projects]] + digest = "1:ad70cf78ff17abf96d92a6082f4d3241fef8f149118f87c3a267ed47a08be603" + name = "github.com/census-instrumentation/opencensus-proto" + packages = [ + "gen-go/agent/common/v1", + "gen-go/agent/metrics/v1", + "gen-go/agent/trace/v1", + "gen-go/metrics/v1", + "gen-go/resource/v1", + "gen-go/trace/v1", + ] + pruneopts = "" + revision = "d89fa54de508111353cb0b06403c00569be780d8" + version = "v0.2.1" + +[[projects]] + digest = "1:f6485831252319cd6ca29fc170adecf1eb81bf1e805f62f44eb48564ce2485fe" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:193f6d32d751f26540aa8eeedc114ce0a51f9e77b6c22dda3a4db4e5f65aec66" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "" + revision = "3c79a0a23c1940ab4479332fb3e0127265650ce3" + version = "v1.1.0" + +[[projects]] + digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6" + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + pruneopts = "" + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + +[[projects]] + digest = "1:46ddeb9dd35d875ac7568c4dc1fc96ce424e034bdbb984239d8ffc151398ec01" + name = "github.com/evanphx/json-patch" + packages = ["."] + pruneopts = "" + revision = "026c730a0dcc5d11f93f1cf1cc65b01247ea7b6f" + version = "v4.5.0" + +[[projects]] + digest = "1:e988ed0ca0d81f4d28772760c02ee95084961311291bdfefc1b04617c178b722" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + branch = "master" + digest = "1:135223bf2c128b2158178ee48779ac9983b003634864d46b73e913c95f7a847e" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "" + revision = "1485a34d5d5723fea214f5710708e19a831720e4" + +[[projects]] + digest = "1:b13707423743d41665fd23f0c36b2f37bb49c30e94adb813319c44188a51ba22" + name = "github.com/ghodss/yaml" + packages = ["."] + pruneopts = "" + revision = "0ca9ea5df5451ffdf184b4428c902747c2c11cd7" + version = "v1.0.0" + +[[projects]] + digest = "1:65587005c6fa4293c0b8a2e457e689df7fda48cc5e1f5449ea2c1e7784551558" + name = "github.com/go-logr/logr" + packages = ["."] + pruneopts = "" + revision = 
"9fb12b3b21c5415d16ac18dc5cd42c1cfdd40c4e" + version = "v0.1.0" + +[[projects]] + digest = "1:d81dfed1aa731d8e4a45d87154ec15ef18da2aa80fa9a2f95bec38577a244a99" + name = "github.com/go-logr/zapr" + packages = ["."] + pruneopts = "" + revision = "03f06a783fbb7dfaf3f629c7825480e43a7105e6" + version = "v0.1.1" + +[[projects]] + digest = "1:c2db84082861ca42d0b00580d28f4b31aceec477a00a38e1a057fb3da75c8adc" + name = "github.com/go-redis/redis" + packages = [ + ".", + "internal", + "internal/consistenthash", + "internal/hashtag", + "internal/pool", + "internal/proto", + "internal/util", + ] + pruneopts = "" + revision = "75795aa4236dc7341eefac3bbe945e68c99ef9df" + version = "v6.15.3" + +[[projects]] + digest = "1:fd53b471edb4c28c7d297f617f4da0d33402755f58d6301e7ca1197ef0a90937" + name = "github.com/gogo/protobuf" + packages = [ + "proto", + "sortkeys", + ] + pruneopts = "" + revision = "ba06b47c162d49f2af050fb4c75bcbc86a159d5c" + version = "v1.2.1" + +[[projects]] + branch = "master" + digest = "1:e1822d37be8e11e101357a27170527b1056c99182407f270e080f76409adbd9a" + name = "github.com/golang/groupcache" + packages = ["lru"] + pruneopts = "" + revision = "869f871628b6baa9cfbc11732cdf6546b17c1298" + +[[projects]] + digest = "1:b852d2b62be24e445fcdbad9ce3015b44c207815d631230dfce3f14e7803f5bf" + name = "github.com/golang/protobuf" + packages = [ + "jsonpb", + "proto", + "protoc-gen-go/descriptor", + "protoc-gen-go/generator", + "protoc-gen-go/generator/internal/remap", + "protoc-gen-go/plugin", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + "ptypes/wrappers", + ] + pruneopts = "" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" + +[[projects]] + digest = "1:1e5b1e14524ed08301977b7b8e10c719ed853cbf3f24ecb66fae783a46f207a6" + name = "github.com/google/btree" + packages = ["."] + pruneopts = "" + revision = "4030bb1f1f0c35b30ca7009e9ebd06849dd45306" + version = "v1.0.0" + +[[projects]] + digest = "1:8d4a577a9643f713c25a32151c0f26af7228b4b97a219b5ddb7fd38d16f6e673" + name = "github.com/google/gofuzz" + packages = ["."] + pruneopts = "" + revision = "f140a6486e521aad38f5917de355cbf147cc0496" + version = "v1.0.0" + +[[projects]] + digest = "1:ad92aa49f34cbc3546063c7eb2cabb55ee2278b72842eda80e2a20a8a06a8d73" + name = "github.com/google/uuid" + packages = ["."] + pruneopts = "" + revision = "0cd6bf5da1e1c83f8b45653022c74f71af0538a4" + version = "v1.1.1" + +[[projects]] + digest = "1:5facc3828b6a56f9aec988433ea33fb4407a89460952ed75be5347cec07318c0" + name = "github.com/googleapis/gnostic" + packages = [ + "OpenAPIv2", + "compiler", + "extensions", + ] + pruneopts = "" + revision = "e73c7ec21d36ddb0711cb36d1502d18363b5c2c9" + version = "v0.3.0" + +[[projects]] + digest = "1:1ea91d049b6a609f628ecdfda32e85f445a0d3671980dcbf7cbe1bbd7ee6aabc" + name = "github.com/graymeta/stow" + packages = [ + ".", + "azure", + "google", + "local", + "oracle", + "s3", + "swift", + ] + pruneopts = "" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" + +[[projects]] + branch = "master" + digest = "1:e1fd67b5695fb12f54f979606c5d650a5aa72ef242f8e71072bfd4f7b5a141a0" + name = "github.com/gregjones/httpcache" + packages = [ + ".", + "diskcache", + ] + pruneopts = "" + revision = "901d90724c7919163f472a9812253fb26761123d" + +[[projects]] + digest = "1:9a0b2dd1f882668a3d7fbcd424eed269c383a16f1faa3a03d14e0dd5fba571b1" + name = "github.com/grpc-ecosystem/go-grpc-middleware" + packages = [ + ".", + "retry", + "util/backoffutils", + "util/metautils", + ] + pruneopts 
= "" + revision = "c250d6563d4d4c20252cd865923440e829844f4e" + version = "v1.0.0" + +[[projects]] + digest = "1:e24dc5ef44694848785de507f439a24e9e6d96d7b43b8cf3d6cfa857aa1e2186" + name = "github.com/grpc-ecosystem/go-grpc-prometheus" + packages = ["."] + pruneopts = "" + revision = "c225b8c3b01faf2899099b768856a9e916e5087b" + version = "v1.2.0" + +[[projects]] + digest = "1:4ab82898193e99be9d4f1f1eb4ca3b1113ab6b7b2ff4605198ae305de864f05e" + name = "github.com/grpc-ecosystem/grpc-gateway" + packages = [ + "internal", + "protoc-gen-swagger/options", + "runtime", + "utilities", + ] + pruneopts = "" + revision = "ad529a448ba494a88058f9e5be0988713174ac86" + version = "v1.9.5" + +[[projects]] + digest = "1:7f6f07500a0b7d3766b00fa466040b97f2f5b5f3eef2ecabfe516e703b05119a" + name = "github.com/hashicorp/golang-lru" + packages = [ + ".", + "simplelru", + ] + pruneopts = "" + revision = "7f827b33c0f158ec5dfbba01bb0b14a4541fd81d" + version = "v0.5.3" + +[[projects]] + digest = "1:d14365c51dd1d34d5c79833ec91413bfbb166be978724f15701e17080dc06dec" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:31bfd110d31505e9ffbc9478e31773bf05bf02adcaeb9b139af42684f9294c13" + name = "github.com/imdario/mergo" + packages = ["."] + pruneopts = "" + revision = "7c29201646fa3de8506f701213473dd407f19646" + version = "v0.3.7" + +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "" + revision = "c2b33e84" + +[[projects]] + digest = "1:e716a02584d94519e2ccf7ac461c4028da736d41a58c1ed95e641c1603bdb056" + name = "github.com/json-iterator/go" + packages = ["."] + pruneopts = "" + revision = "27518f6661eba504be5a7a9a9f6d9460d892ade3" + version = "v1.1.7" + +[[projects]] + digest = "1:0f51cee70b0d254dbc93c22666ea2abf211af81c1701a96d04e2284b408621db" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "" + revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" + version = "v1.0.2" + +[[projects]] + digest = "1:2b5f0e6bc8fb862fed5bccf9fbb1ab819c8b3f8a21e813fe442c06aec3bb3e86" + name = "github.com/lyft/flyteidl" + packages = [ + "clients/go/admin", + "clients/go/admin/mocks", + "clients/go/coreutils", + "clients/go/coreutils/logs", + "clients/go/datacatalog/mocks", + "clients/go/events", + "clients/go/events/errors", + "gen/pb-go/flyteidl/admin", + "gen/pb-go/flyteidl/core", + "gen/pb-go/flyteidl/datacatalog", + "gen/pb-go/flyteidl/event", + "gen/pb-go/flyteidl/plugins", + "gen/pb-go/flyteidl/service", + ] + pruneopts = "" + revision = "793b09d190148236f41ad8160b5cec9a3325c16f" + source = "https://github.com/lyft/flyteidl" + version = "v0.1.0" + +[[projects]] + digest = "1:500471ee50c4141d3523c79615cc90529b3152f8aa5924b63122df6bf201a7a0" + name = "github.com/lyft/flyteplugins" + packages = [ + "go/tasks", + "go/tasks/v1", + "go/tasks/v1/config", + "go/tasks/v1/errors", + "go/tasks/v1/events", + "go/tasks/v1/flytek8s", + 
"go/tasks/v1/flytek8s/config", + "go/tasks/v1/k8splugins", + "go/tasks/v1/logs", + "go/tasks/v1/qubole", + "go/tasks/v1/qubole/client", + "go/tasks/v1/qubole/config", + "go/tasks/v1/qubole/mocks", + "go/tasks/v1/resourcemanager", + "go/tasks/v1/types", + "go/tasks/v1/types/mocks", + "go/tasks/v1/utils", + ] + pruneopts = "" + revision = "8c85a7c9f19de4df4767de329c56a7f09d0a7bbc" + source = "https://github.com/lyft/flyteplugins" + version = "v0.1.0" + +[[projects]] + digest = "1:77615fd7bcfd377c5e898c41d33be827f108166691cb91257d27113ee5d08650" + name = "github.com/lyft/flytestdlib" + packages = [ + "atomic", + "config", + "config/files", + "config/viper", + "contextutils", + "errors", + "ioutils", + "logger", + "pbhash", + "profutils", + "promutils", + "promutils/labeled", + "sets", + "storage", + "utils", + "version", + "yamlutils", + ] + pruneopts = "" + revision = "7292f20ec17b42f104fd61d7f0120e17bcacf751" + source = "https://github.com/lyft/flytestdlib" + version = "v0.2.16" + +[[projects]] + digest = "1:ae39921edb7f801f7ce1b6b5484f9715a1dd2b52cb645daef095cd10fd6ee774" + name = "github.com/magiconair/properties" + packages = [ + ".", + "assert", + ] + pruneopts = "" + revision = "de8848e004dd33dc07a2947b3d76f618a7fc7ef1" + version = "v1.8.1" + +[[projects]] + digest = "1:9ea83adf8e96d6304f394d40436f2eb44c1dc3250d223b74088cc253a6cd0a1c" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:dbfae9da5a674236b914e486086671145b37b5e3880a38da906665aede3c9eab" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "" + revision = "1311e847b0cb909da63b5fecfb5370aa66236465" + version = "v0.0.8" + +[[projects]] + digest = "1:63722a4b1e1717be7b98fc686e0b30d5e7f734b9e93d7dee86293b6deab7ea28" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:bcc46a0fbd9e933087bef394871256b5c60269575bb661935874729c65bbbf60" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:0c0ff2a89c1bb0d01887e1dac043ad7efbf3ec77482ef058ac423d13497e16fd" + name = "github.com/modern-go/concurrent" + packages = ["."] + pruneopts = "" + revision = "bacd9c7ef1dd9b15be4a9909b8ac7a4e313eec94" + version = "1.0.3" + +[[projects]] + digest = "1:e32bdbdb7c377a07a9a46378290059822efdce5c8d96fe71940d87cb4f918855" + name = "github.com/modern-go/reflect2" + packages = ["."] + pruneopts = "" + revision = "4b7aa43c6742a2c18fdef89dd197aaae7dac7ccd" + version = "1.0.1" + +[[projects]] + branch = "master" + digest = "1:b6c101f6c8ab09c631e969c30d3a4b42aeca82580499253bad77cb2426d4fc27" + name = "github.com/ncw/swift" + packages = ["."] + pruneopts = "" + revision = "a24ef33bc9b7e59ae4bed9e87a51d7bc76122731" + +[[projects]] + digest = "1:c1a07a723fa656d4ba5ac489fcb4dfa3aef0fec6b34e415f0002dfc5ee2ba872" + name = "github.com/operator-framework/operator-sdk" + packages = ["pkg/util/k8sutil"] + pruneopts = "" + revision = "e5a0ab096e1a7c0e6b937d2b41707eccb82c3c77" + version = "v0.0.7" + +[[projects]] + digest = "1:a5484d4fa43127138ae6e7b2299a6a52ae006c7f803d98d717f60abf3e97192e" + name = "github.com/pborman/uuid" + packages = ["."] + pruneopts = "" + revision = "adf5a7427709b9deb95d29d3fa8a2bf9cfd388f1" + version = "v1.2" + 
+[[projects]] + digest = "1:3d2c33720d4255686b9f4a7e4d3b94938ee36063f14705c5eb0f73347ed4c496" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "" + revision = "728039f679cbcd4f6a54e080d2219a4c4928c546" + version = "v1.4.0" + +[[projects]] + branch = "master" + digest = "1:5f0faa008e8ff4221b55a1a5057c8b02cb2fd68da6a65c9e31c82b72cbc836d0" + name = "github.com/petar/GoLLRB" + packages = ["llrb"] + pruneopts = "" + revision = "33fb24c13b99c46c93183c291836c573ac382536" + +[[projects]] + digest = "1:4709c61d984ef9ba99b037b047546d8a576ae984fb49486e48d99658aa750cd5" + name = "github.com/peterbourgon/diskv" + packages = ["."] + pruneopts = "" + revision = "0be1b92a6df0e4f5cb0a5d15fb7f643d0ad93ce6" + version = "v3.0.0" + +[[projects]] + digest = "1:1d7e1867c49a6dd9856598ef7c3123604ea3daabf5b83f303ff457bcbc410b1d" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "" + revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:c826496cad27bd9a7644a01230a79d472b4093dd33587236e8f8369bb1d8534e" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + "prometheus/promhttp", + ] + pruneopts = "" + revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" + version = "v0.9.4" + +[[projects]] + branch = "master" + digest = "1:cd67319ee7536399990c4b00fae07c3413035a53193c644549a676091507cadc" + name = "github.com/prometheus/client_model" + packages = ["go"] + pruneopts = "" + revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + +[[projects]] + digest = "1:0f2cee44695a3208fe5d6926076641499c72304e6f015348c9ab2df90a202cdf" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "" + revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" + version = "v0.6.0" + +[[projects]] + digest = "1:9b33e539d6bf6e4453668a847392d1e9e6345225ea1426f9341212c652bcbee4" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/fs", + ] + pruneopts = "" + revision = "3f98efb27840a48a7a2898ec80be07674d19f9c8" + version = "v0.0.3" + +[[projects]] + digest = "1:7f569d906bdd20d906b606415b7d794f798f91a62fcfb6a4daa6d50690fb7a3f" + name = "github.com/satori/uuid" + packages = ["."] + pruneopts = "" + revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" + version = "v1.2.0" + +[[projects]] + digest = "1:1a405cddcf3368445051fb70ab465ae99da56ad7be8d8ca7fc52159d1c2d873c" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "" + revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" + version = "v1.4.2" + +[[projects]] + digest = "1:956f655c87b7255c6b1ae6c203ebb0af98cf2a13ef2507e34c9bf1c0332ac0f5" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "" + revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424" + version = "v1.2.2" + +[[projects]] + digest = "1:ae3493c780092be9d576a1f746ab967293ec165e8473425631f06658b6212afc" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe" + name = "github.com/spf13/cobra" + packages = ["."] + 
pruneopts = "" + revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" + version = "v0.0.5" + +[[projects]] + digest = "1:cc15ae4fbdb02ce31f3392361a70ac041f4f02e0485de8ffac92bd8033e3d26e" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "" + revision = "94f6ae3ed3bceceafa716478c5fbf8d29ca601a1" + version = "v1.1.0" + +[[projects]] + digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:c25a789c738f7cc8ec7f34026badd4e117853f329334a5aa45cf5d0727d7d442" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = "1:711eebe744c0151a9d09af2315f0bb729b2ec7637ef4c410fa90a18ef74b65b6" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c" + name = "github.com/stretchr/testify" + packages = [ + "assert", + "mock", + ] + pruneopts = "" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + +[[projects]] + digest = "1:98f63c8942146f9bf4b3925db1d96637b86c1d83693a894a244eae54aa53bb40" + name = "go.opencensus.io" + packages = [ + ".", + "exemplar", + "internal", + "internal/tagencoding", + "plugin/ocgrpc", + "plugin/ochttp", + "plugin/ochttp/propagation/b3", + "plugin/ochttp/propagation/tracecontext", + "resource", + "stats", + "stats/internal", + "stats/view", + "tag", + "trace", + "trace/internal", + "trace/propagation", + "trace/tracestate", + ] + pruneopts = "" + revision = "aab39bd6a98b853ab66c8a564f5d6cfcad59ce8a" + +[[projects]] + digest = "1:e6ff7840319b6fda979a918a8801005ec2049abca62af19211d96971d8ec3327" + name = "go.uber.org/atomic" + packages = ["."] + pruneopts = "" + revision = "df976f2515e274675050de7b3f42545de80594fd" + version = "v1.4.0" + +[[projects]] + digest = "1:22c7effcb4da0eacb2bb1940ee173fac010e9ef3c691f5de4b524d538bd980f5" + name = "go.uber.org/multierr" + packages = ["."] + pruneopts = "" + revision = "3c4937480c32f4c13a875a1829af76c98ca3d40a" + version = "v1.1.0" + +[[projects]] + digest = "1:984e93aca9088b440b894df41f2043b6a3db8f9cf30767032770bfc4796993b0" + name = "go.uber.org/zap" + packages = [ + ".", + "buffer", + "internal/bufferpool", + "internal/color", + "internal/exit", + "zapcore", + ] + pruneopts = "" + revision = "27376062155ad36be76b0f12cf1572a221d3a48c" + version = "v1.10.0" + +[[projects]] + branch = "master" + digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328" + name = "golang.org/x/crypto" + packages = ["ssh/terminal"] + pruneopts = "" + revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" + +[[projects]] + branch = "master" + digest = "1:955694a7c42527d7fb188505a22f10b3e158c6c2cf31fe64b1e62c9ab7b18401" + name = "golang.org/x/net" + packages = [ + "context", + "context/ctxhttp", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "" + revision = "ca1201d0de80cfde86cb01aea620983605dfe99b" + +[[projects]] + branch = "master" + digest = "1:01bdbbc604dcd5afb6f66a717f69ad45e9643c72d5bc11678d44ffa5c50f9e42" + name = "golang.org/x/oauth2" + packages = [ + ".", + "google", + "internal", + "jws", + "jwt", + ] + pruneopts = "" + revision = 
"0f29369cfe4552d0e4bcddc57cc75f4d7e672a33" + +[[projects]] + branch = "master" + digest = "1:9f6efefb4e401a4f699a295d14518871368eb89403f2dd23ec11dfcd2c0836ba" + name = "golang.org/x/sync" + packages = ["semaphore"] + pruneopts = "" + revision = "112230192c580c3556b8cee6403af37a4fc5f28c" + +[[projects]] + branch = "master" + digest = "1:0b5c2207c72f2d13995040f176feb6e3f453d6b01af2b9d57df76b05ded2e926" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows", + ] + pruneopts = "" + revision = "51ab0e2deafac1f46c46ad59cf0921be2f180c3d" + +[[projects]] + digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" + name = "golang.org/x/text" + packages = [ + "collate", + "collate/build", + "internal/colltab", + "internal/gen", + "internal/language", + "internal/language/compact", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", + "transform", + "unicode/bidi", + "unicode/cldr", + "unicode/norm", + "unicode/rangetable", + ] + pruneopts = "" + revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" + version = "v0.3.2" + +[[projects]] + branch = "master" + digest = "1:9522af4be529c108010f95b05f1022cb872f2b9ff8b101080f554245673466e1" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "" + revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" + +[[projects]] + branch = "master" + digest = "1:3f52587092bc722a3c3843989e6b88ec26924dc4b7b9c971095b7e93a11e0eff" + name = "golang.org/x/tools" + packages = [ + "go/ast/astutil", + "go/gcexportdata", + "go/internal/gcimporter", + "go/internal/packagesdriver", + "go/packages", + "go/types/typeutil", + "imports", + "internal/fastwalk", + "internal/gopathwalk", + "internal/imports", + "internal/module", + "internal/semver", + ] + pruneopts = "" + revision = "e713427fea3f98cb070e72a058c557a1a560cf22" + +[[projects]] + branch = "master" + digest = "1:f77558501305be5977ac30110f9820d21c5f1a89328667dc82db0bd9ebaab4c4" + name = "google.golang.org/api" + packages = [ + "gensupport", + "googleapi", + "googleapi/internal/uritemplates", + "googleapi/transport", + "internal", + "option", + "storage/v1", + "support/bundler", + "transport/http", + "transport/http/internal/propagation", + ] + pruneopts = "" + revision = "6f3912904777a209e099b9dbda3ed7bcb4e25ad7" + +[[projects]] + digest = "1:47f391ee443f578f01168347818cb234ed819521e49e4d2c8dd2fb80d48ee41a" + name = "google.golang.org/appengine" + packages = [ + ".", + "internal", + "internal/app_identity", + "internal/base", + "internal/datastore", + "internal/log", + "internal/modules", + "internal/remote_api", + "internal/urlfetch", + "urlfetch", + ] + pruneopts = "" + revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" + version = "v1.6.1" + +[[projects]] + branch = "master" + digest = "1:95b0a53d4d31736b2483a8c41667b2bd83f303706106f81bd2f54e3f9c24eaf4" + name = "google.golang.org/genproto" + packages = [ + "googleapis/api/annotations", + "googleapis/api/httpbody", + "googleapis/rpc/status", + "protobuf/field_mask", + ] + pruneopts = "" + revision = "fa694d86fc64c7654a660f8908de4e879866748d" + +[[projects]] + digest = "1:425ee670b3e8b6562e31754021a82d78aa46b9281247827376616c8aa78f4687" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/balancerload", + "internal/binarylog", + 
"internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "serviceconfig", + "stats", + "status", + "tap", + ] + pruneopts = "" + revision = "045159ad57f3781d409358e3ade910a018c16b30" + version = "v1.22.1" + +[[projects]] + digest = "1:75fb3fcfc73a8c723efde7777b40e8e8ff9babf30d8c56160d01beffea8a95a6" + name = "gopkg.in/inf.v0" + packages = ["."] + pruneopts = "" + revision = "d2d2541c53f18d2a059457998ce2876cc8e67cbf" + version = "v0.9.1" + +[[projects]] + digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + digest = "1:73ee122857f257aa507ebae097783fe08ad8af49398e5b3876787325411f1a4b" + name = "k8s.io/api" + packages = [ + "admission/v1beta1", + "admissionregistration/v1alpha1", + "admissionregistration/v1beta1", + "apps/v1", + "apps/v1beta1", + "apps/v1beta2", + "auditregistration/v1alpha1", + "authentication/v1", + "authentication/v1beta1", + "authorization/v1", + "authorization/v1beta1", + "autoscaling/v1", + "autoscaling/v2beta1", + "autoscaling/v2beta2", + "batch/v1", + "batch/v1beta1", + "batch/v2alpha1", + "certificates/v1beta1", + "coordination/v1beta1", + "core/v1", + "events/v1beta1", + "extensions/v1beta1", + "networking/v1", + "policy/v1beta1", + "rbac/v1", + "rbac/v1alpha1", + "rbac/v1beta1", + "scheduling/v1alpha1", + "scheduling/v1beta1", + "settings/v1alpha1", + "storage/v1", + "storage/v1alpha1", + "storage/v1beta1", + ] + pruneopts = "" + revision = "27b77cf22008a0bf6e510b2a500b8885805a1b68" + source = "https://github.com/lyft/api" + +[[projects]] + digest = "1:a3bee4b1e4013573fc15631b51a7b7e0d580497e6fec63dc3724b370e624569f" + name = "k8s.io/apimachinery" + packages = [ + "pkg/api/errors", + "pkg/api/meta", + "pkg/api/resource", + "pkg/apis/meta/internalversion", + "pkg/apis/meta/v1", + "pkg/apis/meta/v1/unstructured", + "pkg/apis/meta/v1beta1", + "pkg/conversion", + "pkg/conversion/queryparams", + "pkg/fields", + "pkg/labels", + "pkg/runtime", + "pkg/runtime/schema", + "pkg/runtime/serializer", + "pkg/runtime/serializer/json", + "pkg/runtime/serializer/protobuf", + "pkg/runtime/serializer/recognizer", + "pkg/runtime/serializer/streaming", + "pkg/runtime/serializer/versioning", + "pkg/selection", + "pkg/types", + "pkg/util/cache", + "pkg/util/clock", + "pkg/util/diff", + "pkg/util/errors", + "pkg/util/framer", + "pkg/util/intstr", + "pkg/util/json", + "pkg/util/mergepatch", + "pkg/util/naming", + "pkg/util/net", + "pkg/util/rand", + "pkg/util/runtime", + "pkg/util/sets", + "pkg/util/strategicpatch", + "pkg/util/uuid", + "pkg/util/validation", + "pkg/util/validation/field", + "pkg/util/wait", + "pkg/util/yaml", + "pkg/version", + "pkg/watch", + "third_party/forked/golang/json", + "third_party/forked/golang/reflect", + ] + pruneopts = "" + revision = "695912cabb3a9c08353fbe628115867d47f56b1f" + source = "https://github.com/lyft/apimachinery" + +[[projects]] + digest = "1:8dbb2adde0a196cc682fcff9f26d5d7f407a8639c8ee2936ac5da514582e1f65" + name = "k8s.io/client-go" + packages = [ + "discovery", + "discovery/fake", + "dynamic", + "kubernetes", + "kubernetes/scheme", + "kubernetes/typed/admissionregistration/v1alpha1", + "kubernetes/typed/admissionregistration/v1beta1", + "kubernetes/typed/apps/v1", + 
"kubernetes/typed/apps/v1beta1", + "kubernetes/typed/apps/v1beta2", + "kubernetes/typed/auditregistration/v1alpha1", + "kubernetes/typed/authentication/v1", + "kubernetes/typed/authentication/v1beta1", + "kubernetes/typed/authorization/v1", + "kubernetes/typed/authorization/v1beta1", + "kubernetes/typed/autoscaling/v1", + "kubernetes/typed/autoscaling/v2beta1", + "kubernetes/typed/autoscaling/v2beta2", + "kubernetes/typed/batch/v1", + "kubernetes/typed/batch/v1beta1", + "kubernetes/typed/batch/v2alpha1", + "kubernetes/typed/certificates/v1beta1", + "kubernetes/typed/coordination/v1beta1", + "kubernetes/typed/core/v1", + "kubernetes/typed/events/v1beta1", + "kubernetes/typed/extensions/v1beta1", + "kubernetes/typed/networking/v1", + "kubernetes/typed/policy/v1beta1", + "kubernetes/typed/rbac/v1", + "kubernetes/typed/rbac/v1alpha1", + "kubernetes/typed/rbac/v1beta1", + "kubernetes/typed/scheduling/v1alpha1", + "kubernetes/typed/scheduling/v1beta1", + "kubernetes/typed/settings/v1alpha1", + "kubernetes/typed/storage/v1", + "kubernetes/typed/storage/v1alpha1", + "kubernetes/typed/storage/v1beta1", + "pkg/apis/clientauthentication", + "pkg/apis/clientauthentication/v1alpha1", + "pkg/apis/clientauthentication/v1beta1", + "pkg/version", + "plugin/pkg/client/auth/exec", + "rest", + "rest/watch", + "restmapper", + "testing", + "tools/auth", + "tools/cache", + "tools/clientcmd", + "tools/clientcmd/api", + "tools/clientcmd/api/latest", + "tools/clientcmd/api/v1", + "tools/leaderelection", + "tools/leaderelection/resourcelock", + "tools/metrics", + "tools/pager", + "tools/record", + "tools/reference", + "transport", + "util/buffer", + "util/cert", + "util/connrotation", + "util/flowcontrol", + "util/homedir", + "util/integer", + "util/retry", + "util/workqueue", + ] + pruneopts = "" + revision = "7621a5ebb88b1e49ce7e7837ae8e99ca030a3c13" + version = "kubernetes-1.13.5" + +[[projects]] + digest = "1:d809e6c8dfa3448ae10f5624eff4ed1ebdc906755e7cea294c44e8b7ac0b077a" + name = "k8s.io/code-generator" + packages = [ + "cmd/client-gen", + "cmd/client-gen/args", + "cmd/client-gen/generators", + "cmd/client-gen/generators/fake", + "cmd/client-gen/generators/scheme", + "cmd/client-gen/generators/util", + "cmd/client-gen/path", + "cmd/client-gen/types", + "cmd/conversion-gen", + "cmd/conversion-gen/args", + "cmd/conversion-gen/generators", + "cmd/deepcopy-gen", + "cmd/deepcopy-gen/args", + "cmd/defaulter-gen", + "cmd/defaulter-gen/args", + "cmd/informer-gen", + "cmd/informer-gen/args", + "cmd/informer-gen/generators", + "cmd/lister-gen", + "cmd/lister-gen/args", + "cmd/lister-gen/generators", + "pkg/util", + ] + pruneopts = "" + revision = "c2090bec4d9b1fb25de3812f868accc2bc9ecbae" + version = "kubernetes-1.13.5" + +[[projects]] + branch = "master" + digest = "1:6a2a63e09a59caff3fd2d36d69b7b92c2fe7cf783390f0b7349fb330820f9a8e" + name = "k8s.io/gengo" + packages = [ + "args", + "examples/deepcopy-gen/generators", + "examples/defaulter-gen/generators", + "examples/set-gen/sets", + "generator", + "namer", + "parser", + "types", + ] + pruneopts = "" + revision = "e17681d19d3ac4837a019ece36c2a0ec31ffe985" + +[[projects]] + digest = "1:3063061b6514ad2666c4fa292451685884cacf77c803e1b10b4a4fa23f7787fb" + name = "k8s.io/klog" + packages = ["."] + pruneopts = "" + revision = "3ca30a56d8a775276f9cdae009ba326fdc05af7f" + version = "v0.4.0" + +[[projects]] + branch = "master" + digest = "1:3176cac3365c8442ab92d465e69e05071b0dbc0d715e66b76059b04611811dff" + name = "k8s.io/kube-openapi" + packages = ["pkg/util/proto"] + 
pruneopts = "" + revision = "5e22f3d471e6f24ca20becfdffdc6206c7cecac8" + +[[projects]] + digest = "1:77629c3c036454b4623e99e20f5591b9551dd81d92db616384af92435b52e9b6" + name = "sigs.k8s.io/controller-runtime" + packages = [ + "pkg/cache", + "pkg/cache/informertest", + "pkg/cache/internal", + "pkg/client", + "pkg/client/apiutil", + "pkg/client/config", + "pkg/client/fake", + "pkg/controller/controllertest", + "pkg/event", + "pkg/handler", + "pkg/internal/objectutil", + "pkg/internal/recorder", + "pkg/leaderelection", + "pkg/manager", + "pkg/metrics", + "pkg/patch", + "pkg/predicate", + "pkg/reconcile", + "pkg/recorder", + "pkg/runtime/inject", + "pkg/runtime/log", + "pkg/source", + "pkg/source/internal", + "pkg/webhook/admission", + "pkg/webhook/admission/types", + "pkg/webhook/internal/metrics", + "pkg/webhook/types", + ] + pruneopts = "" + revision = "f1eaba5087d69cebb154c6a48193e6667f5b512c" + version = "v0.1.12" + +[[projects]] + digest = "1:321081b4a44256715f2b68411d8eda9a17f17ebfe6f0cc61d2cc52d11c08acfa" + name = "sigs.k8s.io/yaml" + packages = ["."] + pruneopts = "" + revision = "fd68e9863619f6ec2fdd8625fe1f02e7c877e480" + version = "v1.1.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/DiSiqueira/GoTree", + "github.com/fatih/color", + "github.com/ghodss/yaml", + "github.com/golang/protobuf/jsonpb", + "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/ptypes", + "github.com/golang/protobuf/ptypes/struct", + "github.com/golang/protobuf/ptypes/timestamp", + "github.com/grpc-ecosystem/go-grpc-middleware/retry", + "github.com/lyft/flyteidl/clients/go/admin", + "github.com/lyft/flyteidl/clients/go/admin/mocks", + "github.com/lyft/flyteidl/clients/go/coreutils", + "github.com/lyft/flyteidl/clients/go/datacatalog/mocks", + "github.com/lyft/flyteidl/clients/go/events", + "github.com/lyft/flyteidl/clients/go/events/errors", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/datacatalog", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service", + "github.com/lyft/flyteplugins/go/tasks", + "github.com/lyft/flyteplugins/go/tasks/v1", + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s", + "github.com/lyft/flyteplugins/go/tasks/v1/types", + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks", + "github.com/lyft/flytestdlib/config", + "github.com/lyft/flytestdlib/config/viper", + "github.com/lyft/flytestdlib/contextutils", + "github.com/lyft/flytestdlib/logger", + "github.com/lyft/flytestdlib/pbhash", + "github.com/lyft/flytestdlib/profutils", + "github.com/lyft/flytestdlib/promutils", + "github.com/lyft/flytestdlib/promutils/labeled", + "github.com/lyft/flytestdlib/storage", + "github.com/lyft/flytestdlib/utils", + "github.com/lyft/flytestdlib/version", + "github.com/lyft/flytestdlib/yamlutils", + "github.com/magiconair/properties/assert", + "github.com/mitchellh/mapstructure", + "github.com/operator-framework/operator-sdk/pkg/util/k8sutil", + "github.com/pkg/errors", + "github.com/prometheus/client_golang/prometheus", + "github.com/spf13/cobra", + "github.com/spf13/pflag", + "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", + "golang.org/x/time/rate", + "google.golang.org/grpc", + "google.golang.org/grpc/codes", + "google.golang.org/grpc/status", + "k8s.io/api/batch/v1", + "k8s.io/api/core/v1", + "k8s.io/apimachinery/pkg/api/errors", + 
"k8s.io/apimachinery/pkg/api/resource", + "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/apimachinery/pkg/labels", + "k8s.io/apimachinery/pkg/runtime", + "k8s.io/apimachinery/pkg/runtime/schema", + "k8s.io/apimachinery/pkg/runtime/serializer", + "k8s.io/apimachinery/pkg/types", + "k8s.io/apimachinery/pkg/util/clock", + "k8s.io/apimachinery/pkg/util/rand", + "k8s.io/apimachinery/pkg/util/runtime", + "k8s.io/apimachinery/pkg/util/sets", + "k8s.io/apimachinery/pkg/util/wait", + "k8s.io/apimachinery/pkg/watch", + "k8s.io/client-go/discovery", + "k8s.io/client-go/discovery/fake", + "k8s.io/client-go/kubernetes", + "k8s.io/client-go/kubernetes/scheme", + "k8s.io/client-go/kubernetes/typed/core/v1", + "k8s.io/client-go/rest", + "k8s.io/client-go/testing", + "k8s.io/client-go/tools/cache", + "k8s.io/client-go/tools/clientcmd", + "k8s.io/client-go/tools/leaderelection", + "k8s.io/client-go/tools/leaderelection/resourcelock", + "k8s.io/client-go/tools/record", + "k8s.io/client-go/util/flowcontrol", + "k8s.io/client-go/util/workqueue", + "k8s.io/code-generator/cmd/client-gen", + "k8s.io/code-generator/cmd/conversion-gen", + "k8s.io/code-generator/cmd/deepcopy-gen", + "k8s.io/code-generator/cmd/defaulter-gen", + "k8s.io/code-generator/cmd/informer-gen", + "k8s.io/code-generator/cmd/lister-gen", + "k8s.io/gengo/args", + "sigs.k8s.io/controller-runtime/pkg/cache", + "sigs.k8s.io/controller-runtime/pkg/cache/informertest", + "sigs.k8s.io/controller-runtime/pkg/client", + "sigs.k8s.io/controller-runtime/pkg/client/fake", + "sigs.k8s.io/controller-runtime/pkg/manager", + "sigs.k8s.io/controller-runtime/pkg/runtime/inject", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/flytepropeller/Gopkg.toml b/flytepropeller/Gopkg.toml new file mode 100644 index 0000000000..1a0e8974e5 --- /dev/null +++ b/flytepropeller/Gopkg.toml @@ -0,0 +1,102 @@ +# Gopkg.toml example +# +# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html +# for detailed Gopkg.toml documentation. 
+# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + +required = [ + "k8s.io/code-generator/cmd/defaulter-gen", + "k8s.io/code-generator/cmd/deepcopy-gen", + "k8s.io/code-generator/cmd/conversion-gen", + "k8s.io/code-generator/cmd/client-gen", + "k8s.io/code-generator/cmd/lister-gen", + "k8s.io/code-generator/cmd/informer-gen", + "k8s.io/gengo/args", +] + +[[constraint]] + name = "github.com/fatih/color" + version = "1.7.0" + +[[override]] + name = "contrib.go.opencensus.io/exporter/ocagent" + version = "0.4.x" + +[[constraint]] + name = "github.com/golang/protobuf" + version = "1.1.0" + +[[constraint]] + name = "github.com/lyft/flyteidl" + source = "https://github.com/lyft/flyteidl" + version = "^0.1.x" + +[[constraint]] + name = "github.com/lyft/flyteplugins" + source = "https://github.com/lyft/flyteplugins" + version = "^0.1.0" + +[[override]] + name = "github.com/lyft/flytestdlib" + source = "https://github.com/lyft/flytestdlib" + version = "^0.2.16" + +# Spark has a dependency on 1.11.2, so we cannot upgrade yet +[[override]] + name = "k8s.io/api" + revision = "27b77cf22008a0bf6e510b2a500b8885805a1b68" + source = "https://github.com/lyft/api" + +[[override]] + name = "k8s.io/apimachinery" + source = "https://github.com/lyft/apimachinery" + revision = "695912cabb3a9c08353fbe628115867d47f56b1f" + +[[override]] + name = "k8s.io/client-go" + version = "kubernetes-1.13.5" + +[[constraint]] + name = "github.com/DiSiqueira/GoTree" + version = "2.0.3" + +[[override]] + name = "k8s.io/code-generator" + # revision = "6702109cc68eb6fe6350b83e14407c8d7309fd1a" + version = "kubernetes-1.13.5" + +[[override]] + name = "github.com/graymeta/stow" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" + +[[override]] + name = "github.com/json-iterator/go" + version = "^1.1.5" + +[[override]] + name = "sigs.k8s.io/controller-runtime" + version = "=v0.1.12" + +[[override]] + branch = "master" + name = "golang.org/x/net" diff --git a/flytepropeller/LICENSE b/flytepropeller/LICENSE new file mode 100644 index 0000000000..bed437514f --- /dev/null +++ b/flytepropeller/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/flytepropeller/Makefile b/flytepropeller/Makefile new file mode 100644 index 0000000000..30db27782e --- /dev/null +++ b/flytepropeller/Makefile @@ -0,0 +1,44 @@ +export REPOSITORY=flytepropeller +include boilerplate/lyft/docker_build/Makefile +include boilerplate/lyft/golang_test_targets/Makefile + +.PHONY: update_boilerplate +update_boilerplate: + @boilerplate/update.sh + +.PHONY: linux_compile +linux_compile: + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/kubectl-flyte ./cmd/kubectl-flyte/main.go + +.PHONY: compile +compile: + mkdir -p ./bin + go build -o bin/flytepropeller ./cmd/controller/main.go + go build -o bin/kubectl-flyte ./cmd/kubectl-flyte/main.go && cp bin/kubectl-flyte ${GOPATH}/bin + +cross_compile: + @glide install + @mkdir -p ./bin/cross + GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 go build -o bin/cross/kubectl-flyte ./cmd/kubectl-flyte/main.go + +op_code_generate: + @RESOURCE_NAME=flyteworkflow OPERATOR_PKG=github.com/lyft/flytepropeller ./hack/update-codegen.sh + +benchmark: + mkdir -p ./bin/benchmark + @go test -run=^$ -bench=. -cpuprofile=cpu.out -memprofile=mem.out ./pkg/controller/nodes/. && mv *.out ./bin/benchmark/ && mv *.test ./bin/benchmark/ + +# server starts the service in development mode
.PHONY: server +server: + @go run ./cmd/controller/main.go -logtostderr --kubeconfig=$(HOME)/.kube/config + +clean: + rm -rf bin + +# Generate golden files. Add test packages that generate golden files here. +golden: + go test ./cmd/kubectl-flyte/cmd -update + go test ./pkg/compiler/test -update diff --git a/flytepropeller/NOTICE b/flytepropeller/NOTICE new file mode 100644 index 0000000000..dba5e33d0b --- /dev/null +++ b/flytepropeller/NOTICE @@ -0,0 +1,5 @@ +flytepropeller +Copyright 2019 Lyft Inc. + +This product includes software developed at Lyft Inc. +This product includes software derived from https://github.com/kubernetes/sample-controller diff --git a/flytepropeller/README.rst b/flytepropeller/README.rst new file mode 100644 index 0000000000..de0f1ac2cb --- /dev/null +++ b/flytepropeller/README.rst @@ -0,0 +1,129 @@ +Flyte Propeller +=============== + + +.. +.. image:: https://img.shields.io/github/release/lyft/flytepropeller.svg +:target: https://github.com/lyft/flytepropeller/releases/latest + +.. image:: https://godoc.org/github.com/lyft/flytepropeller?status.svg +:target: https://godoc.org/github.com/lyft/flytepropeller + +.. image:: https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg +:target: http://www.apache.org/licenses/LICENSE-2.0.html + +.. image:: https://img.shields.io/codecov/c/github/lyft/flytepropeller.svg +:target: https://codecov.io/gh/lyft/flytepropeller + +.. image:: https://goreportcard.com/badge/github.com/lyft/flytepropeller +:target: https://goreportcard.com/report/github.com/lyft/flytepropeller + +.. image:: https://img.shields.io/github/commit-activity/w/lyft/flytepropeller.svg?style=plastic + +.. image:: https://img.shields.io/github/commits-since/lyft/flytepropeller/latest.svg?style=plastic + + +A Kubernetes operator that executes Flyte graphs natively on Kubernetes. + +Getting Started +=============== +kubectl-flyte tool
------------------ +kubectl-flyte is a command line tool that can be used as an extension to kubectl. It is a separate binary built from the propeller repo.
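kubectl discovers plugins purely by executable name: any binary called ``kubectl-flyte`` on the ``PATH`` becomes reachable as ``kubectl flyte``, so the standalone and plugin invocations share one entry point. The sketch below illustrates what such an entry point looks like with Cobra; it is a minimal, hypothetical example and not the actual propeller source, which registers its real subcommands under ``cmd/kubectl-flyte``.

.. code-block:: go

    package main

    import (
        "fmt"
        "os"

        "github.com/spf13/cobra"
    )

    // rootCmd is a stand-in for the kubectl-flyte root command; the real tool
    // wires in subcommands such as get, create, delete, compile and visualize.
    var rootCmd = &cobra.Command{
        Use:   "kubectl-flyte",
        Short: "Inspect and manage Flyte workflows, standalone or as `kubectl flyte`",
    }

    func main() {
        // kubectl finds this binary by its name on PATH; no extra registration
        // is needed for `kubectl flyte ...` to resolve to this executable.
        if err := rootCmd.Execute(); err != nil {
            fmt.Fprintln(os.Stderr, err)
            os.Exit(1)
        }
    }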
+ +Install +------- +This command will install kubectl-flyte and flytepropeller to `~/go/bin`: +.. code-block:: make + + $make compile + +Use +--- +There are two ways to execute the command: either standalone as *kubectl-flyte* or as a subcommand of *kubectl*. + +.. code-block:: command + + $ kubectl-flyte --help + OR + $ kubectl flyte --help + Flyte is a serverless workflow processing platform built for native execution on K8s. + It is extensible and flexible to allow adding new operators and comes with many operators built in + + Usage: + kubectl-flyte [flags] + kubectl-flyte [command] + + Available Commands: + compile Compile a workflow from core proto-buffer files and output a closure. + config Runs various config commands, look at the help of this command to get a list of available commands.. + create Creates a new workflow from proto-buffer files. + delete delete a workflow + get Gets a single workflow or lists all workflows currently in execution + help Help about any command + visualize Get GraphViz dot-formatted output. + + +Observing running workflows +--------------------------- + +To retrieve all workflows in a namespace, use the --namespace option; --namespace = "" implies all namespaces. + +.. code-block:: command + + $ kubectl-flyte get --namespace flytekit-development + workflows + ├── flytekit-development/flytekit-development-f01c74085110840b8827 [ExecId: ... ] (2m34s Succeeded) - Time SinceCreation(30h1m39.683602s) + ... + Found 19 workflows + Success: 19, Failed: 0, Running: 0, Waiting: 0 + + +To retrieve a specific workflow, the namespace can either be provided in the format namespace/name or via the --namespace argument. + +.. code-block:: command + + $ kubectl-flyte get flytekit-development/flytekit-development-ff806e973581f4508bf1 + Workflow + └── flytekit-development/flytekit-development-ff806e973581f4508bf1 [ExecId: project:"flytekit" domain:"development" name:"ff806e973581f4508bf1" ] (2m32s Succeeded ) + ├── start-node start 0s Succeeded + ├── c task 0s Succeeded + ├── b task 0s Succeeded + ├── a task 0s Succeeded + └── end-node end 0s Succeeded + +Deleting workflows +------------------ +To delete a specific workflow: + +.. code-block:: command + + $ kubectl-flyte delete --namespace flytekit-development flytekit-development-ff806e973581f4508bf1 + +To delete all completed workflows: they have to be in either a succeeded or failed state, with the special isCompleted label set on them. The label is set `here ` + +.. code-block:: command + + $ kubectl-flyte delete --namespace flytekit-development --all-completed + +Running propeller locally +------------------------- +Use the config.yaml in the repository root, found `here `. Cd into that folder and then run: + +.. code-block:: command + + $ flytepropeller --logtostderr + +The following dependencies need to be met: +1. Blob store (you can forward the minio port to localhost) +2. Admin Service endpoint (can be forwarded) OR *Disable* events to admin and launchplans +3. Access to the kubeconfig and the kube API + +Making changes to CRD +===================== +*Remember*: changes to the CRD should be made carefully; they should be backwards compatible, or else you should use a proper +operator versioning scheme. Once you have made the changes, remember to execute + +..
code-block:: make + + $make op_code_generate diff --git a/flytepropeller/boilerplate/lyft/docker_build/Makefile b/flytepropeller/boilerplate/lyft/docker_build/Makefile new file mode 100644 index 0000000000..4019dab839 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/docker_build/Makefile @@ -0,0 +1,12 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +.PHONY: docker_build +docker_build: + IMAGE_NAME=$$REPOSITORY ./boilerplate/lyft/docker_build/docker_build.sh + +.PHONY: dockerhub_push +dockerhub_push: + IMAGE_NAME=lyft/$$REPOSITORY REGISTRY=docker.io ./boilerplate/lyft/docker_build/docker_build.sh diff --git a/flytepropeller/boilerplate/lyft/docker_build/Readme.rst b/flytepropeller/boilerplate/lyft/docker_build/Readme.rst new file mode 100644 index 0000000000..bb6af9b49e --- /dev/null +++ b/flytepropeller/boilerplate/lyft/docker_build/Readme.rst @@ -0,0 +1,23 @@ +Docker Build and Push +~~~~~~~~~~~~~~~~~~~~~ + +Provides a ``make docker_build`` target that builds your image locally. + +Provides a ``make dockerhub_push`` target that pushes your final image to Dockerhub. + +The Dockerhub image will tagged ``:`` + +If git head has a git tag, the Dockerhub image will also be tagged ``:``. + +**To Enable:** + +Add ``lyft/docker_build`` to your ``boilerplate/update.cfg`` file. + +Add ``include boilerplate/lyft/docker_build/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/docker_build/Makefile + +(this ensures the extra Make targets get included in your main Makefile) diff --git a/flytepropeller/boilerplate/lyft/docker_build/docker_build.sh b/flytepropeller/boilerplate/lyft/docker_build/docker_build.sh new file mode 100755 index 0000000000..f504c100c7 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/docker_build/docker_build.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +echo "" +echo "------------------------------------" +echo " DOCKER BUILD" +echo "------------------------------------" +echo "" + +if [ -n "$REGISTRY" ]; then + # Do not push if there are unstaged git changes + CHANGED=$(git status --porcelain) + if [ -n "$CHANGED" ]; then + echo "Please commit git changes before pushing to a registry" + exit 1 + fi +fi + + +GIT_SHA=$(git rev-parse HEAD) + +IMAGE_TAG_SUFFIX="" +# for intermediate build phases, append -$BUILD_PHASE to all image tags +if [ -n "$BUILD_PHASE" ]; then + IMAGE_TAG_SUFFIX="-${BUILD_PHASE}" +fi + +IMAGE_TAG_WITH_SHA="${IMAGE_NAME}:${GIT_SHA}${IMAGE_TAG_SUFFIX}" + +RELEASE_SEMVER=$(git describe --tags --exact-match "$GIT_SHA" 2>/dev/null) || true +if [ -n "$RELEASE_SEMVER" ]; then + IMAGE_TAG_WITH_SEMVER="${IMAGE_NAME}:${RELEASE_SEMVER}${IMAGE_TAG_SUFFIX}" +fi + +# build the image +# passing no build phase will build the final image +docker build -t "$IMAGE_TAG_WITH_SHA" --target=${BUILD_PHASE} . +echo "${IMAGE_TAG_WITH_SHA} built locally." 
+ +# if REGISTRY specified, push the images to the remote registy +if [ -n "$REGISTRY" ]; then + + if [ -n "${DOCKER_REGISTRY_PASSWORD}" ]; then + docker login --username="$DOCKER_REGISTRY_USERNAME" --password="$DOCKER_REGISTRY_PASSWORD" + fi + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SHA} pushed to remote." + + # If the current commit has a semver tag, also push the images with the semver tag + if [ -n "$RELEASE_SEMVER" ]; then + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER} pushed to remote." + + fi +fi diff --git a/flytepropeller/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate b/flytepropeller/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate new file mode 100644 index 0000000000..5e7b984a11 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/{{REPOSITORY}}/ +WORKDIR /go/src/github.com/lyft/{{REPOSITORY}} + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . /go/src/github.com/lyft/{{REPOSITORY}}/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/{{REPOSITORY}} +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["{{REPOSITORY}}"] diff --git a/flytepropeller/boilerplate/lyft/golang_dockerfile/Readme.rst b/flytepropeller/boilerplate/lyft/golang_dockerfile/Readme.rst new file mode 100644 index 0000000000..f801ef98d6 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golang_dockerfile/Readme.rst @@ -0,0 +1,16 @@ +Golang Dockerfile +~~~~~~~~~~~~~~~~~ + +Provides a Dockerfile that produces a small image. + +**To Enable:** + +Add ``lyft/golang_dockerfile`` to your ``boilerplate/update.cfg`` file. + +Create and configure a ``make linux_compile`` target that compiles your go binaries to the ``/artifacts`` directory :: + + .PHONY: linux_compile + linux_compile: + RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts {{ packages }} + +All binaries compiled to ``/artifacts`` will be available at ``/bin`` in your final image. diff --git a/flytepropeller/boilerplate/lyft/golang_dockerfile/update.sh b/flytepropeller/boilerplate/lyft/golang_dockerfile/update.sh new file mode 100755 index 0000000000..7d84663262 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golang_dockerfile/update.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. 
+# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +echo " - generating Dockerfile in root directory." +sed -e "s/{{REPOSITORY}}/${REPOSITORY}/g" ${DIR}/Dockerfile.GoTemplate > ${DIR}/../../../Dockerfile diff --git a/flytepropeller/boilerplate/lyft/golang_test_targets/Makefile b/flytepropeller/boilerplate/lyft/golang_test_targets/Makefile new file mode 100644 index 0000000000..6c1e527fd6 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,38 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +DEP_SHA=1f7c19e5f52f49ffb9f956f64c010be14683468b + +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.16.0 + golangci-lint run --exclude deprecated + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl "https://raw.githubusercontent.com/golang/dep/${DEP_SHA}/install.sh" | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... -race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/flytepropeller/boilerplate/lyft/golang_test_targets/Readme.rst b/flytepropeller/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100644 index 0000000000..acc5744f59 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` install golang dependencies. + +Provides a ``lint`` make target that uses golangci to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. + +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). 
+ +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/flytepropeller/boilerplate/lyft/golang_test_targets/goimports b/flytepropeller/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 0000000000..160525a8cc --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/flytepropeller/boilerplate/lyft/golangci_file/.golangci.yml b/flytepropeller/boilerplate/lyft/golangci_file/.golangci.yml new file mode 100644 index 0000000000..a414f33f79 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golangci_file/.golangci.yml @@ -0,0 +1,30 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/flytepropeller/boilerplate/lyft/golangci_file/Readme.rst b/flytepropeller/boilerplate/lyft/golangci_file/Readme.rst new file mode 100644 index 0000000000..ba5d2b61ce --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golangci_file/Readme.rst @@ -0,0 +1,8 @@ +GolangCI File +~~~~~~~~~~~~~ + +Provides a ``.golangci`` file with the linters we've agreed upon. + +**To Enable:** + +Add ``lyft/golangci_file`` to your ``boilerplate/update.cfg`` file. diff --git a/flytepropeller/boilerplate/lyft/golangci_file/update.sh b/flytepropeller/boilerplate/lyft/golangci_file/update.sh new file mode 100755 index 0000000000..9e9e6c1f46 --- /dev/null +++ b/flytepropeller/boilerplate/lyft/golangci_file/update.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +# Clone the .golangci file +echo " - copying ${DIR}/.golangci to the root directory." 
+cp ${DIR}/.golangci.yml ${DIR}/../../../.golangci.yml diff --git a/flytepropeller/boilerplate/update.cfg b/flytepropeller/boilerplate/update.cfg new file mode 100644 index 0000000000..5417c80464 --- /dev/null +++ b/flytepropeller/boilerplate/update.cfg @@ -0,0 +1,4 @@ +lyft/docker_build +lyft/golang_test_targets +lyft/golangci_file +lyft/golang_dockerfile diff --git a/flytepropeller/boilerplate/update.sh b/flytepropeller/boilerplate/update.sh new file mode 100755 index 0000000000..bea661d9a0 --- /dev/null +++ b/flytepropeller/boilerplate/update.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +OUT="$(mktemp -d)" +git clone git@github.com:lyft/boilerplate.git "${OUT}" + +echo "Updating the update.sh script." +cp "${OUT}/boilerplate/update.sh" "${DIR}/update.sh" +echo "" + + +CONFIG_FILE="${DIR}/update.cfg" +README="https://github.com/lyft/boilerplate/blob/master/Readme.rst" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "$CONFIG_FILE not found." + echo "This file is required in order to select which features to include." + echo "See $README for more details." + exit 1 +fi + +if [ -z "$REPOSITORY" ]; then + echo '$REPOSITORY is required to run this script' + echo "See $README for more details." + exit 1 +fi + +while read directory; do + echo "***********************************************************************************" + echo "$directory is configured in update.cfg." + echo "-----------------------------------------------------------------------------------" + echo "syncing files from source." 
+ dir_path="${OUT}/boilerplate/${directory}" + rm -rf "${DIR}/${directory}" + mkdir -p $(dirname "${DIR}/${directory}") + cp -r "$dir_path" "${DIR}/${directory}" + if [ -f "${DIR}/${directory}/update.sh" ]; then + echo "executing ${DIR}/${directory}/update.sh" + "${DIR}/${directory}/update.sh" + fi + echo "***********************************************************************************" + echo "" +done < "$CONFIG_FILE" + +rm -rf "${OUT}" diff --git a/flytepropeller/cmd/controller/cmd/root.go b/flytepropeller/cmd/controller/cmd/root.go new file mode 100644 index 0000000000..55923eb790 --- /dev/null +++ b/flytepropeller/cmd/controller/cmd/root.go @@ -0,0 +1,235 @@ +package cmd + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + "sigs.k8s.io/controller-runtime/pkg/cache" + + "sigs.k8s.io/controller-runtime/pkg/client" + + "sigs.k8s.io/controller-runtime/pkg/manager" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flyteplugins/go/tasks" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytestdlib/config/viper" + "github.com/lyft/flytestdlib/version" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/profutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/operator-framework/operator-sdk/pkg/util/k8sutil" + "github.com/pkg/errors" + "github.com/spf13/pflag" + + "github.com/spf13/cobra" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/clientcmd" + + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + informers "github.com/lyft/flytepropeller/pkg/client/informers/externalversions" + "github.com/lyft/flytepropeller/pkg/controller" + "github.com/lyft/flytepropeller/pkg/signals" + restclient "k8s.io/client-go/rest" +) + +const ( + defaultNamespace = "all" + appName = "flytepropeller" +) + +var ( + cfgFile string + configAccessor = viper.NewAccessor(config.Options{StrictMode: true}) +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: "flyte-propeller", + Short: "Operator for running Flyte Workflows", + Long: `Flyte Propeller runs a workflow to completion by recursing through the nodes, + handling their tasks to completion and propagating their status upstream.`, + PreRunE: initConfig, + Run: func(cmd *cobra.Command, args []string) { + executeRootCmd(config2.GetConfig()) + }, +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + version.LogBuildInformation(appName) + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func init() { + // allows `$ flytepropeller --logtostderr` to work + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + err := flag.CommandLine.Parse([]string{}) + if err != nil { + logAndExit(err) + } + + // Here you will define your flags and configuration settings. Cobra supports persistent flags, which, if defined + // here, will be global for your application. 
+ rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", + "config file (default is $HOME/config.yaml)") + + configAccessor.InitializePflags(rootCmd.PersistentFlags()) +} + +func initConfig(_ *cobra.Command, _ []string) error { + configAccessor = viper.NewAccessor(config.Options{ + StrictMode: true, + SearchPaths: []string{cfgFile}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + if err != nil { + return err + } + + // Operator-SDK expects kube config to be in KUBERNETES_CONFIG env var. + controllerCfg := config2.GetConfig() + if controllerCfg.KubeConfigPath != "" { + fmt.Printf("Setting env variable for operator-sdk, %v\n", controllerCfg.KubeConfigPath) + return os.Setenv(k8sutil.KubeConfigEnvVar, os.ExpandEnv(controllerCfg.KubeConfigPath)) + } + + fmt.Printf("Started in-cluster mode\n") + return nil +} + +func logAndExit(err error) { + logger.Error(context.Background(), err) + os.Exit(-1) +} + +func getKubeConfig(_ context.Context, cfg *config2.Config) (*kubernetes.Clientset, *restclient.Config, error) { + var kubecfg *restclient.Config + var err error + if cfg.KubeConfigPath != "" { + kubeConfigPath := os.ExpandEnv(cfg.KubeConfigPath) + kubecfg, err = clientcmd.BuildConfigFromFlags(cfg.MasterURL, kubeConfigPath) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubeconfig") + } + } else { + kubecfg, err = restclient.InClusterConfig() + if err != nil { + return nil, nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") + } + } + + kubeClient, err := kubernetes.NewForConfig(kubecfg) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") + } + return kubeClient, kubecfg, err +} + +func sharedInformerOptions(cfg *config2.Config) []informers.SharedInformerOption { + opts := []informers.SharedInformerOption{ + informers.WithTweakListOptions(func(options *v1.ListOptions) { + options.LabelSelector = v1.FormatLabelSelector(controller.IgnoreCompletedWorkflowsLabelSelector()) + }), + } + if cfg.LimitNamespace != defaultNamespace { + opts = append(opts, informers.WithNamespace(cfg.LimitNamespace)) + } + return opts +} + +func executeRootCmd(cfg *config2.Config) { + baseCtx := context.TODO() + + // set up signals so we handle the first shutdown signal gracefully + ctx := signals.SetupSignalHandler(baseCtx) + + kubeClient, kubecfg, err := getKubeConfig(ctx, cfg) + if err != nil { + logger.Fatalf(ctx, "Error building kubernetes clientset: %s", err.Error()) + } + + flyteworkflowClient, err := clientset.NewForConfig(kubecfg) + if err != nil { + logger.Fatalf(ctx, "Error building example clientset: %s", err.Error()) + } + + opts := sharedInformerOptions(cfg) + flyteworkflowInformerFactory := informers.NewSharedInformerFactoryWithOptions(flyteworkflowClient, cfg.WorkflowReEval.Duration, opts...) + + // Add the propeller subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. + propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(cfg.LimitNamespace) + + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers(ctx, cfg.ProfilerPort.Port, nil) + if err != nil { + logger.Panicf(ctx, "Failed to Start profiling and metrics server. Error: %v", err) + } + }() + + limitNamespace := "" + if cfg.LimitNamespace != defaultNamespace { + limitNamespace = cfg.LimitNamespace + } + + err = flytek8s.Initialize(ctx, limitNamespace, cfg.DownstreamEval.Duration) + if err != nil { + logger.Panicf(ctx, "Failed to initialize k8s plugins. 
Error: %v", err) + } + + if err := tasks.Load(ctx); err != nil { + logger.Fatalf(ctx, "Failed to load task plugins. [%v]", err) + } + + mgr, err := manager.New(kubecfg, manager.Options{ + Namespace: limitNamespace, + SyncPeriod: &cfg.DownstreamEval.Duration, + NewClient: func(cache cache.Cache, config *restclient.Config, options client.Options) (i client.Client, e error) { + rawClient, err := client.New(kubecfg, client.Options{}) + if err != nil { + return nil, err + } + + return executors.NewFallbackClient(&client.DelegatingClient{ + Reader: &client.DelegatingReader{ + CacheReader: cache, + ClientReader: rawClient, + }, + Writer: rawClient, + StatusClient: rawClient, + }, rawClient), nil + }, + }) + if err != nil { + logger.Fatalf(ctx, "Failed to initialize controller run-time manager. Error: %v", err) + } + + c, err := controller.New(ctx, cfg, kubeClient, flyteworkflowClient, flyteworkflowInformerFactory, mgr, propellerScope) + + if err != nil { + logger.Fatalf(ctx, "Failed to start Controller - [%v]", err.Error()) + } else if c == nil { + logger.Fatalf(ctx, "Failed to start Controller, nil controller received.") + } + + go flyteworkflowInformerFactory.Start(ctx.Done()) + + if err = c.Run(ctx); err != nil { + logger.Fatalf(ctx, "Error running controller: %s", err.Error()) + } +} diff --git a/flytepropeller/cmd/controller/main.go b/flytepropeller/cmd/controller/main.go new file mode 100644 index 0000000000..6b551ff730 --- /dev/null +++ b/flytepropeller/cmd/controller/main.go @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/lyft/flytepropeller/cmd/controller/cmd" +) + +func main() { + cmd.Execute() +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/compile.go b/flytepropeller/cmd/kubectl-flyte/cmd/compile.go new file mode 100644 index 0000000000..f43f164cfe --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/compile.go @@ -0,0 +1,106 @@ +package cmd + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + compilerErrors "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +type CompileOpts struct { + *RootOptions + inputFormat format + outputFormat format + protoFile string + outputPath string + dumpClosureYaml bool +} + +func NewCompileCommand(opts *RootOptions) *cobra.Command { + + compileOpts := &CompileOpts{ + RootOptions: opts, + } + compileCmd := &cobra.Command{ + Use: "compile", + Aliases: []string{"new", "compile"}, + Short: "Compile a workflow from core proto-buffer files and output a closure.", + Long: ``, + RunE: func(cmd *cobra.Command, args []string) error { + if err := requiredFlags(cmd, protofileKey, formatKey); err != nil { + return err + } + + fmt.Println("Line numbers in errors enabled") + compilerErrors.SetIncludeSource() + + return compileOpts.compileWorkflowCmd() + }, + } + + compileCmd.Flags().StringVarP(&compileOpts.protoFile, "input-file", "i", "", "Path of the workflow package proto-buffer file to be uploaded") + compileCmd.Flags().StringVarP(&compileOpts.inputFormat, "input-format", "f", formatProto, "Format of the provided file. Supported formats: proto (default), json, yaml") + compileCmd.Flags().StringVarP(&compileOpts.outputPath, "output-file", "o", "", "Path of the generated output file.") + compileCmd.Flags().StringVarP(&compileOpts.outputFormat, "output-format", "m", formatProto, "Format of the generated file. 
Supported formats: proto (default), json, yaml") + compileCmd.Flags().BoolVarP(&compileOpts.dumpClosureYaml, "dump-closure-yaml", "d", false, "Compiles and transforms, but does not create a workflow. OutputsRef ts to STDOUT.") + + return compileCmd +} + +func (c *CompileOpts) compileWorkflowCmd() error { + if c.protoFile == "" { + return errors.Errorf("Input file not specified") + } + fmt.Printf("Received protofiles : [%v].\n", c.protoFile) + + rawWf, err := ioutil.ReadFile(c.protoFile) + if err != nil { + return err + } + + wfClosure := core.WorkflowClosure{} + err = unmarshal(rawWf, c.inputFormat, &wfClosure) + if err != nil { + return errors.Wrapf(err, "Failed to unmarshal input Workflow") + } + + if c.dumpClosureYaml { + b, err := marshal(&wfClosure, formatYaml) + if err != nil { + return err + } + err = ioutil.WriteFile(c.protoFile+".yaml", b, os.ModePerm) + if err != nil { + return err + } + } + + compiledTasks, err := compileTasks(wfClosure.Tasks) + if err != nil { + return err + } + + compileWfClosure, err := compiler.CompileWorkflow(wfClosure.Workflow, []*core.WorkflowTemplate{}, compiledTasks, []common.InterfaceProvider{}) + if err != nil { + return err + } + + fmt.Printf("Workflow compiled successfully, creating output location: [%v] format [%v]\n", c.outputPath, c.outputFormat) + + o, err := marshal(compileWfClosure, c.outputFormat) + if err != nil { + return errors.Wrapf(err, "Failed to marshal final workflow.") + } + + if c.outputPath != "" { + return ioutil.WriteFile(c.outputPath, o, os.ModePerm) + } + fmt.Printf("%v", string(o)) + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/create.go b/flytepropeller/cmd/kubectl-flyte/cmd/create.go new file mode 100644 index 0000000000..d742193828 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/create.go @@ -0,0 +1,228 @@ +package cmd + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + + "github.com/ghodss/yaml" + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + compilerErrors "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +const ( + protofileKey = "proto-path" + formatKey = "format" + executionIDKey = "execution-id" + inputsKey = "input-path" + annotationsKey = "annotations" +) + +type format = string + +const ( + formatProto format = "proto" + formatJSON format = "json" + formatYaml format = "yaml" +) + +const createCmdName = "create" + +type CreateOpts struct { + *RootOptions + format format + execID string + inputsPath string + protoFile string + annotations *stringMapValue + dryRun bool +} + +func NewCreateCommand(opts *RootOptions) *cobra.Command { + + createOpts := &CreateOpts{ + RootOptions: opts, + } + + createCmd := &cobra.Command{ + Use: createCmdName, + Aliases: []string{"new", "compile"}, + Short: "Creates a new workflow from proto-buffer files.", + Long: ``, + RunE: func(cmd *cobra.Command, args []string) error { + if err := requiredFlags(cmd, protofileKey, formatKey); err != nil { + return err + } + + fmt.Println("Line numbers in errors enabled") + compilerErrors.SetIncludeSource() + + return createOpts.createWorkflowFromProto() + }, + } + + createCmd.Flags().StringVarP(&createOpts.protoFile, protofileKey, "p", "", "Path of the workflow package proto-buffer file to be 
uploaded") + createCmd.Flags().StringVarP(&createOpts.format, formatKey, "f", formatProto, "Format of the provided file. Supported formats: proto (default), json, yaml") + createCmd.Flags().StringVarP(&createOpts.execID, executionIDKey, "", "", "Execution Id of the Workflow to create.") + createCmd.Flags().StringVarP(&createOpts.inputsPath, inputsKey, "i", "", "Path to inputs file.") + createOpts.annotations = newStringMapValue() + createCmd.Flags().VarP(createOpts.annotations, annotationsKey, "a", "Defines extra annotations to declare on the created object.") + createCmd.Flags().BoolVarP(&createOpts.dryRun, "dry-run", "d", false, "Compiles and transforms, but does not create a workflow. OutputsRef ts to STDOUT.") + + return createCmd +} + +func unmarshal(in []byte, format format, message proto.Message) (err error) { + switch format { + case formatProto: + err = proto.Unmarshal(in, message) + case formatJSON: + err = jsonpb.Unmarshal(bytes.NewReader(in), message) + if err != nil { + err = errors.Wrapf(err, "Failed to unmarshal converted Json. [%v]", string(in)) + } + case formatYaml: + jsonRaw, err := yaml.YAMLToJSON(in) + if err != nil { + return errors.Wrapf(err, "Failed to convert yaml to JSON. [%v]", string(in)) + } + + return unmarshal(jsonRaw, formatJSON, message) + } + + return +} + +var jsonPbMarshaler = jsonpb.Marshaler{} + +func marshal(message proto.Message, format format) (raw []byte, err error) { + switch format { + case formatProto: + return proto.Marshal(message) + case formatJSON: + b := &bytes.Buffer{} + err := jsonPbMarshaler.Marshal(b, message) + if err != nil { + return nil, errors.Wrapf(err, "Failed to marshal Json.") + } + return b.Bytes(), nil + case formatYaml: + b, err := marshal(message, formatJSON) + if err != nil { + return nil, errors.Wrapf(err, "Failed to marshal JSON") + } + return yaml.JSONToYAML(b) + } + return nil, errors.Errorf("Unknown format type") +} + +func loadInputs(path string, format format) (c *core.LiteralMap, err error) { + // Support reading from s3, etc.? 
+ var raw []byte + raw, err = ioutil.ReadFile(path) + if err != nil { + return + } + + c = &core.LiteralMap{} + err = unmarshal(raw, format, c) + return +} + +func compileTasks(tasks []*core.TaskTemplate) ([]*core.CompiledTask, error) { + res := make([]*core.CompiledTask, 0, len(tasks)) + for _, task := range tasks { + compiledTask, err := compiler.CompileTask(task) + if err != nil { + return nil, err + } + + res = append(res, compiledTask) + } + + return res, nil +} + +func (c *CreateOpts) createWorkflowFromProto() error { + fmt.Printf("Received protofiles : [%v] [%v].\n", c.protoFile, c.inputsPath) + rawWf, err := ioutil.ReadFile(c.protoFile) + if err != nil { + return err + } + + wfClosure := core.WorkflowClosure{} + err = unmarshal(rawWf, c.format, &wfClosure) + if err != nil { + return err + } + + compiledTasks, err := compileTasks(wfClosure.Tasks) + if err != nil { + return err + } + + wf, err := compiler.CompileWorkflow(wfClosure.Workflow, []*core.WorkflowTemplate{}, compiledTasks, []common.InterfaceProvider{}) + if err != nil { + return err + } + + var inputs *core.LiteralMap + if c.inputsPath != "" { + inputs, err = loadInputs(c.inputsPath, c.format) + if err != nil { + return errors.Wrapf(err, "Failed to load inputs.") + } + } + + var executionID *core.WorkflowExecutionIdentifier + if len(c.execID) > 0 { + executionID = &core.WorkflowExecutionIdentifier{ + Name: c.execID, + Domain: wfClosure.Workflow.Id.Domain, + Project: wfClosure.Workflow.Id.Project, + } + } + + flyteWf, err := k8s.BuildFlyteWorkflow(wf, inputs, executionID, c.ConfigOverrides.Context.Namespace) + if err != nil { + return err + } + if flyteWf.Annotations == nil { + flyteWf.Annotations = *c.annotations.value + } else { + for key, val := range *c.annotations.value { + flyteWf.Annotations[key] = val + } + } + + if c.dryRun { + fmt.Printf("Dry Run mode enabled. 
Printing the compiled workflow.") + j, err := json.Marshal(flyteWf) + if err != nil { + return errors.Wrapf(err, "Failed to marshal final workflow to Propeller format.") + } + y, err := yaml.JSONToYAML(j) + if err != nil { + return errors.Wrapf(err, "Failed to marshal final workflow from json to yaml.") + } + fmt.Println(string(y)) + } else { + wf, err := c.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(c.ConfigOverrides.Context.Namespace).Create(flyteWf) + if err != nil { + return err + } + + fmt.Printf("Successfully created Flyte Workflow %v.\n", wf.Name) + } + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/create_test.go b/flytepropeller/cmd/kubectl-flyte/cmd/create_test.go new file mode 100644 index 0000000000..d23485791f --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/create_test.go @@ -0,0 +1,292 @@ +package cmd + +import ( + "encoding/json" + "flag" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/ghodss/yaml" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" +) + +var update = flag.Bool("update", false, "Update .golden files") + +func init() { +} + +func createEmptyVariableMap() *core.VariableMap { + res := &core.VariableMap{ + Variables: map[string]*core.Variable{}, + } + return res +} + +func createVariableMap(variableMap map[string]*core.Variable) *core.VariableMap { + res := &core.VariableMap{ + Variables: variableMap, + } + return res +} + +func TestCreate(t *testing.T) { + t.Run("Generate simple workflow", generateSimpleWorkflow) + t.Run("Generate workflow with inputs", generateWorkflowWithInputs) + t.Run("Compile", testCompile) +} + +func generateSimpleWorkflow(t *testing.T) { + if !*update { + t.SkipNow() + } + + t.Log("Generating golden files.") + closure := core.WorkflowClosure{ + Workflow: &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "workflow-id-123"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Nodes: []*core.Node{ + { + Id: "node-1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-1"}, + }, + }, + }, + }, + { + Id: "node-2", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-2"}, + }, + }, + }, + }, + }, + }, + Tasks: []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task-1"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + }, + }, + }, + { + Id: &core.Identifier{Name: "task-2"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + }, + }, + }, + }, + } + + marshaller := &jsonpb.Marshaler{} + s, err := marshaller.MarshalToString(&closure) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "workflow.json.golden"), []byte(s), os.ModePerm)) + + m 
:= map[string]interface{}{} + err = json.Unmarshal([]byte(s), &m) + assert.NoError(t, err) + + b, err := yaml.Marshal(m) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "workflow.yaml.golden"), b, os.ModePerm)) + + raw, err := proto.Marshal(&closure) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "workflow.pb.golden"), raw, os.ModePerm)) +} + +func generateWorkflowWithInputs(t *testing.T) { + if !*update { + t.SkipNow() + } + + t.Log("Generating golden files.") + closure := core.WorkflowClosure{ + Workflow: &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "workflow-with-inputs"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + "y": { + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + }, + }, + }}, + }), + }, + Nodes: []*core.Node{ + { + Id: "node-1", + Inputs: []*core.Binding{ + {Var: "x", Binding: &core.BindingData{Value: &core.BindingData_Promise{Promise: &core.OutputReference{Var: "x"}}}}, + {Var: "y", Binding: &core.BindingData{Value: &core.BindingData_Promise{Promise: &core.OutputReference{Var: "y"}}}}, + }, + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-1"}, + }, + }, + }, + }, + { + Id: "node-2", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task-2"}, + }, + }, + }, + }, + }, + }, + Tasks: []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task-1"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + "y": { + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + }, + }, + }}, + }), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2"}, + {Name: core.Resources_MEMORY, Value: "2048Mi"}, + }, + }, + }, + }, + }, + { + Id: &core.Identifier{Name: "task-2"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myflyteimage:latest", + Command: []string{"execute-task"}, + Args: []string{"testArg"}, + }, + }, + }, + }, + } + + marshalGolden(t, &closure, "workflow_w_inputs") + sampleInputs := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakeLiteral(2), + "y": utils.MustMakeLiteral([]interface{}{"val1", "val2", "val3"}), + }, + } + + marshalGolden(t, &sampleInputs, "inputs") +} + +func marshalGolden(t *testing.T, message proto.Message, filename string) { + marshaller := &jsonpb.Marshaler{} + s, err := marshaller.MarshalToString(message) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", filename+".json.golden"), []byte(s), os.ModePerm)) + + m := map[string]interface{}{} + err = 
json.Unmarshal([]byte(s), &m) + assert.NoError(t, err) + + b, err := yaml.Marshal(m) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", filename+".yaml.golden"), b, os.ModePerm)) + + raw, err := proto.Marshal(message) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", filename+".pb.golden"), raw, os.ModePerm)) +} + +func testCompile(t *testing.T) { + f := func(t *testing.T, filePath, format string) { + raw, err := ioutil.ReadFile(filepath.Join("testdata", filePath)) + assert.NoError(t, err) + wf := &core.WorkflowClosure{} + err = unmarshal(raw, format, wf) + assert.NoError(t, err) + assert.NotNil(t, wf) + assert.Equal(t, 2, len(wf.Tasks)) + if len(wf.Tasks) == 2 { + c := wf.Tasks[0].GetContainer() + assert.NotNil(t, c) + compiledTasks, err := compileTasks(wf.Tasks) + assert.NoError(t, err) + compiledWf, err := compiler.CompileWorkflow(wf.Workflow, []*core.WorkflowTemplate{}, compiledTasks, []common.InterfaceProvider{}) + assert.NoError(t, err) + _, err = k8s.BuildFlyteWorkflow(compiledWf, nil, nil, "") + assert.NoError(t, err) + } + } + + t.Run("yaml", func(t *testing.T) { + f(t, "workflow.yaml.golden", formatYaml) + }) + + t.Run("json", func(t *testing.T) { + f(t, "workflow.json.golden", formatJSON) + }) + + t.Run("proto", func(t *testing.T) { + f(t, "workflow.pb.golden", formatProto) + }) +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/delete.go b/flytepropeller/cmd/kubectl-flyte/cmd/delete.go new file mode 100644 index 0000000000..7ddede1e09 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/delete.go @@ -0,0 +1,87 @@ +package cmd + +import ( + "fmt" + + "github.com/lyft/flytepropeller/pkg/controller" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type DeleteOpts struct { + *RootOptions + force bool + allCompleted bool + chunkSize int64 + limit int64 +} + +func NewDeleteCommand(opts *RootOptions) *cobra.Command { + + deleteOpts := &DeleteOpts{ + RootOptions: opts, + } + + // deleteCmd represents the delete command + deleteCmd := &cobra.Command{ + Use: "delete [workflow-name]", + Short: "delete a workflow", + Long: ``, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + name := args[0] + return deleteOpts.deleteWorkflow(name) + } + + return deleteOpts.deleteCompletedWorkflows() + }, + } + + deleteCmd.Flags().BoolVarP(&deleteOpts.force, "force", "f", false, "Enable force deletion to remove finalizers from a workflow.") + deleteCmd.Flags().BoolVarP(&deleteOpts.allCompleted, "all-completed", "a", false, "Delete all the workflows that have completed. 
Cannot be used with --force.") + deleteCmd.Flags().Int64VarP(&deleteOpts.chunkSize, "chunk-size", "c", 100, "When using all-completed, provide a chunk size to retrieve at once from the server.") + deleteCmd.Flags().Int64VarP(&deleteOpts.limit, "limit", "l", -1, "Only iterate over max limit records.") + + return deleteCmd +} + +func (d *DeleteOpts) deleteCompletedWorkflows() error { + if d.force && d.allCompleted { + return fmt.Errorf("cannot delete multiple workflows with --force") + } + if !d.allCompleted { + return fmt.Errorf("all completed | workflow name is required") + } + + t, err := d.GetTimeoutSeconds() + if err != nil { + return err + } + + p := v1.DeletePropagationBackground + return d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).DeleteCollection( + &v1.DeleteOptions{PropagationPolicy: &p}, v1.ListOptions{ + TimeoutSeconds: &t, + LabelSelector: v1.FormatLabelSelector(controller.CompletedWorkflowsLabelSelector()), + }, + ) + +} + +func (d *DeleteOpts) deleteWorkflow(name string) error { + p := v1.DeletePropagationBackground + if err := d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).Delete(name, &v1.DeleteOptions{PropagationPolicy: &p}); err != nil { + return err + } + if d.force { + w, err := d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).Get(name, v1.GetOptions{}) + if err != nil { + return err + } + w.SetFinalizers([]string{}) + if _, err := d.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(d.ConfigOverrides.Context.Namespace).Update(w); err != nil { + return err + } + } + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/get.go b/flytepropeller/cmd/kubectl-flyte/cmd/get.go new file mode 100644 index 0000000000..b1460c340b --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/get.go @@ -0,0 +1,141 @@ +package cmd + +import ( + "fmt" + "strings" + + gotree "github.com/DiSiqueira/GoTree" + "github.com/lyft/flytepropeller/cmd/kubectl-flyte/cmd/printers" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type GetOpts struct { + *RootOptions + detailsEnabledFlag bool + limit int64 + chunkSize int64 +} + +func NewGetCommand(opts *RootOptions) *cobra.Command { + + getOpts := &GetOpts{ + RootOptions: opts, + } + + getCmd := &cobra.Command{ + Use: "get [opts] []", + Short: "Gets a single workflow or lists all workflows currently in execution", + Long: `use labels to filter`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + name := args[0] + return getOpts.getWorkflow(name) + } + return getOpts.listWorkflows() + }, + } + + getCmd.Flags().BoolVarP(&getOpts.detailsEnabledFlag, "details", "d", false, "If details of node execs are desired.") + getCmd.Flags().Int64VarP(&getOpts.chunkSize, "chunk-size", "c", 100, "Use this much batch size.") + getCmd.Flags().Int64VarP(&getOpts.limit, "limit", "l", -1, "Only get limit records. 
-1 => all records.") + + return getCmd +} + +func (g *GetOpts) getWorkflow(name string) error { + parts := strings.Split(name, "/") + if len(parts) > 1 { + g.ConfigOverrides.Context.Namespace = parts[0] + name = parts[1] + } + w, err := g.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(g.ConfigOverrides.Context.Namespace).Get(name, v1.GetOptions{}) + if err != nil { + return err + } + wp := printers.WorkflowPrinter{} + tree := gotree.New("Workflow") + if err := wp.Print(tree, w); err != nil { + return err + } + fmt.Print(tree.Print()) + return nil +} + +func (g *GetOpts) iterateOverWorkflows(f func(*v1alpha1.FlyteWorkflow) error, batchSize int64, limit int64) error { + if limit > 0 && limit < batchSize { + batchSize = limit + } + t, err := g.GetTimeoutSeconds() + if err != nil { + return err + } + opts := &v1.ListOptions{ + Limit: batchSize, + TimeoutSeconds: &t, + } + var counter int64 + for { + wList, err := g.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(g.ConfigOverrides.Context.Namespace).List(*opts) + if err != nil { + return err + } + for _, w := range wList.Items { + if err := f(&w); err != nil { + return err + } + counter++ + if counter == limit { + return nil + } + } + if wList.Continue == "" { + return nil + } + opts.Continue = wList.Continue + } +} + +func (g *GetOpts) listWorkflows() error { + fmt.Printf("Listing workflows in [%s]\n", g.ConfigOverrides.Context.Namespace) + wp := printers.WorkflowPrinter{} + workflows := gotree.New("workflows") + var counter int64 + var succeeded = 0 + var failed = 0 + var running = 0 + var waiting = 0 + err := g.iterateOverWorkflows( + func(w *v1alpha1.FlyteWorkflow) error { + counter++ + if err := wp.PrintShort(workflows, w); err != nil { + return err + } + switch w.GetExecutionStatus().GetPhase() { + case v1alpha1.WorkflowPhaseReady: + waiting++ + case v1alpha1.WorkflowPhaseSuccess: + succeeded++ + case v1alpha1.WorkflowPhaseFailed: + failed++ + default: + running++ + } + if counter%g.chunkSize == 0 { + fmt.Println("") + fmt.Print(workflows.Print()) + workflows = gotree.New("\nworkflows") + } else { + fmt.Print(".") + } + return nil + }, g.chunkSize, g.limit) + if err != nil { + return err + } + fmt.Print(workflows.Print()) + fmt.Printf("Found %d workflows\n", counter) + fmt.Printf("Sucess: %d, Failed: %d, Running: %d, Waiting: %d\n", succeeded, failed, running, waiting) + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/printers/node.go b/flytepropeller/cmd/kubectl-flyte/cmd/printers/node.go new file mode 100644 index 0000000000..d1c935049c --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/printers/node.go @@ -0,0 +1,157 @@ +package printers + +import ( + "fmt" + "strconv" + "strings" + "time" + + "k8s.io/apimachinery/pkg/util/sets" + + gotree "github.com/DiSiqueira/GoTree" + "github.com/fatih/color" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + "github.com/lyft/flytepropeller/pkg/utils" +) + +var boldString = color.New(color.Bold) + +func ColorizeNodePhase(p v1alpha1.NodePhase) string { + switch p { + case v1alpha1.NodePhaseNotYetStarted: + return p.String() + case v1alpha1.NodePhaseRunning: + return color.YellowString("%s", p.String()) + case v1alpha1.NodePhaseSucceeded: + return color.HiGreenString("%s", p.String()) + case v1alpha1.NodePhaseFailed: + return color.HiRedString("%s", p.String()) + } + return color.CyanString("%s", p.String()) +} + +func 
CalculateRuntime(s v1alpha1.ExecutableNodeStatus) string { + if s.GetStartedAt() != nil { + if s.GetStoppedAt() != nil { + return s.GetStoppedAt().Sub(s.GetStartedAt().Time).String() + } + return time.Since(s.GetStartedAt().Time).String() + } + return "na" +} + +type NodePrinter struct { + NodeStatusPrinter +} + +func (p NodeStatusPrinter) BaseNodeInfo(node v1alpha1.BaseNode, nodeStatus v1alpha1.ExecutableNodeStatus) []string { + return []string{ + fmt.Sprintf("%s (%s)", boldString.Sprint(node.GetID()), node.GetKind().String()), + CalculateRuntime(nodeStatus), + ColorizeNodePhase(nodeStatus.GetPhase()), + nodeStatus.GetMessage(), + } +} + +func (p NodeStatusPrinter) NodeInfo(wName string, node v1alpha1.BaseNode, nodeStatus v1alpha1.ExecutableNodeStatus) []string { + resourceName, err := utils.FixedLengthUniqueIDForParts(task.IDMaxLength, wName, node.GetID(), strconv.Itoa(int(nodeStatus.GetAttempts()))) + if err != nil { + resourceName = "na" + } + return append( + p.BaseNodeInfo(node, nodeStatus), + fmt.Sprintf("resource=%s", resourceName), + ) +} + +func (p NodePrinter) BranchNodeInfo(node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) []string { + info := p.BaseNodeInfo(node, nodeStatus) + branchStatus := nodeStatus.GetOrCreateBranchStatus() + info = append(info, branchStatus.GetPhase().String()) + if branchStatus.GetFinalizedNode() != nil { + info = append(info, *branchStatus.GetFinalizedNode()) + } + return info + +} + +func (p NodePrinter) traverseNode(tree gotree.Tree, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) error { + switch node.GetKind() { + case v1alpha1.NodeKindBranch: + subTree := tree.Add(strings.Join(p.BranchNodeInfo(node, nodeStatus), " | ")) + f := func(nodeID *v1alpha1.NodeID) error { + if nodeID != nil { + ifNode, ok := w.GetNode(*nodeID) + if !ok { + return fmt.Errorf("failed to find branch node %s", *nodeID) + } + if err := p.traverseNode(subTree, w, ifNode, nodeStatus.GetNodeExecutionStatus(*nodeID)); err != nil { + return err + } + } + return nil + } + if err := f(node.GetBranchNode().GetIf().GetThenNode()); err != nil { + return err + } + if len(node.GetBranchNode().GetElseIf()) > 0 { + for _, n := range node.GetBranchNode().GetElseIf() { + if err := f(n.GetThenNode()); err != nil { + return err + } + } + } + if err := f(node.GetBranchNode().GetElse()); err != nil { + return err + } + case v1alpha1.NodeKindWorkflow: + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + s := w.FindSubWorkflow(*node.GetWorkflowNode().GetSubWorkflowRef()) + wp := WorkflowPrinter{} + cw := executors.NewSubContextualWorkflow(w, s, nodeStatus) + return wp.Print(tree, cw) + } + case v1alpha1.NodeKindTask: + sub := tree.Add(strings.Join(p.NodeInfo(w.GetName(), node, nodeStatus), " | ")) + if err := p.PrintRecursive(sub, w.GetName(), nodeStatus); err != nil { + return err + } + default: + _ = tree.Add(strings.Join(p.NodeInfo(w.GetName(), node, nodeStatus), " | ")) + } + return nil +} + +func (p NodePrinter) PrintList(tree gotree.Tree, w v1alpha1.ExecutableWorkflow, nodes []v1alpha1.ExecutableNode) error { + for _, n := range nodes { + s := w.GetNodeExecutionStatus(n.GetID()) + if err := p.traverseNode(tree, w, n, s); err != nil { + return err + } + } + return nil +} + +type NodeStatusPrinter struct { +} + +func (p NodeStatusPrinter) PrintRecursive(tree gotree.Tree, wfName string, s v1alpha1.ExecutableNodeStatus) error { + orderedKeys := sets.String{} + allStatuses := 
map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus{} + s.VisitNodeStatuses(func(node v1alpha1.NodeID, status v1alpha1.ExecutableNodeStatus) { + orderedKeys.Insert(node) + allStatuses[node] = status + }) + + for _, id := range orderedKeys.List() { + ns := allStatuses[id] + sub := tree.Add(strings.Join(p.NodeInfo(wfName, &v1alpha1.NodeSpec{ID: id}, ns), " | ")) + if err := p.PrintRecursive(sub, wfName, ns); err != nil { + return err + } + } + + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/printers/workflow.go b/flytepropeller/cmd/kubectl-flyte/cmd/printers/workflow.go new file mode 100644 index 0000000000..f15da36bd1 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/printers/workflow.go @@ -0,0 +1,63 @@ +package printers + +import ( + "fmt" + "time" + + gotree "github.com/DiSiqueira/GoTree" + "github.com/fatih/color" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/visualize" +) + +func ColorizeWorkflowPhase(p v1alpha1.WorkflowPhase) string { + switch p { + case v1alpha1.WorkflowPhaseReady: + return p.String() + case v1alpha1.WorkflowPhaseRunning: + return color.YellowString("%s", p.String()) + case v1alpha1.WorkflowPhaseSuccess: + return color.HiGreenString("%s", p.String()) + case v1alpha1.WorkflowPhaseFailed: + return color.HiRedString("%s", p.String()) + } + return color.CyanString("%s", p.String()) +} + +func CalculateWorkflowRuntime(s v1alpha1.ExecutableWorkflowStatus) string { + if s.GetStartedAt() != nil { + if s.GetStoppedAt() != nil { + return s.GetStoppedAt().Sub(s.GetStartedAt().Time).String() + } + return time.Since(s.GetStartedAt().Time).String() + } + return "na" +} + +type WorkflowPrinter struct { +} + +func (p WorkflowPrinter) Print(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { + sortedNodes, err := visualize.TopologicalSort(w) + if err != nil { + return err + } + newTree := gotree.New(fmt.Sprintf("%s/%s [ExecId: %s] (%s %s %s)", + w.GetNamespace(), boldString.Sprint(w.GetName()), w.GetExecutionID(), CalculateWorkflowRuntime(w.GetExecutionStatus()), + ColorizeWorkflowPhase(w.GetExecutionStatus().GetPhase()), w.GetExecutionStatus().GetMessage())) + if tree != nil { + tree.AddTree(newTree) + } + np := NodePrinter{} + return np.PrintList(newTree, w, sortedNodes) +} + +func (p WorkflowPrinter) PrintShort(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { + if tree == nil { + return fmt.Errorf("bad state in printer") + } + tree.Add(fmt.Sprintf("%s/%s [ExecId: %s] (%s %s) - Time SinceCreation(%s)", + w.GetNamespace(), boldString.Sprint(w.GetName()), w.GetExecutionID(), CalculateWorkflowRuntime(w.GetExecutionStatus()), + ColorizeWorkflowPhase(w.GetExecutionStatus().GetPhase()), time.Since(w.GetCreationTimestamp().Time))) + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/root.go b/flytepropeller/cmd/kubectl-flyte/cmd/root.go new file mode 100644 index 0000000000..f56009f43e --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/root.go @@ -0,0 +1,116 @@ +package cmd + +import ( + "context" + "flag" + "fmt" + "os" + "runtime" + "time" + + "github.com/lyft/flytestdlib/config/viper" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/version" + "github.com/spf13/pflag" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + + flyteclient "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + "github.com/spf13/cobra" +) + +func init() { + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + err := 
flag.CommandLine.Parse([]string{}) + if err != nil { + logger.Error(context.TODO(), "Error in initializing: %v", err) + os.Exit(-1) + } +} + +type RootOptions struct { + *clientcmd.ConfigOverrides + allNamespaces bool + showSource bool + clientConfig clientcmd.ClientConfig + restConfig *rest.Config + kubeClient kubernetes.Interface + flyteClient flyteclient.Interface +} + +func (r *RootOptions) GetTimeoutSeconds() (int64, error) { + if r.Timeout != "" { + d, err := time.ParseDuration(r.Timeout) + if err != nil { + return 10, err + } + return int64(d.Seconds()), nil + } + return 10, nil + +} + +func (r *RootOptions) executeRootCmd() error { + ctx := context.TODO() + logger.Infof(ctx, "Go Version: %s", runtime.Version()) + logger.Infof(ctx, "Go OS/Arch: %s/%s", runtime.GOOS, runtime.GOARCH) + version.LogBuildInformation("kubectl-flyte") + return fmt.Errorf("use one of the sub-commands") +} + +func (r *RootOptions) ConfigureClient() error { + restConfig, err := r.clientConfig.ClientConfig() + if err != nil { + return err + } + k, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return err + } + fc, err := flyteclient.NewForConfig(restConfig) + if err != nil { + return err + } + r.restConfig = restConfig + r.kubeClient = k + r.flyteClient = fc + return nil +} + +// NewCommand returns a new instance of an argo command +func NewFlyteCommand() *cobra.Command { + rootOpts := &RootOptions{} + command := &cobra.Command{ + Use: "kubectl-flyte", + Short: "kubectl-flyte allows launching and managing K8s native workflows", + Long: `Flyte is a serverless workflow processing platform built for native execution on K8s. + It is extensible and flexible to allow adding new operators and comes with many operators built in`, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return rootOpts.ConfigureClient() + }, + RunE: func(cmd *cobra.Command, args []string) error { + return rootOpts.executeRootCmd() + }, + } + + command.AddCommand(NewDeleteCommand(rootOpts)) + command.AddCommand(NewGetCommand(rootOpts)) + command.AddCommand(NewVisualizeCommand(rootOpts)) + command.AddCommand(NewCreateCommand(rootOpts)) + command.AddCommand(NewCompileCommand(rootOpts)) + + loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() + loadingRules.DefaultClientConfig = &clientcmd.DefaultClientConfig + rootOpts.ConfigOverrides = &clientcmd.ConfigOverrides{} + kflags := clientcmd.RecommendedConfigOverrideFlags("") + command.PersistentFlags().StringVar(&loadingRules.ExplicitPath, "kubeconfig", "", "Path to a kube config. 
Only required if out-of-cluster") + clientcmd.BindOverrideFlags(rootOpts.ConfigOverrides, command.PersistentFlags(), kflags) + rootOpts.clientConfig = clientcmd.NewInteractiveDeferredLoadingClientConfig(loadingRules, rootOpts.ConfigOverrides, os.Stdin) + + command.PersistentFlags().BoolVar(&rootOpts.allNamespaces, "all-namespaces", false, "Enable this flag to execute for all namespaces") + command.PersistentFlags().BoolVarP(&rootOpts.showSource, "show-source", "s", false, "Show line number for errors") + command.AddCommand(viper.GetConfigCommand()) + + return command +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/string_map_value.go b/flytepropeller/cmd/kubectl-flyte/cmd/string_map_value.go new file mode 100644 index 0000000000..6f067ab209 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/string_map_value.go @@ -0,0 +1,83 @@ +package cmd + +import ( + "bytes" + "fmt" + "regexp" + "strings" +) + +// Represents a pflag value that parses a string into a map +type stringMapValue struct { + value *map[string]string + changed bool +} + +func newStringMapValue() *stringMapValue { + return &stringMapValue{ + value: &map[string]string{}, + changed: false, + } +} + +var entryRegex = regexp.MustCompile("(?P[^,]+)=(?P[^,]+)") + +// Parses val into a map. Accepted format: a=1,b=2 +func (s *stringMapValue) Set(val string) error { + matches := entryRegex.FindAllStringSubmatch(val, -1) + out := make(map[string]string, len(matches)) + for _, entry := range matches { + if len(entry) != 3 { + return fmt.Errorf("invalid value for entry. Entries must be formatted as key=value. Found %v", + entry) + } + + out[strings.TrimSpace(entry[1])] = strings.TrimSpace(entry[2]) + } + + if !s.changed { + *s.value = out + } else { + for k, v := range out { + (*s.value)[k] = v + } + } + s.changed = true + return nil +} + +func (s *stringMapValue) Type() string { + return "stringToString" +} + +func (s *stringMapValue) String() string { + var buf bytes.Buffer + i := 0 + for k, v := range *s.value { + if i > 0 { + _, err := buf.WriteRune(',') + if err != nil { + return "" + } + } + + _, err := buf.WriteString(k) + if err != nil { + return "" + } + + _, err = buf.WriteRune('=') + if err != nil { + return "" + } + + _, err = buf.WriteString(v) + if err != nil { + return "" + } + + i++ + } + + return "[" + buf.String() + "]" +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/string_map_value_test.go b/flytepropeller/cmd/kubectl-flyte/cmd/string_map_value_test.go new file mode 100644 index 0000000000..ff3ae8d17f --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/string_map_value_test.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func formatArg(values map[string]string) string { + res := "" + for key, value := range values { + res += fmt.Sprintf(",%v%v%v=%v%v%v", randSpaces(), key, randSpaces(), randSpaces(), value, randSpaces()) + } + + if len(values) > 0 { + return res[1:] + } + + return res +} + +func randSpaces() string { + res := "" + for cnt := rand.Int() % 10; cnt > 0; cnt-- { // nolint: gas + res += " " + } + + return res +} + +func runPositiveTest(t *testing.T, expected map[string]string) { + v := newStringMapValue() + assert.NoError(t, v.Set(formatArg(expected))) + + assert.Equal(t, len(expected), len(*v.value)) + assert.Equal(t, expected, *v.value) +} + +func TestSet(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + expected := map[string]string{ + "a": "1", + "b": "2", + "c": "3", + "d.sub": "x.y", + "e": "4", + } 
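+		// formatArg pads every key and value with random whitespace (randSpaces), so this
+		// round trip also exercises the strings.TrimSpace calls in stringMapValue.Set.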
+ runPositiveTest(t, expected) + }) + + t.Run("Empty arg", func(t *testing.T) { + expected := map[string]string{ + "": "", + "a": "1", + "b": "2", + "c": "3", + "d.sub": "x.y", + "e": "4", + } + + runPositiveTest(t, expected) + }) +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden new file mode 100755 index 0000000000..0af476b90f --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.json.golden @@ -0,0 +1 @@ +{"literals":{"x":{"scalar":{"primitive":{"integer":"2"}}},"y":{"collection":{"literals":[{"scalar":{"primitive":{"stringValue":"val1"}}},{"scalar":{"primitive":{"stringValue":"val2"}}},{"scalar":{"primitive":{"stringValue":"val3"}}}]}}}} \ No newline at end of file diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden new file mode 100755 index 0000000000..5ff2c7009d --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.pb.golden @@ -0,0 +1,19 @@ + + +x + + ++ +y&$ + + + +val1 + + + +val2 + + + +val3 \ No newline at end of file diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden new file mode 100755 index 0000000000..efeefc56a3 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/inputs.yaml.golden @@ -0,0 +1,17 @@ +literals: + x: + scalar: + primitive: + integer: "2" + "y": + collection: + literals: + - scalar: + primitive: + stringValue: val1 + - scalar: + primitive: + stringValue: val2 + - scalar: + primitive: + stringValue: val3 diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden new file mode 100755 index 0000000000..83c9dc2652 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.json.golden @@ -0,0 +1 @@ +{"workflow":{"id":{"name":"workflow-id-123"},"interface":{"inputs":{"variables":{}}},"nodes":[{"id":"node-1","taskNode":{"referenceId":{"name":"task-1"}}},{"id":"node-2","taskNode":{"referenceId":{"name":"task-2"}}}]},"tasks":[{"id":{"name":"task-1"},"interface":{"inputs":{"variables":{}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"]}},{"id":{"name":"task-2"},"interface":{"inputs":{"variables":{}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"]}}]} \ No newline at end of file diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.pb.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.pb.golden new file mode 100755 index 0000000000000000000000000000000000000000..4bfb48a1dafc57742dfc9220ce918aee17e24a57 GIT binary patch literal 193 zcmd;b<`PumFV8Q^PRq$J*Ue1PH8e7oV&Y;@65(RY%TGxK^0~M;l-NoVi?ekN;i5)J zqDDeCNGgDOjdZw#b1TzwDoau`a}(23t#T4eQj1H3cv34;lS@ldbwP$nv4g}Mi_(#F J8!3^j8vtv^HK70i literal 0 HcmV?d00001 diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden new file mode 100755 index 0000000000..261175022f --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow.yaml.golden @@ -0,0 +1,38 @@ +tasks: +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + id: + name: task-1 + interface: + inputs: + variables: {} +- container: + args: + - testArg + command: + - execute-task + image: 
myflyteimage:latest + id: + name: task-2 + interface: + inputs: + variables: {} +workflow: + id: + name: workflow-id-123 + interface: + inputs: + variables: {} + nodes: + - id: node-1 + taskNode: + referenceId: + name: task-1 + - id: node-2 + taskNode: + referenceId: + name: task-2 diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden new file mode 100755 index 0000000000..08fd7df321 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.json.golden @@ -0,0 +1 @@ +{"workflow":{"id":{"name":"workflow-with-inputs"},"interface":{"inputs":{"variables":{"x":{"type":{"simple":"INTEGER"}},"y":{"type":{"collectionType":{"simple":"STRING"}}}}}},"nodes":[{"id":"node-1","inputs":[{"var":"x","binding":{"promise":{"var":"x"}}},{"var":"y","binding":{"promise":{"var":"y"}}}],"taskNode":{"referenceId":{"name":"task-1"}}},{"id":"node-2","taskNode":{"referenceId":{"name":"task-2"}}}]},"tasks":[{"id":{"name":"task-1"},"interface":{"inputs":{"variables":{"x":{"type":{"simple":"INTEGER"}},"y":{"type":{"collectionType":{"simple":"STRING"}}}}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"],"resources":{"requests":[{"name":"CPU","value":"2"},{"name":"MEMORY","value":"2048Mi"}]}}},{"id":{"name":"task-2"},"interface":{"inputs":{"variables":{}}},"container":{"image":"myflyteimage:latest","command":["execute-task"],"args":["testArg"]}}]} \ No newline at end of file diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.pb.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.pb.golden new file mode 100755 index 0000000000000000000000000000000000000000..0db9aa296b0171d57b56d2534d9895e863580129 GIT binary patch literal 291 zcmb7-y9&ZE7)DKR#$TscL^5V>zy=2qmp*_GQ8cZgH)&IvyuAcPItgwchr@Tkz>dXJ z*H=Z|O?*=N!BzVl^*|}?VFk#L06dIf4akxR0_tI(B@ngUUGkKIGeYVUWZbszZ3O^k zk^umuw3mMg+N% VPa<(Rp00{qzA3>xScv4G%NOXALH_^% literal 0 HcmV?d00001 diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden new file mode 100755 index 0000000000..ce937c2c9b --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/workflow_w_inputs.yaml.golden @@ -0,0 +1,67 @@ +tasks: +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + resources: + requests: + - name: CPU + value: "2" + - name: MEMORY + value: 2048Mi + id: + name: task-1 + interface: + inputs: + variables: + x: + type: + simple: INTEGER + "y": + type: + collectionType: + simple: STRING +- container: + args: + - testArg + command: + - execute-task + image: myflyteimage:latest + id: + name: task-2 + interface: + inputs: + variables: {} +workflow: + id: + name: workflow-with-inputs + interface: + inputs: + variables: + x: + type: + simple: INTEGER + "y": + type: + collectionType: + simple: STRING + nodes: + - id: node-1 + inputs: + - binding: + promise: + var: x + var: x + - binding: + promise: + var: "y" + var: "y" + taskNode: + referenceId: + name: task-1 + - id: node-2 + taskNode: + referenceId: + name: task-2 diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/util.go b/flytepropeller/cmd/kubectl-flyte/cmd/util.go new file mode 100644 index 0000000000..a2346ae9ab --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/util.go @@ -0,0 +1,18 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func 
requiredFlags(cmd *cobra.Command, flags ...string) error { + for _, flag := range flags { + f := cmd.Flag(flag) + if f == nil { + return fmt.Errorf("unable to find Key [%v]", flag) + } + } + + return nil +} diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/visualize.go b/flytepropeller/cmd/kubectl-flyte/cmd/visualize.go new file mode 100644 index 0000000000..561363c3ea --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/visualize.go @@ -0,0 +1,39 @@ +package cmd + +import ( + "fmt" + + "github.com/lyft/flytepropeller/pkg/visualize" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type VisualizeOpts struct { + *RootOptions +} + +func NewVisualizeCommand(opts *RootOptions) *cobra.Command { + + vizOpts := &VisualizeOpts{ + RootOptions: opts, + } + + visualizeCmd := &cobra.Command{ + Use: "visualize ", + Short: "Get GraphViz dot-formatted output.", + Long: `Generates GraphViz dot-formatted output for the workflow.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + w, err := vizOpts.flyteClient.FlyteworkflowV1alpha1().FlyteWorkflows(vizOpts.ConfigOverrides.Context.Namespace).Get(name, v1.GetOptions{}) + if err != nil { + return err + } + + fmt.Printf("Dot-formatted: %v\n", visualize.WorkflowToGraphViz(w)) + return nil + }, + } + + return visualizeCmd +} diff --git a/flytepropeller/cmd/kubectl-flyte/main.go b/flytepropeller/cmd/kubectl-flyte/main.go new file mode 100644 index 0000000000..f0e7fa82b0 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "os" + + "github.com/lyft/flytepropeller/cmd/kubectl-flyte/cmd" +) + +func main() { + + rootCmd := cmd.NewFlyteCommand() + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/flytepropeller/config.yaml b/flytepropeller/config.yaml new file mode 100644 index 0000000000..c50dbc4b31 --- /dev/null +++ b/flytepropeller/config.yaml @@ -0,0 +1,95 @@ +# This is a sample configuration file. 
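+# It covers the propeller, plugins, storage, event, admin, errors and logger sections;
+# duration values (workflow-reeval-duration, downstream-eval-duration, etc.) use Go
+# time.Duration syntax such as 10s or 5m.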
+# Real configuration when running inside K8s (local or otherwise) lives in a ConfigMap +propeller: + workers: 4 + workflow-reeval-duration: 10s + downstream-eval-duration: 5s + limit-namespace: "all" + prof-port: 11254 + metrics-prefix: flyte + enable-admin-launcher: true + max-ttl-hours: 1 + gc-interval: 500m + queue: + type: batch + queue: + type: bucket + rate: 20 + capacity: 100 + sub-queue: + type: bucket + rate: 100 + capacity: 1000 + kube-config: "$HOME/.kube/config" + publish-k8s-events: true +# Sample plugins config +plugins: + # Set of enabled plugins at root level + enabled-plugins: + - container + - waitable + - K8S-ARRAY + # All k8s plugins default configuration + k8s: + inject-finalizer: true + default-annotations: + - annotationKey1: annotationValue1 + resource-tolerations: + nvidia.com/gpu: + key: flyte/gpu + value: dedicated + operator: Equal + effect: NoSchedule + default-env-vars: + - AWS_METADATA_SERVICE_TIMEOUT: 5 + - AWS_METADATA_SERVICE_NUM_ATTEMPTS: 20 + - FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000" + - FLYTE_AWS_ACCESS_KEY_ID: minio + - FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage + # Spark Plugin configuration + spark: + spark-config-default: + - spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version: "2" + - spark.kubernetes.allocation.batch.size: "50" + - spark.hadoop.fs.s3a.acl.default: "BucketOwnerFullControl" + - spark.hadoop.fs.s3n.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3n.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3a.impl: "org.apache.hadoop.fs.s3a.S3AFileSystem" + - spark.hadoop.fs.AbstractFileSystem.s3a.impl: "org.apache.hadoop.fs.s3a.S3A" + - spark.hadoop.fs.s3a.multipart.threshold: "536870912" + - spark.blacklist.enabled: "true" + - spark.blacklist.timeout: "5m" + # Waitable plugin configuration + waitable: + console-uri: http://localhost:30081/console + # Logging configuration + logs: + kubernetes-enabled: true + kubernetes-url: "http://localhost:30082" +storage: + connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://localhost:9000 + region: us-east-1 + secret-key: miniostorage + cache: + max_size_mbs: 10 + target_gc_percent: 100 + container: myflytecontainer + type: minio +event: + type: admin + rate: 500 + capacity: 1000 +admin: + endpoint: localhost:8089 + insecure: true +errors: + show-source: true +logger: + level: 4 + show-source: true diff --git a/flytepropeller/hack/boilerplate.go.txt b/flytepropeller/hack/boilerplate.go.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytepropeller/hack/custom-boilerplate.go.txt b/flytepropeller/hack/custom-boilerplate.go.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytepropeller/hack/update-codegen.sh b/flytepropeller/hack/update-codegen.sh new file mode 100755 index 0000000000..e116e7cc88 --- /dev/null +++ b/flytepropeller/hack/update-codegen.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# Copyright 2017 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file was derived from https://github.com/kubernetes/sample-controller/blob/4d46ec53ca337e118754c5fc50f02634b6a83380/hack/update-codegen.sh + +set -o errexit +set -o nounset +set -o pipefail + +: "${RESOURCE_NAME:?should be set for CRD}" +: "${OPERATOR_PKG:?should be set for operator}" + +echo "Generating CRD: ${RESOURCE_NAME}, in package ${OPERATOR_PKG}..." + +SCRIPT_ROOT=$(dirname ${BASH_SOURCE})/.. +CODEGEN_PKG=${CODEGEN_PKG:-$(cd ${SCRIPT_ROOT}; ls -d -1 ./vendor/k8s.io/code-generator 2>/dev/null || echo ../code-generator)} + +# generate the code with: +# --output-base because this script should also be able to run inside the vendor dir of +# k8s.io/kubernetes. The output-base is needed for the generators to output into the vendor dir +# instead of the $GOPATH directly. For normal projects this can be dropped. +${CODEGEN_PKG}/generate-groups.sh "deepcopy,client,informer,lister" \ + ${OPERATOR_PKG}/pkg/client \ + ${OPERATOR_PKG}/pkg/apis \ + ${RESOURCE_NAME}:v1alpha1 \ + --output-base "$(dirname ${BASH_SOURCE})/../../../.." \ + --go-header-file ${SCRIPT_ROOT}/hack/boilerplate.go.txt + +# To use your own boilerplate text use: +# --go-header-file ${SCRIPT_ROOT}/hack/custom-boilerplate.go.txt diff --git a/flytepropeller/hack/verify-codegen.sh b/flytepropeller/hack/verify-codegen.sh new file mode 100755 index 0000000000..fb944feda9 --- /dev/null +++ b/flytepropeller/hack/verify-codegen.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Copyright 2017 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/7ec2c1043bd0e5b511bfdf79eb215bc429effa24/hack/verify-codegen.sh + +set -o errexit +set -o nounset +set -o pipefail + +SCRIPT_ROOT=$(dirname "${BASH_SOURCE}")/.. + +DIFFROOT="${SCRIPT_ROOT}/pkg" +TMP_DIFFROOT="${SCRIPT_ROOT}/_tmp/pkg" +_tmp="${SCRIPT_ROOT}/_tmp" + +cleanup() { + rm -rf "${_tmp}" +} +trap "cleanup" EXIT SIGINT + +cleanup + +mkdir -p "${TMP_DIFFROOT}" +cp -a "${DIFFROOT}"/* "${TMP_DIFFROOT}" + +"${SCRIPT_ROOT}/hack/update-codegen.sh" +echo "diffing ${DIFFROOT} against freshly generated codegen" +ret=0 +diff -Naupr "${DIFFROOT}" "${TMP_DIFFROOT}" || ret=$? +cp -a "${TMP_DIFFROOT}"/* "${DIFFROOT}" +if [[ $ret -eq 0 ]] +then + echo "${DIFFROOT} up to date." +else + echo "${DIFFROOT} is out of date. 
Please run hack/update-codegen.sh" + exit 1 +fi diff --git a/flytepropeller/pkg/apis/flyteworkflow/register.go b/flytepropeller/pkg/apis/flyteworkflow/register.go new file mode 100644 index 0000000000..6ea43f4567 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/register.go @@ -0,0 +1,5 @@ +package flyteworkflow + +const ( + GroupName = "flyte.lyft.com" +) diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go new file mode 100644 index 0000000000..692e73e0e7 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch.go @@ -0,0 +1,100 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type Error struct { + *core.Error +} + +func (in Error) UnmarshalJSON(b []byte) error { + in.Error = &core.Error{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.Error) +} + +func (in Error) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.Error); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Error) DeepCopyInto(out *Error) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this + +} + +type BooleanExpression struct { + *core.BooleanExpression +} + +func (in BooleanExpression) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.BooleanExpression); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *BooleanExpression) UnmarshalJSON(b []byte) error { + in.BooleanExpression = &core.BooleanExpression{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.BooleanExpression) +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *BooleanExpression) DeepCopyInto(out *BooleanExpression) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type IfBlock struct { + Condition BooleanExpression `json:"condition"` + ThenNode *NodeID `json:"then"` +} + +func (in IfBlock) GetCondition() *core.BooleanExpression { + return in.Condition.BooleanExpression +} + +func (in *IfBlock) GetThenNode() *NodeID { + return in.ThenNode +} + +type BranchNodeSpec struct { + If IfBlock `json:"if"` + ElseIf []*IfBlock `json:"elseIf,omitempty"` + Else *NodeID `json:"else,omitempty"` + ElseFail *Error `json:"elseFail,omitempty"` +} + +func (in *BranchNodeSpec) GetIf() ExecutableIfBlock { + return &in.If +} + +func (in *BranchNodeSpec) GetElse() *NodeID { + return in.Else +} + +func (in *BranchNodeSpec) GetElseIf() []ExecutableIfBlock { + elifs := make([]ExecutableIfBlock, 0, len(in.ElseIf)) + for _, b := range in.ElseIf { + elifs = append(elifs, b) + } + return elifs +} + +func (in *BranchNodeSpec) GetElseFail() *core.Error { + if in.ElseFail != nil { + return in.ElseFail.Error + } + return nil +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch_test.go new file mode 100644 index 0000000000..3f23f5553b --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/branch_test.go @@ -0,0 +1,22 @@ +package v1alpha1_test + +import ( + "encoding/json" + "io/ioutil" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestMarshalUnMarshal_BranchTask(t *testing.T) { + r, err := ioutil.ReadFile("testdata/branch.json") + assert.NoError(t, err) + o := v1alpha1.NodeSpec{} + err = json.Unmarshal(r, &o) + assert.NoError(t, err) + assert.NotNil(t, o.BranchNode.If) + assert.Equal(t, core.ComparisonExpression_GT, o.BranchNode.If.Condition.BooleanExpression.GetComparison().Operator) + assert.Equal(t, 1, len(o.InputBindings)) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/doc.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/doc.go new file mode 100644 index 0000000000..37762e6968 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/doc.go @@ -0,0 +1,5 @@ +// +k8s:deepcopy-gen=package + +// Package v1alpha1 is the v1alpha1 version of the API. 
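+// The types in this package wrap flyteidl core protobuf messages (for example Identifier,
+// Error and BooleanExpression) with jsonpb marshalling so they round-trip as JSON inside
+// the FlyteWorkflow custom resource.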
+// +groupName=flyteworkflow.flyte.net +package v1alpha1 diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go new file mode 100644 index 0000000000..ff102680c7 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier.go @@ -0,0 +1,45 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type Identifier struct { + *core.Identifier +} + +func (in *Identifier) UnmarshalJSON(b []byte) error { + in.Identifier = &core.Identifier{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.Identifier) +} + +func (in *Identifier) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.Identifier); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *Identifier) DeepCopyInto(out *Identifier) { + *out = *in +} + +type WorkflowExecutionIdentifier struct { + *core.WorkflowExecutionIdentifier +} + +func (in *WorkflowExecutionIdentifier) DeepCopyInto(out *WorkflowExecutionIdentifier) { + *out = *in +} + +type TaskExecutionIdentifier struct { + *core.TaskExecutionIdentifier +} + +func (in *TaskExecutionIdentifier) DeepCopyInto(out *TaskExecutionIdentifier) { + *out = *in +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go new file mode 100644 index 0000000000..deea5db502 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -0,0 +1,391 @@ +package v1alpha1 + +import ( + "context" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + types2 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/storage" +) + +// The intention of these interfaces is to decouple the algorithm and usage from the actual CRD definition. +// this would help in ease of changes underneath without affecting the code. + +//go:generate mockery -all + +type CustomState map[string]interface{} +type WorkflowID = string +type TaskID = string +type NodeID = string +type LaunchPlanRefID = Identifier +type ExecutionID = WorkflowExecutionIdentifier + +// NodeKind refers to the type of Node. +type NodeKind string + +func (n NodeKind) String() string { + return string(n) +} + +type DataReference = storage.DataReference + +const ( + // TODO Should we default a NodeKindTask to empty? thus we can assume all unspecified nodetypes as task + NodeKindTask NodeKind = "task" + NodeKindBranch NodeKind = "branch" // A Branch node with conditions + NodeKindWorkflow NodeKind = "workflow" // Either an inline workflow or a remote workflow definition + NodeKindStart NodeKind = "start" // Start node is a special node + NodeKindEnd NodeKind = "end" +) + +// NodePhase indicates the current state of the Node (phase). 
A node progresses through these states +type NodePhase int + +const ( + NodePhaseNotYetStarted NodePhase = iota + NodePhaseQueued + NodePhaseRunning + NodePhaseFailing + NodePhaseSucceeding + NodePhaseSucceeded + NodePhaseFailed + NodePhaseSkipped + NodePhaseRetryableFailure +) + +func (p NodePhase) String() string { + switch p { + case NodePhaseNotYetStarted: + return "NotYetStarted" + case NodePhaseQueued: + return "Queued" + case NodePhaseRunning: + return "Running" + case NodePhaseSucceeding: + return "Succeeding" + case NodePhaseSucceeded: + return "Succeeded" + case NodePhaseFailed: + return "Failed" + case NodePhaseSkipped: + return "Skipped" + case NodePhaseRetryableFailure: + return "RetryableFailure" + } + return "Unknown" +} + +// WorkflowPhase indicates current state of the Workflow. +type WorkflowPhase int + +const ( + WorkflowPhaseReady WorkflowPhase = iota + WorkflowPhaseRunning + WorkflowPhaseSucceeding + WorkflowPhaseSuccess + WorkflowPhaseFailing + WorkflowPhaseFailed + WorkflowPhaseAborted +) + +func (p WorkflowPhase) String() string { + switch p { + case WorkflowPhaseReady: + return "Ready" + case WorkflowPhaseRunning: + return "Running" + case WorkflowPhaseSuccess: + return "Succeeded" + case WorkflowPhaseFailed: + return "Failed" + case WorkflowPhaseFailing: + return "Failing" + case WorkflowPhaseSucceeding: + return "Succeeding" + case WorkflowPhaseAborted: + return "Aborted" + } + return "Unknown" +} + +// A branchNode has its own Phases. These are used by the child nodes to ensure that the branch node is in the right state +type BranchNodePhase int + +const ( + BranchNodeNotYetEvaluated BranchNodePhase = iota + BranchNodeSuccess + BranchNodeError +) + +func (b BranchNodePhase) String() string { + switch b { + case BranchNodeNotYetEvaluated: + return "NotYetEvaluated" + case BranchNodeSuccess: + return "BranchEvalSuccess" + case BranchNodeError: + return "BranchEvalFailed" + } + return "Undefined" +} + +// TaskType is a dynamic enumeration, that is defined by configuration +type TaskType = string + +// Interface for a Task that can be executed +type ExecutableTask interface { + TaskType() TaskType + CoreTask() *core.TaskTemplate +} + +// Interface for the executable If block +type ExecutableIfBlock interface { + GetCondition() *core.BooleanExpression + GetThenNode() *NodeID +} + +// Interface for branch node status. This is the mutable API for a branch node +type ExecutableBranchNodeStatus interface { + GetPhase() BranchNodePhase + GetFinalizedNode() *NodeID +} + +type MutableBranchNodeStatus interface { + ExecutableBranchNodeStatus + + SetBranchNodeError() + SetBranchNodeSuccess(id NodeID) +} + +// Interface for dynamic node status. +type ExecutableDynamicNodeStatus interface { + GetDynamicNodePhase() DynamicNodePhase +} + +type MutableDynamicNodeStatus interface { + ExecutableDynamicNodeStatus + + SetDynamicNodePhase(phase DynamicNodePhase) +} + +// Interface for Branch node. All the methods are purely read only except for the GetExecutionStatus. +// Phase returns ExecutableBranchNodeStatus, which permits some mutations +type ExecutableBranchNode interface { + GetIf() ExecutableIfBlock + GetElse() *NodeID + GetElseIf() []ExecutableIfBlock + GetElseFail() *core.Error +} + +type ExecutableWorkflowNodeStatus interface { + // Name of the child execution. We only store name since the project and domain will be + // the same as the parent workflow execution. 
+ GetWorkflowExecutionName() string +} + +type MutableWorkflowNodeStatus interface { + ExecutableWorkflowNodeStatus + + // Sets the name of the child execution. We only store name since the project and domain + // will be the same as the parent workflow execution. + SetWorkflowExecutionName(name string) +} + +type MutableNodeStatus interface { + // Mutation API's + SetDataDir(DataReference) + SetParentNodeID(n *NodeID) + SetParentTaskID(t *core.TaskExecutionIdentifier) + UpdatePhase(phase NodePhase, occurredAt metav1.Time, reason string) + IncrementAttempts() uint32 + SetCached() + ResetDirty() + + GetOrCreateBranchStatus() MutableBranchNodeStatus + GetOrCreateWorkflowStatus() MutableWorkflowNodeStatus + ClearWorkflowStatus() + GetOrCreateTaskStatus() MutableTaskNodeStatus + ClearTaskStatus() + GetOrCreateSubWorkflowStatus() MutableSubWorkflowNodeStatus + ClearSubWorkflowStatus() + GetOrCreateDynamicNodeStatus() MutableDynamicNodeStatus + ClearDynamicNodeStatus() +} + +// Interface for a Node Phase. This provides a mutable API. +type ExecutableNodeStatus interface { + NodeStatusGetter + MutableNodeStatus + NodeStatusVisitor + GetPhase() NodePhase + GetQueuedAt() *metav1.Time + GetStoppedAt() *metav1.Time + GetStartedAt() *metav1.Time + GetLastUpdatedAt() *metav1.Time + GetParentNodeID() *NodeID + GetParentTaskID() *core.TaskExecutionIdentifier + GetDataDir() DataReference + GetMessage() string + GetAttempts() uint32 + GetWorkflowNodeStatus() ExecutableWorkflowNodeStatus + GetTaskNodeStatus() ExecutableTaskNodeStatus + GetSubWorkflowNodeStatus() ExecutableSubWorkflowNodeStatus + + IsCached() bool + IsDirty() bool +} + +type ExecutableSubWorkflowNodeStatus interface { + GetPhase() WorkflowPhase +} + +type MutableSubWorkflowNodeStatus interface { + ExecutableSubWorkflowNodeStatus + SetPhase(phase WorkflowPhase) +} + +type ExecutableTaskNodeStatus interface { + GetPhase() types2.TaskPhase + GetPhaseVersion() uint32 + GetCustomState() types2.CustomState +} + +type MutableTaskNodeStatus interface { + ExecutableTaskNodeStatus + SetPhase(phase types2.TaskPhase) + SetPhaseVersion(version uint32) + SetCustomState(state types2.CustomState) +} + +// Interface for a Child Workflow Node +type ExecutableWorkflowNode interface { + GetLaunchPlanRefID() *LaunchPlanRefID + GetSubWorkflowRef() *WorkflowID +} + +type BaseNode interface { + GetID() NodeID + GetKind() NodeKind +} + +// Interface for the Executable Node +type ExecutableNode interface { + BaseNode + IsStartNode() bool + IsEndNode() bool + GetTaskID() *TaskID + GetBranchNode() ExecutableBranchNode + GetWorkflowNode() ExecutableWorkflowNode + GetOutputAlias() []Alias + GetInputBindings() []*Binding + GetResources() *v1.ResourceRequirements + GetConfig() *v1.ConfigMap + GetRetryStrategy() *RetryStrategy +} + +// Interface for the Workflow Phase. 
This is the mutable portion for a Workflow +type ExecutableWorkflowStatus interface { + NodeStatusGetter + UpdatePhase(p WorkflowPhase, msg string) + GetPhase() WorkflowPhase + GetStoppedAt() *metav1.Time + GetStartedAt() *metav1.Time + GetLastUpdatedAt() *metav1.Time + IsTerminated() bool + GetMessage() string + SetDataDir(DataReference) + GetDataDir() DataReference + GetOutputReference() DataReference + SetOutputReference(reference DataReference) + IncFailedAttempts() + SetMessage(msg string) + ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name NodeID) (storage.DataReference, error) +} + +type BaseWorkflow interface { + StartNode() ExecutableNode + GetID() WorkflowID + // From returns all nodes that can be reached directly + // from the node with the given unique name. + FromNode(name NodeID) ([]NodeID, error) + GetNode(nodeID NodeID) (ExecutableNode, bool) +} + +type BaseWorkflowWithStatus interface { + BaseWorkflow + NodeStatusGetter +} + +// This interface captures the methods available on any workflow (top level or child). The Meta section is available +// only for the top level workflow +type ExecutableSubWorkflow interface { + BaseWorkflow + GetOutputBindings() []*Binding + GetOnFailureNode() ExecutableNode + GetNodes() []NodeID + GetConnections() *Connections + GetOutputs() *OutputVarMap +} + +// WorkflowMeta provides an interface to retrieve labels, annotations and other concepts that are declared only once +// for the top level workflow +type WorkflowMeta interface { + GetExecutionID() ExecutionID + GetK8sWorkflowID() types.NamespacedName + NewControllerRef() metav1.OwnerReference + GetNamespace() string + GetCreationTimestamp() metav1.Time + GetAnnotations() map[string]string + GetLabels() map[string]string + GetName() string + GetServiceAccountName() string +} + +type WorkflowMetaExtended interface { + WorkflowMeta + GetTask(id TaskID) (ExecutableTask, error) + FindSubWorkflow(subID WorkflowID) ExecutableSubWorkflow + GetExecutionStatus() ExecutableWorkflowStatus +} + +// A Top level Workflow is a combination of WorkflowMeta and an ExecutableSubWorkflow +type ExecutableWorkflow interface { + ExecutableSubWorkflow + WorkflowMetaExtended + NodeStatusGetter +} + +type NodeStatusGetter interface { + GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus +} + +type NodeStatusMap = map[NodeID]ExecutableNodeStatus + +type NodeStatusVisitFn = func(node NodeID, status ExecutableNodeStatus) + +type NodeStatusVisitor interface { + VisitNodeStatuses(visitor NodeStatusVisitFn) +} + +// Simple callback that can be used to indicate that the workflow with WorkflowID should be re-enqueued for examination. +type EnqueueWorkflow func(workflowID WorkflowID) + +func GetOutputsFile(outputDir DataReference) DataReference { + return outputDir + "/outputs.pb" +} + +func GetInputsFile(inputDir DataReference) DataReference { + return inputDir + "/inputs.pb" +} + +func GetOutputErrorFile(inputDir DataReference) DataReference { + return inputDir + "/error.pb" +} + +func GetFutureFile() string { + return "futures.pb" +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go new file mode 100644 index 0000000000..a6ed88364f --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseNode.go @@ -0,0 +1,39 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
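+// Illustrative usage (a hypothetical test sketch, not part of the generated
+// mock): each mock in this package embeds testify's mock.Mock, so a test
+// programs expectations with On(...).Return(...) and verifies them afterwards.
+// The ID "node-1" below is a made-up example value.
+//
+//	n := &mocks.BaseNode{}
+//	n.On("GetID").Return("node-1")
+//	n.On("GetKind").Return(v1alpha1.NodeKindTask)
+//	// hand n to any code that accepts a v1alpha1.BaseNode, then:
+//	n.AssertExpectations(t)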
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// BaseNode is an autogenerated mock type for the BaseNode type +type BaseNode struct { + mock.Mock +} + +// GetID provides a mock function with given fields: +func (_m *BaseNode) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetKind provides a mock function with given fields: +func (_m *BaseNode) GetKind() v1alpha1.NodeKind { + ret := _m.Called() + + var r0 v1alpha1.NodeKind + if rf, ok := ret.Get(0).(func() v1alpha1.NodeKind); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.NodeKind) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go new file mode 100644 index 0000000000..5f4dbb259f --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go @@ -0,0 +1,87 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// BaseWorkflow is an autogenerated mock type for the BaseWorkflow type +type BaseWorkflow struct { + mock.Mock +} + +// FromNode provides a mock function with given fields: name +func (_m *BaseWorkflow) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetID provides a mock function with given fields: +func (_m *BaseWorkflow) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *BaseWorkflow) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// StartNode provides a mock function with given fields: +func (_m *BaseWorkflow) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go new file mode 100644 index 0000000000..fd33751164 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go @@ -0,0 +1,103 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
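+// Usage note (a hypothetical test sketch, not part of the generated mock):
+// because each generated method first checks whether the stored return value
+// is a function with the matching signature, Return can be given a function
+// to compute the result from the call's arguments, e.g. for FromNode below.
+//
+//	wf := &mocks.BaseWorkflowWithStatus{}
+//	wf.On("FromNode", mock.Anything).Return(func(name string) []string {
+//		return []string{name + "-child"}
+//	}, nil)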
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// BaseWorkflowWithStatus is an autogenerated mock type for the BaseWorkflowWithStatus type +type BaseWorkflowWithStatus struct { + mock.Mock +} + +// FromNode provides a mock function with given fields: name +func (_m *BaseWorkflowWithStatus) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetID provides a mock function with given fields: +func (_m *BaseWorkflowWithStatus) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *BaseWorkflowWithStatus) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *BaseWorkflowWithStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// StartNode provides a mock function with given fields: +func (_m *BaseWorkflowWithStatus) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go new file mode 100644 index 0000000000..6c2c3948f6 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNode.go @@ -0,0 +1,76 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableBranchNode is an autogenerated mock type for the ExecutableBranchNode type +type ExecutableBranchNode struct { + mock.Mock +} + +// GetElse provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetElse() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetElseFail provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetElseFail() *core.Error { + ret := _m.Called() + + var r0 *core.Error + if rf, ok := ret.Get(0).(func() *core.Error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Error) + } + } + + return r0 +} + +// GetElseIf provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetElseIf() []v1alpha1.ExecutableIfBlock { + ret := _m.Called() + + var r0 []v1alpha1.ExecutableIfBlock + if rf, ok := ret.Get(0).(func() []v1alpha1.ExecutableIfBlock); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]v1alpha1.ExecutableIfBlock) + } + } + + return r0 +} + +// GetIf provides a mock function with given fields: +func (_m *ExecutableBranchNode) GetIf() v1alpha1.ExecutableIfBlock { + ret := _m.Called() + + var r0 v1alpha1.ExecutableIfBlock + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableIfBlock); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableIfBlock) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go new file mode 100644 index 0000000000..a24a937043 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableBranchNodeStatus.go @@ -0,0 +1,41 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableBranchNodeStatus is an autogenerated mock type for the ExecutableBranchNodeStatus type +type ExecutableBranchNodeStatus struct { + mock.Mock +} + +// GetFinalizedNode provides a mock function with given fields: +func (_m *ExecutableBranchNodeStatus) GetFinalizedNode() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableBranchNodeStatus) GetPhase() v1alpha1.BranchNodePhase { + ret := _m.Called() + + var r0 v1alpha1.BranchNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.BranchNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.BranchNodePhase) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go new file mode 100644 index 0000000000..fc8819ba33 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableDynamicNodeStatus.go @@ -0,0 +1,25 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableDynamicNodeStatus is an autogenerated mock type for the ExecutableDynamicNodeStatus type +type ExecutableDynamicNodeStatus struct { + mock.Mock +} + +// GetDynamicNodePhase provides a mock function with given fields: +func (_m *ExecutableDynamicNodeStatus) GetDynamicNodePhase() v1alpha1.DynamicNodePhase { + ret := _m.Called() + + var r0 v1alpha1.DynamicNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.DynamicNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.DynamicNodePhase) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go new file mode 100644 index 0000000000..7e29c8b373 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableIfBlock.go @@ -0,0 +1,43 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// ExecutableIfBlock is an autogenerated mock type for the ExecutableIfBlock type +type ExecutableIfBlock struct { + mock.Mock +} + +// GetCondition provides a mock function with given fields: +func (_m *ExecutableIfBlock) GetCondition() *core.BooleanExpression { + ret := _m.Called() + + var r0 *core.BooleanExpression + if rf, ok := ret.Get(0).(func() *core.BooleanExpression); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.BooleanExpression) + } + } + + return r0 +} + +// GetThenNode provides a mock function with given fields: +func (_m *ExecutableIfBlock) GetThenNode() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go new file mode 100644 index 0000000000..0a7fa9e598 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go @@ -0,0 +1,196 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
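+// Composition sketch (hypothetical test code, not part of the generated mock):
+// mocks can be nested, so a node may present itself as a branch node and hand
+// back another mock as its branch definition.
+//
+//	branch := &mocks.ExecutableBranchNode{}
+//	node := &mocks.ExecutableNode{}
+//	node.On("GetKind").Return(v1alpha1.NodeKindBranch)
+//	node.On("GetBranchNode").Return(branch)
+//	// branch.On("GetIf")... can then describe the condition under test.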
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1 "k8s.io/api/core/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableNode is an autogenerated mock type for the ExecutableNode type +type ExecutableNode struct { + mock.Mock +} + +// GetBranchNode provides a mock function with given fields: +func (_m *ExecutableNode) GetBranchNode() v1alpha1.ExecutableBranchNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableBranchNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableBranchNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableBranchNode) + } + } + + return r0 +} + +// GetConfig provides a mock function with given fields: +func (_m *ExecutableNode) GetConfig() *v1.ConfigMap { + ret := _m.Called() + + var r0 *v1.ConfigMap + if rf, ok := ret.Get(0).(func() *v1.ConfigMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ConfigMap) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *ExecutableNode) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetInputBindings provides a mock function with given fields: +func (_m *ExecutableNode) GetInputBindings() []*v1alpha1.Binding { + ret := _m.Called() + + var r0 []*v1alpha1.Binding + if rf, ok := ret.Get(0).(func() []*v1alpha1.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*v1alpha1.Binding) + } + } + + return r0 +} + +// GetKind provides a mock function with given fields: +func (_m *ExecutableNode) GetKind() v1alpha1.NodeKind { + ret := _m.Called() + + var r0 v1alpha1.NodeKind + if rf, ok := ret.Get(0).(func() v1alpha1.NodeKind); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.NodeKind) + } + + return r0 +} + +// GetOutputAlias provides a mock function with given fields: +func (_m *ExecutableNode) GetOutputAlias() []v1alpha1.Alias { + ret := _m.Called() + + var r0 []v1alpha1.Alias + if rf, ok := ret.Get(0).(func() []v1alpha1.Alias); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]v1alpha1.Alias) + } + } + + return r0 +} + +// GetResources provides a mock function with given fields: +func (_m *ExecutableNode) GetResources() *v1.ResourceRequirements { + ret := _m.Called() + + var r0 *v1.ResourceRequirements + if rf, ok := ret.Get(0).(func() *v1.ResourceRequirements); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ResourceRequirements) + } + } + + return r0 +} + +// GetRetryStrategy provides a mock function with given fields: +func (_m *ExecutableNode) GetRetryStrategy() *v1alpha1.RetryStrategy { + ret := _m.Called() + + var r0 *v1alpha1.RetryStrategy + if rf, ok := ret.Get(0).(func() *v1alpha1.RetryStrategy); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.RetryStrategy) + } + } + + return r0 +} + +// GetTaskID provides a mock function with given fields: +func (_m *ExecutableNode) GetTaskID() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetWorkflowNode provides a mock function with given fields: +func (_m *ExecutableNode) GetWorkflowNode() v1alpha1.ExecutableWorkflowNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowNode + 
if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowNode) + } + } + + return r0 +} + +// IsEndNode provides a mock function with given fields: +func (_m *ExecutableNode) IsEndNode() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsStartNode provides a mock function with given fields: +func (_m *ExecutableNode) IsStartNode() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go new file mode 100644 index 0000000000..ee91f4ad98 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go @@ -0,0 +1,407 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import storage "github.com/lyft/flytestdlib/storage" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableNodeStatus is an autogenerated mock type for the ExecutableNodeStatus type +type ExecutableNodeStatus struct { + mock.Mock +} + +// ClearDynamicNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearDynamicNodeStatus() { + _m.Called() +} + +// ClearSubWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearSubWorkflowStatus() { + _m.Called() +} + +// ClearTaskStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearTaskStatus() { + _m.Called() +} + +// ClearWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearWorkflowStatus() { + _m.Called() +} + +// GetAttempts provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetAttempts() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// GetDataDir provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetDataDir() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetLastUpdatedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetLastUpdatedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetMessage provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetMessage() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *ExecutableNodeStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var 
r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// GetOrCreateBranchStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateBranchStatus() v1alpha1.MutableBranchNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableBranchNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableBranchNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableBranchNodeStatus) + } + } + + return r0 +} + +// GetOrCreateDynamicNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateDynamicNodeStatus() v1alpha1.MutableDynamicNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableDynamicNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableDynamicNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableDynamicNodeStatus) + } + } + + return r0 +} + +// GetOrCreateSubWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateSubWorkflowStatus() v1alpha1.MutableSubWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableSubWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableSubWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableSubWorkflowNodeStatus) + } + } + + return r0 +} + +// GetOrCreateTaskStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateTaskStatus() v1alpha1.MutableTaskNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableTaskNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableTaskNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableTaskNodeStatus) + } + } + + return r0 +} + +// GetOrCreateWorkflowStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateWorkflowStatus() v1alpha1.MutableWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableWorkflowNodeStatus) + } + } + + return r0 +} + +// GetParentNodeID provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetParentNodeID() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetParentTaskID provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetParentTaskID() *core.TaskExecutionIdentifier { + ret := _m.Called() + + var r0 *core.TaskExecutionIdentifier + if rf, ok := ret.Get(0).(func() *core.TaskExecutionIdentifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskExecutionIdentifier) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetPhase() v1alpha1.NodePhase { + ret := _m.Called() + + var r0 v1alpha1.NodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.NodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.NodePhase) + } + + return r0 +} + +// GetQueuedAt provides a mock function with given fields: +func 
(_m *ExecutableNodeStatus) GetQueuedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetStartedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetStartedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetStoppedAt provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetStoppedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetSubWorkflowNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetSubWorkflowNodeStatus() v1alpha1.ExecutableSubWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableSubWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableSubWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflowNodeStatus) + } + } + + return r0 +} + +// GetTaskNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetTaskNodeStatus() v1alpha1.ExecutableTaskNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableTaskNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableTaskNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTaskNodeStatus) + } + } + + return r0 +} + +// GetWorkflowNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetWorkflowNodeStatus() v1alpha1.ExecutableWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowNodeStatus) + } + } + + return r0 +} + +// IncrementAttempts provides a mock function with given fields: +func (_m *ExecutableNodeStatus) IncrementAttempts() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// IsCached provides a mock function with given fields: +func (_m *ExecutableNodeStatus) IsCached() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsDirty provides a mock function with given fields: +func (_m *ExecutableNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// ResetDirty provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ResetDirty() { + _m.Called() +} + +// SetCached provides a mock function with given fields: +func (_m *ExecutableNodeStatus) SetCached() { + _m.Called() +} + +// SetDataDir provides a mock function with given fields: _a0 +func (_m *ExecutableNodeStatus) SetDataDir(_a0 storage.DataReference) { + _m.Called(_a0) +} + +// SetParentNodeID provides a mock function with given fields: n +func (_m *ExecutableNodeStatus) SetParentNodeID(n 
*string) { + _m.Called(n) +} + +// SetParentTaskID provides a mock function with given fields: t +func (_m *ExecutableNodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { + _m.Called(t) +} + +// UpdatePhase provides a mock function with given fields: phase, occurredAt, reason +func (_m *ExecutableNodeStatus) UpdatePhase(phase v1alpha1.NodePhase, occurredAt v1.Time, reason string) { + _m.Called(phase, occurredAt, reason) +} + +// VisitNodeStatuses provides a mock function with given fields: visitor +func (_m *ExecutableNodeStatus) VisitNodeStatuses(visitor func(string, v1alpha1.ExecutableNodeStatus)) { + _m.Called(visitor) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go new file mode 100644 index 0000000000..132991edbd --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go @@ -0,0 +1,167 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableSubWorkflow is an autogenerated mock type for the ExecutableSubWorkflow type +type ExecutableSubWorkflow struct { + mock.Mock +} + +// FromNode provides a mock function with given fields: name +func (_m *ExecutableSubWorkflow) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetConnections provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetConnections() *v1alpha1.Connections { + ret := _m.Called() + + var r0 *v1alpha1.Connections + if rf, ok := ret.Get(0).(func() *v1alpha1.Connections); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.Connections) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *ExecutableSubWorkflow) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodes provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetNodes() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetOnFailureNode provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetOnFailureNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + 
r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} + +// GetOutputBindings provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetOutputBindings() []*v1alpha1.Binding { + ret := _m.Called() + + var r0 []*v1alpha1.Binding + if rf, ok := ret.Get(0).(func() []*v1alpha1.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*v1alpha1.Binding) + } + } + + return r0 +} + +// GetOutputs provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) GetOutputs() *v1alpha1.OutputVarMap { + ret := _m.Called() + + var r0 *v1alpha1.OutputVarMap + if rf, ok := ret.Get(0).(func() *v1alpha1.OutputVarMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.OutputVarMap) + } + } + + return r0 +} + +// StartNode provides a mock function with given fields: +func (_m *ExecutableSubWorkflow) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go new file mode 100644 index 0000000000..c90dc45985 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflowNodeStatus.go @@ -0,0 +1,25 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableSubWorkflowNodeStatus is an autogenerated mock type for the ExecutableSubWorkflowNodeStatus type +type ExecutableSubWorkflowNodeStatus struct { + mock.Mock +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableSubWorkflowNodeStatus) GetPhase() v1alpha1.WorkflowPhase { + ret := _m.Called() + + var r0 v1alpha1.WorkflowPhase + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowPhase) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go new file mode 100644 index 0000000000..61700e7ecc --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTask.go @@ -0,0 +1,41 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
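+// Example stubbing (hypothetical test code, not part of the generated mock):
+// the task type string "container" is only an illustrative value, since
+// TaskType is a dynamic, configuration-defined enumeration.
+//
+//	task := &mocks.ExecutableTask{}
+//	task.On("TaskType").Return("container")
+//	task.On("CoreTask").Return(&core.TaskTemplate{})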
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// ExecutableTask is an autogenerated mock type for the ExecutableTask type +type ExecutableTask struct { + mock.Mock +} + +// CoreTask provides a mock function with given fields: +func (_m *ExecutableTask) CoreTask() *core.TaskTemplate { + ret := _m.Called() + + var r0 *core.TaskTemplate + if rf, ok := ret.Get(0).(func() *core.TaskTemplate); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskTemplate) + } + } + + return r0 +} + +// TaskType provides a mock function with given fields: +func (_m *ExecutableTask) TaskType() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go new file mode 100644 index 0000000000..c1c2c7e258 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableTaskNodeStatus.go @@ -0,0 +1,55 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" + +// ExecutableTaskNodeStatus is an autogenerated mock type for the ExecutableTaskNodeStatus type +type ExecutableTaskNodeStatus struct { + mock.Mock +} + +// GetCustomState provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetCustomState() map[string]interface{} { + ret := _m.Called() + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func() map[string]interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetPhase() types.TaskPhase { + ret := _m.Called() + + var r0 types.TaskPhase + if rf, ok := ret.Get(0).(func() types.TaskPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.TaskPhase) + } + + return r0 +} + +// GetPhaseVersion provides a mock function with given fields: +func (_m *ExecutableTaskNodeStatus) GetPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go new file mode 100644 index 0000000000..d14b3d0f93 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go @@ -0,0 +1,370 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
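+// Wiring sketch (hypothetical test code, not part of the generated mock): a
+// workflow mock is typically assembled together with node and status mocks;
+// "n1" is a made-up node ID.
+//
+//	wf := &mocks.ExecutableWorkflow{}
+//	node := &mocks.ExecutableNode{}
+//	status := &mocks.ExecutableNodeStatus{}
+//	wf.On("GetNode", "n1").Return(node, true)
+//	wf.On("GetNodeExecutionStatus", "n1").Return(status)
+//	status.On("GetPhase").Return(v1alpha1.NodePhaseRunning)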
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableWorkflow is an autogenerated mock type for the ExecutableWorkflow type +type ExecutableWorkflow struct { + mock.Mock +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *ExecutableWorkflow) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} + +// FromNode provides a mock function with given fields: name +func (_m *ExecutableWorkflow) FromNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAnnotations provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetConnections provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetConnections() *v1alpha1.Connections { + ret := _m.Called() + + var r0 *v1alpha1.Connections + if rf, ok := ret.Get(0).(func() *v1alpha1.Connections); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.Connections) + } + } + + return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetExecutionID provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +// GetExecutionStatus provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowStatus) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetK8sWorkflowID() types.NamespacedName { + ret := 
_m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *ExecutableWorkflow) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *ExecutableWorkflow) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// GetNodes provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetNodes() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetOnFailureNode provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetOnFailureNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} + +// GetOutputBindings provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetOutputBindings() []*v1alpha1.Binding { + ret := _m.Called() + + var r0 []*v1alpha1.Binding + if rf, ok := ret.Get(0).(func() []*v1alpha1.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*v1alpha1.Binding) + } + } + + return r0 +} + +// GetOutputs provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetOutputs() *v1alpha1.OutputVarMap { + ret := _m.Called() + + var r0 *v1alpha1.OutputVarMap + if rf, ok := ret.Get(0).(func() *v1alpha1.OutputVarMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.OutputVarMap) + } + } + + return r0 +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m 
*ExecutableWorkflow) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetTask provides a mock function with given fields: id +func (_m *ExecutableWorkflow) GetTask(id string) (v1alpha1.ExecutableTask, error) { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewControllerRef provides a mock function with given fields: +func (_m *ExecutableWorkflow) NewControllerRef() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +// StartNode provides a mock function with given fields: +func (_m *ExecutableWorkflow) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go new file mode 100644 index 0000000000..3cf799a200 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNode.go @@ -0,0 +1,43 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableWorkflowNode is an autogenerated mock type for the ExecutableWorkflowNode type +type ExecutableWorkflowNode struct { + mock.Mock +} + +// GetLaunchPlanRefID provides a mock function with given fields: +func (_m *ExecutableWorkflowNode) GetLaunchPlanRefID() *v1alpha1.Identifier { + ret := _m.Called() + + var r0 *v1alpha1.Identifier + if rf, ok := ret.Get(0).(func() *v1alpha1.Identifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.Identifier) + } + } + + return r0 +} + +// GetSubWorkflowRef provides a mock function with given fields: +func (_m *ExecutableWorkflowNode) GetSubWorkflowRef() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go new file mode 100644 index 0000000000..fbf9544210 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowNodeStatus.go @@ -0,0 +1,24 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +// ExecutableWorkflowNodeStatus is an autogenerated mock type for the ExecutableWorkflowNodeStatus type +type ExecutableWorkflowNodeStatus struct { + mock.Mock +} + +// GetWorkflowExecutionName provides a mock function with given fields: +func (_m *ExecutableWorkflowNodeStatus) GetWorkflowExecutionName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go new file mode 100644 index 0000000000..8de18b32c2 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflowStatus.go @@ -0,0 +1,194 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import storage "github.com/lyft/flytestdlib/storage" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// ExecutableWorkflowStatus is an autogenerated mock type for the ExecutableWorkflowStatus type +type ExecutableWorkflowStatus struct { + mock.Mock +} + +// ConstructNodeDataDir provides a mock function with given fields: ctx, constructor, name +func (_m *ExecutableWorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name string) (storage.DataReference, error) { + ret := _m.Called(ctx, constructor, name) + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func(context.Context, storage.ReferenceConstructor, string) storage.DataReference); ok { + r0 = rf(ctx, constructor, name) + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, storage.ReferenceConstructor, string) error); ok { + r1 = rf(ctx, constructor, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetDataDir provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetDataDir() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetLastUpdatedAt provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetLastUpdatedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetMessage provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetMessage() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *ExecutableWorkflowStatus) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} + +// GetOutputReference provides a mock 
function with given fields: +func (_m *ExecutableWorkflowStatus) GetOutputReference() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetPhase() v1alpha1.WorkflowPhase { + ret := _m.Called() + + var r0 v1alpha1.WorkflowPhase + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowPhase) + } + + return r0 +} + +// GetStartedAt provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetStartedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// GetStoppedAt provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) GetStoppedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +// IncFailedAttempts provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) IncFailedAttempts() { + _m.Called() +} + +// IsTerminated provides a mock function with given fields: +func (_m *ExecutableWorkflowStatus) IsTerminated() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// SetDataDir provides a mock function with given fields: _a0 +func (_m *ExecutableWorkflowStatus) SetDataDir(_a0 storage.DataReference) { + _m.Called(_a0) +} + +// SetMessage provides a mock function with given fields: msg +func (_m *ExecutableWorkflowStatus) SetMessage(msg string) { + _m.Called(msg) +} + +// SetOutputReference provides a mock function with given fields: reference +func (_m *ExecutableWorkflowStatus) SetOutputReference(reference storage.DataReference) { + _m.Called(reference) +} + +// UpdatePhase provides a mock function with given fields: p, msg +func (_m *ExecutableWorkflowStatus) UpdatePhase(p v1alpha1.WorkflowPhase, msg string) { + _m.Called(p, msg) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go new file mode 100644 index 0000000000..fcf090d22f --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableBranchNodeStatus.go @@ -0,0 +1,51 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
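+// Verification sketch (hypothetical test code, not part of the generated
+// mock): the setters return nothing, so tests usually just assert they were
+// invoked with the expected node ID ("then-node" is a made-up value).
+//
+//	st := &mocks.MutableBranchNodeStatus{}
+//	st.On("SetBranchNodeSuccess", "then-node").Return()
+//	// ... exercise the code that records the branch decision ...
+//	st.AssertCalled(t, "SetBranchNodeSuccess", "then-node")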
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableBranchNodeStatus is an autogenerated mock type for the MutableBranchNodeStatus type +type MutableBranchNodeStatus struct { + mock.Mock +} + +// GetFinalizedNode provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) GetFinalizedNode() *string { + ret := _m.Called() + + var r0 *string + if rf, ok := ret.Get(0).(func() *string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) GetPhase() v1alpha1.BranchNodePhase { + ret := _m.Called() + + var r0 v1alpha1.BranchNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.BranchNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.BranchNodePhase) + } + + return r0 +} + +// SetBranchNodeError provides a mock function with given fields: +func (_m *MutableBranchNodeStatus) SetBranchNodeError() { + _m.Called() +} + +// SetBranchNodeSuccess provides a mock function with given fields: id +func (_m *MutableBranchNodeStatus) SetBranchNodeSuccess(id string) { + _m.Called(id) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go new file mode 100644 index 0000000000..0208256ad4 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableDynamicNodeStatus.go @@ -0,0 +1,30 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableDynamicNodeStatus is an autogenerated mock type for the MutableDynamicNodeStatus type +type MutableDynamicNodeStatus struct { + mock.Mock +} + +// GetDynamicNodePhase provides a mock function with given fields: +func (_m *MutableDynamicNodeStatus) GetDynamicNodePhase() v1alpha1.DynamicNodePhase { + ret := _m.Called() + + var r0 v1alpha1.DynamicNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.DynamicNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.DynamicNodePhase) + } + + return r0 +} + +// SetDynamicNodePhase provides a mock function with given fields: phase +func (_m *MutableDynamicNodeStatus) SetDynamicNodePhase(phase v1alpha1.DynamicNodePhase) { + _m.Called(phase) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go new file mode 100644 index 0000000000..0ef3f0b330 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go @@ -0,0 +1,158 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" +import storage "github.com/lyft/flytestdlib/storage" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableNodeStatus is an autogenerated mock type for the MutableNodeStatus type +type MutableNodeStatus struct { + mock.Mock +} + +// ClearDynamicNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearDynamicNodeStatus() { + _m.Called() +} + +// ClearSubWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearSubWorkflowStatus() { + _m.Called() +} + +// ClearTaskStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearTaskStatus() { + _m.Called() +} + +// ClearWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearWorkflowStatus() { + _m.Called() +} + +// GetOrCreateBranchStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateBranchStatus() v1alpha1.MutableBranchNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableBranchNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableBranchNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableBranchNodeStatus) + } + } + + return r0 +} + +// GetOrCreateDynamicNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateDynamicNodeStatus() v1alpha1.MutableDynamicNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableDynamicNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableDynamicNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableDynamicNodeStatus) + } + } + + return r0 +} + +// GetOrCreateSubWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateSubWorkflowStatus() v1alpha1.MutableSubWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableSubWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableSubWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableSubWorkflowNodeStatus) + } + } + + return r0 +} + +// GetOrCreateTaskStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateTaskStatus() v1alpha1.MutableTaskNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableTaskNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableTaskNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableTaskNodeStatus) + } + } + + return r0 +} + +// GetOrCreateWorkflowStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateWorkflowStatus() v1alpha1.MutableWorkflowNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableWorkflowNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableWorkflowNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableWorkflowNodeStatus) + } + } + + return r0 +} + +// IncrementAttempts provides a mock function with given fields: +func (_m *MutableNodeStatus) IncrementAttempts() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// ResetDirty provides a mock function with given 
fields: +func (_m *MutableNodeStatus) ResetDirty() { + _m.Called() +} + +// SetCached provides a mock function with given fields: +func (_m *MutableNodeStatus) SetCached() { + _m.Called() +} + +// SetDataDir provides a mock function with given fields: _a0 +func (_m *MutableNodeStatus) SetDataDir(_a0 storage.DataReference) { + _m.Called(_a0) +} + +// SetParentNodeID provides a mock function with given fields: n +func (_m *MutableNodeStatus) SetParentNodeID(n *string) { + _m.Called(n) +} + +// SetParentTaskID provides a mock function with given fields: t +func (_m *MutableNodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { + _m.Called(t) +} + +// UpdatePhase provides a mock function with given fields: phase, occurredAt, reason +func (_m *MutableNodeStatus) UpdatePhase(phase v1alpha1.NodePhase, occurredAt v1.Time, reason string) { + _m.Called(phase, occurredAt, reason) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go new file mode 100644 index 0000000000..b194a63732 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableSubWorkflowNodeStatus.go @@ -0,0 +1,30 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// MutableSubWorkflowNodeStatus is an autogenerated mock type for the MutableSubWorkflowNodeStatus type +type MutableSubWorkflowNodeStatus struct { + mock.Mock +} + +// GetPhase provides a mock function with given fields: +func (_m *MutableSubWorkflowNodeStatus) GetPhase() v1alpha1.WorkflowPhase { + ret := _m.Called() + + var r0 v1alpha1.WorkflowPhase + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowPhase) + } + + return r0 +} + +// SetPhase provides a mock function with given fields: phase +func (_m *MutableSubWorkflowNodeStatus) SetPhase(phase v1alpha1.WorkflowPhase) { + _m.Called(phase) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go new file mode 100644 index 0000000000..adade901a5 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableTaskNodeStatus.go @@ -0,0 +1,70 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
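+// Illustrative, hand-written note (not generated by mockery): for write-only methods
+// such as SetPhase the generated mock only records the call through _m.Called, so a
+// test usually registers a loose expectation and then asserts on the arguments.
+// mock.Anything and AssertCalled come from github.com/stretchr/testify/mock; the
+// somePhase value below is only an example.
+//
+//   m := &mocks.MutableTaskNodeStatus{}
+//   m.On("SetPhase", mock.Anything)
+//
+//   m.SetPhase(somePhase)
+//   m.AssertCalled(t, "SetPhase", somePhase)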
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "github.com/lyft/flyteplugins/go/tasks/v1/types" + +// MutableTaskNodeStatus is an autogenerated mock type for the MutableTaskNodeStatus type +type MutableTaskNodeStatus struct { + mock.Mock +} + +// GetCustomState provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetCustomState() map[string]interface{} { + ret := _m.Called() + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func() map[string]interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + return r0 +} + +// GetPhase provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetPhase() types.TaskPhase { + ret := _m.Called() + + var r0 types.TaskPhase + if rf, ok := ret.Get(0).(func() types.TaskPhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.TaskPhase) + } + + return r0 +} + +// GetPhaseVersion provides a mock function with given fields: +func (_m *MutableTaskNodeStatus) GetPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// SetCustomState provides a mock function with given fields: state +func (_m *MutableTaskNodeStatus) SetCustomState(state map[string]interface{}) { + _m.Called(state) +} + +// SetPhase provides a mock function with given fields: phase +func (_m *MutableTaskNodeStatus) SetPhase(phase types.TaskPhase) { + _m.Called(phase) +} + +// SetPhaseVersion provides a mock function with given fields: version +func (_m *MutableTaskNodeStatus) SetPhaseVersion(version uint32) { + _m.Called(version) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go new file mode 100644 index 0000000000..a6dc4c8c9f --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableWorkflowNodeStatus.go @@ -0,0 +1,29 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// MutableWorkflowNodeStatus is an autogenerated mock type for the MutableWorkflowNodeStatus type +type MutableWorkflowNodeStatus struct { + mock.Mock +} + +// GetWorkflowExecutionName provides a mock function with given fields: +func (_m *MutableWorkflowNodeStatus) GetWorkflowExecutionName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// SetWorkflowExecutionName provides a mock function with given fields: name +func (_m *MutableWorkflowNodeStatus) SetWorkflowExecutionName(name string) { + _m.Called(name) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go new file mode 100644 index 0000000000..fff448c480 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusGetter.go @@ -0,0 +1,27 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// NodeStatusGetter is an autogenerated mock type for the NodeStatusGetter type +type NodeStatusGetter struct { + mock.Mock +} + +// GetNodeExecutionStatus provides a mock function with given fields: id +func (_m *NodeStatusGetter) GetNodeExecutionStatus(id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go new file mode 100644 index 0000000000..fc30a240ba --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeStatusVisitor.go @@ -0,0 +1,16 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// NodeStatusVisitor is an autogenerated mock type for the NodeStatusVisitor type +type NodeStatusVisitor struct { + mock.Mock +} + +// VisitNodeStatuses provides a mock function with given fields: visitor +func (_m *NodeStatusVisitor) VisitNodeStatuses(visitor func(string, v1alpha1.ExecutableNodeStatus)) { + _m.Called(visitor) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go new file mode 100644 index 0000000000..6149479c45 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go @@ -0,0 +1,143 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// WorkflowMeta is an autogenerated mock type for the WorkflowMeta type +type WorkflowMeta struct { + mock.Mock +} + +// GetAnnotations provides a mock function with given fields: +func (_m *WorkflowMeta) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *WorkflowMeta) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetExecutionID provides a mock function with given fields: +func (_m *WorkflowMeta) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *WorkflowMeta) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *WorkflowMeta) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *WorkflowMeta) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *WorkflowMeta) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *WorkflowMeta) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// NewControllerRef provides a mock function with given fields: +func (_m *WorkflowMeta) NewControllerRef() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go new file mode 100644 index 0000000000..a9d69c6b7a --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go @@ -0,0 +1,198 @@ +// Code 
generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import types "k8s.io/apimachinery/pkg/types" +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// WorkflowMetaExtended is an autogenerated mock type for the WorkflowMetaExtended type +type WorkflowMetaExtended struct { + mock.Mock +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *WorkflowMetaExtended) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} + +// GetAnnotations provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +// GetExecutionID provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +// GetExecutionStatus provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowStatus) + } + } + + return r0 +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +// GetLabels provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +// GetName provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetNamespace provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = 
ret.Get(0).(string) + } + + return r0 +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *WorkflowMetaExtended) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetTask provides a mock function with given fields: id +func (_m *WorkflowMetaExtended) GetTask(id string) (v1alpha1.ExecutableTask, error) { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewControllerRef provides a mock function with given fields: +func (_m *WorkflowMetaExtended) NewControllerRef() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go new file mode 100644 index 0000000000..a02c99c946 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -0,0 +1,512 @@ +package v1alpha1 + +import ( + "encoding/json" + "reflect" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type BranchNodeStatus struct { + Phase BranchNodePhase `json:"phase"` + FinalizedNodeID *NodeID `json:"finalNodeId"` +} + +func (in *BranchNodeStatus) GetPhase() BranchNodePhase { + return in.Phase +} + +func (in *BranchNodeStatus) SetBranchNodeError() { + in.Phase = BranchNodeError +} + +func (in *BranchNodeStatus) SetBranchNodeSuccess(id NodeID) { + in.Phase = BranchNodeSuccess + in.FinalizedNodeID = &id +} + +func (in *BranchNodeStatus) GetFinalizedNode() *NodeID { + return in.FinalizedNodeID +} + +func (in *BranchNodeStatus) Equals(other *BranchNodeStatus) bool { + if in == nil && other == nil { + return true + } + if in != nil && other != nil { + phaseEqual := in.Phase == other.Phase + if phaseEqual { + if in.FinalizedNodeID == nil && other.FinalizedNodeID == nil { + return true + } + if in.FinalizedNodeID != nil && other.FinalizedNodeID != nil { + return *in.FinalizedNodeID == *other.FinalizedNodeID + } + return false + } + return false + } + return false +} + +type DynamicNodePhase int + +const ( + DynamicNodePhaseNone DynamicNodePhase = iota + DynamicNodePhaseExecuting +) + +type DynamicNodeStatus struct { + Phase DynamicNodePhase `json:"phase"` +} + +func (s *DynamicNodeStatus) GetDynamicNodePhase() DynamicNodePhase { + return s.Phase +} + +func (s *DynamicNodeStatus) SetDynamicNodePhase(phase DynamicNodePhase) { + s.Phase = phase +} + +func (s *DynamicNodeStatus) Equals(o *DynamicNodeStatus) bool { + if s == nil && o == nil { + return true + } + if s != nil && o != nil { + return s.Phase == o.Phase + } + return false +} + +type SubWorkflowNodeStatus struct { + Phase WorkflowPhase `json:"phase"` +} + +func (s SubWorkflowNodeStatus) GetPhase() WorkflowPhase { + return s.Phase +} + +func (s *SubWorkflowNodeStatus) SetPhase(phase WorkflowPhase) { + s.Phase = phase +} + +type WorkflowNodeStatus struct { + WorkflowName 
string `json:"name"` +} + +func (in *WorkflowNodeStatus) SetWorkflowExecutionName(name string) { + in.WorkflowName = name +} + +func (in *WorkflowNodeStatus) GetWorkflowExecutionName() string { + return in.WorkflowName +} + +type NodeStatus struct { + Phase NodePhase `json:"phase"` + QueuedAt *metav1.Time `json:"queuedAt,omitempty"` + StartedAt *metav1.Time `json:"startedAt,omitempty"` + StoppedAt *metav1.Time `json:"stoppedAt,omitempty"` + LastUpdatedAt *metav1.Time `json:"lastUpdatedAt,omitempty"` + Message string `json:"message,omitempty"` + DataDir DataReference `json:"dataDir,omitempty"` + Attempts uint32 `json:"attempts"` + Cached bool `json:"cached"` + dirty bool + // This is useful only for branch nodes. If this is set, then it can be used to determine if execution can proceed + ParentNode *NodeID `json:"parentNode,omitempty"` + ParentTask *TaskExecutionIdentifier `json:"parentTask,omitempty"` + BranchStatus *BranchNodeStatus `json:"branchStatus,omitempty"` + SubNodeStatus map[NodeID]*NodeStatus `json:"subNodeStatus,omitempty"` + // We can store the outputs at this layer + + WorkflowNodeStatus *WorkflowNodeStatus `json:"workflowNodeStatus,omitempty"` + TaskNodeStatus *TaskNodeStatus `json:",omitempty"` + SubWorkflowNodeStatus *SubWorkflowNodeStatus `json:"subWorkflowStatus,omitempty"` + DynamicNodeStatus *DynamicNodeStatus `json:"dynamicNodeStatus,omitempty"` +} + +func (in NodeStatus) VisitNodeStatuses(visitor NodeStatusVisitFn) { + for n, s := range in.SubNodeStatus { + visitor(n, s) + } +} + +func (in *NodeStatus) ClearWorkflowStatus() { + in.WorkflowNodeStatus = nil +} + +func (in *NodeStatus) ClearTaskStatus() { + in.TaskNodeStatus = nil +} + +func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time { + return in.LastUpdatedAt +} + +func (in *NodeStatus) GetAttempts() uint32 { + return in.Attempts +} + +func (in *NodeStatus) SetCached() { + in.Cached = true + in.setDirty() +} + +func (in *NodeStatus) setDirty() { + in.dirty = true +} +func (in *NodeStatus) IsCached() bool { + return in.Cached +} + +func (in *NodeStatus) IsDirty() bool { + return in.dirty +} + +// ResetDirty is for unit tests, shouldn't be used in actual logic. 
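+// The dirty flag is set internally by the methods that mutate NodeStatus (for example
+// SetCached, SetDataDir, IncrementAttempts and UpdatePhase all call setDirty), so a
+// test that reuses one status value across steps might do, illustratively:
+//
+//   s := &NodeStatus{}
+//   s.SetDataDir("s3://bucket/exec/node-1") // any DataReference; marks the status dirty
+//   // ... assertions that rely on IsDirty() being true ...
+//   s.ResetDirty()                          // clear the flag before the next step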
+func (in *NodeStatus) ResetDirty() { + in.dirty = false +} + +func (in *NodeStatus) IncrementAttempts() uint32 { + in.Attempts++ + in.setDirty() + return in.Attempts +} + +func (in *NodeStatus) GetOrCreateDynamicNodeStatus() MutableDynamicNodeStatus { + if in.DynamicNodeStatus == nil { + in.setDirty() + in.DynamicNodeStatus = &DynamicNodeStatus{} + } + + return in.DynamicNodeStatus +} + +func (in *NodeStatus) ClearDynamicNodeStatus() { + in.DynamicNodeStatus = nil +} + +func (in *NodeStatus) GetOrCreateBranchStatus() MutableBranchNodeStatus { + if in.BranchStatus == nil { + in.BranchStatus = &BranchNodeStatus{} + } + + in.setDirty() + return in.BranchStatus +} + +func (in *NodeStatus) GetWorkflowNodeStatus() ExecutableWorkflowNodeStatus { + if in.WorkflowNodeStatus == nil { + return nil + } + + in.setDirty() + return in.WorkflowNodeStatus +} + +func (in *NodeStatus) GetPhase() NodePhase { + return in.Phase +} + +func (in *NodeStatus) GetMessage() string { + return in.Message +} + +func IsPhaseTerminal(phase NodePhase) bool { + return phase == NodePhaseSucceeded || phase == NodePhaseFailed || phase == NodePhaseSkipped +} + +func (in *NodeStatus) GetOrCreateTaskStatus() MutableTaskNodeStatus { + if in.TaskNodeStatus == nil { + in.TaskNodeStatus = &TaskNodeStatus{} + } + + in.setDirty() + return in.TaskNodeStatus +} + +func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string) { + if in.Phase == p { + // We will not update the phase multiple times. This prevents the comparison from returning false positives + return + } + + in.Phase = p + in.Message = reason + if len(reason) > maxMessageSize { + in.Message = reason[:maxMessageSize] + } + + n := occurredAt + if occurredAt.IsZero() { + n = metav1.Now() + } + + if p == NodePhaseQueued && in.QueuedAt == nil { + in.QueuedAt = &n + } else if p == NodePhaseRunning && in.StartedAt == nil { + in.StartedAt = &n + } else if IsPhaseTerminal(p) && in.StoppedAt == nil { + if in.StartedAt == nil { + in.StartedAt = &n + } + + in.StoppedAt = &n + } + + // The early return above guarantees the phase actually changed, so always record the transition time. + in.LastUpdatedAt = &n + + in.setDirty() +} + +func (in *NodeStatus) GetStartedAt() *metav1.Time { + return in.StartedAt +} + +func (in *NodeStatus) GetStoppedAt() *metav1.Time { + return in.StoppedAt +} + +func (in *NodeStatus) GetQueuedAt() *metav1.Time { + return in.QueuedAt +} + +func (in *NodeStatus) GetParentNodeID() *NodeID { + return in.ParentNode +} + +func (in *NodeStatus) GetParentTaskID() *core.TaskExecutionIdentifier { + if in.ParentTask != nil { + return in.ParentTask.TaskExecutionIdentifier + } + return nil +} + +func (in *NodeStatus) SetParentNodeID(n *NodeID) { + in.ParentNode = n + in.setDirty() +} + +func (in *NodeStatus) SetParentTaskID(t *core.TaskExecutionIdentifier) { + in.ParentTask = &TaskExecutionIdentifier{ + TaskExecutionIdentifier: t, + } + in.setDirty() +} + +func (in *NodeStatus) GetOrCreateWorkflowStatus() MutableWorkflowNodeStatus { + if in.WorkflowNodeStatus == nil { + in.WorkflowNodeStatus = &WorkflowNodeStatus{} + } + + in.setDirty() + return in.WorkflowNodeStatus +} + +func (in NodeStatus) GetTaskNodeStatus() ExecutableTaskNodeStatus { + // Explicitly return nil here to avoid a misleading non-nil interface.
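+ // Illustrative background for this guard: in Go, assigning a typed nil pointer to an
+ // interface produces a non-nil interface value, e.g.
+ //   var ts *TaskNodeStatus              // ts == nil
+ //   var s ExecutableTaskNodeStatus = ts // s != nil, even though ts is nil
+ // so returning in.TaskNodeStatus directly would make `status == nil` checks in callers fail.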
+ if in.TaskNodeStatus == nil { + return nil + } + + return in.TaskNodeStatus +} + +func (in NodeStatus) GetSubWorkflowNodeStatus() ExecutableSubWorkflowNodeStatus { + if in.SubWorkflowNodeStatus == nil { + return nil + } + + return in.SubWorkflowNodeStatus +} + +func (in NodeStatus) GetOrCreateSubWorkflowStatus() MutableSubWorkflowNodeStatus { + if in.SubWorkflowNodeStatus == nil { + in.SubWorkflowNodeStatus = &SubWorkflowNodeStatus{} + } + + return in.SubWorkflowNodeStatus +} + +func (in *NodeStatus) ClearSubWorkflowStatus() { + in.SubWorkflowNodeStatus = nil +} + +func (in *NodeStatus) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { + n, ok := in.SubNodeStatus[id] + if ok { + return n + } + if in.SubNodeStatus == nil { + in.SubNodeStatus = make(map[NodeID]*NodeStatus) + } + newNodeStatus := &NodeStatus{} + newNodeStatus.SetParentTaskID(in.GetParentTaskID()) + newNodeStatus.SetParentNodeID(in.GetParentNodeID()) + + in.SubNodeStatus[id] = newNodeStatus + return newNodeStatus +} + +func (in *NodeStatus) IsTerminated() bool { + return in.GetPhase() == NodePhaseFailed || in.GetPhase() == NodePhaseSkipped || in.GetPhase() == NodePhaseSucceeded +} + +func (in *NodeStatus) GetDataDir() DataReference { + return in.DataDir +} + +func (in *NodeStatus) SetDataDir(d DataReference) { + in.DataDir = d + in.setDirty() +} + +func (in *NodeStatus) Equals(other *NodeStatus) bool { + // Assuming in is never nil + if other == nil { + return false + } + + if in.Attempts != other.Attempts { + return false + } + + if in.Phase != other.Phase { + return false + } + + if !reflect.DeepEqual(in.TaskNodeStatus, other.TaskNodeStatus) { + return false + } + + if in.DataDir != other.DataDir { + return false + } + + if in.ParentNode != nil && other.ParentNode != nil { + if *in.ParentNode != *other.ParentNode { + return false + } + } else if !(in.ParentNode == other.ParentNode) { + // Both are not nil + return false + } + + if !reflect.DeepEqual(in.ParentTask, other.ParentTask) { + return false + } + + if len(in.SubNodeStatus) != len(other.SubNodeStatus) { + return false + } + + for k, v := range in.SubNodeStatus { + otherV, ok := other.SubNodeStatus[k] + if !ok { + return false + } + if !v.Equals(otherV) { + return false + } + } + + return in.BranchStatus.Equals(other.BranchStatus) // && in.DynamicNodeStatus.Equals(other.DynamicNodeStatus) +} + +// THIS IS NOT AUTO GENERATED +func (in *CustomState) DeepCopyInto(out *CustomState) { + if in == nil || *in == nil { + return + } + + raw, err := json.Marshal(in) + if err != nil { + return + } + + err = json.Unmarshal(raw, out) + if err != nil { + return + } +} + +func (in *CustomState) DeepCopy() *CustomState { + if in == nil || *in == nil { + return nil + } + + out := &CustomState{} + in.DeepCopyInto(out) + return out +} + +type TaskNodeStatus struct { + Phase types.TaskPhase `json:"phase,omitempty"` + PhaseVersion uint32 `json:"phaseVersion,omitempty"` + CustomState types.CustomState `json:"custom,omitempty"` +} + +func (in *TaskNodeStatus) SetPhase(phase types.TaskPhase) { + in.Phase = phase +} + +func (in *TaskNodeStatus) SetPhaseVersion(version uint32) { + in.PhaseVersion = version +} + +func (in *TaskNodeStatus) SetCustomState(state types.CustomState) { + in.CustomState = state +} + +func (in TaskNodeStatus) GetPhase() types.TaskPhase { + return in.Phase +} + +func (in TaskNodeStatus) GetPhaseVersion() uint32 { + return in.PhaseVersion +} + +func (in TaskNodeStatus) GetCustomState() types.CustomState { + return in.CustomState +} + +func (in *TaskNodeStatus) 
UpdatePhase(phase types.TaskPhase, phaseVersion uint32) { + in.Phase = phase + in.PhaseVersion = phaseVersion +} + +func (in *TaskNodeStatus) UpdateCustomState(state types.CustomState) { + in.CustomState = state +} + +func (in *TaskNodeStatus) DeepCopyInto(out *TaskNodeStatus) { + if in == nil { + return + } + + raw, err := json.Marshal(in) + if err != nil { + return + } + + err = json.Unmarshal(raw, out) + if err != nil { + return + } +} + +func (in *TaskNodeStatus) DeepCopy() *TaskNodeStatus { + if in == nil { + return nil + } + + out := &TaskNodeStatus{} + in.DeepCopyInto(out) + return out +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go new file mode 100644 index 0000000000..f458c1c38b --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go @@ -0,0 +1,156 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsPhaseTerminal(t *testing.T) { + assert.True(t, IsPhaseTerminal(NodePhaseFailed)) + assert.True(t, IsPhaseTerminal(NodePhaseSkipped)) + assert.True(t, IsPhaseTerminal(NodePhaseSucceeded)) + + assert.False(t, IsPhaseTerminal(NodePhaseFailing)) + assert.False(t, IsPhaseTerminal(NodePhaseRunning)) + assert.False(t, IsPhaseTerminal(NodePhaseNotYetStarted)) +} + +func TestNodeStatus_Equals(t *testing.T) { + one := &NodeStatus{} + var other *NodeStatus + assert.False(t, one.Equals(other)) + + other = &NodeStatus{} + assert.True(t, one.Equals(other)) + + one.Phase = NodePhaseRunning + assert.False(t, one.Equals(other)) + + other.Phase = one.Phase + assert.True(t, one.Equals(other)) + + one.DataDir = "data-dir" + assert.False(t, one.Equals(other)) + + other.DataDir = one.DataDir + assert.True(t, one.Equals(other)) + + parentNode := "x" + one.ParentNode = &parentNode + assert.False(t, one.Equals(other)) + + parentNode2 := "y" + other.ParentNode = &parentNode2 + assert.False(t, one.Equals(other)) + + other.ParentNode = &parentNode + assert.True(t, one.Equals(other)) + + one.BranchStatus = &BranchNodeStatus{} + assert.False(t, one.Equals(other)) + other.BranchStatus = &BranchNodeStatus{} + assert.True(t, one.Equals(other)) + + node := "x" + one.SubNodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.False(t, one.Equals(other)) + other.SubNodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.True(t, one.Equals(other)) + + one.SubNodeStatus[node].Phase = NodePhaseRunning + assert.False(t, one.Equals(other)) + other.SubNodeStatus[node].Phase = NodePhaseRunning + assert.True(t, one.Equals(other)) +} + +func TestBranchNodeStatus_Equals(t *testing.T) { + var one *BranchNodeStatus + var other *BranchNodeStatus + assert.True(t, one.Equals(other)) + one = &BranchNodeStatus{} + + assert.False(t, one.Equals(other)) + other = &BranchNodeStatus{} + + assert.True(t, one.Equals(other)) + + one.Phase = BranchNodeError + assert.False(t, one.Equals(other)) + other.Phase = one.Phase + + assert.True(t, one.Equals(other)) + + node := "x" + one.FinalizedNodeID = &node + assert.False(t, one.Equals(other)) + + node2 := "y" + other.FinalizedNodeID = &node2 + assert.False(t, one.Equals(other)) + + node2 = node + other.FinalizedNodeID = &node2 + assert.True(t, one.Equals(other)) +} + +func TestDynamicNodeStatus_Equals(t *testing.T) { + var one *DynamicNodeStatus + var other *DynamicNodeStatus + assert.True(t, one.Equals(other)) + one = &DynamicNodeStatus{} + + assert.False(t, one.Equals(other)) + other = 
&DynamicNodeStatus{} + + assert.True(t, one.Equals(other)) + + one.Phase = DynamicNodePhaseExecuting + assert.False(t, one.Equals(other)) + other.Phase = one.Phase + + assert.True(t, one.Equals(other)) +} + +func TestCustomState_DeepCopyInto(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var in CustomState + var out CustomState + in.DeepCopyInto(&out) + assert.Nil(t, in) + assert.Nil(t, out) + }) + + t.Run("Not nil in", func(t *testing.T) { + in := CustomState(map[string]interface{}{ + "key1": "hello", + }) + + var out CustomState + in.DeepCopyInto(&out) + assert.NotNil(t, out) + assert.Equal(t, 1, len(out)) + }) +} + +func TestCustomState_DeepCopy(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var in CustomState + assert.Nil(t, in) + assert.Nil(t, in.DeepCopy()) + }) + + t.Run("Not nil in", func(t *testing.T) { + in := CustomState(map[string]interface{}{ + "key1": "hello", + }) + + out := in.DeepCopy() + assert.NotNil(t, out) + assert.Equal(t, 1, len(*out)) + }) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go new file mode 100644 index 0000000000..3a6fa0bfe7 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go @@ -0,0 +1,193 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + typesv1 "k8s.io/api/core/v1" +) + +var marshaler = jsonpb.Marshaler{} + +type OutputVarMap struct { + *core.VariableMap +} + +func (in *OutputVarMap) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.VariableMap); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func (in *OutputVarMap) UnmarshalJSON(b []byte) error { + in.VariableMap = &core.VariableMap{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.VariableMap) +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *OutputVarMap) DeepCopyInto(out *OutputVarMap) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type Binding struct { + *core.Binding +} + +func (in *Binding) UnmarshalJSON(b []byte) error { + in.Binding = &core.Binding{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.Binding) +} + +func (in *Binding) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.Binding); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Binding) DeepCopyInto(out *Binding) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +// Strategy to be used to Retry a node that is in RetryableFailure state +type RetryStrategy struct { + // MinAttempts implies the atleast n attempts to try this node before giving up. The atleast here is because we may + // fail to write the attempt information and end up retrying again. + // Also `0` and `1` both mean atleast one attempt will be done. 0 is a degenerate case. + MinAttempts *int `json:"minAttempts"` + // TODO Add retrydelay? +} + +type Alias struct { + core.Alias +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *Alias) DeepCopyInto(out *Alias) { + *out = *in + // We do not manipulate the object, so it's ok + // Once we figure out the autogenerate story we can replace this +} + +type NodeMetadata struct { + core.NodeMetadata +} + +func (in *NodeMetadata) DeepCopyInto(out *NodeMetadata) { + *out = *in + // We do not manipulate the object, so it's ok + // Once we figure out the autogenerate story we can replace this +} + +type NodeSpec struct { + ID NodeID `json:"id"` + Resources *typesv1.ResourceRequirements `json:"resources,omitempty"` + Kind NodeKind `json:"kind"` + BranchNode *BranchNodeSpec `json:"branch,omitempty"` + TaskRef *TaskID `json:"task,omitempty"` + WorkflowNode *WorkflowNodeSpec `json:"workflow,omitempty"` + InputBindings []*Binding `json:"inputBindings,omitempty"` + Config *typesv1.ConfigMap `json:"config,omitempty"` + RetryStrategy *RetryStrategy `json:"retry,omitempty"` + OutputAliases []Alias `json:"outputAlias,omitempty"` + + // SecurityContext holds pod-level security attributes and common container settings. + // Optional: Defaults to empty. See type description for default values of each field. + // +optional + SecurityContext *typesv1.PodSecurityContext `json:"securityContext,omitempty" protobuf:"bytes,14,opt,name=securityContext"` + // ImagePullSecrets is an optional list of references to secrets in the same namespace to use for pulling any of the images used by this PodSpec. + // If specified, these secrets will be passed to individual puller implementations for them to use. For example, + // in the case of docker, only DockerConfig type secrets are honored. + // More info: https://kubernetes.io/docs/concepts/containers/images#specifying-imagepullsecrets-on-a-pod + // +optional + // +patchMergeKey=name + // +patchStrategy=merge + ImagePullSecrets []typesv1.LocalObjectReference `json:"imagePullSecrets,omitempty" patchStrategy:"merge" patchMergeKey:"name" protobuf:"bytes,15,rep,name=imagePullSecrets"` + // Specifies the hostname of the Pod + // If not specified, the pod's hostname will be set to a system-defined value. + // +optional + Hostname string `json:"hostname,omitempty" protobuf:"bytes,16,opt,name=hostname"` + // If specified, the fully qualified Pod hostname will be "<hostname>.<subdomain>.<pod namespace>.svc.<cluster domain>". + // If not specified, the pod will not have a domainname at all. + // +optional + Subdomain string `json:"subdomain,omitempty" protobuf:"bytes,17,opt,name=subdomain"` + // If specified, the pod's scheduling constraints + // +optional + Affinity *typesv1.Affinity `json:"affinity,omitempty" protobuf:"bytes,18,opt,name=affinity"` + // If specified, the pod will be dispatched by specified scheduler. + // If not specified, the pod will be dispatched by default scheduler. + // +optional + SchedulerName string `json:"schedulerName,omitempty" protobuf:"bytes,19,opt,name=schedulerName"` + // If specified, the pod's tolerations. + // +optional + Tolerations []typesv1.Toleration `json:"tolerations,omitempty" protobuf:"bytes,22,opt,name=tolerations"` + // Optional duration in seconds, measured relative to StartTime, that this node may remain active before the system will actively try to mark it failed and kill associated containers. + // Value must be a positive integer.
+ // +optional + ActiveDeadlineSeconds *int64 `json:"activeDeadlineSeconds,omitempty"` +} + +func (in *NodeSpec) GetRetryStrategy() *RetryStrategy { + return in.RetryStrategy +} + +func (in *NodeSpec) GetConfig() *typesv1.ConfigMap { + return in.Config +} + +func (in *NodeSpec) GetResources() *typesv1.ResourceRequirements { + return in.Resources +} + +func (in *NodeSpec) GetOutputAlias() []Alias { + return in.OutputAliases +} + +func (in *NodeSpec) GetWorkflowNode() ExecutableWorkflowNode { + if in.WorkflowNode == nil { + return nil + } + return in.WorkflowNode +} + +func (in *NodeSpec) GetBranchNode() ExecutableBranchNode { + if in.BranchNode == nil { + return nil + } + return in.BranchNode +} + +func (in *NodeSpec) GetTaskID() *TaskID { + return in.TaskRef +} + +func (in *NodeSpec) GetKind() NodeKind { + return in.Kind +} + +func (in *NodeSpec) GetID() NodeID { + return in.ID +} + +func (in *NodeSpec) IsStartNode() bool { + return in.ID == StartNodeID +} + +func (in *NodeSpec) IsEndNode() bool { + return in.ID == EndNodeID +} + +func (in *NodeSpec) GetInputBindings() []*Binding { + return in.InputBindings +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go new file mode 100644 index 0000000000..56772feed6 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go @@ -0,0 +1,38 @@ +package v1alpha1 + +import ( + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +const FlyteWorkflowKind = "flyteworkflow" + +// SchemeGroupVersion is group version used to register these objects +var SchemeGroupVersion = schema.GroupVersion{Group: flyteworkflow.GroupName, Version: "v1alpha1"} + +// GetKind takes an unqualified kind and returns back a Group qualified GroupKind +func Kind(kind string) schema.GroupKind { + return SchemeGroupVersion.WithKind(kind).GroupKind() +} + +// Resource takes an unqualified resource and returns a Group qualified GroupResource +func Resource(resource string) schema.GroupResource { + return SchemeGroupVersion.WithResource(resource).GroupResource() +} + +var ( + SchemeBuilder = runtime.NewSchemeBuilder(addKnownTypes) + AddToScheme = SchemeBuilder.AddToScheme +) + +// Adds the list of known types to Scheme. +func addKnownTypes(scheme *runtime.Scheme) error { + scheme.AddKnownTypes(SchemeGroupVersion, + &FlyteWorkflow{}, + &FlyteWorkflowList{}, + ) + metav1.AddToGroupVersion(scheme, SchemeGroupVersion) + return nil +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go new file mode 100644 index 0000000000..a7d7532b97 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow.go @@ -0,0 +1,23 @@ +package v1alpha1 + +type WorkflowNodeSpec struct { + // Either one of the two + LaunchPlanRefID *LaunchPlanRefID `json:"launchPlanRefId,omitempty"` + // We currently want the SubWorkflow to be completely contained in the node. this is because + // We use the node status to store the information of the execution. + // Important Note: This may cause a bloat in case we use the same SubWorkflow in multiple nodes. The recommended + // technique for that is to use launch plan refs. This is because we will end up executing the launch plan refs as + // disparate executions in Flyte propeller. 
This is potentially better as it prevents us from hitting the storage limit + // in etcd + //+optional. + // Workflow *WorkflowSpec `json:"workflow,omitempty"` + SubWorkflowReference *WorkflowID `json:"subWorkflowRef,omitempty"` +} + +func (in *WorkflowNodeSpec) GetLaunchPlanRefID() *LaunchPlanRefID { + return in.LaunchPlanRefID +} + +func (in *WorkflowNodeSpec) GetSubWorkflowRef() *WorkflowID { + return in.SubWorkflowReference +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go new file mode 100644 index 0000000000..bcf9227304 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks.go @@ -0,0 +1,39 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type TaskSpec struct { + *core.TaskTemplate +} + +func (in *TaskSpec) TaskType() TaskType { + return in.Type +} + +func (in *TaskSpec) CoreTask() *core.TaskTemplate { + return in.TaskTemplate +} + +func (in *TaskSpec) DeepCopyInto(out *TaskSpec) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +func (in *TaskSpec) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.TaskTemplate); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *TaskSpec) UnmarshalJSON(b []byte) error { + in.TaskTemplate = &core.TaskTemplate{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.TaskTemplate) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go new file mode 100644 index 0000000000..3020223134 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/tasks_test.go @@ -0,0 +1,20 @@ +package v1alpha1_test + +import ( + "encoding/json" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestTaskSpec(t *testing.T) { + j, err := ReadYamlFileAsJSON("testdata/task.yaml") + assert.NoError(t, err) + + task := &v1alpha1.TaskSpec{} + assert.NoError(t, json.Unmarshal(j, task)) + + assert.NotNil(t, task.CoreTask()) + assert.Equal(t, "demo", task.TaskType()) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json new file mode 100644 index 0000000000..73d91d8e8d --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/branch.json @@ -0,0 +1,34 @@ +{ + "branch": { + "if": { + "condition": { + "comparison": { + "operator": "GT", + "leftValue": { + "primitive": { + "integer": "5" + } + }, + "rightValue": { + "var": "x" + } + } + }, + "then": "foo1" + }, + "else": "foo2" + }, + "id": "foobranch", + "inputBindings": [ + { + "binding": { + "promise": { + "nodeId": "start", + "var": "x" + } + }, + "var": "x" + } + ], + "kind": "branch" +} \ No newline at end of file diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json new file mode 100644 index 0000000000..c671cdbddd --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/connections.json @@ -0,0 +1,15 @@ +{ + "n1": [ + "n2", + "n3" + ], + "n2": [ + "n4" + ], + "n3": [ + "n4" + ], + "n4": [ + "n5" + ] +} \ No newline at end of file diff --git 
a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml new file mode 100644 index 0000000000..d99786db52 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/task.yaml @@ -0,0 +1,33 @@ +id: + name: add-one-and-print +type: "demo" +interface: + inputs: + variables: + value_to_print: + type: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER +metadata: + runtime: + version: 1.19.0b7 + timeout: 0s +container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=add_one_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflyteimage:abc123 + resources: + requests: + - value: "0.000" + - value: "2.000" + - value: 2048Mi + diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml new file mode 100644 index 0000000000..c83f8182b9 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/testdata/workflowspec.yaml @@ -0,0 +1,220 @@ +workflow.apiVersion: flyte.lyft.com/v1alpha1 +kind: flyteworkflow +metadata: + creationTimestamp: null + generateName: dummy-workflow-1-0- + labels: + execution-id: "" + workflow-id: dummy-workflow-1-0 +inputs: + literals: + triggered_date: + scalar: + primitive: + datetime: 2018-08-01T18:09:18Z +spec: + connections: + add-one-and-print-0: + - sum-non-none-0 + add-one-and-print-1: + - add-one-and-print-2 + - sum-and-print-0 + add-one-and-print-2: + - sum-and-print-0 + add-one-and-print-3: + - sum-non-none-0 + start-node: + - add-one-and-print-0 + - add-one-and-print-3 + - print-every-time-0 + sum-and-print-0: + - end-node + - print-every-time-0 + sum-non-none-0: + - add-one-and-print-1 + - sum-and-print-0 + id: dummy-workflow-1-0 + nodes: + add-one-and-print-0: + activeDeadlineSeconds: 0 + id: add-one-and-print-0 + input_bindings: + - binding: + scalar: + primitive: + integer: "3" + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + add-one-and-print-1: + activeDeadlineSeconds: 0 + id: add-one-and-print-1 + input_bindings: + - binding: + promise: + nodeId: sum-non-none-0 + var: out + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + add-one-and-print-2: + activeDeadlineSeconds: 0 + id: add-one-and-print-2 + input_bindings: + - binding: + promise: + nodeId: add-one-and-print-1 + var: out + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + add-one-and-print-3: + activeDeadlineSeconds: 0 + id: add-one-and-print-3 + input_bindings: + - binding: + scalar: + primitive: + integer: "101" + var: value_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: add-one-and-print + end-node: + id: end-node + input_bindings: + - binding: + promise: + nodeId: sum-and-print-0 + var: out + var: output + kind: end + resources: {} + status: + phase: 0 + print-every-time-0: + activeDeadlineSeconds: 0 + id: print-every-time-0 + input_bindings: + - binding: + promise: + nodeId: start-node + var: triggered_date + var: date_triggered + - binding: + promise: + nodeId: sum-and-print-0 + var: out_blob + var: in_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: multi_blob + var: multi_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: out + var: value_to_print + kind: task + resources: {} + 
status: + phase: 0 + task_ref: print-every-time + start-node: + id: start-node + kind: start + resources: {} + status: + phase: 0 + sum-and-print-0: + activeDeadlineSeconds: 0 + id: sum-and-print-0 + input_bindings: + - binding: + collection: + bindings: + - promise: + nodeId: sum-non-none-0 + var: out + - promise: + nodeId: add-one-and-print-1 + var: out + - promise: + nodeId: add-one-and-print-2 + var: out + - scalar: + primitive: + integer: "100" + var: values_to_add + kind: task + resources: {} + status: + phase: 0 + task_ref: sum-and-print + sum-non-none-0: + activeDeadlineSeconds: 0 + id: sum-non-none-0 + input_bindings: + - binding: + collection: + bindings: + - promise: + nodeId: add-one-and-print-0 + var: out + - promise: + nodeId: add-one-and-print-3 + var: out + var: values_to_print + kind: task + resources: {} + status: + phase: 0 + task_ref: sum-non-none +status: + phase: 0 +tasks: + add-one-and-print: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=add_one_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflyteimage:abc123 + resources: + requests: + - value: "0.000" + - value: "2.000" + - value: 2048Mi + id: + name: add-one-and-print + interface: + inputs: + variables: + value_to_print: + type: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER + metadata: + runtime: + version: 1.19.0b7 + timeout: 0s + type: "7" diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go new file mode 100644 index 0000000000..c53a680a98 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -0,0 +1,232 @@ +package v1alpha1 + +import ( + "bytes" + "encoding/json" + + "k8s.io/apimachinery/pkg/types" + + "github.com/golang/protobuf/jsonpb" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/pkg/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const StartNodeID = "start-node" +const EndNodeID = "end-node" + +// +genclient +// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object + +// FlyteWorkflow: represents one Execution Workflow object +type FlyteWorkflow struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + *WorkflowSpec `json:"spec"` + Inputs *Inputs `json:"inputs,omitempty"` + ExecutionID ExecutionID `json:"executionId"` + Tasks map[TaskID]*TaskSpec `json:"tasks"` + SubWorkflows map[WorkflowID]*WorkflowSpec `json:"subWorkflows,omitempty"` + // StartTime before the system will actively try to mark it failed and kill associated containers. + // Value must be a positive integer. + // +optional + ActiveDeadlineSeconds *int64 `json:"activeDeadlineSeconds,omitempty"` + // Specifies the time when the workflow has been accepted into the system. + AcceptedAt *metav1.Time `json:"acceptedAt,omitEmpty"` + // ServiceAccountName is the name of the ServiceAccount to use to run this pod. + // More info: https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/ + // +optional + ServiceAccountName string `json:"serviceAccountName,omitempty" protobuf:"bytes,8,opt,name=serviceAccountName"` + // Status is the only mutable section in the workflow. 
It holds all the execution information + Status WorkflowStatus `json:"status,omitempty"` +} + +var FlyteWorkflowGVK = SchemeGroupVersion.WithKind(FlyteWorkflowKind) + +func (in *FlyteWorkflow) NewControllerRef() metav1.OwnerReference { + // TODO Open Issue - https://github.com/kubernetes/client-go/issues/308 + // For some reason the CRD does not have the GVK correctly populated. So we will fake it. + if len(in.GroupVersionKind().Group) == 0 || len(in.GroupVersionKind().Kind) == 0 || len(in.GroupVersionKind().Version) == 0 { + return *metav1.NewControllerRef(in, FlyteWorkflowGVK) + } + return *metav1.NewControllerRef(in, in.GroupVersionKind()) +} + +func (in *FlyteWorkflow) GetTask(id TaskID) (ExecutableTask, error) { + t, ok := in.Tasks[id] + if !ok { + return nil, errors.Errorf("Unable to find task with Id [%v]", id) + } + return t, nil +} + +func (in *FlyteWorkflow) GetExecutionStatus() ExecutableWorkflowStatus { + return &in.Status +} + +func (in *FlyteWorkflow) GetK8sWorkflowID() types.NamespacedName { + return types.NamespacedName{ + Name: in.GetName(), + Namespace: in.GetNamespace(), + } +} + +func (in *FlyteWorkflow) GetExecutionID() ExecutionID { + return in.ExecutionID +} + +func (in *FlyteWorkflow) FindSubWorkflow(subID WorkflowID) ExecutableSubWorkflow { + s, ok := in.SubWorkflows[subID] + if !ok { + return nil + } + return s +} + +func (in *FlyteWorkflow) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { + return in.Status.GetNodeExecutionStatus(id) +} + +func (in *FlyteWorkflow) GetServiceAccountName() string { + return in.ServiceAccountName +} + +type Inputs struct { + *core.LiteralMap +} + +func (in *Inputs) UnmarshalJSON(b []byte) error { + in.LiteralMap = &core.LiteralMap{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.LiteralMap) +} + +func (in *Inputs) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.LiteralMap); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Inputs) DeepCopyInto(out *Inputs) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +type Connections struct { + DownstreamEdges map[NodeID][]NodeID + UpstreamEdges map[NodeID][]NodeID +} + +func (in *Connections) UnmarshalJSON(b []byte) error { + in.DownstreamEdges = map[NodeID][]NodeID{} + err := json.Unmarshal(b, &in.DownstreamEdges) + if err != nil { + return err + } + in.UpstreamEdges = map[NodeID][]NodeID{} + for from, nodes := range in.DownstreamEdges { + for _, to := range nodes { + if _, ok := in.UpstreamEdges[to]; !ok { + in.UpstreamEdges[to] = []NodeID{} + } + in.UpstreamEdges[to] = append(in.UpstreamEdges[to], from) + } + } + return nil +} + +func (in *Connections) MarshalJSON() ([]byte, error) { + return json.Marshal(in.DownstreamEdges) +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *Connections) DeepCopyInto(out *Connections) { + *out = *in + // We do not manipulate the object, so its ok + // Once we figure out the autogenerate story we can replace this +} + +// WorkflowSpec is the spec for the actual Flyte Workflow (DAG) +type WorkflowSpec struct { + ID WorkflowID `json:"id"` + Nodes map[NodeID]*NodeSpec `json:"nodes"` + + // Defines the set of connections (both data dependencies and execution dependencies) that the graph is + // formed of. The execution engine will respect and follow these connections as it determines which nodes + // can and should be executed. + Connections Connections `json:"connections"` + + // Defines a single node to execute in case the system determined the Workflow has failed. + OnFailure *NodeSpec `json:"onFailure,omitempty"` + + // Defines the declaration of the outputs types and names this workflow is expected to generate. + Outputs *OutputVarMap `json:"outputs,omitempty"` + + // Defines the data links used to construct the final outputs of the workflow. Bindings will typically + // refer to specific outputs of a subset of the nodes executed in the Workflow. When executing the end-node, + // the execution engine will traverse these bindings and assemble the final set of outputs of the workflow. + OutputBindings []*Binding `json:"outputBindings,omitempty"` +} + +func (in *WorkflowSpec) StartNode() ExecutableNode { + n, ok := in.Nodes[StartNodeID] + if !ok { + return nil + } + return n +} + +func (in *WorkflowSpec) GetID() WorkflowID { + return in.ID +} + +func (in *WorkflowSpec) FromNode(name NodeID) ([]NodeID, error) { + if _, ok := in.Nodes[name]; !ok { + return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID) + } + downstreamNodes := in.Connections.DownstreamEdges[name] + return downstreamNodes, nil +} + +func (in *WorkflowSpec) GetOutputs() *OutputVarMap { + return in.Outputs +} + +func (in *WorkflowSpec) GetNode(nodeID NodeID) (ExecutableNode, bool) { + n, ok := in.Nodes[nodeID] + return n, ok +} + +func (in *WorkflowSpec) GetConnections() *Connections { + return &in.Connections +} + +func (in *WorkflowSpec) GetOutputBindings() []*Binding { + return in.OutputBindings +} + +func (in *WorkflowSpec) GetOnFailureNode() ExecutableNode { + if in.OnFailure == nil { + return nil + } + return in.OnFailure +} + +func (in *WorkflowSpec) GetNodes() []NodeID { + nodeIds := make([]NodeID, 0, len(in.Nodes)) + for id := range in.Nodes { + nodeIds = append(nodeIds, id) + } + return nodeIds +} + +// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object +// FlyteWorkflowList is a list of FlyteWorkflow resources +type FlyteWorkflowList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata"` + Items []FlyteWorkflow `json:"items"` +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go new file mode 100644 index 0000000000..4027d4d958 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status.go @@ -0,0 +1,154 @@ +package v1alpha1 + +import ( + "context" + + "github.com/lyft/flytestdlib/storage" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const maxMessageSize = 1024 + +type WorkflowStatus struct { + Phase WorkflowPhase `json:"phase"` + StartedAt *metav1.Time `json:"startedAt,omitempty"` + StoppedAt *metav1.Time `json:"stoppedAt,omitempty"` + LastUpdatedAt *metav1.Time `json:"lastUpdatedAt,omitempty"` + Message string `json:"message,omitempty"` + 
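// DataDir is the storage reference used as the base when constructing per-node data directories (see ConstructNodeDataDir). +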
DataDir DataReference `json:"dataDir,omitempty"` + OutputReference DataReference `json:"outputRef,omitempty"` + + // We can store the outputs at this layer + // We can also store a cross section of nodes being executed currently here. This could be an optimization + + NodeStatus map[NodeID]*NodeStatus `json:"nodeStatus,omitempty"` + + // Number of Attempts completed with rounds resulting in error. this is used to cap out poison pill workflows + // that spin in an error loop. The value should be set at the global level and will be enforced. At the end of + // the retries the workflow will fail + FailedAttempts uint32 `json:"failedAttempts,omitempty"` +} + +func IsWorkflowPhaseTerminal(p WorkflowPhase) bool { + return p == WorkflowPhaseFailed || p == WorkflowPhaseSuccess || p == WorkflowPhaseAborted +} + +func (in *WorkflowStatus) SetMessage(msg string) { + in.Message = msg +} + +func (in *WorkflowStatus) UpdatePhase(p WorkflowPhase, msg string) { + in.Phase = p + in.Message = msg + if len(msg) > maxMessageSize { + in.Message = msg[:maxMessageSize] + } + + n := metav1.Now() + if in.StartedAt == nil { + in.StartedAt = &n + } + + if IsWorkflowPhaseTerminal(p) && in.StoppedAt == nil { + in.StoppedAt = &n + } + + in.LastUpdatedAt = &n +} + +func (in *WorkflowStatus) IncFailedAttempts() { + in.FailedAttempts++ +} + +func (in *WorkflowStatus) GetPhase() WorkflowPhase { + return in.Phase +} + +func (in *WorkflowStatus) GetStartedAt() *metav1.Time { + return in.StartedAt +} + +func (in *WorkflowStatus) GetStoppedAt() *metav1.Time { + return in.StoppedAt +} + +func (in *WorkflowStatus) GetLastUpdatedAt() *metav1.Time { + return in.LastUpdatedAt +} + +func (in *WorkflowStatus) IsTerminated() bool { + return in.Phase == WorkflowPhaseSuccess || in.Phase == WorkflowPhaseFailed || in.Phase == WorkflowPhaseAborted +} + +func (in *WorkflowStatus) GetMessage() string { + return in.Message +} + +func (in *WorkflowStatus) GetNodeExecutionStatus(id NodeID) ExecutableNodeStatus { + n, ok := in.NodeStatus[id] + if ok { + return n + } + if in.NodeStatus == nil { + in.NodeStatus = make(map[NodeID]*NodeStatus) + } + newNodeStatus := &NodeStatus{} + in.NodeStatus[id] = newNodeStatus + return newNodeStatus +} + +func (in *WorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, name NodeID) (storage.DataReference, error) { + return constructor.ConstructReference(ctx, in.GetDataDir(), name, "data") +} + +func (in *WorkflowStatus) GetDataDir() DataReference { + return in.DataDir +} + +func (in *WorkflowStatus) SetDataDir(d DataReference) { + in.DataDir = d +} + +func (in *WorkflowStatus) GetOutputReference() DataReference { + return in.OutputReference +} + +func (in *WorkflowStatus) SetOutputReference(reference DataReference) { + in.OutputReference = reference +} + +func (in *WorkflowStatus) Equals(other *WorkflowStatus) bool { + // Assuming in is never nil! 
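+ // Compares Phase, FailedAttempts, DataDir, OutputReference and the per-node statuses; timestamps and Message are intentionally ignored.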
+ if other == nil { + return false + } + if in.FailedAttempts != other.FailedAttempts { + return false + } + if in.Phase != other.Phase { + return false + } + // We will not compare the time and message + if in.DataDir != other.DataDir { + return false + } + + if in.OutputReference != other.OutputReference { + return false + } + + if len(in.NodeStatus) != len(other.NodeStatus) { + return false + } + + for k, v := range in.NodeStatus { + otherV, ok := other.NodeStatus[k] + if !ok { + return false + } + if !v.Equals(otherV) { + return false + } + } + return true +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go new file mode 100644 index 0000000000..9d53caac70 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_status_test.go @@ -0,0 +1,54 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsWorkflowPhaseTerminal(t *testing.T) { + assert.True(t, IsWorkflowPhaseTerminal(WorkflowPhaseFailed)) + assert.True(t, IsWorkflowPhaseTerminal(WorkflowPhaseSuccess)) + + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseFailing)) + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseSucceeding)) + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseReady)) + assert.False(t, IsWorkflowPhaseTerminal(WorkflowPhaseRunning)) +} + +func TestWorkflowStatus_Equals(t *testing.T) { + one := &WorkflowStatus{} + other := &WorkflowStatus{} + assert.True(t, one.Equals(other)) + + one.Phase = WorkflowPhaseRunning + assert.False(t, one.Equals(other)) + + other.Phase = one.Phase + assert.True(t, one.Equals(other)) + + one.DataDir = "data-dir" + assert.False(t, one.Equals(other)) + other.DataDir = one.DataDir + assert.True(t, one.Equals(other)) + + node := "x" + one.NodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.False(t, one.Equals(other)) + other.NodeStatus = map[NodeID]*NodeStatus{ + node: {}, + } + assert.True(t, one.Equals(other)) + + one.NodeStatus[node].Phase = NodePhaseRunning + assert.False(t, one.Equals(other)) + other.NodeStatus[node].Phase = NodePhaseRunning + assert.True(t, one.Equals(other)) + + one.OutputReference = "out" + assert.False(t, one.Equals(other)) + other.OutputReference = "out" + assert.True(t, one.Equals(other)) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go new file mode 100644 index 0000000000..c50d118de8 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow_test.go @@ -0,0 +1,51 @@ +package v1alpha1_test + +import ( + "encoding/json" + "io/ioutil" + "testing" + + "github.com/ghodss/yaml" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func TestMarshalUnmarshal_Connections(t *testing.T) { + r, err := ioutil.ReadFile("testdata/connections.json") + assert.NoError(t, err) + o := v1alpha1.Connections{} + err = json.Unmarshal(r, &o) + assert.NoError(t, err) + assert.Equal(t, map[v1alpha1.NodeID][]v1alpha1.NodeID{ + "n1": {"n2", "n3"}, + "n2": {"n4"}, + "n3": {"n4"}, + "n4": {"n5"}, + }, o.DownstreamEdges) + assert.Equal(t, []v1alpha1.NodeID{"n1"}, o.UpstreamEdges["n2"]) + assert.Equal(t, []v1alpha1.NodeID{"n1"}, o.UpstreamEdges["n3"]) + assert.Equal(t, []v1alpha1.NodeID{"n4"}, o.UpstreamEdges["n5"]) + assert.True(t, 
sets.NewString(o.UpstreamEdges["n4"]...).Equal(sets.NewString("n2", "n3"))) +} + +func ReadYamlFileAsJSON(path string) ([]byte, error) { + r, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + return yaml.YAMLToJSON(r) +} + +func TestWorkflowSpec(t *testing.T) { + j, err := ReadYamlFileAsJSON("testdata/workflowspec.yaml") + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + err = json.Unmarshal(j, w) + assert.NoError(t, err) + assert.NotNil(t, w.WorkflowSpec) + assert.Nil(t, w.GetOnFailureNode()) + assert.Equal(t, 7, len(w.Connections.DownstreamEdges)) + assert.Equal(t, 8, len(w.Connections.UpstreamEdges)) + +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go new file mode 100644 index 0000000000..a4cd7186af --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go @@ -0,0 +1,676 @@ +// +build !ignore_autogenerated + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1 "k8s.io/api/core/v1" + runtime "k8s.io/apimachinery/pkg/runtime" +) + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Alias. +func (in *Alias) DeepCopy() *Alias { + if in == nil { + return nil + } + out := new(Alias) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Binding. +func (in *Binding) DeepCopy() *Binding { + if in == nil { + return nil + } + out := new(Binding) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BooleanExpression. +func (in *BooleanExpression) DeepCopy() *BooleanExpression { + if in == nil { + return nil + } + out := new(BooleanExpression) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *BranchNodeSpec) DeepCopyInto(out *BranchNodeSpec) { + *out = *in + in.If.DeepCopyInto(&out.If) + if in.ElseIf != nil { + in, out := &in.ElseIf, &out.ElseIf + *out = make([]*IfBlock, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(IfBlock) + (*in).DeepCopyInto(*out) + } + } + } + if in.Else != nil { + in, out := &in.Else, &out.Else + *out = new(string) + **out = **in + } + if in.ElseFail != nil { + in, out := &in.ElseFail, &out.ElseFail + *out = (*in).DeepCopy() + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BranchNodeSpec. +func (in *BranchNodeSpec) DeepCopy() *BranchNodeSpec { + if in == nil { + return nil + } + out := new(BranchNodeSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *BranchNodeStatus) DeepCopyInto(out *BranchNodeStatus) { + *out = *in + if in.FinalizedNodeID != nil { + in, out := &in.FinalizedNodeID, &out.FinalizedNodeID + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BranchNodeStatus. +func (in *BranchNodeStatus) DeepCopy() *BranchNodeStatus { + if in == nil { + return nil + } + out := new(BranchNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Connections. 
+func (in *Connections) DeepCopy() *Connections { + if in == nil { + return nil + } + out := new(Connections) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DynamicNodeStatus) DeepCopyInto(out *DynamicNodeStatus) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamicNodeStatus. +func (in *DynamicNodeStatus) DeepCopy() *DynamicNodeStatus { + if in == nil { + return nil + } + out := new(DynamicNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Error. +func (in *Error) DeepCopy() *Error { + if in == nil { + return nil + } + out := new(Error) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FlyteWorkflow) DeepCopyInto(out *FlyteWorkflow) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + if in.WorkflowSpec != nil { + in, out := &in.WorkflowSpec, &out.WorkflowSpec + *out = new(WorkflowSpec) + (*in).DeepCopyInto(*out) + } + if in.Inputs != nil { + in, out := &in.Inputs, &out.Inputs + *out = (*in).DeepCopy() + } + in.ExecutionID.DeepCopyInto(&out.ExecutionID) + if in.Tasks != nil { + in, out := &in.Tasks, &out.Tasks + *out = make(map[string]*TaskSpec, len(*in)) + for key, val := range *in { + var outVal *TaskSpec + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = (*in).DeepCopy() + } + (*out)[key] = outVal + } + } + if in.SubWorkflows != nil { + in, out := &in.SubWorkflows, &out.SubWorkflows + *out = make(map[string]*WorkflowSpec, len(*in)) + for key, val := range *in { + var outVal *WorkflowSpec + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(WorkflowSpec) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + if in.ActiveDeadlineSeconds != nil { + in, out := &in.ActiveDeadlineSeconds, &out.ActiveDeadlineSeconds + *out = new(int64) + **out = **in + } + if in.AcceptedAt != nil { + in, out := &in.AcceptedAt, &out.AcceptedAt + *out = (*in).DeepCopy() + } + in.Status.DeepCopyInto(&out.Status) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FlyteWorkflow. +func (in *FlyteWorkflow) DeepCopy() *FlyteWorkflow { + if in == nil { + return nil + } + out := new(FlyteWorkflow) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *FlyteWorkflow) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FlyteWorkflowList) DeepCopyInto(out *FlyteWorkflowList) { + *out = *in + out.TypeMeta = in.TypeMeta + out.ListMeta = in.ListMeta + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]FlyteWorkflow, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FlyteWorkflowList. 
+func (in *FlyteWorkflowList) DeepCopy() *FlyteWorkflowList { + if in == nil { + return nil + } + out := new(FlyteWorkflowList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *FlyteWorkflowList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Identifier. +func (in *Identifier) DeepCopy() *Identifier { + if in == nil { + return nil + } + out := new(Identifier) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *IfBlock) DeepCopyInto(out *IfBlock) { + *out = *in + in.Condition.DeepCopyInto(&out.Condition) + if in.ThenNode != nil { + in, out := &in.ThenNode, &out.ThenNode + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IfBlock. +func (in *IfBlock) DeepCopy() *IfBlock { + if in == nil { + return nil + } + out := new(IfBlock) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Inputs. +func (in *Inputs) DeepCopy() *Inputs { + if in == nil { + return nil + } + out := new(Inputs) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeMetadata. +func (in *NodeMetadata) DeepCopy() *NodeMetadata { + if in == nil { + return nil + } + out := new(NodeMetadata) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *NodeSpec) DeepCopyInto(out *NodeSpec) { + *out = *in + if in.Resources != nil { + in, out := &in.Resources, &out.Resources + *out = new(v1.ResourceRequirements) + (*in).DeepCopyInto(*out) + } + if in.BranchNode != nil { + in, out := &in.BranchNode, &out.BranchNode + *out = new(BranchNodeSpec) + (*in).DeepCopyInto(*out) + } + if in.TaskRef != nil { + in, out := &in.TaskRef, &out.TaskRef + *out = new(string) + **out = **in + } + if in.WorkflowNode != nil { + in, out := &in.WorkflowNode, &out.WorkflowNode + *out = new(WorkflowNodeSpec) + (*in).DeepCopyInto(*out) + } + if in.InputBindings != nil { + in, out := &in.InputBindings, &out.InputBindings + *out = make([]*Binding, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = (*in).DeepCopy() + } + } + } + if in.Config != nil { + in, out := &in.Config, &out.Config + *out = new(v1.ConfigMap) + (*in).DeepCopyInto(*out) + } + if in.RetryStrategy != nil { + in, out := &in.RetryStrategy, &out.RetryStrategy + *out = new(RetryStrategy) + (*in).DeepCopyInto(*out) + } + if in.OutputAliases != nil { + in, out := &in.OutputAliases, &out.OutputAliases + *out = make([]Alias, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.SecurityContext != nil { + in, out := &in.SecurityContext, &out.SecurityContext + *out = new(v1.PodSecurityContext) + (*in).DeepCopyInto(*out) + } + if in.ImagePullSecrets != nil { + in, out := &in.ImagePullSecrets, &out.ImagePullSecrets + *out = make([]v1.LocalObjectReference, len(*in)) + copy(*out, *in) + } + if in.Affinity != nil { + in, out := &in.Affinity, &out.Affinity + *out = new(v1.Affinity) + (*in).DeepCopyInto(*out) + } + if in.Tolerations != nil { + in, out := &in.Tolerations, &out.Tolerations + *out = make([]v1.Toleration, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.ActiveDeadlineSeconds != nil { + in, out := &in.ActiveDeadlineSeconds, &out.ActiveDeadlineSeconds + *out = new(int64) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeSpec. +func (in *NodeSpec) DeepCopy() *NodeSpec { + if in == nil { + return nil + } + out := new(NodeSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *NodeStatus) DeepCopyInto(out *NodeStatus) { + *out = *in + if in.QueuedAt != nil { + in, out := &in.QueuedAt, &out.QueuedAt + *out = (*in).DeepCopy() + } + if in.StartedAt != nil { + in, out := &in.StartedAt, &out.StartedAt + *out = (*in).DeepCopy() + } + if in.StoppedAt != nil { + in, out := &in.StoppedAt, &out.StoppedAt + *out = (*in).DeepCopy() + } + if in.LastUpdatedAt != nil { + in, out := &in.LastUpdatedAt, &out.LastUpdatedAt + *out = (*in).DeepCopy() + } + if in.ParentNode != nil { + in, out := &in.ParentNode, &out.ParentNode + *out = new(string) + **out = **in + } + if in.ParentTask != nil { + in, out := &in.ParentTask, &out.ParentTask + *out = (*in).DeepCopy() + } + if in.BranchStatus != nil { + in, out := &in.BranchStatus, &out.BranchStatus + *out = new(BranchNodeStatus) + (*in).DeepCopyInto(*out) + } + if in.SubNodeStatus != nil { + in, out := &in.SubNodeStatus, &out.SubNodeStatus + *out = make(map[string]*NodeStatus, len(*in)) + for key, val := range *in { + var outVal *NodeStatus + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(NodeStatus) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + if in.WorkflowNodeStatus != nil { + in, out := &in.WorkflowNodeStatus, &out.WorkflowNodeStatus + *out = new(WorkflowNodeStatus) + **out = **in + } + if in.TaskNodeStatus != nil { + in, out := &in.TaskNodeStatus, &out.TaskNodeStatus + *out = (*in).DeepCopy() + } + if in.SubWorkflowNodeStatus != nil { + in, out := &in.SubWorkflowNodeStatus, &out.SubWorkflowNodeStatus + *out = new(SubWorkflowNodeStatus) + **out = **in + } + if in.DynamicNodeStatus != nil { + in, out := &in.DynamicNodeStatus, &out.DynamicNodeStatus + *out = new(DynamicNodeStatus) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodeStatus. +func (in *NodeStatus) DeepCopy() *NodeStatus { + if in == nil { + return nil + } + out := new(NodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OutputVarMap. +func (in *OutputVarMap) DeepCopy() *OutputVarMap { + if in == nil { + return nil + } + out := new(OutputVarMap) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RetryStrategy) DeepCopyInto(out *RetryStrategy) { + *out = *in + if in.MinAttempts != nil { + in, out := &in.MinAttempts, &out.MinAttempts + *out = new(int) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RetryStrategy. +func (in *RetryStrategy) DeepCopy() *RetryStrategy { + if in == nil { + return nil + } + out := new(RetryStrategy) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SubWorkflowNodeStatus) DeepCopyInto(out *SubWorkflowNodeStatus) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SubWorkflowNodeStatus. +func (in *SubWorkflowNodeStatus) DeepCopy() *SubWorkflowNodeStatus { + if in == nil { + return nil + } + out := new(SubWorkflowNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TaskExecutionIdentifier. 
+func (in *TaskExecutionIdentifier) DeepCopy() *TaskExecutionIdentifier { + if in == nil { + return nil + } + out := new(TaskExecutionIdentifier) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TaskSpec. +func (in *TaskSpec) DeepCopy() *TaskSpec { + if in == nil { + return nil + } + out := new(TaskSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowExecutionIdentifier. +func (in *WorkflowExecutionIdentifier) DeepCopy() *WorkflowExecutionIdentifier { + if in == nil { + return nil + } + out := new(WorkflowExecutionIdentifier) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkflowNodeSpec) DeepCopyInto(out *WorkflowNodeSpec) { + *out = *in + if in.LaunchPlanRefID != nil { + in, out := &in.LaunchPlanRefID, &out.LaunchPlanRefID + *out = (*in).DeepCopy() + } + if in.SubWorkflowReference != nil { + in, out := &in.SubWorkflowReference, &out.SubWorkflowReference + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowNodeSpec. +func (in *WorkflowNodeSpec) DeepCopy() *WorkflowNodeSpec { + if in == nil { + return nil + } + out := new(WorkflowNodeSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkflowNodeStatus) DeepCopyInto(out *WorkflowNodeStatus) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowNodeStatus. +func (in *WorkflowNodeStatus) DeepCopy() *WorkflowNodeStatus { + if in == nil { + return nil + } + out := new(WorkflowNodeStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkflowSpec) DeepCopyInto(out *WorkflowSpec) { + *out = *in + if in.Nodes != nil { + in, out := &in.Nodes, &out.Nodes + *out = make(map[string]*NodeSpec, len(*in)) + for key, val := range *in { + var outVal *NodeSpec + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(NodeSpec) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + in.Connections.DeepCopyInto(&out.Connections) + if in.OnFailure != nil { + in, out := &in.OnFailure, &out.OnFailure + *out = new(NodeSpec) + (*in).DeepCopyInto(*out) + } + if in.Outputs != nil { + in, out := &in.Outputs, &out.Outputs + *out = (*in).DeepCopy() + } + if in.OutputBindings != nil { + in, out := &in.OutputBindings, &out.OutputBindings + *out = make([]*Binding, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = (*in).DeepCopy() + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowSpec. +func (in *WorkflowSpec) DeepCopy() *WorkflowSpec { + if in == nil { + return nil + } + out := new(WorkflowSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *WorkflowStatus) DeepCopyInto(out *WorkflowStatus) { + *out = *in + if in.StartedAt != nil { + in, out := &in.StartedAt, &out.StartedAt + *out = (*in).DeepCopy() + } + if in.StoppedAt != nil { + in, out := &in.StoppedAt, &out.StoppedAt + *out = (*in).DeepCopy() + } + if in.LastUpdatedAt != nil { + in, out := &in.LastUpdatedAt, &out.LastUpdatedAt + *out = (*in).DeepCopy() + } + if in.NodeStatus != nil { + in, out := &in.NodeStatus, &out.NodeStatus + *out = make(map[string]*NodeStatus, len(*in)) + for key, val := range *in { + var outVal *NodeStatus + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = new(NodeStatus) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkflowStatus. +func (in *WorkflowStatus) DeepCopy() *WorkflowStatus { + if in == nil { + return nil + } + out := new(WorkflowStatus) + in.DeepCopyInto(out) + return out +} diff --git a/flytepropeller/pkg/client/clientset/versioned/clientset.go b/flytepropeller/pkg/client/clientset/versioned/clientset.go new file mode 100644 index 0000000000..93756ede06 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/clientset.go @@ -0,0 +1,82 @@ +// Code generated by client-gen. DO NOT EDIT. + +package versioned + +import ( + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + discovery "k8s.io/client-go/discovery" + rest "k8s.io/client-go/rest" + flowcontrol "k8s.io/client-go/util/flowcontrol" +) + +type Interface interface { + Discovery() discovery.DiscoveryInterface + FlyteworkflowV1alpha1() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface + // Deprecated: please explicitly pick a version if possible. + Flyteworkflow() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface +} + +// Clientset contains the clients for groups. Each group has exactly one +// version included in a Clientset. +type Clientset struct { + *discovery.DiscoveryClient + flyteworkflowV1alpha1 *flyteworkflowv1alpha1.FlyteworkflowV1alpha1Client +} + +// FlyteworkflowV1alpha1 retrieves the FlyteworkflowV1alpha1Client +func (c *Clientset) FlyteworkflowV1alpha1() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return c.flyteworkflowV1alpha1 +} + +// Deprecated: Flyteworkflow retrieves the default version of FlyteworkflowClient. +// Please explicitly pick a version. +func (c *Clientset) Flyteworkflow() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return c.flyteworkflowV1alpha1 +} + +// Discovery retrieves the DiscoveryClient +func (c *Clientset) Discovery() discovery.DiscoveryInterface { + if c == nil { + return nil + } + return c.DiscoveryClient +} + +// NewForConfig creates a new Clientset for the given config. +func NewForConfig(c *rest.Config) (*Clientset, error) { + configShallowCopy := *c + if configShallowCopy.RateLimiter == nil && configShallowCopy.QPS > 0 { + configShallowCopy.RateLimiter = flowcontrol.NewTokenBucketRateLimiter(configShallowCopy.QPS, configShallowCopy.Burst) + } + var cs Clientset + var err error + cs.flyteworkflowV1alpha1, err = flyteworkflowv1alpha1.NewForConfig(&configShallowCopy) + if err != nil { + return nil, err + } + + cs.DiscoveryClient, err = discovery.NewDiscoveryClientForConfig(&configShallowCopy) + if err != nil { + return nil, err + } + return &cs, nil +} + +// NewForConfigOrDie creates a new Clientset for the given config and +// panics if there is an error in the config. 
+func NewForConfigOrDie(c *rest.Config) *Clientset { + var cs Clientset + cs.flyteworkflowV1alpha1 = flyteworkflowv1alpha1.NewForConfigOrDie(c) + + cs.DiscoveryClient = discovery.NewDiscoveryClientForConfigOrDie(c) + return &cs +} + +// New creates a new Clientset for the given RESTClient. +func New(c rest.Interface) *Clientset { + var cs Clientset + cs.flyteworkflowV1alpha1 = flyteworkflowv1alpha1.New(c) + + cs.DiscoveryClient = discovery.NewDiscoveryClient(c) + return &cs +} diff --git a/flytepropeller/pkg/client/clientset/versioned/doc.go b/flytepropeller/pkg/client/clientset/versioned/doc.go new file mode 100644 index 0000000000..0e0c2a8900 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package has the automatically generated clientset. +package versioned diff --git a/flytepropeller/pkg/client/clientset/versioned/fake/clientset_generated.go b/flytepropeller/pkg/client/clientset/versioned/fake/clientset_generated.go new file mode 100644 index 0000000000..65395e5db5 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/fake/clientset_generated.go @@ -0,0 +1,66 @@ +// Code generated by client-gen. DO NOT EDIT. + +package fake + +import ( + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + fakeflyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/discovery" + fakediscovery "k8s.io/client-go/discovery/fake" + "k8s.io/client-go/testing" +) + +// NewSimpleClientset returns a clientset that will respond with the provided objects. +// It's backed by a very simple object tracker that processes creates, updates and deletions as-is, +// without applying any validations and/or defaults. It shouldn't be considered a replacement +// for a real clientset and is mostly useful in simple unit tests. +func NewSimpleClientset(objects ...runtime.Object) *Clientset { + o := testing.NewObjectTracker(scheme, codecs.UniversalDecoder()) + for _, obj := range objects { + if err := o.Add(obj); err != nil { + panic(err) + } + } + + cs := &Clientset{} + cs.discovery = &fakediscovery.FakeDiscovery{Fake: &cs.Fake} + cs.AddReactor("*", "*", testing.ObjectReaction(o)) + cs.AddWatchReactor("*", func(action testing.Action) (handled bool, ret watch.Interface, err error) { + gvr := action.GetResource() + ns := action.GetNamespace() + watch, err := o.Watch(gvr, ns) + if err != nil { + return false, nil, err + } + return true, watch, nil + }) + + return cs +} + +// Clientset implements clientset.Interface. Meant to be embedded into a +// struct to get a default implementation. This makes faking out just the method +// you want to test easier. 
+type Clientset struct { + testing.Fake + discovery *fakediscovery.FakeDiscovery +} + +func (c *Clientset) Discovery() discovery.DiscoveryInterface { + return c.discovery +} + +var _ clientset.Interface = &Clientset{} + +// FlyteworkflowV1alpha1 retrieves the FlyteworkflowV1alpha1Client +func (c *Clientset) FlyteworkflowV1alpha1() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return &fakeflyteworkflowv1alpha1.FakeFlyteworkflowV1alpha1{Fake: &c.Fake} +} + +// Flyteworkflow retrieves the FlyteworkflowV1alpha1Client +func (c *Clientset) Flyteworkflow() flyteworkflowv1alpha1.FlyteworkflowV1alpha1Interface { + return &fakeflyteworkflowv1alpha1.FakeFlyteworkflowV1alpha1{Fake: &c.Fake} +} diff --git a/flytepropeller/pkg/client/clientset/versioned/fake/doc.go b/flytepropeller/pkg/client/clientset/versioned/fake/doc.go new file mode 100644 index 0000000000..3630ed1cd1 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/fake/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package has the automatically generated fake clientset. +package fake diff --git a/flytepropeller/pkg/client/clientset/versioned/fake/register.go b/flytepropeller/pkg/client/clientset/versioned/fake/register.go new file mode 100644 index 0000000000..23a9c3a390 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/fake/register.go @@ -0,0 +1,40 @@ +// Code generated by client-gen. DO NOT EDIT. + +package fake + +import ( + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + schema "k8s.io/apimachinery/pkg/runtime/schema" + serializer "k8s.io/apimachinery/pkg/runtime/serializer" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" +) + +var scheme = runtime.NewScheme() +var codecs = serializer.NewCodecFactory(scheme) +var parameterCodec = runtime.NewParameterCodec(scheme) +var localSchemeBuilder = runtime.SchemeBuilder{ + flyteworkflowv1alpha1.AddToScheme, +} + +// AddToScheme adds all types of this clientset into the given scheme. This allows composition +// of clientsets, like in: +// +// import ( +// "k8s.io/client-go/kubernetes" +// clientsetscheme "k8s.io/client-go/kubernetes/scheme" +// aggregatorclientsetscheme "k8s.io/kube-aggregator/pkg/client/clientset_generated/clientset/scheme" +// ) +// +// kclientset, _ := kubernetes.NewForConfig(c) +// _ = aggregatorclientsetscheme.AddToScheme(clientsetscheme.Scheme) +// +// After this, RawExtensions in Kubernetes types will serialize kube-aggregator types +// correctly. +var AddToScheme = localSchemeBuilder.AddToScheme + +func init() { + v1.AddToGroupVersion(scheme, schema.GroupVersion{Version: "v1"}) + utilruntime.Must(AddToScheme(scheme)) +} diff --git a/flytepropeller/pkg/client/clientset/versioned/scheme/doc.go b/flytepropeller/pkg/client/clientset/versioned/scheme/doc.go new file mode 100644 index 0000000000..14db57a58f --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/scheme/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package contains the scheme of the automatically generated clientset. +package scheme diff --git a/flytepropeller/pkg/client/clientset/versioned/scheme/register.go b/flytepropeller/pkg/client/clientset/versioned/scheme/register.go new file mode 100644 index 0000000000..6323cb3268 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/scheme/register.go @@ -0,0 +1,40 @@ +// Code generated by client-gen. 
DO NOT EDIT. + +package scheme + +import ( + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + schema "k8s.io/apimachinery/pkg/runtime/schema" + serializer "k8s.io/apimachinery/pkg/runtime/serializer" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" +) + +var Scheme = runtime.NewScheme() +var Codecs = serializer.NewCodecFactory(Scheme) +var ParameterCodec = runtime.NewParameterCodec(Scheme) +var localSchemeBuilder = runtime.SchemeBuilder{ + flyteworkflowv1alpha1.AddToScheme, +} + +// AddToScheme adds all types of this clientset into the given scheme. This allows composition +// of clientsets, like in: +// +// import ( +// "k8s.io/client-go/kubernetes" +// clientsetscheme "k8s.io/client-go/kubernetes/scheme" +// aggregatorclientsetscheme "k8s.io/kube-aggregator/pkg/client/clientset_generated/clientset/scheme" +// ) +// +// kclientset, _ := kubernetes.NewForConfig(c) +// _ = aggregatorclientsetscheme.AddToScheme(clientsetscheme.Scheme) +// +// After this, RawExtensions in Kubernetes types will serialize kube-aggregator types +// correctly. +var AddToScheme = localSchemeBuilder.AddToScheme + +func init() { + v1.AddToGroupVersion(Scheme, schema.GroupVersion{Version: "v1"}) + utilruntime.Must(AddToScheme(Scheme)) +} diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go new file mode 100644 index 0000000000..93a7ca4e0e --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// This package has the automatically generated typed clients. +package v1alpha1 diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go new file mode 100644 index 0000000000..2b5ba4c8e4 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/doc.go @@ -0,0 +1,4 @@ +// Code generated by client-gen. DO NOT EDIT. + +// Package fake has the automatically generated clients. +package fake diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go new file mode 100644 index 0000000000..c9f48bec43 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow.go @@ -0,0 +1,124 @@ +// Code generated by client-gen. DO NOT EDIT. 
+ +package fake + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + labels "k8s.io/apimachinery/pkg/labels" + schema "k8s.io/apimachinery/pkg/runtime/schema" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + testing "k8s.io/client-go/testing" +) + +// FakeFlyteWorkflows implements FlyteWorkflowInterface +type FakeFlyteWorkflows struct { + Fake *FakeFlyteworkflowV1alpha1 + ns string +} + +var flyteworkflowsResource = schema.GroupVersionResource{Group: "flyteworkflow.flyte.net", Version: "v1alpha1", Resource: "flyteworkflows"} + +var flyteworkflowsKind = schema.GroupVersionKind{Group: "flyteworkflow.flyte.net", Version: "v1alpha1", Kind: "FlyteWorkflow"} + +// Get takes name of the flyteWorkflow, and returns the corresponding flyteWorkflow object, and an error if there is any. +func (c *FakeFlyteWorkflows) Get(name string, options v1.GetOptions) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewGetAction(flyteworkflowsResource, c.ns, name), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// List takes label and field selectors, and returns the list of FlyteWorkflows that match those selectors. +func (c *FakeFlyteWorkflows) List(opts v1.ListOptions) (result *v1alpha1.FlyteWorkflowList, err error) { + obj, err := c.Fake. + Invokes(testing.NewListAction(flyteworkflowsResource, flyteworkflowsKind, c.ns, opts), &v1alpha1.FlyteWorkflowList{}) + + if obj == nil { + return nil, err + } + + label, _, _ := testing.ExtractFromListOptions(opts) + if label == nil { + label = labels.Everything() + } + list := &v1alpha1.FlyteWorkflowList{ListMeta: obj.(*v1alpha1.FlyteWorkflowList).ListMeta} + for _, item := range obj.(*v1alpha1.FlyteWorkflowList).Items { + if label.Matches(labels.Set(item.Labels)) { + list.Items = append(list.Items, item) + } + } + return list, err +} + +// Watch returns a watch.Interface that watches the requested flyteWorkflows. +func (c *FakeFlyteWorkflows) Watch(opts v1.ListOptions) (watch.Interface, error) { + return c.Fake. + InvokesWatch(testing.NewWatchAction(flyteworkflowsResource, c.ns, opts)) + +} + +// Create takes the representation of a flyteWorkflow and creates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *FakeFlyteWorkflows) Create(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewCreateAction(flyteworkflowsResource, c.ns, flyteWorkflow), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// Update takes the representation of a flyteWorkflow and updates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *FakeFlyteWorkflows) Update(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewUpdateAction(flyteworkflowsResource, c.ns, flyteWorkflow), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// UpdateStatus was generated because the type contains a Status member. +// Add a +genclient:noStatus comment above the type to avoid generating UpdateStatus(). 
+func (c *FakeFlyteWorkflows) UpdateStatus(flyteWorkflow *v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) { + obj, err := c.Fake. + Invokes(testing.NewUpdateSubresourceAction(flyteworkflowsResource, "status", c.ns, flyteWorkflow), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} + +// Delete takes name of the flyteWorkflow and deletes it. Returns an error if one occurs. +func (c *FakeFlyteWorkflows) Delete(name string, options *v1.DeleteOptions) error { + _, err := c.Fake. + Invokes(testing.NewDeleteAction(flyteworkflowsResource, c.ns, name), &v1alpha1.FlyteWorkflow{}) + + return err +} + +// DeleteCollection deletes a collection of objects. +func (c *FakeFlyteWorkflows) DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + action := testing.NewDeleteCollectionAction(flyteworkflowsResource, c.ns, listOptions) + + _, err := c.Fake.Invokes(action, &v1alpha1.FlyteWorkflowList{}) + return err +} + +// Patch applies the patch and returns the patched flyteWorkflow. +func (c *FakeFlyteWorkflows) Patch(name string, pt types.PatchType, data []byte, subresources ...string) (result *v1alpha1.FlyteWorkflow, err error) { + obj, err := c.Fake. + Invokes(testing.NewPatchSubresourceAction(flyteworkflowsResource, c.ns, name, pt, data, subresources...), &v1alpha1.FlyteWorkflow{}) + + if obj == nil { + return nil, err + } + return obj.(*v1alpha1.FlyteWorkflow), err +} diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go new file mode 100644 index 0000000000..11460605c9 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/fake/fake_flyteworkflow_client.go @@ -0,0 +1,24 @@ +// Code generated by client-gen. DO NOT EDIT. + +package fake + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + rest "k8s.io/client-go/rest" + testing "k8s.io/client-go/testing" +) + +type FakeFlyteworkflowV1alpha1 struct { + *testing.Fake +} + +func (c *FakeFlyteworkflowV1alpha1) FlyteWorkflows(namespace string) v1alpha1.FlyteWorkflowInterface { + return &FakeFlyteWorkflows{c, namespace} +} + +// RESTClient returns a RESTClient that is used to communicate +// with API server by this client implementation. +func (c *FakeFlyteworkflowV1alpha1) RESTClient() rest.Interface { + var ret *rest.RESTClient + return ret +} diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go new file mode 100644 index 0000000000..6b2dc62c14 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow.go @@ -0,0 +1,175 @@ +// Code generated by client-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + "time" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + scheme "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/scheme" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + rest "k8s.io/client-go/rest" +) + +// FlyteWorkflowsGetter has a method to return a FlyteWorkflowInterface. +// A group's client should implement this interface. 
+type FlyteWorkflowsGetter interface { + FlyteWorkflows(namespace string) FlyteWorkflowInterface +} + +// FlyteWorkflowInterface has methods to work with FlyteWorkflow resources. +type FlyteWorkflowInterface interface { + Create(*v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) + Update(*v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) + UpdateStatus(*v1alpha1.FlyteWorkflow) (*v1alpha1.FlyteWorkflow, error) + Delete(name string, options *v1.DeleteOptions) error + DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error + Get(name string, options v1.GetOptions) (*v1alpha1.FlyteWorkflow, error) + List(opts v1.ListOptions) (*v1alpha1.FlyteWorkflowList, error) + Watch(opts v1.ListOptions) (watch.Interface, error) + Patch(name string, pt types.PatchType, data []byte, subresources ...string) (result *v1alpha1.FlyteWorkflow, err error) + FlyteWorkflowExpansion +} + +// flyteWorkflows implements FlyteWorkflowInterface +type flyteWorkflows struct { + client rest.Interface + ns string +} + +// newFlyteWorkflows returns a FlyteWorkflows +func newFlyteWorkflows(c *FlyteworkflowV1alpha1Client, namespace string) *flyteWorkflows { + return &flyteWorkflows{ + client: c.RESTClient(), + ns: namespace, + } +} + +// Get takes name of the flyteWorkflow, and returns the corresponding flyteWorkflow object, and an error if there is any. +func (c *flyteWorkflows) Get(name string, options v1.GetOptions) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Get(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(name). + VersionedParams(&options, scheme.ParameterCodec). + Do(). + Into(result) + return +} + +// List takes label and field selectors, and returns the list of FlyteWorkflows that match those selectors. +func (c *flyteWorkflows) List(opts v1.ListOptions) (result *v1alpha1.FlyteWorkflowList, err error) { + var timeout time.Duration + if opts.TimeoutSeconds != nil { + timeout = time.Duration(*opts.TimeoutSeconds) * time.Second + } + result = &v1alpha1.FlyteWorkflowList{} + err = c.client.Get(). + Namespace(c.ns). + Resource("flyteworkflows"). + VersionedParams(&opts, scheme.ParameterCodec). + Timeout(timeout). + Do(). + Into(result) + return +} + +// Watch returns a watch.Interface that watches the requested flyteWorkflows. +func (c *flyteWorkflows) Watch(opts v1.ListOptions) (watch.Interface, error) { + var timeout time.Duration + if opts.TimeoutSeconds != nil { + timeout = time.Duration(*opts.TimeoutSeconds) * time.Second + } + opts.Watch = true + return c.client.Get(). + Namespace(c.ns). + Resource("flyteworkflows"). + VersionedParams(&opts, scheme.ParameterCodec). + Timeout(timeout). + Watch() +} + +// Create takes the representation of a flyteWorkflow and creates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *flyteWorkflows) Create(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Post(). + Namespace(c.ns). + Resource("flyteworkflows"). + Body(flyteWorkflow). + Do(). + Into(result) + return +} + +// Update takes the representation of a flyteWorkflow and updates it. Returns the server's representation of the flyteWorkflow, and an error, if there is any. +func (c *flyteWorkflows) Update(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Put(). + Namespace(c.ns). 
+ Resource("flyteworkflows"). + Name(flyteWorkflow.Name). + Body(flyteWorkflow). + Do(). + Into(result) + return +} + +// UpdateStatus was generated because the type contains a Status member. +// Add a +genclient:noStatus comment above the type to avoid generating UpdateStatus(). + +func (c *flyteWorkflows) UpdateStatus(flyteWorkflow *v1alpha1.FlyteWorkflow) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Put(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(flyteWorkflow.Name). + SubResource("status"). + Body(flyteWorkflow). + Do(). + Into(result) + return +} + +// Delete takes name of the flyteWorkflow and deletes it. Returns an error if one occurs. +func (c *flyteWorkflows) Delete(name string, options *v1.DeleteOptions) error { + return c.client.Delete(). + Namespace(c.ns). + Resource("flyteworkflows"). + Name(name). + Body(options). + Do(). + Error() +} + +// DeleteCollection deletes a collection of objects. +func (c *flyteWorkflows) DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + var timeout time.Duration + if listOptions.TimeoutSeconds != nil { + timeout = time.Duration(*listOptions.TimeoutSeconds) * time.Second + } + return c.client.Delete(). + Namespace(c.ns). + Resource("flyteworkflows"). + VersionedParams(&listOptions, scheme.ParameterCodec). + Timeout(timeout). + Body(options). + Do(). + Error() +} + +// Patch applies the patch and returns the patched flyteWorkflow. +func (c *flyteWorkflows) Patch(name string, pt types.PatchType, data []byte, subresources ...string) (result *v1alpha1.FlyteWorkflow, err error) { + result = &v1alpha1.FlyteWorkflow{} + err = c.client.Patch(pt). + Namespace(c.ns). + Resource("flyteworkflows"). + SubResource(subresources...). + Name(name). + Body(data). + Do(). + Into(result) + return +} diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go new file mode 100644 index 0000000000..2d7414f236 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/flyteworkflow_client.go @@ -0,0 +1,74 @@ +// Code generated by client-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/scheme" + serializer "k8s.io/apimachinery/pkg/runtime/serializer" + rest "k8s.io/client-go/rest" +) + +type FlyteworkflowV1alpha1Interface interface { + RESTClient() rest.Interface + FlyteWorkflowsGetter +} + +// FlyteworkflowV1alpha1Client is used to interact with features provided by the flyteworkflow.flyte.net group. +type FlyteworkflowV1alpha1Client struct { + restClient rest.Interface +} + +func (c *FlyteworkflowV1alpha1Client) FlyteWorkflows(namespace string) FlyteWorkflowInterface { + return newFlyteWorkflows(c, namespace) +} + +// NewForConfig creates a new FlyteworkflowV1alpha1Client for the given config. +func NewForConfig(c *rest.Config) (*FlyteworkflowV1alpha1Client, error) { + config := *c + if err := setConfigDefaults(&config); err != nil { + return nil, err + } + client, err := rest.RESTClientFor(&config) + if err != nil { + return nil, err + } + return &FlyteworkflowV1alpha1Client{client}, nil +} + +// NewForConfigOrDie creates a new FlyteworkflowV1alpha1Client for the given config and +// panics if there is an error in the config. 
+func NewForConfigOrDie(c *rest.Config) *FlyteworkflowV1alpha1Client { + client, err := NewForConfig(c) + if err != nil { + panic(err) + } + return client +} + +// New creates a new FlyteworkflowV1alpha1Client for the given RESTClient. +func New(c rest.Interface) *FlyteworkflowV1alpha1Client { + return &FlyteworkflowV1alpha1Client{c} +} + +func setConfigDefaults(config *rest.Config) error { + gv := v1alpha1.SchemeGroupVersion + config.GroupVersion = &gv + config.APIPath = "/apis" + config.NegotiatedSerializer = serializer.DirectCodecFactory{CodecFactory: scheme.Codecs} + + if config.UserAgent == "" { + config.UserAgent = rest.DefaultKubernetesUserAgent() + } + + return nil +} + +// RESTClient returns a RESTClient that is used to communicate +// with API server by this client implementation. +func (c *FlyteworkflowV1alpha1Client) RESTClient() rest.Interface { + if c == nil { + return nil + } + return c.restClient +} diff --git a/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go new file mode 100644 index 0000000000..eb8294c165 --- /dev/null +++ b/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1/generated_expansion.go @@ -0,0 +1,5 @@ +// Code generated by client-gen. DO NOT EDIT. + +package v1alpha1 + +type FlyteWorkflowExpansion interface{} diff --git a/flytepropeller/pkg/client/informers/externalversions/factory.go b/flytepropeller/pkg/client/informers/externalversions/factory.go new file mode 100644 index 0000000000..2a094285f9 --- /dev/null +++ b/flytepropeller/pkg/client/informers/externalversions/factory.go @@ -0,0 +1,164 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package externalversions + +import ( + reflect "reflect" + sync "sync" + time "time" + + versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + flyteworkflow "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/flyteworkflow" + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + schema "k8s.io/apimachinery/pkg/runtime/schema" + cache "k8s.io/client-go/tools/cache" +) + +// SharedInformerOption defines the functional option type for SharedInformerFactory. +type SharedInformerOption func(*sharedInformerFactory) *sharedInformerFactory + +type sharedInformerFactory struct { + client versioned.Interface + namespace string + tweakListOptions internalinterfaces.TweakListOptionsFunc + lock sync.Mutex + defaultResync time.Duration + customResync map[reflect.Type]time.Duration + + informers map[reflect.Type]cache.SharedIndexInformer + // startedInformers is used for tracking which informers have been started. + // This allows Start() to be called multiple times safely. + startedInformers map[reflect.Type]bool +} + +// WithCustomResyncConfig sets a custom resync period for the specified informer types. +func WithCustomResyncConfig(resyncConfig map[v1.Object]time.Duration) SharedInformerOption { + return func(factory *sharedInformerFactory) *sharedInformerFactory { + for k, v := range resyncConfig { + factory.customResync[reflect.TypeOf(k)] = v + } + return factory + } +} + +// WithTweakListOptions sets a custom filter on all listers of the configured SharedInformerFactory. 
+func WithTweakListOptions(tweakListOptions internalinterfaces.TweakListOptionsFunc) SharedInformerOption { + return func(factory *sharedInformerFactory) *sharedInformerFactory { + factory.tweakListOptions = tweakListOptions + return factory + } +} + +// WithNamespace limits the SharedInformerFactory to the specified namespace. +func WithNamespace(namespace string) SharedInformerOption { + return func(factory *sharedInformerFactory) *sharedInformerFactory { + factory.namespace = namespace + return factory + } +} + +// NewSharedInformerFactory constructs a new instance of sharedInformerFactory for all namespaces. +func NewSharedInformerFactory(client versioned.Interface, defaultResync time.Duration) SharedInformerFactory { + return NewSharedInformerFactoryWithOptions(client, defaultResync) +} + +// NewFilteredSharedInformerFactory constructs a new instance of sharedInformerFactory. +// Listers obtained via this SharedInformerFactory will be subject to the same filters +// as specified here. +// Deprecated: Please use NewSharedInformerFactoryWithOptions instead +func NewFilteredSharedInformerFactory(client versioned.Interface, defaultResync time.Duration, namespace string, tweakListOptions internalinterfaces.TweakListOptionsFunc) SharedInformerFactory { + return NewSharedInformerFactoryWithOptions(client, defaultResync, WithNamespace(namespace), WithTweakListOptions(tweakListOptions)) +} + +// NewSharedInformerFactoryWithOptions constructs a new instance of a SharedInformerFactory with additional options. +func NewSharedInformerFactoryWithOptions(client versioned.Interface, defaultResync time.Duration, options ...SharedInformerOption) SharedInformerFactory { + factory := &sharedInformerFactory{ + client: client, + namespace: v1.NamespaceAll, + defaultResync: defaultResync, + informers: make(map[reflect.Type]cache.SharedIndexInformer), + startedInformers: make(map[reflect.Type]bool), + customResync: make(map[reflect.Type]time.Duration), + } + + // Apply all options + for _, opt := range options { + factory = opt(factory) + } + + return factory +} + +// Start initializes all requested informers. +func (f *sharedInformerFactory) Start(stopCh <-chan struct{}) { + f.lock.Lock() + defer f.lock.Unlock() + + for informerType, informer := range f.informers { + if !f.startedInformers[informerType] { + go informer.Run(stopCh) + f.startedInformers[informerType] = true + } + } +} + +// WaitForCacheSync waits for all started informers' cache were synced. +func (f *sharedInformerFactory) WaitForCacheSync(stopCh <-chan struct{}) map[reflect.Type]bool { + informers := func() map[reflect.Type]cache.SharedIndexInformer { + f.lock.Lock() + defer f.lock.Unlock() + + informers := map[reflect.Type]cache.SharedIndexInformer{} + for informerType, informer := range f.informers { + if f.startedInformers[informerType] { + informers[informerType] = informer + } + } + return informers + }() + + res := map[reflect.Type]bool{} + for informType, informer := range informers { + res[informType] = cache.WaitForCacheSync(stopCh, informer.HasSynced) + } + return res +} + +// InternalInformerFor returns the SharedIndexInformer for obj using an internal +// client. 
+func (f *sharedInformerFactory) InformerFor(obj runtime.Object, newFunc internalinterfaces.NewInformerFunc) cache.SharedIndexInformer { + f.lock.Lock() + defer f.lock.Unlock() + + informerType := reflect.TypeOf(obj) + informer, exists := f.informers[informerType] + if exists { + return informer + } + + resyncPeriod, exists := f.customResync[informerType] + if !exists { + resyncPeriod = f.defaultResync + } + + informer = newFunc(f.client, resyncPeriod) + f.informers[informerType] = informer + + return informer +} + +// SharedInformerFactory provides shared informers for resources in all known +// API group versions. +type SharedInformerFactory interface { + internalinterfaces.SharedInformerFactory + ForResource(resource schema.GroupVersionResource) (GenericInformer, error) + WaitForCacheSync(stopCh <-chan struct{}) map[reflect.Type]bool + + Flyteworkflow() flyteworkflow.Interface +} + +func (f *sharedInformerFactory) Flyteworkflow() flyteworkflow.Interface { + return flyteworkflow.New(f, f.namespace, f.tweakListOptions) +} diff --git a/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/interface.go b/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/interface.go new file mode 100644 index 0000000000..b8410c1688 --- /dev/null +++ b/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/interface.go @@ -0,0 +1,30 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package flyteworkflow + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1" + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" +) + +// Interface provides access to each of this group's versions. +type Interface interface { + // V1alpha1 provides access to shared informers for resources in V1alpha1. + V1alpha1() v1alpha1.Interface +} + +type group struct { + factory internalinterfaces.SharedInformerFactory + namespace string + tweakListOptions internalinterfaces.TweakListOptionsFunc +} + +// New returns a new Interface. +func New(f internalinterfaces.SharedInformerFactory, namespace string, tweakListOptions internalinterfaces.TweakListOptionsFunc) Interface { + return &group{factory: f, namespace: namespace, tweakListOptions: tweakListOptions} +} + +// V1alpha1 returns a new v1alpha1.Interface. +func (g *group) V1alpha1() v1alpha1.Interface { + return v1alpha1.New(g.factory, g.namespace, g.tweakListOptions) +} diff --git a/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go b/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go new file mode 100644 index 0000000000..3ea918b5d9 --- /dev/null +++ b/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/flyteworkflow.go @@ -0,0 +1,73 @@ +// Code generated by informer-gen. DO NOT EDIT. 
+ +package v1alpha1 + +import ( + time "time" + + flyteworkflowv1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" + v1alpha1 "github.com/lyft/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + watch "k8s.io/apimachinery/pkg/watch" + cache "k8s.io/client-go/tools/cache" +) + +// FlyteWorkflowInformer provides access to a shared informer and lister for +// FlyteWorkflows. +type FlyteWorkflowInformer interface { + Informer() cache.SharedIndexInformer + Lister() v1alpha1.FlyteWorkflowLister +} + +type flyteWorkflowInformer struct { + factory internalinterfaces.SharedInformerFactory + tweakListOptions internalinterfaces.TweakListOptionsFunc + namespace string +} + +// NewFlyteWorkflowInformer constructs a new informer for FlyteWorkflow type. +// Always prefer using an informer factory to get a shared informer instead of getting an independent +// one. This reduces memory footprint and number of connections to the server. +func NewFlyteWorkflowInformer(client versioned.Interface, namespace string, resyncPeriod time.Duration, indexers cache.Indexers) cache.SharedIndexInformer { + return NewFilteredFlyteWorkflowInformer(client, namespace, resyncPeriod, indexers, nil) +} + +// NewFilteredFlyteWorkflowInformer constructs a new informer for FlyteWorkflow type. +// Always prefer using an informer factory to get a shared informer instead of getting an independent +// one. This reduces memory footprint and number of connections to the server. +func NewFilteredFlyteWorkflowInformer(client versioned.Interface, namespace string, resyncPeriod time.Duration, indexers cache.Indexers, tweakListOptions internalinterfaces.TweakListOptionsFunc) cache.SharedIndexInformer { + return cache.NewSharedIndexInformer( + &cache.ListWatch{ + ListFunc: func(options v1.ListOptions) (runtime.Object, error) { + if tweakListOptions != nil { + tweakListOptions(&options) + } + return client.FlyteworkflowV1alpha1().FlyteWorkflows(namespace).List(options) + }, + WatchFunc: func(options v1.ListOptions) (watch.Interface, error) { + if tweakListOptions != nil { + tweakListOptions(&options) + } + return client.FlyteworkflowV1alpha1().FlyteWorkflows(namespace).Watch(options) + }, + }, + &flyteworkflowv1alpha1.FlyteWorkflow{}, + resyncPeriod, + indexers, + ) +} + +func (f *flyteWorkflowInformer) defaultInformer(client versioned.Interface, resyncPeriod time.Duration) cache.SharedIndexInformer { + return NewFilteredFlyteWorkflowInformer(client, f.namespace, resyncPeriod, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc}, f.tweakListOptions) +} + +func (f *flyteWorkflowInformer) Informer() cache.SharedIndexInformer { + return f.factory.InformerFor(&flyteworkflowv1alpha1.FlyteWorkflow{}, f.defaultInformer) +} + +func (f *flyteWorkflowInformer) Lister() v1alpha1.FlyteWorkflowLister { + return v1alpha1.NewFlyteWorkflowLister(f.Informer().GetIndexer()) +} diff --git a/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go b/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go new file mode 100644 index 0000000000..c4425cb4c9 --- /dev/null +++ b/flytepropeller/pkg/client/informers/externalversions/flyteworkflow/v1alpha1/interface.go @@ -0,0 +1,29 @@ +// Code 
generated by informer-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + internalinterfaces "github.com/lyft/flytepropeller/pkg/client/informers/externalversions/internalinterfaces" +) + +// Interface provides access to all the informers in this group version. +type Interface interface { + // FlyteWorkflows returns a FlyteWorkflowInformer. + FlyteWorkflows() FlyteWorkflowInformer +} + +type version struct { + factory internalinterfaces.SharedInformerFactory + namespace string + tweakListOptions internalinterfaces.TweakListOptionsFunc +} + +// New returns a new Interface. +func New(f internalinterfaces.SharedInformerFactory, namespace string, tweakListOptions internalinterfaces.TweakListOptionsFunc) Interface { + return &version{factory: f, namespace: namespace, tweakListOptions: tweakListOptions} +} + +// FlyteWorkflows returns a FlyteWorkflowInformer. +func (v *version) FlyteWorkflows() FlyteWorkflowInformer { + return &flyteWorkflowInformer{factory: v.factory, namespace: v.namespace, tweakListOptions: v.tweakListOptions} +} diff --git a/flytepropeller/pkg/client/informers/externalversions/generic.go b/flytepropeller/pkg/client/informers/externalversions/generic.go new file mode 100644 index 0000000000..3d1564aa53 --- /dev/null +++ b/flytepropeller/pkg/client/informers/externalversions/generic.go @@ -0,0 +1,46 @@ +// Code generated by informer-gen. DO NOT EDIT. + +package externalversions + +import ( + "fmt" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + schema "k8s.io/apimachinery/pkg/runtime/schema" + cache "k8s.io/client-go/tools/cache" +) + +// GenericInformer is type of SharedIndexInformer which will locate and delegate to other +// sharedInformers based on type +type GenericInformer interface { + Informer() cache.SharedIndexInformer + Lister() cache.GenericLister +} + +type genericInformer struct { + informer cache.SharedIndexInformer + resource schema.GroupResource +} + +// Informer returns the SharedIndexInformer. +func (f *genericInformer) Informer() cache.SharedIndexInformer { + return f.informer +} + +// Lister returns the GenericLister. +func (f *genericInformer) Lister() cache.GenericLister { + return cache.NewGenericLister(f.Informer().GetIndexer(), f.resource) +} + +// ForResource gives generic access to a shared informer of the matching type +// TODO extend this to unknown resources with a client pool +func (f *sharedInformerFactory) ForResource(resource schema.GroupVersionResource) (GenericInformer, error) { + switch resource { + // Group=flyteworkflow.flyte.net, Version=v1alpha1 + case v1alpha1.SchemeGroupVersion.WithResource("flyteworkflows"): + return &genericInformer{resource: resource.GroupResource(), informer: f.Flyteworkflow().V1alpha1().FlyteWorkflows().Informer()}, nil + + } + + return nil, fmt.Errorf("no informer found for %v", resource) +} diff --git a/flytepropeller/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go b/flytepropeller/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go new file mode 100644 index 0000000000..147b4a34cd --- /dev/null +++ b/flytepropeller/pkg/client/informers/externalversions/internalinterfaces/factory_interfaces.go @@ -0,0 +1,24 @@ +// Code generated by informer-gen. DO NOT EDIT. 
+ +package internalinterfaces + +import ( + time "time" + + versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + runtime "k8s.io/apimachinery/pkg/runtime" + cache "k8s.io/client-go/tools/cache" +) + +// NewInformerFunc takes versioned.Interface and time.Duration to return a SharedIndexInformer. +type NewInformerFunc func(versioned.Interface, time.Duration) cache.SharedIndexInformer + +// SharedInformerFactory a small interface to allow for adding an informer without an import cycle +type SharedInformerFactory interface { + Start(stopCh <-chan struct{}) + InformerFor(obj runtime.Object, newFunc NewInformerFunc) cache.SharedIndexInformer +} + +// TweakListOptionsFunc is a function that transforms a v1.ListOptions. +type TweakListOptionsFunc func(*v1.ListOptions) diff --git a/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go b/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go new file mode 100644 index 0000000000..74ad855480 --- /dev/null +++ b/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/expansion_generated.go @@ -0,0 +1,11 @@ +// Code generated by lister-gen. DO NOT EDIT. + +package v1alpha1 + +// FlyteWorkflowListerExpansion allows custom methods to be added to +// FlyteWorkflowLister. +type FlyteWorkflowListerExpansion interface{} + +// FlyteWorkflowNamespaceListerExpansion allows custom methods to be added to +// FlyteWorkflowNamespaceLister. +type FlyteWorkflowNamespaceListerExpansion interface{} diff --git a/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go b/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go new file mode 100644 index 0000000000..1ddbf256bb --- /dev/null +++ b/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1/flyteworkflow.go @@ -0,0 +1,78 @@ +// Code generated by lister-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/client-go/tools/cache" +) + +// FlyteWorkflowLister helps list FlyteWorkflows. +type FlyteWorkflowLister interface { + // List lists all FlyteWorkflows in the indexer. + List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) + // FlyteWorkflows returns an object that can list and get FlyteWorkflows. + FlyteWorkflows(namespace string) FlyteWorkflowNamespaceLister + FlyteWorkflowListerExpansion +} + +// flyteWorkflowLister implements the FlyteWorkflowLister interface. +type flyteWorkflowLister struct { + indexer cache.Indexer +} + +// NewFlyteWorkflowLister returns a new FlyteWorkflowLister. +func NewFlyteWorkflowLister(indexer cache.Indexer) FlyteWorkflowLister { + return &flyteWorkflowLister{indexer: indexer} +} + +// List lists all FlyteWorkflows in the indexer. +func (s *flyteWorkflowLister) List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) { + err = cache.ListAll(s.indexer, selector, func(m interface{}) { + ret = append(ret, m.(*v1alpha1.FlyteWorkflow)) + }) + return ret, err +} + +// FlyteWorkflows returns an object that can list and get FlyteWorkflows. +func (s *flyteWorkflowLister) FlyteWorkflows(namespace string) FlyteWorkflowNamespaceLister { + return flyteWorkflowNamespaceLister{indexer: s.indexer, namespace: namespace} +} + +// FlyteWorkflowNamespaceLister helps list and get FlyteWorkflows. 
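Pulling the generated pieces together, here is a hedged sketch (again, not part of the generated code) of the intended flow: construct the aggregated clientset, build a namespaced SharedInformerFactory, start it, wait for the cache to sync, and read FlyteWorkflows from the lister. NewForConfig on the versioned clientset is assumed to exist as the usual client-gen constructor; the kubeconfig path, namespace, and resync period are illustrative.

package main

import (
	"log"
	"time"

	versioned "github.com/lyft/flytepropeller/pkg/client/clientset/versioned"
	"github.com/lyft/flytepropeller/pkg/client/informers/externalversions"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/client-go/tools/clientcmd"
)

func main() {
	cfg, err := clientcmd.BuildConfigFromFlags("", "/path/to/kubeconfig") // illustrative path
	if err != nil {
		log.Fatal(err)
	}

	// Aggregated clientset for the flyteworkflow group (constructor assumed from client-gen conventions).
	clientset, err := versioned.NewForConfig(cfg)
	if err != nil {
		log.Fatal(err)
	}

	stopCh := make(chan struct{})
	defer close(stopCh)

	// Factory scoped to the "flyte" namespace with a 30s default resync (both illustrative).
	factory := externalversions.NewSharedInformerFactoryWithOptions(
		clientset, 30*time.Second, externalversions.WithNamespace("flyte"))

	// Requesting the lister before Start registers the FlyteWorkflow informer with the factory.
	wfLister := factory.Flyteworkflow().V1alpha1().FlyteWorkflows().Lister()

	factory.Start(stopCh)
	factory.WaitForCacheSync(stopCh)

	// The lister serves reads from the informer's local cache rather than the API server.
	wfs, err := wfLister.List(labels.Everything())
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("cached FlyteWorkflows: %d", len(wfs))
}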
+type FlyteWorkflowNamespaceLister interface { + // List lists all FlyteWorkflows in the indexer for a given namespace. + List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) + // Get retrieves the FlyteWorkflow from the indexer for a given namespace and name. + Get(name string) (*v1alpha1.FlyteWorkflow, error) + FlyteWorkflowNamespaceListerExpansion +} + +// flyteWorkflowNamespaceLister implements the FlyteWorkflowNamespaceLister +// interface. +type flyteWorkflowNamespaceLister struct { + indexer cache.Indexer + namespace string +} + +// List lists all FlyteWorkflows in the indexer for a given namespace. +func (s flyteWorkflowNamespaceLister) List(selector labels.Selector) (ret []*v1alpha1.FlyteWorkflow, err error) { + err = cache.ListAllByNamespace(s.indexer, s.namespace, selector, func(m interface{}) { + ret = append(ret, m.(*v1alpha1.FlyteWorkflow)) + }) + return ret, err +} + +// Get retrieves the FlyteWorkflow from the indexer for a given namespace and name. +func (s flyteWorkflowNamespaceLister) Get(name string) (*v1alpha1.FlyteWorkflow, error) { + obj, exists, err := s.indexer.GetByKey(s.namespace + "/" + name) + if err != nil { + return nil, err + } + if !exists { + return nil, errors.NewNotFound(v1alpha1.Resource("flyteworkflow"), name) + } + return obj.(*v1alpha1.FlyteWorkflow), nil +} diff --git a/flytepropeller/pkg/compiler/builders.go b/flytepropeller/pkg/compiler/builders.go new file mode 100755 index 0000000000..70f3205254 --- /dev/null +++ b/flytepropeller/pkg/compiler/builders.go @@ -0,0 +1,136 @@ +package compiler + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" +) + +type flyteTask = core.TaskTemplate +type flyteWorkflow = core.CompiledWorkflow +type flyteNode = core.Node + +// A builder object for the Graph struct. This contains information the compiler uses while building the final Graph +// struct. +type workflowBuilder struct { + CoreWorkflow *flyteWorkflow + LaunchPlans map[c.WorkflowIDKey]c.InterfaceProvider + Tasks c.TaskIndex + downstreamNodes c.StringAdjacencyList + upstreamNodes c.StringAdjacencyList + Nodes c.NodeIndex + + // These are references to all subgraphs and tasks passed to CompileWorkflow. They will be passed around but will + // not show in their entirety in the final Graph. The required subset of these will be added to each subgraph as + // the compile traverses them. 
+ allLaunchPlans map[string]c.InterfaceProvider + allTasks c.TaskIndex + allSubWorkflows c.WorkflowIndex +} + +func (w workflowBuilder) GetFailureNode() c.Node { + if w.GetCoreWorkflow() != nil && w.GetCoreWorkflow().GetTemplate() != nil && w.GetCoreWorkflow().GetTemplate().FailureNode != nil { + return w.NewNodeBuilder(w.GetCoreWorkflow().GetTemplate().FailureNode) + } + + return nil +} + +func (w workflowBuilder) GetNodes() c.NodeIndex { + return w.Nodes +} + +func (w workflowBuilder) GetTasks() c.TaskIndex { + return w.Tasks +} + +func (w workflowBuilder) GetDownstreamNodes() c.StringAdjacencyList { + return w.downstreamNodes +} + +func (w workflowBuilder) GetUpstreamNodes() c.StringAdjacencyList { + return w.upstreamNodes +} + +func (w workflowBuilder) NewNodeBuilder(n *flyteNode) c.NodeBuilder { + return &nodeBuilder{flyteNode: n} +} + +func (w workflowBuilder) GetNode(id c.NodeID) (node c.NodeBuilder, found bool) { + node, found = w.Nodes[id] + return +} + +func (w workflowBuilder) GetTask(id c.TaskID) (task c.Task, found bool) { + task, found = w.Tasks[id.String()] + return +} + +func (w workflowBuilder) GetLaunchPlan(id c.LaunchPlanID) (wf c.InterfaceProvider, found bool) { + wf, found = w.LaunchPlans[id.String()] + return +} + +func (w workflowBuilder) GetSubWorkflow(id c.WorkflowID) (wf *core.CompiledWorkflow, found bool) { + wf, found = w.allSubWorkflows[id.String()] + return +} + +func (w workflowBuilder) GetCoreWorkflow() *flyteWorkflow { + return w.CoreWorkflow +} + +// A wrapper around core.nodeBuilder to augment with computed fields during compilation +type nodeBuilder struct { + *flyteNode + subWorkflow c.Workflow + Task c.Task + Iface *core.TypedInterface +} + +func (n nodeBuilder) GetTask() c.Task { + return n.Task +} + +func (n *nodeBuilder) SetTask(task c.Task) { + n.Task = task +} + +func (n nodeBuilder) GetSubWorkflow() c.Workflow { + return n.subWorkflow +} + +func (n nodeBuilder) GetCoreNode() *core.Node { + return n.flyteNode +} + +func (n nodeBuilder) GetInterface() *core.TypedInterface { + return n.Iface +} + +func (n *nodeBuilder) SetInterface(iface *core.TypedInterface) { + n.Iface = iface +} + +func (n *nodeBuilder) SetSubWorkflow(wf c.Workflow) { + n.subWorkflow = wf +} + +func (n *nodeBuilder) SetInputs(inputs []*core.Binding) { + n.Inputs = inputs +} + +type taskBuilder struct { + *flyteTask +} + +func (t taskBuilder) GetCoreTask() *core.TaskTemplate { + return t.flyteTask +} + +func (t taskBuilder) GetID() c.Identifier { + if t.flyteTask.Id != nil { + return *t.flyteTask.Id + } + + return c.Identifier{} +} diff --git a/flytepropeller/pkg/compiler/common/builder.go b/flytepropeller/pkg/compiler/common/builder.go new file mode 100644 index 0000000000..032b9c02e5 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/builder.go @@ -0,0 +1,32 @@ +// This package defines the intermediate layer that the compiler builds and transformers accept. +package common + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +const ( + StartNodeID = "start-node" + EndNodeID = "end-node" +) + +//go:generate mockery -all -output=mocks -case=underscore + +// A mutable workflow used during the build of the intermediate layer. 
+type WorkflowBuilder interface { + Workflow + AddExecutionEdge(nodeFrom, nodeTo NodeID) + AddNode(n NodeBuilder, errs errors.CompileErrors) (node NodeBuilder, ok bool) + ValidateWorkflow(fg *core.CompiledWorkflow, errs errors.CompileErrors) (Workflow, bool) + NewNodeBuilder(n *core.Node) NodeBuilder +} + +// A mutable node used during the build of the intermediate layer. +type NodeBuilder interface { + Node + SetInterface(iface *core.TypedInterface) + SetInputs(inputs []*core.Binding) + SetSubWorkflow(wf Workflow) + SetTask(task Task) +} diff --git a/flytepropeller/pkg/compiler/common/id_set.go b/flytepropeller/pkg/compiler/common/id_set.go new file mode 100644 index 0000000000..8489f10f03 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/id_set.go @@ -0,0 +1,99 @@ +package common + +import ( + "sort" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type Empty struct{} +type Identifier = core.Identifier +type IdentifierSet map[string]Identifier + +// NewString creates a String from a list of values. +func NewIdentifierSet(items ...Identifier) IdentifierSet { + ss := IdentifierSet{} + ss.Insert(items...) + return ss +} + +// Insert adds items to the set. +func (s IdentifierSet) Insert(items ...Identifier) { + for _, item := range items { + s[item.String()] = item + } +} + +// Delete removes all items from the set. +func (s IdentifierSet) Delete(items ...Identifier) { + for _, item := range items { + delete(s, item.String()) + } +} + +// Has returns true if and only if item is contained in the set. +func (s IdentifierSet) Has(item Identifier) bool { + _, contained := s[item.String()] + return contained +} + +// HasAll returns true if and only if all items are contained in the set. +func (s IdentifierSet) HasAll(items ...Identifier) bool { + for _, item := range items { + if !s.Has(item) { + return false + } + } + return true +} + +// HasAny returns true if any items are contained in the set. +func (s IdentifierSet) HasAny(items ...Identifier) bool { + for _, item := range items { + if s.Has(item) { + return true + } + } + return false +} + +type sortableSliceOfString []Identifier + +func (s sortableSliceOfString) Len() int { return len(s) } +func (s sortableSliceOfString) Less(i, j int) bool { + first, second := s[i], s[j] + if first.ResourceType != second.ResourceType { + return first.ResourceType < second.ResourceType + } + + if first.Project != second.Project { + return first.Project < second.Project + } + + if first.Domain != second.Domain { + return first.Domain < second.Domain + } + + if first.Name != second.Name { + return first.Name < second.Name + } + + if first.Version != second.Version { + return first.Version < second.Version + } + + return false +} + +func (s sortableSliceOfString) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// List returns the contents as a sorted Identifier slice. +func (s IdentifierSet) List() []Identifier { + res := make(sortableSliceOfString, 0, len(s)) + for _, value := range s { + res = append(res, value) + } + + sort.Sort(res) + return []Identifier(res) +} diff --git a/flytepropeller/pkg/compiler/common/index.go b/flytepropeller/pkg/compiler/common/index.go new file mode 100644 index 0000000000..e445616fb4 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/index.go @@ -0,0 +1,71 @@ +package common + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +// Defines an index of nodebuilders based on the id. 
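Since IdentifierSet keys its entries by Identifier.String(), a short sketch of how the set behaves may help; the identifiers below are made up, and the import path is the compiler common package added in this patch.

package main

import (
	"fmt"

	"github.com/lyft/flytepropeller/pkg/compiler/common"
)

func main() {
	a := common.Identifier{Project: "proj", Domain: "development", Name: "task-a", Version: "v1"}
	b := common.Identifier{Project: "proj", Domain: "development", Name: "task-b", Version: "v1"}

	ids := common.NewIdentifierSet(b)
	ids.Insert(a)

	fmt.Println(ids.Has(a))       // true
	fmt.Println(ids.HasAll(a, b)) // true

	// List returns members sorted by resource type, project, domain, name, then version.
	for _, id := range ids.List() {
		fmt.Println(id.Name) // task-a, then task-b
	}
}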
+type NodeIndex map[NodeID]NodeBuilder + +// Defines an index of tasks based on the id. +type TaskIndex map[TaskIDKey]Task + +type WorkflowIndex map[WorkflowIDKey]*core.CompiledWorkflow + +// Defines a string adjacency list. +type AdjacencyList map[string]IdentifierSet + +type StringAdjacencyList map[string]sets.String + +// Converts the sets in the adjacency list to sorted arrays. +func (l AdjacencyList) ToMapOfLists() map[string][]Identifier { + res := make(map[string][]Identifier, len(l)) + for key, set := range l { + res[key] = set.List() + } + + return res +} + +// Creates a new TaskIndex. +func NewTaskIndex(tasks ...Task) TaskIndex { + res := make(TaskIndex, len(tasks)) + for _, task := range tasks { + id := task.GetID() + res[(&id).String()] = task + } + + return res +} + +// Creates a new NodeIndex +func NewNodeIndex(nodes ...NodeBuilder) NodeIndex { + res := make(NodeIndex, len(nodes)) + for _, task := range nodes { + res[task.GetId()] = task + } + + return res +} + +func NewWorkflowIndex(workflows []*core.CompiledWorkflow, errs errors.CompileErrors) (index WorkflowIndex, ok bool) { + ok = true + index = make(WorkflowIndex, len(workflows)) + for _, wf := range workflows { + if wf.Template.Id == nil { + // TODO: Log/Return error + return nil, false + } + + if _, found := index[wf.Template.Id.String()]; found { + errs.Collect(errors.NewDuplicateIDFoundErr(wf.Template.Id.String())) + ok = false + } else { + index[wf.Template.Id.String()] = wf + } + } + + return +} diff --git a/flytepropeller/pkg/compiler/common/mocks/interface_provider.go b/flytepropeller/pkg/compiler/common/mocks/interface_provider.go new file mode 100644 index 0000000000..d7f776ffc1 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/mocks/interface_provider.go @@ -0,0 +1,59 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// InterfaceProvider is an autogenerated mock type for the InterfaceProvider type +type InterfaceProvider struct { + mock.Mock +} + +// GetExpectedInputs provides a mock function with given fields: +func (_m *InterfaceProvider) GetExpectedInputs() *core.ParameterMap { + ret := _m.Called() + + var r0 *core.ParameterMap + if rf, ok := ret.Get(0).(func() *core.ParameterMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ParameterMap) + } + } + + return r0 +} + +// GetExpectedOutputs provides a mock function with given fields: +func (_m *InterfaceProvider) GetExpectedOutputs() *core.VariableMap { + ret := _m.Called() + + var r0 *core.VariableMap + if rf, ok := ret.Get(0).(func() *core.VariableMap); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.VariableMap) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *InterfaceProvider) GetID() *core.Identifier { + ret := _m.Called() + + var r0 *core.Identifier + if rf, ok := ret.Get(0).(func() *core.Identifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Identifier) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/compiler/common/mocks/node.go b/flytepropeller/pkg/compiler/common/mocks/node.go new file mode 100644 index 0000000000..eebbeb5099 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/mocks/node.go @@ -0,0 +1,202 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// Node is an autogenerated mock type for the Node type +type Node struct { + mock.Mock +} + +// GetBranchNode provides a mock function with given fields: +func (_m *Node) GetBranchNode() *core.BranchNode { + ret := _m.Called() + + var r0 *core.BranchNode + if rf, ok := ret.Get(0).(func() *core.BranchNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.BranchNode) + } + } + + return r0 +} + +// GetCoreNode provides a mock function with given fields: +func (_m *Node) GetCoreNode() *core.Node { + ret := _m.Called() + + var r0 *core.Node + if rf, ok := ret.Get(0).(func() *core.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Node) + } + } + + return r0 +} + +// GetId provides a mock function with given fields: +func (_m *Node) GetId() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetInputs provides a mock function with given fields: +func (_m *Node) GetInputs() []*core.Binding { + ret := _m.Called() + + var r0 []*core.Binding + if rf, ok := ret.Get(0).(func() []*core.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Binding) + } + } + + return r0 +} + +// GetInterface provides a mock function with given fields: +func (_m *Node) GetInterface() *core.TypedInterface { + ret := _m.Called() + + var r0 *core.TypedInterface + if rf, ok := ret.Get(0).(func() *core.TypedInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TypedInterface) + } + } + + return r0 +} + +// GetMetadata provides a mock function with given fields: +func (_m *Node) GetMetadata() *core.NodeMetadata { + ret := _m.Called() + + var r0 *core.NodeMetadata + if rf, ok := ret.Get(0).(func() *core.NodeMetadata); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.NodeMetadata) + } + } + + return r0 +} + +// GetOutputAliases provides a mock function with given fields: +func (_m *Node) GetOutputAliases() []*core.Alias { + ret := _m.Called() + + var r0 []*core.Alias + if rf, ok := ret.Get(0).(func() []*core.Alias); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Alias) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: +func (_m *Node) GetSubWorkflow() common.Workflow { + ret := _m.Called() + + var r0 common.Workflow + if rf, ok := ret.Get(0).(func() common.Workflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Workflow) + } + } + + return r0 +} + +// GetTask provides a mock function with given fields: +func (_m *Node) GetTask() common.Task { + ret := _m.Called() + + var r0 common.Task + if rf, ok := ret.Get(0).(func() common.Task); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + return r0 +} + +// GetTaskNode provides a mock function with given fields: +func (_m *Node) GetTaskNode() *core.TaskNode { + ret := _m.Called() + + var r0 *core.TaskNode + if rf, ok := ret.Get(0).(func() *core.TaskNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskNode) + } + } + + return r0 +} + +// GetUpstreamNodeIds provides a mock function with given fields: +func (_m *Node) 
GetUpstreamNodeIds() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetWorkflowNode provides a mock function with given fields: +func (_m *Node) GetWorkflowNode() *core.WorkflowNode { + ret := _m.Called() + + var r0 *core.WorkflowNode + if rf, ok := ret.Get(0).(func() *core.WorkflowNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.WorkflowNode) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/compiler/common/mocks/node_builder.go b/flytepropeller/pkg/compiler/common/mocks/node_builder.go new file mode 100644 index 0000000000..2a164b806e --- /dev/null +++ b/flytepropeller/pkg/compiler/common/mocks/node_builder.go @@ -0,0 +1,222 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// NodeBuilder is an autogenerated mock type for the NodeBuilder type +type NodeBuilder struct { + mock.Mock +} + +// GetBranchNode provides a mock function with given fields: +func (_m *NodeBuilder) GetBranchNode() *core.BranchNode { + ret := _m.Called() + + var r0 *core.BranchNode + if rf, ok := ret.Get(0).(func() *core.BranchNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.BranchNode) + } + } + + return r0 +} + +// GetCoreNode provides a mock function with given fields: +func (_m *NodeBuilder) GetCoreNode() *core.Node { + ret := _m.Called() + + var r0 *core.Node + if rf, ok := ret.Get(0).(func() *core.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Node) + } + } + + return r0 +} + +// GetId provides a mock function with given fields: +func (_m *NodeBuilder) GetId() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetInputs provides a mock function with given fields: +func (_m *NodeBuilder) GetInputs() []*core.Binding { + ret := _m.Called() + + var r0 []*core.Binding + if rf, ok := ret.Get(0).(func() []*core.Binding); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Binding) + } + } + + return r0 +} + +// GetInterface provides a mock function with given fields: +func (_m *NodeBuilder) GetInterface() *core.TypedInterface { + ret := _m.Called() + + var r0 *core.TypedInterface + if rf, ok := ret.Get(0).(func() *core.TypedInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TypedInterface) + } + } + + return r0 +} + +// GetMetadata provides a mock function with given fields: +func (_m *NodeBuilder) GetMetadata() *core.NodeMetadata { + ret := _m.Called() + + var r0 *core.NodeMetadata + if rf, ok := ret.Get(0).(func() *core.NodeMetadata); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.NodeMetadata) + } + } + + return r0 +} + +// GetOutputAliases provides a mock function with given fields: +func (_m *NodeBuilder) GetOutputAliases() []*core.Alias { + ret := _m.Called() + + var r0 []*core.Alias + if rf, ok := ret.Get(0).(func() []*core.Alias); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*core.Alias) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: +func (_m 
*NodeBuilder) GetSubWorkflow() common.Workflow { + ret := _m.Called() + + var r0 common.Workflow + if rf, ok := ret.Get(0).(func() common.Workflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Workflow) + } + } + + return r0 +} + +// GetTask provides a mock function with given fields: +func (_m *NodeBuilder) GetTask() common.Task { + ret := _m.Called() + + var r0 common.Task + if rf, ok := ret.Get(0).(func() common.Task); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + return r0 +} + +// GetTaskNode provides a mock function with given fields: +func (_m *NodeBuilder) GetTaskNode() *core.TaskNode { + ret := _m.Called() + + var r0 *core.TaskNode + if rf, ok := ret.Get(0).(func() *core.TaskNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskNode) + } + } + + return r0 +} + +// GetUpstreamNodeIds provides a mock function with given fields: +func (_m *NodeBuilder) GetUpstreamNodeIds() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// GetWorkflowNode provides a mock function with given fields: +func (_m *NodeBuilder) GetWorkflowNode() *core.WorkflowNode { + ret := _m.Called() + + var r0 *core.WorkflowNode + if rf, ok := ret.Get(0).(func() *core.WorkflowNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.WorkflowNode) + } + } + + return r0 +} + +// SetInputs provides a mock function with given fields: inputs +func (_m *NodeBuilder) SetInputs(inputs []*core.Binding) { + _m.Called(inputs) +} + +// SetInterface provides a mock function with given fields: iface +func (_m *NodeBuilder) SetInterface(iface *core.TypedInterface) { + _m.Called(iface) +} + +// SetSubWorkflow provides a mock function with given fields: wf +func (_m *NodeBuilder) SetSubWorkflow(wf common.Workflow) { + _m.Called(wf) +} + +// SetTask provides a mock function with given fields: task +func (_m *NodeBuilder) SetTask(task common.Task) { + _m.Called(task) +} diff --git a/flytepropeller/pkg/compiler/common/mocks/task.go b/flytepropeller/pkg/compiler/common/mocks/task.go new file mode 100644 index 0000000000..4769618740 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/mocks/task.go @@ -0,0 +1,57 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// Task is an autogenerated mock type for the Task type +type Task struct { + mock.Mock +} + +// GetCoreTask provides a mock function with given fields: +func (_m *Task) GetCoreTask() *core.TaskTemplate { + ret := _m.Called() + + var r0 *core.TaskTemplate + if rf, ok := ret.Get(0).(func() *core.TaskTemplate); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TaskTemplate) + } + } + + return r0 +} + +// GetID provides a mock function with given fields: +func (_m *Task) GetID() core.Identifier { + ret := _m.Called() + + var r0 core.Identifier + if rf, ok := ret.Get(0).(func() core.Identifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(core.Identifier) + } + + return r0 +} + +// GetInterface provides a mock function with given fields: +func (_m *Task) GetInterface() *core.TypedInterface { + ret := _m.Called() + + var r0 *core.TypedInterface + if rf, ok := ret.Get(0).(func() *core.TypedInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.TypedInterface) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/compiler/common/mocks/workflow.go b/flytepropeller/pkg/compiler/common/mocks/workflow.go new file mode 100644 index 0000000000..2e1a3dc096 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/mocks/workflow.go @@ -0,0 +1,200 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import mock "github.com/stretchr/testify/mock" + +// Workflow is an autogenerated mock type for the Workflow type +type Workflow struct { + mock.Mock +} + +// GetCoreWorkflow provides a mock function with given fields: +func (_m *Workflow) GetCoreWorkflow() *core.CompiledWorkflow { + ret := _m.Called() + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func() *core.CompiledWorkflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + return r0 +} + +// GetDownstreamNodes provides a mock function with given fields: +func (_m *Workflow) GetDownstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} + +// GetFailureNode provides a mock function with given fields: +func (_m *Workflow) GetFailureNode() common.Node { + ret := _m.Called() + + var r0 common.Node + if rf, ok := ret.Get(0).(func() common.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Node) + } + } + + return r0 +} + +// GetLaunchPlan provides a mock function with given fields: id +func (_m *Workflow) GetLaunchPlan(id core.Identifier) (common.InterfaceProvider, bool) { + ret := _m.Called(id) + + var r0 common.InterfaceProvider + if rf, ok := ret.Get(0).(func(core.Identifier) common.InterfaceProvider); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.InterfaceProvider) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNode provides a mock function with given fields: id +func (_m *Workflow) GetNode(id string) 
(common.NodeBuilder, bool) { + ret := _m.Called(id) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(string) common.NodeBuilder); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodes provides a mock function with given fields: +func (_m *Workflow) GetNodes() common.NodeIndex { + ret := _m.Called() + + var r0 common.NodeIndex + if rf, ok := ret.Get(0).(func() common.NodeIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeIndex) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: id +func (_m *Workflow) GetSubWorkflow(id core.Identifier) (*core.CompiledWorkflow, bool) { + ret := _m.Called(id) + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func(core.Identifier) *core.CompiledWorkflow); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTask provides a mock function with given fields: id +func (_m *Workflow) GetTask(id core.Identifier) (common.Task, bool) { + ret := _m.Called(id) + + var r0 common.Task + if rf, ok := ret.Get(0).(func(core.Identifier) common.Task); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTasks provides a mock function with given fields: +func (_m *Workflow) GetTasks() common.TaskIndex { + ret := _m.Called() + + var r0 common.TaskIndex + if rf, ok := ret.Get(0).(func() common.TaskIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.TaskIndex) + } + } + + return r0 +} + +// GetUpstreamNodes provides a mock function with given fields: +func (_m *Workflow) GetUpstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/compiler/common/mocks/workflow_builder.go b/flytepropeller/pkg/compiler/common/mocks/workflow_builder.go new file mode 100644 index 0000000000..7816e76246 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/mocks/workflow_builder.go @@ -0,0 +1,268 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
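These mockery-generated doubles plug into testify's expectation API; a brief, hypothetical test sketch of the expected usage follows (the test name and package are made up, and "start-node" mirrors the StartNodeID constant defined earlier in this patch).

package compiler_test

import (
	"testing"

	"github.com/lyft/flytepropeller/pkg/compiler/common/mocks"
	"github.com/stretchr/testify/assert"
)

func TestNodeBuilderMock(t *testing.T) {
	n := &mocks.NodeBuilder{}
	n.On("GetId").Return("node-1")
	n.On("GetUpstreamNodeIds").Return([]string{"start-node"})

	// Code under test would accept n anywhere a common.NodeBuilder is expected.
	assert.Equal(t, "node-1", n.GetId())
	assert.Equal(t, []string{"start-node"}, n.GetUpstreamNodeIds())
	n.AssertExpectations(t)
}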
+ +package mocks + +import common "github.com/lyft/flytepropeller/pkg/compiler/common" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import errors "github.com/lyft/flytepropeller/pkg/compiler/errors" +import mock "github.com/stretchr/testify/mock" + +// WorkflowBuilder is an autogenerated mock type for the WorkflowBuilder type +type WorkflowBuilder struct { + mock.Mock +} + +// AddExecutionEdge provides a mock function with given fields: nodeFrom, nodeTo +func (_m *WorkflowBuilder) AddExecutionEdge(nodeFrom string, nodeTo string) { + _m.Called(nodeFrom, nodeTo) +} + +// AddNode provides a mock function with given fields: n, errs +func (_m *WorkflowBuilder) AddNode(n common.NodeBuilder, errs errors.CompileErrors) (common.NodeBuilder, bool) { + ret := _m.Called(n, errs) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(common.NodeBuilder, errors.CompileErrors) common.NodeBuilder); ok { + r0 = rf(n, errs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(common.NodeBuilder, errors.CompileErrors) bool); ok { + r1 = rf(n, errs) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetCoreWorkflow provides a mock function with given fields: +func (_m *WorkflowBuilder) GetCoreWorkflow() *core.CompiledWorkflow { + ret := _m.Called() + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func() *core.CompiledWorkflow); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + return r0 +} + +// GetDownstreamNodes provides a mock function with given fields: +func (_m *WorkflowBuilder) GetDownstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} + +// GetFailureNode provides a mock function with given fields: +func (_m *WorkflowBuilder) GetFailureNode() common.Node { + ret := _m.Called() + + var r0 common.Node + if rf, ok := ret.Get(0).(func() common.Node); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Node) + } + } + + return r0 +} + +// GetLaunchPlan provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetLaunchPlan(id core.Identifier) (common.InterfaceProvider, bool) { + ret := _m.Called(id) + + var r0 common.InterfaceProvider + if rf, ok := ret.Get(0).(func(core.Identifier) common.InterfaceProvider); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.InterfaceProvider) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNode provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetNode(id string) (common.NodeBuilder, bool) { + ret := _m.Called(id) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(string) common.NodeBuilder); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetNodes provides a mock function with given fields: +func (_m *WorkflowBuilder) GetNodes() common.NodeIndex { + ret := _m.Called() + + var r0 common.NodeIndex + if rf, ok := 
ret.Get(0).(func() common.NodeIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeIndex) + } + } + + return r0 +} + +// GetSubWorkflow provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetSubWorkflow(id core.Identifier) (*core.CompiledWorkflow, bool) { + ret := _m.Called(id) + + var r0 *core.CompiledWorkflow + if rf, ok := ret.Get(0).(func(core.Identifier) *core.CompiledWorkflow); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.CompiledWorkflow) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTask provides a mock function with given fields: id +func (_m *WorkflowBuilder) GetTask(id core.Identifier) (common.Task, bool) { + ret := _m.Called(id) + + var r0 common.Task + if rf, ok := ret.Get(0).(func(core.Identifier) common.Task); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Task) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(core.Identifier) bool); ok { + r1 = rf(id) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// GetTasks provides a mock function with given fields: +func (_m *WorkflowBuilder) GetTasks() common.TaskIndex { + ret := _m.Called() + + var r0 common.TaskIndex + if rf, ok := ret.Get(0).(func() common.TaskIndex); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.TaskIndex) + } + } + + return r0 +} + +// GetUpstreamNodes provides a mock function with given fields: +func (_m *WorkflowBuilder) GetUpstreamNodes() common.StringAdjacencyList { + ret := _m.Called() + + var r0 common.StringAdjacencyList + if rf, ok := ret.Get(0).(func() common.StringAdjacencyList); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.StringAdjacencyList) + } + } + + return r0 +} + +// NewNodeBuilder provides a mock function with given fields: n +func (_m *WorkflowBuilder) NewNodeBuilder(n *core.Node) common.NodeBuilder { + ret := _m.Called(n) + + var r0 common.NodeBuilder + if rf, ok := ret.Get(0).(func(*core.Node) common.NodeBuilder); ok { + r0 = rf(n) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.NodeBuilder) + } + } + + return r0 +} + +// ValidateWorkflow provides a mock function with given fields: fg, errs +func (_m *WorkflowBuilder) ValidateWorkflow(fg *core.CompiledWorkflow, errs errors.CompileErrors) (common.Workflow, bool) { + ret := _m.Called(fg, errs) + + var r0 common.Workflow + if rf, ok := ret.Get(0).(func(*core.CompiledWorkflow, errors.CompileErrors) common.Workflow); ok { + r0 = rf(fg, errs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Workflow) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(*core.CompiledWorkflow, errors.CompileErrors) bool); ok { + r1 = rf(fg, errs) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/compiler/common/reader.go b/flytepropeller/pkg/compiler/common/reader.go new file mode 100644 index 0000000000..2edd098da9 --- /dev/null +++ b/flytepropeller/pkg/compiler/common/reader.go @@ -0,0 +1,55 @@ +package common + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type NodeID = string +type TaskID = Identifier +type WorkflowID = Identifier +type LaunchPlanID = Identifier +type TaskIDKey = string +type WorkflowIDKey = string + +// An immutable workflow that represents the final output of the compiler. 
+type Workflow interface { + GetNode(id NodeID) (node NodeBuilder, found bool) + GetTask(id TaskID) (task Task, found bool) + GetLaunchPlan(id LaunchPlanID) (wf InterfaceProvider, found bool) + GetSubWorkflow(id WorkflowID) (wf *core.CompiledWorkflow, found bool) + GetCoreWorkflow() *core.CompiledWorkflow + GetFailureNode() Node + GetNodes() NodeIndex + GetTasks() TaskIndex + GetDownstreamNodes() StringAdjacencyList + GetUpstreamNodes() StringAdjacencyList +} + +// An immutable Node that represents the final output of the compiler. +type Node interface { + GetId() NodeID + GetInterface() *core.TypedInterface + GetInputs() []*core.Binding + GetWorkflowNode() *core.WorkflowNode + GetOutputAliases() []*core.Alias + GetUpstreamNodeIds() []string + GetCoreNode() *core.Node + GetBranchNode() *core.BranchNode + GetTaskNode() *core.TaskNode + GetMetadata() *core.NodeMetadata + GetTask() Task + GetSubWorkflow() Workflow +} + +// An immutable task that represents the final output of the compiler. +type Task interface { + GetID() TaskID + GetCoreTask() *core.TaskTemplate + GetInterface() *core.TypedInterface +} + +type InterfaceProvider interface { + GetID() *core.Identifier + GetExpectedInputs() *core.ParameterMap + GetExpectedOutputs() *core.VariableMap +} diff --git a/flytepropeller/pkg/compiler/errors/compiler_error_test.go b/flytepropeller/pkg/compiler/errors/compiler_error_test.go new file mode 100644 index 0000000000..a3d1eb7556 --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/compiler_error_test.go @@ -0,0 +1,52 @@ +package errors + +import ( + "testing" + + "github.com/magiconair/properties/assert" + "github.com/pkg/errors" +) + +func mustErrorCode(t *testing.T, compileError *CompileError, code ErrorCode) { + assert.Equal(t, code, compileError.Code()) +} + +func TestErrorCodes(t *testing.T) { + testCases := map[ErrorCode]*CompileError{ + CycleDetected: NewCycleDetectedInWorkflowErr("", ""), + BranchNodeIDNotFound: NewBranchNodeNotSpecified(""), + BranchNodeHasNoCondition: NewBranchNodeHasNoCondition(""), + ValueRequired: NewValueRequiredErr("", ""), + NodeReferenceNotFound: NewNodeReferenceNotFoundErr("", ""), + TaskReferenceNotFound: NewTaskReferenceNotFoundErr("", ""), + WorkflowReferenceNotFound: NewWorkflowReferenceNotFoundErr("", ""), + VariableNameNotFound: NewVariableNameNotFoundErr("", "", ""), + DuplicateAlias: NewDuplicateAliasErr("", ""), + DuplicateNodeID: NewDuplicateIDFoundErr(""), + MismatchingTypes: NewMismatchingTypesErr("", "", "", ""), + MismatchingInterfaces: NewMismatchingInterfacesErr("", ""), + InconsistentTypes: NewInconsistentTypesErr("", "", ""), + ParameterBoundMoreThanOnce: NewParameterBoundMoreThanOnceErr("", ""), + ParameterNotBound: NewParameterNotBoundErr("", ""), + NoEntryNodeFound: NewWorkflowHasNoEntryNodeErr(""), + UnreachableNodes: NewUnreachableNodesErr("", ""), + UnrecognizedValue: NewUnrecognizedValueErr("", ""), + WorkflowBuildError: NewWorkflowBuildError(errors.New("")), + } + + for key, value := range testCases { + t.Run(string(key), func(t *testing.T) { + mustErrorCode(t, value, key) + }) + } +} + +func TestIncludeSource(t *testing.T) { + e := NewCycleDetectedInWorkflowErr("", "") + assert.Equal(t, e.source, "") + + SetConfig(Config{IncludeSource: true}) + e = NewCycleDetectedInWorkflowErr("", "") + assert.Equal(t, e.source, "compiler_error_test.go:49") + SetConfig(Config{}) +} diff --git a/flytepropeller/pkg/compiler/errors/compiler_errors.go b/flytepropeller/pkg/compiler/errors/compiler_errors.go new file mode 100755 index 
0000000000..86236c8943 --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/compiler_errors.go @@ -0,0 +1,272 @@ +package errors + +import ( + "fmt" + "runtime" + "strings" +) + +const ( + // A cycle is detected in the Workflow; the error description should detail the nodes involved. + CycleDetected ErrorCode = "CycleDetected" + + // BranchNode is missing a case with ThenNode populated. + BranchNodeIDNotFound ErrorCode = "BranchNodeIdNotFound" + + // BranchNode is missing a condition. + BranchNodeHasNoCondition ErrorCode = "BranchNodeHasNoCondition" + + // An expected field isn't populated. + ValueRequired ErrorCode = "ValueRequired" + + // A nodeBuilder referenced by an edge doesn't belong to the Workflow. + NodeReferenceNotFound ErrorCode = "NodeReferenceNotFound" + + // A Task referenced by a node wasn't found. + TaskReferenceNotFound ErrorCode = "TaskReferenceNotFound" + + // A Workflow referenced by a node wasn't found. + WorkflowReferenceNotFound ErrorCode = "WorkflowReferenceNotFound" + + // A referenced variable (in a parameter or a condition) wasn't found. + VariableNameNotFound ErrorCode = "VariableNameNotFound" + + // An alias existed twice. + DuplicateAlias ErrorCode = "DuplicateAlias" + + // An Id existed twice. + DuplicateNodeID ErrorCode = "DuplicateId" + + // Two types expected to be compatible but aren't. + MismatchingTypes ErrorCode = "MismatchingTypes" + + // A binding is attempted via a list or map syntax, but the underlying type isn't a list or map. + MismatchingBindings ErrorCode = "MismatchingBindings" + + // Two interfaces expected to be compatible but aren't. + MismatchingInterfaces ErrorCode = "MismatchingInterfaces" + + // Expected types to be consistent. + InconsistentTypes ErrorCode = "InconsistentTypes" + + // An input/output parameter was assigned a value through an edge more than once. + ParameterBoundMoreThanOnce ErrorCode = "ParameterBoundMoreThanOnce" + + // One of the required input parameters or a Workflow output parameter wasn't bound. + ParameterNotBound ErrorCode = "ParameterNotBound" + + // When we couldn't assign an entry point to the Workflow. + NoEntryNodeFound ErrorCode = "NoEntryNodeFound" + + // When one or more unreachable nodes are detected. + UnreachableNodes ErrorCode = "UnreachableNodes" + + // A Value doesn't fall within the expected range. + UnrecognizedValue ErrorCode = "UnrecognizedValue" + + // An unknown error occurred while building the workflow. + WorkflowBuildError ErrorCode = "WorkflowBuildError" + + // A value is expected to be unique but wasn't. + ValueCollision ErrorCode = "ValueCollision" + + // A value doesn't follow the expected syntax.
+ SyntaxError ErrorCode = "SyntaxError" +) + +func NewBranchNodeNotSpecified(branchNodeID string) *CompileError { + return newError( + BranchNodeIDNotFound, + fmt.Sprintf("BranchNode not assigned"), + branchNodeID, + ) +} + +func NewBranchNodeHasNoCondition(branchNodeID string) *CompileError { + return newError( + BranchNodeHasNoCondition, + "One of the branches on the node doesn't have a condition.", + branchNodeID, + ) +} + +func NewValueRequiredErr(nodeID, paramName string) *CompileError { + return newError( + ValueRequired, + fmt.Sprintf("Value required [%v].", paramName), + nodeID, + ) +} + +func NewParameterNotBoundErr(nodeID, paramName string) *CompileError { + return newError( + ParameterNotBound, + fmt.Sprintf("Parameter not bound [%v].", paramName), + nodeID, + ) +} + +func NewNodeReferenceNotFoundErr(nodeID, referenceID string) *CompileError { + return newError( + NodeReferenceNotFound, + fmt.Sprintf("Referenced node [%v] not found.", referenceID), + nodeID, + ) +} + +func NewWorkflowReferenceNotFoundErr(nodeID, referenceID string) *CompileError { + return newError( + WorkflowReferenceNotFound, + fmt.Sprintf("Referenced Workflow [%v] not found.", referenceID), + nodeID, + ) +} + +func NewTaskReferenceNotFoundErr(nodeID, referenceID string) *CompileError { + return newError( + TaskReferenceNotFound, + fmt.Sprintf("Referenced Task [%v] not found.", referenceID), + nodeID, + ) +} + +func NewVariableNameNotFoundErr(nodeID, referenceID, variableName string) *CompileError { + return newError( + VariableNameNotFound, + fmt.Sprintf("Variable [%v] not found on node [%v].", variableName, referenceID), + nodeID, + ) +} + +func NewParameterBoundMoreThanOnceErr(nodeID, paramName string) *CompileError { + return newError( + ParameterBoundMoreThanOnce, + fmt.Sprintf("Input [%v] is bound more than once.", paramName), + nodeID, + ) +} + +func NewDuplicateAliasErr(nodeID, alias string) *CompileError { + return newError( + DuplicateAlias, + fmt.Sprintf("Duplicate alias [%v] found. An output alias can only be used once in the Workflow.", alias), + nodeID, + ) +} + +func NewDuplicateIDFoundErr(nodeID string) *CompileError { + return newError( + DuplicateNodeID, + "Trying to insert two nodes with the same id.", + nodeID, + ) +} + +func NewMismatchingTypesErr(nodeID, fromVar, fromType, toType string) *CompileError { + return newError( + MismatchingTypes, + fmt.Sprintf("Variable [%v] (type [%v]) doesn't match expected type [%v].", fromVar, fromType, + toType), + nodeID, + ) +} + +func NewMismatchingBindingsErr(nodeID, sinkParam, expectedType, receivedType string) *CompileError { + return newError( + MismatchingBindings, + fmt.Sprintf("Input [%v] on node [%v] expects bindings of type [%v]. 
Received [%v]", sinkParam, nodeID, expectedType, receivedType), + nodeID, + ) +} + +func NewMismatchingInterfacesErr(nodeID1, nodeID2 string) *CompileError { + return newError( + MismatchingInterfaces, + fmt.Sprintf("Interfaces of nodes [%v] and [%v] do not match.", nodeID1, nodeID2), + nodeID1, + ) +} + +func NewInconsistentTypesErr(nodeID, expectedType, actualType string) *CompileError { + return newError( + InconsistentTypes, + fmt.Sprintf("Expected type: %v but found %v", expectedType, actualType), + nodeID, + ) +} + +func NewWorkflowHasNoEntryNodeErr(graphID string) *CompileError { + return newError( + NoEntryNodeFound, + fmt.Sprintf("Can't find a node to start executing Workflow [%v].", graphID), + graphID, + ) +} + +func NewCycleDetectedInWorkflowErr(nodeID, cycle string) *CompileError { + return newError( + CycleDetected, + fmt.Sprintf("A cycle has been detected while traversing the Workflow [%v].", cycle), + nodeID, + ) +} + +func NewUnreachableNodesErr(nodeID, nodes string) *CompileError { + return newError( + UnreachableNodes, + fmt.Sprintf("The Workflow contain unreachable nodes [%v].", nodes), + nodeID, + ) +} + +func NewUnrecognizedValueErr(nodeID, value string) *CompileError { + return newError( + UnrecognizedValue, + fmt.Sprintf("Unrecognized value [%v].", value), + nodeID, + ) +} + +func NewWorkflowBuildError(err error) *CompileError { + return newError(WorkflowBuildError, err.Error(), "") +} + +func NewValueCollisionError(nodeID string, valueName, value string) *CompileError { + return newError( + ValueCollision, + fmt.Sprintf("%v is expected to be unique. %v already exists.", valueName, value), + nodeID, + ) +} + +func NewSyntaxError(nodeID string, element string, err error) *CompileError { + return newError(SyntaxError, + fmt.Sprintf("Failed to parse element [%v].", element), + nodeID, + ) +} + +func newError(code ErrorCode, description, nodeID string) (err *CompileError) { + err = &CompileError{ + code: code, + description: description, + nodeID: nodeID, + } + + if GetConfig().IncludeSource { + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "???" + line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] + } + } + + err.source = fmt.Sprintf("%v:%v", file, line) + } + + return +} diff --git a/flytepropeller/pkg/compiler/errors/config.go b/flytepropeller/pkg/compiler/errors/config.go new file mode 100644 index 0000000000..7cde12c912 --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/config.go @@ -0,0 +1,27 @@ +package errors + +// Represents error config that can change the behavior of how errors collection/reporting is handled. +type Config struct { + // Indicates that a panic should be issued as soon as the first error is collected. + PanicOnError bool + + // Indicates that errors should include source code information when collected. There is an associated performance + // penalty with this behavior. + IncludeSource bool +} + +var config = Config{} + +// Sets global config. +func SetConfig(cfg Config) { + config = cfg +} + +// Gets global config. +func GetConfig() Config { + return config +} + +func SetIncludeSource() { + config.IncludeSource = true +} diff --git a/flytepropeller/pkg/compiler/errors/error.go b/flytepropeller/pkg/compiler/errors/error.go new file mode 100755 index 0000000000..735e42619f --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/error.go @@ -0,0 +1,125 @@ +// This package is a central repository of all compile errors that can be reported. 
It contains ways to collect and format +// errors to make it easy to find and correct workflow spec problems. +package errors + +import "fmt" + +type ErrorCode string + +// Represents a compile error for coreWorkflow. +type CompileError struct { + code ErrorCode + nodeID string + description string + source string +} + +// Represents a compile error with a root cause. +type CompileErrorWithCause struct { + *CompileError + cause error +} + +// A set of Compile errors. +type CompileErrors interface { + error + Collect(e ...*CompileError) + NewScope() CompileErrors + Errors() *compileErrorSet + HasErrors() bool + ErrorCount() int +} + +type compileErrors struct { + errorSet *compileErrorSet + parent CompileErrors + errorCountInScope int +} + +// Gets the Compile Error code +func (err CompileError) Code() ErrorCode { + return err.code +} + +// Gets a readable/formatted string explaining the compile error as well as at which node it occurred. +func (err CompileError) Error() string { + source := "" + if err.source != "" { + source = fmt.Sprintf("[%v] ", err.source) + } + + return fmt.Sprintf("%vCode: %s, Node Id: %s, Description: %s", source, err.code, err.nodeID, err.description) +} + +// Gets a readable/formatted string explaining the compile error as well as at which node it occurred. +func (err CompileErrorWithCause) Error() string { + cause := "" + if err.cause != nil { + cause = fmt.Sprintf(", Cause: %v", err.cause.Error()) + } + + return fmt.Sprintf("%v%v", err.CompileError.Error(), cause) +} + +// Exposes the set of unique errors. +func (errs *compileErrors) Errors() *compileErrorSet { + return errs.errorSet +} + +// Appends a compile error to the set. +func (errs *compileErrors) Collect(e ...*CompileError) { + if e != nil { + if GetConfig().PanicOnError { + panic(e) + } + + if errs.parent != nil { + errs.parent.Collect(e...) + errs.errorCountInScope += len(e) + } else { + for _, err := range e { + if err != nil { + errs.errorSet.Put(*err) + errs.errorCountInScope++ + } + } + } + } +} + +// Creates a new scope for compile errors. Parent scope will always automatically collect errors reported in any of its +// child scopes. +func (errs *compileErrors) NewScope() CompileErrors { + return &compileErrors{parent: errs} +} + +// Gets a formatted string of all compile errors collected. +func (errs *compileErrors) Error() (err string) { + if errs.parent != nil { + return errs.parent.Error() + } + + err = fmt.Sprintf("Collected Errors: %v\n", len(*errs.Errors())) + i := 0 + for _, e := range errs.Errors().List() { + err += fmt.Sprintf("\tError %d: %s\n", i, e.Error()) + i++ + } + + return err +} + +// Gets a value indicating whether there are any errors collected within current scope and all of its children. +func (errs *compileErrors) HasErrors() bool { + return errs.errorCountInScope > 0 +} + +// Gets the number of errors collected within current scope and all of its children. 
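The scope-and-bubble-up behavior documented above is easiest to see in a small driver. A minimal sketch, written as if inside this errors package; it is illustrative only, and validateOne is a hypothetical check:

// validateAll drives several checks, each in its own child scope; errors collected in a
// child scope bubble up to the parent automatically.
func validateAll(nodeIDs []string) error {
	// Opt in to source locations; this carries the performance penalty noted in config.go.
	SetConfig(Config{IncludeSource: true})
	defer SetConfig(Config{})

	errs := NewCompileErrors()
	for _, id := range nodeIDs {
		validateOne(id, errs.NewScope())
	}

	if errs.HasErrors() {
		return errs // CompileErrors satisfies the error interface.
	}

	return nil
}

// validateOne is hypothetical; it stands in for any check that reports through CompileErrors.
func validateOne(nodeID string, errs CompileErrors) {
	if nodeID == "" {
		errs.Collect(NewValueRequiredErr(nodeID, "id"))
	}
}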
+func (errs *compileErrors) ErrorCount() int { + return errs.errorCountInScope +} + +// Creates a new empty compile errors +func NewCompileErrors() CompileErrors { + return &compileErrors{errorSet: &compileErrorSet{}} +} diff --git a/flytepropeller/pkg/compiler/errors/error_test.go b/flytepropeller/pkg/compiler/errors/error_test.go new file mode 100755 index 0000000000..3756501fb7 --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/error_test.go @@ -0,0 +1,43 @@ +package errors + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func addError(errs CompileErrors) { + errs.Collect(NewValueRequiredErr("node", "param")) +} + +func TestCompileErrors_Collect(t *testing.T) { + errs := NewCompileErrors() + assert.False(t, errs.HasErrors()) + addError(errs) + assert.True(t, errs.HasErrors()) +} + +func TestCompileErrors_NewScope(t *testing.T) { + errs := NewCompileErrors() + addError(errs.NewScope().NewScope()) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 1, errs.ErrorCount()) +} + +func TestCompileErrors_Errors(t *testing.T) { + errs := NewCompileErrors() + addError(errs.NewScope().NewScope()) + addError(errs.NewScope().NewScope()) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 2, errs.ErrorCount()) + + set := errs.Errors() + assert.Equal(t, 1, len(*set)) +} + +func TestCompileErrors_Error(t *testing.T) { + errs := NewCompileErrors() + addError(errs.NewScope().NewScope()) + addError(errs.NewScope().NewScope()) + assert.NotEqual(t, "", errs.Error()) +} diff --git a/flytepropeller/pkg/compiler/errors/sets.go b/flytepropeller/pkg/compiler/errors/sets.go new file mode 100755 index 0000000000..c03c7259d7 --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/sets.go @@ -0,0 +1,44 @@ +package errors + +import ( + "sort" + "strings" +) + +var keyExists = struct{}{} + +type compileErrorSet map[CompileError]struct{} + +func (s compileErrorSet) Put(key CompileError) { + s[key] = keyExists +} + +func (s compileErrorSet) Contains(key CompileError) bool { + _, ok := s[key] + return ok +} + +func (s compileErrorSet) Remove(key CompileError) { + delete(s, key) +} + +func refCompileError(x CompileError) *CompileError { + return &x +} + +func (s compileErrorSet) List() []*CompileError { + res := make([]*CompileError, 0, len(s)) + for key := range s { + res = append(res, refCompileError(key)) + } + + sort.SliceStable(res, func(i, j int) bool { + if res[i].Code() == res[j].Code() { + return res[i].Error() < res[j].Error() + } + + return strings.Compare(string(res[i].Code()), string(res[j].Code())) < 0 + }) + + return res +} diff --git a/flytepropeller/pkg/compiler/errors/sets_test.go b/flytepropeller/pkg/compiler/errors/sets_test.go new file mode 100644 index 0000000000..d487d3b560 --- /dev/null +++ b/flytepropeller/pkg/compiler/errors/sets_test.go @@ -0,0 +1,20 @@ +package errors + +import ( + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestCompileErrorSet_List(t *testing.T) { + set := compileErrorSet{} + set.Put(*NewValueRequiredErr("node1", "param")) + set.Put(*NewWorkflowHasNoEntryNodeErr("graph1")) + set.Put(*NewWorkflowHasNoEntryNodeErr("graph1")) + assert.Equal(t, len(set), 2) + + lst := set.List() + assert.Equal(t, len(lst), 2) + assert.Equal(t, lst[0].Code(), NoEntryNodeFound) + assert.Equal(t, lst[1].Code(), ValueRequired) +} diff --git a/flytepropeller/pkg/compiler/requirements.go b/flytepropeller/pkg/compiler/requirements.go new file mode 100755 index 0000000000..989ecb4034 --- /dev/null +++ b/flytepropeller/pkg/compiler/requirements.go @@ -0,0 
+1,88 @@ +package compiler + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +type TaskIdentifier = common.Identifier +type LaunchPlanRefIdentifier = common.Identifier + +// Represents the set of required resources for a given Workflow's execution. All of the resources should be loaded before +// hand and passed to the compiler. +type WorkflowExecutionRequirements struct { + taskIds []TaskIdentifier + launchPlanIds []LaunchPlanRefIdentifier +} + +// Gets a slice of required Task ids to load. +func (g WorkflowExecutionRequirements) GetRequiredTaskIds() []TaskIdentifier { + return g.taskIds +} + +// Gets a slice of required Workflow ids to load. +func (g WorkflowExecutionRequirements) GetRequiredLaunchPlanIds() []LaunchPlanRefIdentifier { + return g.launchPlanIds +} + +// Computes requirements for a given Workflow. +func GetRequirements(fg *core.WorkflowTemplate, subWfs []*core.WorkflowTemplate) (reqs WorkflowExecutionRequirements, err error) { + errs := errors.NewCompileErrors() + compiledSubWfs := toCompiledWorkflows(subWfs...) + + index, ok := common.NewWorkflowIndex(compiledSubWfs, errs) + + if ok { + return getRequirements(fg, index, true, errs), nil + } + + return WorkflowExecutionRequirements{}, errs +} + +func getRequirements(fg *core.WorkflowTemplate, subWfs common.WorkflowIndex, followSubworkflows bool, + errs errors.CompileErrors) (reqs WorkflowExecutionRequirements) { + + taskIds := common.NewIdentifierSet() + launchPlanIds := common.NewIdentifierSet() + updateWorkflowRequirements(fg, subWfs, taskIds, launchPlanIds, followSubworkflows, errs) + + reqs.taskIds = taskIds.List() + reqs.launchPlanIds = launchPlanIds.List() + + return +} + +// Augments taskIds and launchPlanIds with referenced tasks/workflows within coreWorkflow nodes +func updateWorkflowRequirements(workflow *core.WorkflowTemplate, subWfs common.WorkflowIndex, + taskIds, workflowIds common.IdentifierSet, followSubworkflows bool, errs errors.CompileErrors) { + + for _, node := range workflow.Nodes { + updateNodeRequirements(node, subWfs, taskIds, workflowIds, followSubworkflows, errs) + } +} + +func updateNodeRequirements(node *flyteNode, subWfs common.WorkflowIndex, taskIds, workflowIds common.IdentifierSet, + followSubworkflows bool, errs errors.CompileErrors) (ok bool) { + + if taskN := node.GetTaskNode(); taskN != nil && taskN.GetReferenceId() != nil { + taskIds.Insert(*taskN.GetReferenceId()) + } else if workflowNode := node.GetWorkflowNode(); workflowNode != nil { + if workflowNode.GetLaunchplanRef() != nil { + workflowIds.Insert(*workflowNode.GetLaunchplanRef()) + } else if workflowNode.GetSubWorkflowRef() != nil && followSubworkflows { + if subWf, found := subWfs[workflowNode.GetSubWorkflowRef().String()]; !found { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr(node.Id, workflowNode.GetSubWorkflowRef().String())) + } else { + updateWorkflowRequirements(subWf.Template, subWfs, taskIds, workflowIds, followSubworkflows, errs) + } + } + } else if branchN := node.GetBranchNode(); branchN != nil { + updateNodeRequirements(branchN.IfElse.Case.ThenNode, subWfs, taskIds, workflowIds, followSubworkflows, errs) + for _, otherCase := range branchN.IfElse.Other { + updateNodeRequirements(otherCase.ThenNode, subWfs, taskIds, workflowIds, followSubworkflows, errs) + } + } + + return !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/requirements_test.go 
b/flytepropeller/pkg/compiler/requirements_test.go new file mode 100755 index 0000000000..fdc7eaa72f --- /dev/null +++ b/flytepropeller/pkg/compiler/requirements_test.go @@ -0,0 +1,125 @@ +package compiler + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestGetRequirements(t *testing.T) { + g := &core.WorkflowTemplate{ + Nodes: []*core.Node{ + { + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_1"}, + }, + }, + }, + }, + { + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_2"}, + }, + }, + }, + }, + { + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "Graph_1"}, + }, + }, + }, + }, + { + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "Graph_1"}, + }, + }, + }, + }, + }, + Other: []*core.IfBlock{ + { + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_3"}, + }, + }, + }, + }, + }, + { + ThenNode: &core.Node{ + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "Graph_2"}, + }, + }, + }, + }, + }, + Other: []*core.IfBlock{ + { + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_4"}, + }, + }, + }, + }, + }, + { + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "Task_5"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(g, subWorkflows) + assert.NoError(t, err) + assert.Equal(t, 5, len(reqs.GetRequiredTaskIds())) + assert.Equal(t, 2, len(reqs.GetRequiredLaunchPlanIds())) +} diff --git a/flytepropeller/pkg/compiler/task_compiler.go b/flytepropeller/pkg/compiler/task_compiler.go new file mode 100644 index 0000000000..08f3b88067 --- /dev/null +++ b/flytepropeller/pkg/compiler/task_compiler.go @@ -0,0 +1,98 @@ +package compiler + +import ( + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/api/resource" +) + +func validateResource(resourceName core.Resources_ResourceName, resourceVal string, errs errors.CompileErrors) (ok bool) { + if _, err := resource.ParseQuantity(resourceVal); err != nil { + errs.Collect(errors.NewUnrecognizedValueErr(fmt.Sprintf("resources.%v", resourceName), resourceVal)) + return true + } + return false +} + +func validateKnownResources(resources []*core.Resources_ResourceEntry, errs errors.CompileErrors) (ok 
bool) { + for _, r := range resources { + validateResource(r.Name, r.Value, errs.NewScope()) + } + + return !errs.HasErrors() +} + +func validateResources(resources *core.Resources, errs errors.CompileErrors) (ok bool) { + // Validate known resource keys. + validateKnownResources(resources.Requests, errs.NewScope()) + validateKnownResources(resources.Limits, errs.NewScope()) + + return !errs.HasErrors() +} + +func validateContainerCommand(task *core.TaskTemplate, errs errors.CompileErrors) (ok bool) { + if task.Interface == nil { + // Nothing to validate. + return + } + hasInputs := task.Interface.Inputs != nil && len(task.Interface.GetInputs().Variables) > 0 + hasOutputs := task.Interface.Outputs != nil && len(task.Interface.GetOutputs().Variables) > 0 + if !(hasInputs || hasOutputs) { + // Nothing to validate. + return + } + if task.GetContainer().Command == nil && task.GetContainer().Args == nil { + // When an interface with inputs or outputs is defined, the container command + args together must not be empty. + errs.Collect(errors.NewValueRequiredErr("container", "command")) + } + + return !errs.HasErrors() +} + +func validateContainer(task *core.TaskTemplate, errs errors.CompileErrors) (ok bool) { + if task.GetContainer() == nil { + errs.Collect(errors.NewValueRequiredErr("root", "container")) + return + } + + validateContainerCommand(task, errs) + + container := task.GetContainer() + if container.Image == "" { + errs.Collect(errors.NewValueRequiredErr("container", "image")) + } + + if container.Resources != nil { + validateResources(container.Resources, errs.NewScope()) + } + + return !errs.HasErrors() +} + +func compileTaskInternal(task *core.TaskTemplate, errs errors.CompileErrors) (common.Task, bool) { + if task.Id == nil { + errs.Collect(errors.NewValueRequiredErr("root", "Id")) + } + + switch task.GetTarget().(type) { + case *core.TaskTemplate_Container: + validateContainer(task, errs.NewScope()) + } + + return taskBuilder{flyteTask: task}, !errs.HasErrors() +} + +// Task compiler compiles a given Task into an executable Task. It validates all required parameters and ensures a Task +// is well-formed. 
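A rough usage sketch for CompileTask, which follows next; the task values here are hypothetical and mirror the shape exercised by task_compiler_test.go later in this change:

// compileExampleTask builds a minimal container task and runs it through CompileTask.
func compileExampleTask() (*core.CompiledTask, error) {
	return CompileTask(&core.TaskTemplate{
		Id: &core.Identifier{Name: "example_task"},
		Target: &core.TaskTemplate_Container{
			Container: &core.Container{
				Image:   "image://example",
				Command: []string{"echo"},
				Args:    []string{"hello"},
			},
		},
	})
}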
+func CompileTask(task *core.TaskTemplate) (*core.CompiledTask, error) { + errs := errors.NewCompileErrors() + t, _ := compileTaskInternal(task, errs.NewScope()) + if errs.HasErrors() { + return nil, errs + } + + return &core.CompiledTask{Template: t.GetCoreTask()}, nil +} diff --git a/flytepropeller/pkg/compiler/task_compiler_test.go b/flytepropeller/pkg/compiler/task_compiler_test.go new file mode 100644 index 0000000000..ff878509b6 --- /dev/null +++ b/flytepropeller/pkg/compiler/task_compiler_test.go @@ -0,0 +1,82 @@ +package compiler + +import ( + "testing" + + "github.com/lyft/flytepropeller/pkg/compiler/errors" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func MakeResource(name core.Resources_ResourceName, v string) *core.Resources_ResourceEntry { + return &core.Resources_ResourceEntry{ + Name: name, + Value: v, + } +} + +func TestValidateContainerCommand(t *testing.T) { + task := core.TaskTemplate{ + Id: &core.Identifier{Name: "task_123"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "foo": {}, + }), + Outputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + }, + }, + } + errs := errors.NewCompileErrors() + assert.False(t, validateContainerCommand(&task, errs)) + assert.Contains(t, errs.Error(), "Node Id: container, Description: Value required [command]") + + task.GetContainer().Command = []string{"cmd"} + errs = errors.NewCompileErrors() + assert.True(t, validateContainerCommand(&task, errs)) + assert.False(t, errs.HasErrors()) +} + +func TestCompileTask(t *testing.T) { + task, err := CompileTask(&core.TaskTemplate{ + Id: &core.Identifier{Name: "task_123"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + Command: []string{"cmd"}, + Args: []string{"args"}, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + MakeResource(core.Resources_CPU, "5"), + }, + Limits: []*core.Resources_ResourceEntry{ + MakeResource(core.Resources_MEMORY, "100Gi"), + }, + }, + Env: []*core.KeyValuePair{ + { + Key: "Env_Var", + Value: "Env_Val", + }, + }, + Config: []*core.KeyValuePair{ + { + Key: "config_key", + Value: "config_value", + }, + }, + }, + }, + }) + + assert.NoError(t, err) + assert.NotNil(t, task) +} diff --git a/flytepropeller/pkg/compiler/test/compiler_test.go b/flytepropeller/pkg/compiler/test/compiler_test.go new file mode 100644 index 0000000000..f9bc360670 --- /dev/null +++ b/flytepropeller/pkg/compiler/test/compiler_test.go @@ -0,0 +1,247 @@ +package test + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ghodss/yaml" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytepropeller/pkg/visualize" + "github.com/stretchr/testify/assert" +) + +var update = flag.Bool("update", false, "Update .golden files") +var reverse = flag.Bool("reverse", false, "Reverse .golden files") + +func makeDefaultInputs(iface 
*core.TypedInterface) *core.LiteralMap { + if iface == nil || iface.GetInputs() == nil { + return nil + } + + res := make(map[string]*core.Literal, len(iface.GetInputs().Variables)) + for inputName, inputVar := range iface.GetInputs().Variables { + val := utils.MustMakeDefaultLiteralForType(inputVar.Type) + res[inputName] = val + } + + return &core.LiteralMap{ + Literals: res, + } +} + +func setDefaultFields(task *core.TaskTemplate) { + if container := task.GetContainer(); container != nil { + if container.Config == nil { + container.Config = []*core.KeyValuePair{} + } + + container.Config = append(container.Config, &core.KeyValuePair{ + Key: "testKey1", + Value: "testValue1", + }) + container.Config = append(container.Config, &core.KeyValuePair{ + Key: "testKey2", + Value: "testValue2", + }) + container.Config = append(container.Config, &core.KeyValuePair{ + Key: "testKey3", + Value: "testValue3", + }) + } +} + +func mustCompileTasks(t *testing.T, tasks []*core.TaskTemplate) []*core.CompiledTask { + compiledTasks := make([]*core.CompiledTask, 0, len(tasks)) + for _, inputTask := range tasks { + setDefaultFields(inputTask) + task, err := compiler.CompileTask(inputTask) + compiledTasks = append(compiledTasks, task) + assert.NoError(t, err) + if err != nil { + assert.FailNow(t, err.Error()) + } + } + + return compiledTasks +} + +func marshalProto(t *testing.T, filename string, p proto.Message) { + marshaller := &jsonpb.Marshaler{} + s, err := marshaller.MarshalToString(p) + assert.NoError(t, err) + + if err != nil { + return + } + + originalRaw, err := proto.Marshal(p) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(strings.Replace(filename, filepath.Ext(filename), ".pb", 1), originalRaw, os.ModePerm)) + + m := map[string]interface{}{} + err = json.Unmarshal([]byte(s), &m) + assert.NoError(t, err) + + b, err := yaml.Marshal(m) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(strings.Replace(filename, filepath.Ext(filename), ".yaml", 1), b, os.ModePerm)) +} + +func TestReverseEngineerFromYaml(t *testing.T) { + root := "testdata" + errors.SetConfig(errors.Config{IncludeSource: true}) + assert.NoError(t, filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + if !strings.HasSuffix(path, ".yaml") { + return nil + } + + if strings.HasSuffix(path, "-inputs.yaml") { + return nil + } + + ext := ".yaml" + + testName := strings.TrimLeft(path, root) + testName = strings.Trim(testName, string(os.PathSeparator)) + testName = strings.TrimSuffix(testName, ext) + testName = strings.Replace(testName, string(os.PathSeparator), "_", -1) + + t.Run(testName, func(t *testing.T) { + t.Log("Reading from file") + raw, err := ioutil.ReadFile(path) + assert.NoError(t, err) + + raw, err = yaml.YAMLToJSON(raw) + assert.NoError(t, err) + + t.Log("Unmarshalling Workflow Closure") + wf := &core.WorkflowClosure{} + err = jsonpb.UnmarshalString(string(raw), wf) + assert.NoError(t, err) + assert.NotNil(t, wf) + if err != nil { + return + } + + t.Log("Compiling Workflow") + compiledWf, err := compiler.CompileWorkflow(wf.Workflow, []*core.WorkflowTemplate{}, mustCompileTasks(t, wf.Tasks), []common.InterfaceProvider{}) + assert.NoError(t, err) + if err != nil { + return + } + + inputs := makeDefaultInputs(compiledWf.Primary.Template.GetInterface()) + if *reverse { + marshalProto(t, strings.Replace(path, ext, fmt.Sprintf("-inputs%v", ext), -1), inputs) + } + + t.Log("Building k8s resource") + _, err = 
k8s.BuildFlyteWorkflow(compiledWf, inputs, nil, "") + assert.NoError(t, err) + if err != nil { + return + } + + dotFormat := visualize.ToGraphViz(compiledWf.Primary) + t.Logf("GraphViz Dot: %v\n", dotFormat) + + if *reverse { + marshalProto(t, path, wf) + } + }) + + return nil + })) +} + +func TestCompileAndBuild(t *testing.T) { + root := "testdata" + errors.SetConfig(errors.Config{IncludeSource: true}) + assert.NoError(t, filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + if ext := filepath.Ext(path); ext != ".pb" { + return nil + } + + if strings.HasSuffix(path, "-inputs.pb") { + return nil + } + + testName := strings.TrimLeft(path, root) + testName = strings.Trim(testName, string(os.PathSeparator)) + testName = strings.Trim(testName, filepath.Ext(testName)) + testName = strings.Replace(testName, string(os.PathSeparator), "_", -1) + + t.Run(testName, func(t *testing.T) { + t.Log("Reading from file") + raw, err := ioutil.ReadFile(path) + assert.NoError(t, err) + + t.Log("Unmarshalling Workflow Closure") + wf := &core.WorkflowClosure{} + err = proto.Unmarshal(raw, wf) + assert.NoError(t, err) + assert.NotNil(t, wf) + if err != nil { + return + } + + t.Log("Compiling Workflow") + compiledWf, err := compiler.CompileWorkflow(wf.Workflow, []*core.WorkflowTemplate{}, mustCompileTasks(t, wf.Tasks), []common.InterfaceProvider{}) + assert.NoError(t, err) + if err != nil { + return + } + + inputs := makeDefaultInputs(compiledWf.Primary.Template.GetInterface()) + if *update { + marshalProto(t, strings.Replace(path, filepath.Ext(path), fmt.Sprintf("-inputs%v", filepath.Ext(path)), -1), inputs) + } + + t.Log("Building k8s resource") + _, err = k8s.BuildFlyteWorkflow(compiledWf, inputs, nil, "") + assert.NoError(t, err) + if err != nil { + return + } + + dotFormat := visualize.ToGraphViz(compiledWf.Primary) + t.Logf("GraphViz Dot: %v\n", dotFormat) + + if *update { + marshalProto(t, path, wf) + } + }) + + return nil + })) +} diff --git a/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.pb b/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.pb new file mode 100755 index 0000000000..e69de29bb2 diff --git a/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml b/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml new file mode 100755 index 0000000000..3893cfb77a --- /dev/null +++ b/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f-inputs.yaml @@ -0,0 +1 @@ +literals: {} diff --git a/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb b/flytepropeller/pkg/compiler/test/testdata/app-workflows-work-one-python-task-w-f.pb new file mode 100755 index 0000000000000000000000000000000000000000..5f6557e400ff8e248c0ac1ca59de2cc54c8758a4 GIT binary patch literal 560 zcma)&J5R$f6or!zmFSjIFrq~k0}4+eIck!P4b-x*FtJu{*%p+WwB=Q-wS8hQ={ji)7`8Vx88C2+l~H%Xh(vr0I&RiH{VlgMt+E0|vtL z2+yLl4$&bP)7j(nX*PQ{m3sv{a1aKmoC|+mEg=(KKo4}#46U7e?bO@3@7%(*d0B#3 zzaX`;{D6D&^114!1xr%&t^Iei-zk94(TSm={fRZPY&yZcCaZ8A)izx?gV7FkNsP7f z%ot9gPZL%VR3xYgbmvQjag_UEF4z^PF?_UJ92Mo_6Y8fMJy_)f=nI= 0 { + for param := range diff { + errs.Collect(errors.NewParameterNotBoundErr(nodeID, param)) + } + } + + return !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go 
b/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go new file mode 100644 index 0000000000..01d667c35d --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go @@ -0,0 +1 @@ +package k8s diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node.go b/flytepropeller/pkg/compiler/transformers/k8s/node.go new file mode 100644 index 0000000000..fe31b6e59a --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/node.go @@ -0,0 +1,169 @@ +package k8s + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/utils" +) + +// Gets the compiled subgraph if this node contains an inline-declared coreWorkflow. Otherwise nil. +func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.CompileErrors) (*v1alpha1.NodeSpec, bool) { + if n == nil { + errs.Collect(errors.NewValueRequiredErr("root", "node")) + return nil, !errs.HasErrors() + } + + if n.GetId() != common.StartNodeID && n.GetId() != common.EndNodeID && + n.GetTarget() == nil { + + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "target")) + return nil, !errs.HasErrors() + } + + var task *core.TaskTemplate + if n.GetTaskNode() != nil { + taskID := n.GetTaskNode().GetReferenceId().String() + // TODO: Use task index for quick lookup + for _, t := range tasks { + if t.Template.Id.String() == taskID { + task = t.Template + break + } + } + + if task == nil { + errs.Collect(errors.NewTaskReferenceNotFoundErr(n.GetId(), taskID)) + return nil, !errs.HasErrors() + } + } + + res, err := utils.ToK8sResourceRequirements(getResources(task)) + if err != nil { + errs.Collect(errors.NewWorkflowBuildError(err)) + return nil, false + } + + nodeSpec := &v1alpha1.NodeSpec{ + ID: n.GetId(), + RetryStrategy: computeRetryStrategy(n, task), + Resources: res, + OutputAliases: toAliasValueArray(n.GetOutputAliases()), + InputBindings: toBindingValueArray(n.GetInputs()), + ActiveDeadlineSeconds: computeActiveDeadlineSeconds(n, task), + } + + switch v := n.GetTarget().(type) { + case *core.Node_TaskNode: + nodeSpec.Kind = v1alpha1.NodeKindTask + nodeSpec.TaskRef = refStr(n.GetTaskNode().GetReferenceId().String()) + case *core.Node_WorkflowNode: + if n.GetWorkflowNode().Reference == nil { + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "WorkflowNode.Reference")) + return nil, !errs.HasErrors() + } + + switch n.GetWorkflowNode().Reference.(type) { + case *core.WorkflowNode_LaunchplanRef: + nodeSpec.Kind = v1alpha1.NodeKindWorkflow + nodeSpec.WorkflowNode = &v1alpha1.WorkflowNodeSpec{ + LaunchPlanRefID: &v1alpha1.LaunchPlanRefID{Identifier: n.GetWorkflowNode().GetLaunchplanRef()}, + } + case *core.WorkflowNode_SubWorkflowRef: + nodeSpec.Kind = v1alpha1.NodeKindWorkflow + if v.WorkflowNode.GetSubWorkflowRef() != nil { + nodeSpec.WorkflowNode = &v1alpha1.WorkflowNodeSpec{ + SubWorkflowReference: refStr(v.WorkflowNode.GetSubWorkflowRef().String()), + } + } else if v.WorkflowNode.GetLaunchplanRef() != nil { + nodeSpec.WorkflowNode = &v1alpha1.WorkflowNodeSpec{ + LaunchPlanRefID: &v1alpha1.LaunchPlanRefID{Identifier: n.GetWorkflowNode().GetLaunchplanRef()}, + } + } else { + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "WorkflowNode.WorkflowTemplate")) + return nil, !errs.HasErrors() + } + } + case *core.Node_BranchNode: + nodeSpec.Kind = v1alpha1.NodeKindBranch + nodeSpec.BranchNode = 
buildBranchNodeSpec(n.GetBranchNode(), errs.NewScope()) + default: + if n.GetId() == v1alpha1.StartNodeID { + nodeSpec.Kind = v1alpha1.NodeKindStart + } else if n.GetId() == v1alpha1.EndNodeID { + nodeSpec.Kind = v1alpha1.NodeKindEnd + } + } + + return nodeSpec, !errs.HasErrors() +} + +func buildIfBlockSpec(block *core.IfBlock, _ errors.CompileErrors) *v1alpha1.IfBlock { + return &v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{BooleanExpression: block.Condition}, + ThenNode: refStr(block.ThenNode.Id), + } +} + +func buildBranchNodeSpec(branch *core.BranchNode, errs errors.CompileErrors) *v1alpha1.BranchNodeSpec { + if branch == nil { + return nil + } + + res := &v1alpha1.BranchNodeSpec{ + If: *buildIfBlockSpec(branch.IfElse.Case, errs.NewScope()), + } + + switch branch.IfElse.GetDefault().(type) { + case *core.IfElseBlock_ElseNode: + res.Else = refStr(branch.IfElse.GetElseNode().Id) + case *core.IfElseBlock_Error: + res.ElseFail = &v1alpha1.Error{Error: branch.IfElse.GetError()} + } + + other := make([]*v1alpha1.IfBlock, 0, len(branch.IfElse.Other)) + for _, block := range branch.IfElse.Other { + other = append(other, buildIfBlockSpec(block, errs.NewScope())) + } + + res.ElseIf = other + + return res +} + +func buildNodes(nodes []*core.Node, tasks []*core.CompiledTask, errs errors.CompileErrors) (map[common.NodeID]*v1alpha1.NodeSpec, bool) { + res := make(map[common.NodeID]*v1alpha1.NodeSpec, len(nodes)) + for _, nodeBuidler := range nodes { + n, ok := buildNodeSpec(nodeBuidler, tasks, errs.NewScope()) + if !ok { + return nil, ok + } + + if _, exists := res[n.ID]; exists { + errs.Collect(errors.NewValueCollisionError(nodeBuidler.GetId(), "Id", n.ID)) + } + + res[n.ID] = n + } + + return res, !errs.HasErrors() +} + +func buildTasks(tasks []*core.CompiledTask, errs errors.CompileErrors) map[common.TaskIDKey]*v1alpha1.TaskSpec { + res := make(map[common.TaskIDKey]*v1alpha1.TaskSpec, len(tasks)) + for _, flyteTask := range tasks { + if flyteTask == nil { + errs.Collect(errors.NewValueRequiredErr("root", "coreTask")) + } else { + taskID := flyteTask.Template.Id.String() + if _, exists := res[taskID]; exists { + errs.Collect(errors.NewValueCollisionError(taskID, "Id", taskID)) + } + + res[taskID] = &v1alpha1.TaskSpec{TaskTemplate: flyteTask.Template} + } + } + + return res +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go new file mode 100644 index 0000000000..abe13f6bcd --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go @@ -0,0 +1,181 @@ +package k8s + +import ( + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/stretchr/testify/assert" +) + +func createNodeWithTask() *core.Node { + return &core.Node{ + Id: "n_1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_1"}, + }, + }, + }, + } +} + +func TestBuildNodeSpec(t *testing.T) { + n := mockNode{ + id: "n_1", + Node: &core.Node{}, + } + + tasks := []*core.CompiledTask{ + { + Template: &core.TaskTemplate{ + Id: &core.Identifier{Name: "ref_1"}, + }, + }, + { + Template: &core.TaskTemplate{ + Id: &core.Identifier{Name: "ref_2"}, + Target: &core.TaskTemplate_Container{ + Container: 
&core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "10Mi", + }, + }, + }, + }, + }, + }, + }, + } + + errors.SetConfig(errors.Config{IncludeSource: true}) + errs := errors.NewCompileErrors() + + mustBuild := func(n common.Node, errs errors.CompileErrors) *v1alpha1.NodeSpec { + spec, ok := buildNodeSpec(n.GetCoreNode(), tasks, errs) + assert.False(t, errs.HasErrors()) + assert.True(t, ok) + assert.NotNil(t, spec) + + if errs.HasErrors() { + assert.Fail(t, errs.Error()) + } + + return spec + } + + t.Run("Task", func(t *testing.T) { + n.Node.Target = &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_1"}, + }, + }, + } + + mustBuild(n, errs.NewScope()) + }) + + t.Run("Task with resources", func(t *testing.T) { + expectedCPU := resource.MustParse("10Mi") + n.Node.Target = &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_2"}, + }, + }, + } + + spec := mustBuild(n, errs.NewScope()) + assert.NotNil(t, spec.Resources) + assert.NotNil(t, spec.Resources.Requests.Cpu()) + assert.Equal(t, expectedCPU.Value(), spec.Resources.Requests.Cpu().Value()) + }) + + t.Run("LaunchPlanRef", func(t *testing.T) { + n.Node.Target = &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{Name: "ref_1"}, + }, + }, + } + + mustBuild(n, errs.NewScope()) + }) + + t.Run("Workflow", func(t *testing.T) { + n.subWF = createSampleMockWorkflow() + n.Node.Target = &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: n.subWF.GetCoreWorkflow().Template.Id, + }, + }, + } + + mustBuild(n, errs.NewScope()) + }) + + t.Run("Branch", func(t *testing.T) { + n.Node.Target = &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Other: []*core.IfBlock{}, + Default: &core.IfElseBlock_Error{ + Error: &core.Error{ + Message: "failed", + }, + }, + Case: &core.IfBlock{ + ThenNode: &core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_1"}, + }, + }, + }, + }, + Condition: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_EQ, + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 123, + }, + }, + }, + }, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 123, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + mustBuild(n, errs.NewScope()) + }) + +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/utils.go b/flytepropeller/pkg/compiler/transformers/k8s/utils.go new file mode 100644 index 0000000000..1836944ea6 --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/utils.go @@ -0,0 +1,86 @@ +package k8s + +import ( + "math" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +func refInt(i int) *int { + return &i +} + +func refStr(s string) *string { + return &s +} + +func computeRetryStrategy(n *core.Node, t *core.TaskTemplate) *v1alpha1.RetryStrategy { + if 
n.GetMetadata() != nil && n.GetMetadata().GetRetries() != nil { + return &v1alpha1.RetryStrategy{ + MinAttempts: refInt(int(n.GetMetadata().GetRetries().Retries + 1)), + } + } + + if t != nil && t.GetMetadata() != nil && t.GetMetadata().GetRetries() != nil { + return &v1alpha1.RetryStrategy{ + MinAttempts: refInt(int(t.GetMetadata().GetRetries().Retries + 1)), + } + } + + return nil +} + +func computeActiveDeadlineSeconds(n *core.Node, t *core.TaskTemplate) *int64 { + if n.GetMetadata() != nil && n.GetMetadata().Timeout != nil { + return &n.GetMetadata().Timeout.Seconds + } + + if t != nil && t.GetMetadata() != nil && t.GetMetadata().Timeout != nil { + return &t.GetMetadata().Timeout.Seconds + } + + return nil +} + +func getResources(task *core.TaskTemplate) *core.Resources { + if task == nil { + return nil + } + + if task.GetContainer() == nil { + return nil + } + + return task.GetContainer().Resources +} + +func toAliasValueArray(aliases []*core.Alias) []v1alpha1.Alias { + if aliases == nil { + return nil + } + + res := make([]v1alpha1.Alias, 0, len(aliases)) + for _, alias := range aliases { + res = append(res, v1alpha1.Alias{Alias: *alias}) + } + + return res +} + +func toBindingValueArray(bindings []*core.Binding) []*v1alpha1.Binding { + if bindings == nil { + return nil + } + + res := make([]*v1alpha1.Binding, 0, len(bindings)) + for _, binding := range bindings { + res = append(res, &v1alpha1.Binding{Binding: binding}) + } + + return res +} + +func minInt(i, j int) int { + return int(math.Min(float64(i), float64(j))) +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go b/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go new file mode 100644 index 0000000000..371b2531a6 --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go @@ -0,0 +1,58 @@ +package k8s + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestComputeRetryStrategy(t *testing.T) { + + tests := []struct { + name string + nodeRetries int + taskRetries int + expectedRetries int + }{ + {"node-only", 1, 0, 2}, + {"task-only", 0, 1, 2}, + {"node-task", 2, 3, 3}, + {"no-retries", 0, 0, 0}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + var node *core.Node + if test.nodeRetries != 0 { + node = &core.Node{ + Metadata: &core.NodeMetadata{ + Retries: &core.RetryStrategy{ + Retries: uint32(test.nodeRetries), + }, + }, + } + } + + var tmpl *core.TaskTemplate + if test.taskRetries != 0 { + tmpl = &core.TaskTemplate{ + Metadata: &core.TaskMetadata{ + Retries: &core.RetryStrategy{ + Retries: uint32(test.taskRetries), + }, + }, + } + } + + r := computeRetryStrategy(node, tmpl) + if test.expectedRetries != 0 { + assert.NotNil(t, r) + assert.Equal(t, test.expectedRetries, *r.MinAttempts) + } else { + assert.Nil(t, r) + } + }) + } + +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/workflow.go b/flytepropeller/pkg/compiler/transformers/k8s/workflow.go new file mode 100644 index 0000000000..868e36bbf6 --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/workflow.go @@ -0,0 +1,208 @@ +// This package converts the output of the compiler into a K8s resource for propeller to execute. 
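Before the workflow transformer below, a small sketch (not part of this change) of the precedence computeRetryStrategy in utils.go above implements: node-level retry metadata wins over task-level metadata, and MinAttempts is always retries + 1.

// retryPrecedenceSketch is hypothetical and only illustrates the precedence rule.
func retryPrecedenceSketch() *v1alpha1.RetryStrategy {
	node := &core.Node{
		Metadata: &core.NodeMetadata{
			Retries: &core.RetryStrategy{Retries: 2},
		},
	}

	task := &core.TaskTemplate{
		Metadata: &core.TaskMetadata{
			Retries: &core.RetryStrategy{Retries: 5},
		},
	}

	// Node metadata wins: 2 retries yield MinAttempts of 3.
	return computeRetryStrategy(node, task)
}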
+package k8s + +import ( + "fmt" + "strings" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ExecutionIDLabel = "execution-id" +const WorkflowIDLabel = "workflow-id" + +func requiresInputs(w *core.WorkflowTemplate) bool { + if w == nil || w.GetInterface() == nil || w.GetInterface().GetInputs() == nil || + w.GetInterface().GetInputs().Variables == nil { + + return false + } + + return len(w.GetInterface().GetInputs().Variables) > 0 +} + +func WorkflowIDAsString(id *core.Identifier) string { + b := strings.Builder{} + _, err := b.WriteString(id.Project) + if err != nil { + return "" + } + + _, err = b.WriteRune(':') + if err != nil { + return "" + } + + _, err = b.WriteString(id.Domain) + if err != nil { + return "" + } + + _, err = b.WriteRune(':') + if err != nil { + return "" + } + + _, err = b.WriteString(id.Name) + if err != nil { + return "" + } + + return b.String() +} + +func buildFlyteWorkflowSpec(wf *core.CompiledWorkflow, tasks []*core.CompiledTask, errs errors.CompileErrors) ( + spec *v1alpha1.WorkflowSpec, ok bool) { + var failureN *v1alpha1.NodeSpec + if n := wf.Template.GetFailureNode(); n != nil { + failureN, _ = buildNodeSpec(n, tasks, errs.NewScope()) + } + + nodes, _ := buildNodes(wf.Template.GetNodes(), tasks, errs.NewScope()) + + if errs.HasErrors() { + return nil, !errs.HasErrors() + } + + outputBindings := make([]*v1alpha1.Binding, 0, len(wf.Template.Outputs)) + for _, b := range wf.Template.Outputs { + outputBindings = append(outputBindings, &v1alpha1.Binding{ + Binding: b, + }) + } + + var outputs *v1alpha1.OutputVarMap + if wf.Template.GetInterface() != nil { + outputs = &v1alpha1.OutputVarMap{VariableMap: wf.Template.GetInterface().Outputs} + } else { + outputs = &v1alpha1.OutputVarMap{VariableMap: &core.VariableMap{}} + } + + return &v1alpha1.WorkflowSpec{ + ID: WorkflowIDAsString(wf.Template.Id), + OnFailure: failureN, + Nodes: nodes, + Connections: buildConnections(wf), + Outputs: outputs, + OutputBindings: outputBindings, + }, !errs.HasErrors() +} + +func withSeparatorIfNotEmpty(value string) string { + if len(value) > 0 { + return fmt.Sprintf("%v-", value) + } + + return "" +} + +func generateName(wfID *core.Identifier, execID *core.WorkflowExecutionIdentifier) ( + name string, generateName string, label string, err error) { + + if execID != nil { + return execID.Name, "", execID.Name, nil + } else if wfID != nil { + wid := fmt.Sprintf("%v%v%v", + withSeparatorIfNotEmpty(wfID.Project), + withSeparatorIfNotEmpty(wfID.Domain), + wfID.Name, + ) + + // TODO: this is a hack until we figure out how to restrict generated names. K8s has a limitation of 63 chars + wid = wid[:minInt(32, len(wid))] + return "", fmt.Sprintf("%v-", wid), wid, nil + } else { + return "", "", "", fmt.Errorf("expected param not set. wfID or execID must be non-nil values") + } +} + +// Builds v1alpha1.FlyteWorkflow resource. Returned error, if not nil, is of type errors.CompilerErrors. 
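The typical call sequence for BuildFlyteWorkflow, defined next, mirrors test/compiler_test.go: compile the closure first, then convert it into the CRD. A hypothetical consumer-side sketch (package and helper names are assumptions):

package example // hypothetical consumer package, for illustration only

import (
	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
	"github.com/lyft/flytepropeller/pkg/compiler"
	"github.com/lyft/flytepropeller/pkg/compiler/common"
	"github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s"
)

// buildResourceSketch compiles a workflow template against its tasks and converts the
// result into a FlyteWorkflow custom resource.
func buildResourceSketch(wf *core.WorkflowTemplate, tasks []*core.CompiledTask,
	inputs *core.LiteralMap, namespace string) (*v1alpha1.FlyteWorkflow, error) {

	compiledWf, err := compiler.CompileWorkflow(wf, []*core.WorkflowTemplate{}, tasks, []common.InterfaceProvider{})
	if err != nil {
		return nil, err
	}

	// A nil execution ID is allowed; the object name is then generated from the workflow ID.
	return k8s.BuildFlyteWorkflow(compiledWf, inputs, nil, namespace)
}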
+func BuildFlyteWorkflow(wfClosure *core.CompiledWorkflowClosure, inputs *core.LiteralMap, + executionID *core.WorkflowExecutionIdentifier, namespace string) (*v1alpha1.FlyteWorkflow, error) { + + errs := errors.NewCompileErrors() + if wfClosure == nil { + errs.Collect(errors.NewValueRequiredErr("root", "wfClosure")) + return nil, errs + } + + primarySpec, _ := buildFlyteWorkflowSpec(wfClosure.Primary, wfClosure.Tasks, errs.NewScope()) + subwfs := make(map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec, len(wfClosure.SubWorkflows)) + for _, subWf := range wfClosure.SubWorkflows { + // Build a spec for each subworkflow, keyed by its own id. + spec, _ := buildFlyteWorkflowSpec(subWf, wfClosure.Tasks, errs.NewScope()) + subwfs[subWf.Template.Id.String()] = spec + } + + wf := wfClosure.Primary.Template + tasks := wfClosure.Tasks + // Fill in inputs in the start node. + if inputs != nil { + if ok := validateInputs(common.StartNodeID, wf.GetInterface(), *inputs, errs.NewScope()); !ok { + return nil, errs + } + } else if requiresInputs(wf) { + errs.Collect(errors.NewValueRequiredErr("root", "inputs")) + return nil, errs + } + + obj := &v1alpha1.FlyteWorkflow{ + TypeMeta: v1.TypeMeta{ + Kind: v1alpha1.FlyteWorkflowKind, + APIVersion: v1alpha1.SchemeGroupVersion.String(), + }, + ObjectMeta: v1.ObjectMeta{ + Namespace: namespace, + Labels: map[string]string{}, + }, + Inputs: &v1alpha1.Inputs{LiteralMap: inputs}, + WorkflowSpec: primarySpec, + SubWorkflows: subwfs, + Tasks: buildTasks(tasks, errs.NewScope()), + } + + var err error + obj.ObjectMeta.Name, obj.ObjectMeta.GenerateName, obj.ObjectMeta.Labels[ExecutionIDLabel], err = + generateName(wf.GetId(), executionID) + + if err != nil { + errs.Collect(errors.NewWorkflowBuildError(err)) + } + + if obj.Nodes == nil || obj.Connections.DownstreamEdges == nil { + // If we come here, we'd better have an error generated earlier. Otherwise, add one to make sure build fails. + if !errs.HasErrors() { + errs.Collect(errors.NewWorkflowBuildError(fmt.Errorf("failed to build workflow for unknown reason."
+ + " Make sure to pass this workflow through the compiler first"))) + } + } else if startingNodes, err := obj.FromNode(v1alpha1.StartNodeID); err == nil && len(startingNodes) == 0 { + errs.Collect(errors.NewWorkflowHasNoEntryNodeErr(wf.GetId().String())) + } else if err != nil { + errs.Collect(errors.NewWorkflowBuildError(err)) + } + + if errs.HasErrors() { + return nil, errs + } + + return obj, nil +} + +func toMapOfLists(connections map[string]*core.ConnectionSet_IdList) map[string][]string { + res := make(map[string][]string, len(connections)) + for key, val := range connections { + res[key] = val.Ids + } + + return res +} + +func buildConnections(w *core.CompiledWorkflow) v1alpha1.Connections { + res := v1alpha1.Connections{} + res.DownstreamEdges = toMapOfLists(w.GetConnections().GetDownstream()) + res.UpstreamEdges = toMapOfLists(w.GetConnections().GetUpstream()) + return res +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go b/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go new file mode 100644 index 0000000000..098dab245a --- /dev/null +++ b/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go @@ -0,0 +1,238 @@ +package k8s + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func createSampleMockWorkflow() *mockWorkflow { + return &mockWorkflow{ + tasks: common.TaskIndex{ + "task_1": &mockTask{ + task: &core.TaskTemplate{ + Id: &core.Identifier{Name: "task_1"}, + }, + }, + }, + nodes: common.NodeIndex{ + "node_1": &mockNode{ + upstream: []string{common.StartNodeID}, + inputs: []*core.Binding{}, + task: &mockTask{}, + id: "node_1", + Node: createNodeWithTask(), + }, + common.StartNodeID: &mockNode{ + id: common.StartNodeID, + Node: &core.Node{}, + }, + }, + //failureNode: &mockNode{ + // id: "node_1", + //}, + downstream: common.StringAdjacencyList{ + common.StartNodeID: sets.NewString("node_1"), + }, + upstream: common.StringAdjacencyList{ + "node_1": sets.NewString(common.StartNodeID), + }, + CompiledWorkflow: &core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "wf_1"}, + Nodes: []*core.Node{ + createNodeWithTask(), + { + Id: common.StartNodeID, + }, + }, + }, + Connections: &core.ConnectionSet{ + Downstream: map[string]*core.ConnectionSet_IdList{ + common.StartNodeID: { + Ids: []string{"node_1"}, + }, + }, + }, + }, + } +} + +func TestWorkflowIDAsString(t *testing.T) { + assert.Equal(t, "project:domain:name", WorkflowIDAsString(&core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "v1", + })) + + assert.Equal(t, ":domain:name", WorkflowIDAsString(&core.Identifier{ + Domain: "domain", + Name: "name", + Version: "v1", + })) + + assert.Equal(t, "project::name", WorkflowIDAsString(&core.Identifier{ + Project: "project", + Name: "name", + Version: "v1", + })) + + assert.Equal(t, "project:domain:name", WorkflowIDAsString(&core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + })) +} + +func TestBuildFlyteWorkflow(t *testing.T) { + w := createSampleMockWorkflow() + + errors.SetConfig(errors.Config{IncludeSource: true}) + wf, err := BuildFlyteWorkflow( + &core.CompiledWorkflowClosure{ + Primary: w.GetCoreWorkflow(), + Tasks: []*core.CompiledTask{ + { + Template: &core.TaskTemplate{ + Id: 
&core.Identifier{Name: "ref_1"}, + }, + }, + }, + }, + nil, nil, "") + assert.NoError(t, err) + assert.NotNil(t, wf) + errors.SetConfig(errors.Config{}) +} + +func TestBuildFlyteWorkflow_withInputs(t *testing.T) { + w := createSampleMockWorkflow() + + startNode := w.GetNodes()[common.StartNodeID].(*mockNode) + vars := []*core.Variable{ + { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + }, + } + + w.Template.Interface = &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": vars[0], + "y": vars[1], + }, + }, + } + + startNode.iface = &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": vars[0], + "y": vars[1], + }, + }, + } + + intLiteral, err := utils.MakePrimitiveLiteral(123) + assert.NoError(t, err) + stringLiteral, err := utils.MakePrimitiveLiteral("hello") + assert.NoError(t, err) + inputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": intLiteral, + "y": stringLiteral, + }, + } + + errors.SetConfig(errors.Config{IncludeSource: true}) + wf, err := BuildFlyteWorkflow( + &core.CompiledWorkflowClosure{ + Primary: w.GetCoreWorkflow(), + Tasks: []*core.CompiledTask{ + { + Template: &core.TaskTemplate{ + Id: &core.Identifier{Name: "ref_1"}, + }, + }, + }, + }, + inputs, nil, "") + assert.NoError(t, err) + assert.NotNil(t, wf) + errors.SetConfig(errors.Config{}) + + assert.Equal(t, 2, len(wf.Inputs.Literals)) + assert.Equal(t, int64(123), wf.Inputs.Literals["x"].GetScalar().GetPrimitive().GetInteger()) +} + +func TestGenerateName(t *testing.T) { + t.Run("Invalid params", func(t *testing.T) { + _, _, _, err := generateName(nil, nil) + assert.Error(t, err) + }) + + t.Run("wfID full", func(t *testing.T) { + name, generateName, _, err := generateName(&core.Identifier{ + Name: "myworkflow", + Project: "myproject", + Domain: "development", + }, nil) + + assert.NoError(t, err) + assert.Empty(t, name) + assert.Equal(t, "myproject-development-myworkflow-", generateName) + }) + + t.Run("wfID missing project domain", func(t *testing.T) { + name, generateName, _, err := generateName(&core.Identifier{ + Name: "myworkflow", + }, nil) + + assert.NoError(t, err) + assert.Empty(t, name) + assert.Equal(t, "myworkflow-", generateName) + }) + + t.Run("wfID too long", func(t *testing.T) { + name, generateName, _, err := generateName(&core.Identifier{ + Name: "workflowsomethingsomethingsomething", + Project: "myproject", + Domain: "development", + }, nil) + + assert.NoError(t, err) + assert.Empty(t, name) + assert.Equal(t, "myproject-development-workflowso-", generateName) + }) + + t.Run("execID full", func(t *testing.T) { + name, generateName, _, err := generateName(nil, &core.WorkflowExecutionIdentifier{ + Name: "myexecution", + Project: "myproject", + Domain: "development", + }) + + assert.NoError(t, err) + assert.Empty(t, generateName) + assert.Equal(t, "myexecution", name) + }) + + t.Run("execID missing project domain", func(t *testing.T) { + name, generateName, _, err := generateName(nil, &core.WorkflowExecutionIdentifier{ + Name: "myexecution", + }) + + assert.NoError(t, err) + assert.Empty(t, generateName) + assert.Equal(t, "myexecution", name) + }) +} diff --git a/flytepropeller/pkg/compiler/typing/variable.go b/flytepropeller/pkg/compiler/typing/variable.go new file mode 100644 index 0000000000..958d49e3c4 --- /dev/null +++ 
b/flytepropeller/pkg/compiler/typing/variable.go @@ -0,0 +1,36 @@ +package typing + +import ( + "fmt" + "regexp" + "strconv" +) + +var arrayVarMatcher = regexp.MustCompile(`(\[(?P<index>\d+)\]\.)?(?P<name>\w+)`) + +type Variable struct { + Name string + Index *int +} + +// ParseVarName parses a variable reference of the form "name" or "[index].name" (e.g. "[2].out") into its name and optional collection index. +func ParseVarName(varName string) (v Variable, err error) { + allMatches := arrayVarMatcher.FindAllStringSubmatch(varName, -1) + if len(allMatches) != 1 { + return Variable{}, fmt.Errorf("unexpected number of matches [%v]", len(allMatches)) + } + + if len(allMatches[0]) != 4 { + return Variable{}, fmt.Errorf("unexpected number of groups [%v]", len(allMatches[0])) + } + + res := Variable{} + if len(allMatches[0][2]) > 0 { + index, convErr := strconv.Atoi(allMatches[0][2]) + err = convErr + res.Index = &index + } + + res.Name = allMatches[0][3] + return res, err +} diff --git a/flytepropeller/pkg/compiler/utils.go b/flytepropeller/pkg/compiler/utils.go new file mode 100755 index 0000000000..1ff24fc1bc --- /dev/null +++ b/flytepropeller/pkg/compiler/utils.go @@ -0,0 +1,79 @@ +package compiler + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "k8s.io/apimachinery/pkg/util/sets" +) + +func toInterfaceProviderMap(tasks []common.InterfaceProvider) map[string]common.InterfaceProvider { + res := make(map[string]common.InterfaceProvider, len(tasks)) + for _, task := range tasks { + res[task.GetID().String()] = task + } + + return res +} + +func toSlice(s sets.String) []string { + res := make([]string, 0, len(s)) + for str := range s { + res = append(res, str) + } + + return res +} + +func toNodeIdsSet(nodes common.NodeIndex) sets.String { + res := sets.NewString() + for nodeID := range nodes { + res.Insert(nodeID) + } + + return res +} + +// Runs a depth-first coreWorkflow traversal to detect any cycles in the coreWorkflow. It produces the first cycle found, as well as +// all visited nodes and a boolean indicating whether or not it found a cycle. +func detectCycle(startNode string, neighbors func(nodeId string) sets.String) (cycle []common.NodeID, visited sets.String, + detected bool) { + + // This is a set of nodes that were ever visited. + visited = sets.NewString() + // This is a set of in-progress visiting nodes.
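+ // Reaching a node that is still in this in-progress set means the traversal followed a back edge, i.e. a cycle.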
+ visiting := sets.NewString() + var detector func(nodeId string) ([]common.NodeID, bool) + detector = func(nodeId string) ([]common.NodeID, bool) { + if visiting.Has(nodeId) { + return []common.NodeID{}, true + } + + visiting.Insert(nodeId) + visited.Insert(nodeId) + + for nextID := range neighbors(nodeId) { + if path, detected := detector(nextID); detected { + return append([]common.NodeID{nextID}, path...), true + } + } + + visiting.Delete(nodeId) + + return []common.NodeID{}, false + } + + if path, detected := detector(startNode); detected { + return append([]common.NodeID{startNode}, path...), visiting, true + } + + return []common.NodeID{}, visited, false +} + +func toCompiledWorkflows(wfs ...*core.WorkflowTemplate) []*core.CompiledWorkflow { + compiledSubWfs := make([]*core.CompiledWorkflow, 0, len(wfs)) + for _, wf := range wfs { + compiledSubWfs = append(compiledSubWfs, &core.CompiledWorkflow{Template: wf}) + } + + return compiledSubWfs +} diff --git a/flytepropeller/pkg/compiler/utils_test.go b/flytepropeller/pkg/compiler/utils_test.go new file mode 100644 index 0000000000..6d7065f7a3 --- /dev/null +++ b/flytepropeller/pkg/compiler/utils_test.go @@ -0,0 +1,65 @@ +package compiler + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func neighbors(adjList map[string][]string) func(nodeId string) sets.String { + return func(nodeId string) sets.String { + if lst, found := adjList[nodeId]; found { + return sets.NewString(lst...) + } + + return sets.NewString() + } +} + +func uniqueNodesCount(adjList map[string][]string) int { + uniqueNodeIds := sets.NewString() + for key, value := range adjList { + uniqueNodeIds.Insert(key) + uniqueNodeIds.Insert(value...) + } + + return uniqueNodeIds.Len() +} + +func assertNoCycle(t *testing.T, startNode string, adjList map[string][]string) { + cycle, visited, detected := detectCycle(startNode, neighbors(adjList)) + assert.False(t, detected) + assert.Equal(t, uniqueNodesCount(adjList), len(visited)) + assert.Equal(t, 0, len(cycle)) +} + +func assertCycle(t *testing.T, startNode string, adjList map[string][]string) { + cycle, _, detected := detectCycle(startNode, neighbors(adjList)) + assert.True(t, detected) + assert.NotEqual(t, 0, len(cycle)) + t.Logf("Cycle: %v", strings.Join(cycle, ",")) +} + +func TestDetectCycle(t *testing.T) { + t.Run("Linear", func(t *testing.T) { + linear := map[string][]string{ + "1": {"2"}, + "2": {"3"}, + "3": {"4"}, + } + + assertNoCycle(t, "1", linear) + }) + + t.Run("Cycle", func(t *testing.T) { + cyclic := map[string][]string{ + "1": {"2", "3"}, + "2": {"3"}, + "3": {"1"}, + } + + assertCycle(t, "1", cyclic) + }) +} diff --git a/flytepropeller/pkg/compiler/validators/bindings.go b/flytepropeller/pkg/compiler/validators/bindings.go new file mode 100644 index 0000000000..10cf0c705c --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/bindings.go @@ -0,0 +1,105 @@ +package validators + +import ( + "reflect" + + "github.com/lyft/flytepropeller/pkg/compiler/typing" + + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, binding *flyte.BindingData, expectedType *flyte.LiteralType, errs errors.CompileErrors) ( + []c.NodeID, bool) { + + switch binding.GetValue().(type) { + case *flyte.BindingData_Collection: + if 
expectedType.GetCollectionType() != nil { + allNodeIds := make([]c.NodeID, 0, len(binding.GetMap().GetBindings())) + for _, v := range binding.GetCollection().GetBindings() { + if nodeIds, ok := validateBinding(w, nodeID, nodeParam, v, expectedType.GetCollectionType(), errs.NewScope()); ok { + allNodeIds = append(allNodeIds, nodeIds...) + } + } + return allNodeIds, !errs.HasErrors() + } + errs.Collect(errors.NewMismatchingBindingsErr(nodeID, nodeParam, expectedType.String(), binding.GetCollection().String())) + case *flyte.BindingData_Map: + if expectedType.GetMapValueType() != nil { + allNodeIds := make([]c.NodeID, 0, len(binding.GetMap().GetBindings())) + for _, v := range binding.GetMap().GetBindings() { + if nodeIds, ok := validateBinding(w, nodeID, nodeParam, v, expectedType.GetMapValueType(), errs.NewScope()); ok { + allNodeIds = append(allNodeIds, nodeIds...) + } + } + return allNodeIds, !errs.HasErrors() + } + + errs.Collect(errors.NewMismatchingBindingsErr(nodeID, nodeParam, expectedType.String(), binding.GetMap().String())) + case *flyte.BindingData_Promise: + if upNode, found := validateNodeID(w, binding.GetPromise().NodeId, errs.NewScope()); found { + v, err := typing.ParseVarName(binding.GetPromise().GetVar()) + if err != nil { + errs.Collect(errors.NewSyntaxError(nodeID, binding.GetPromise().GetVar(), err)) + return nil, !errs.HasErrors() + } + + if param, paramFound := validateOutputVar(upNode, v.Name, errs.NewScope()); paramFound { + if AreTypesCastable(param.Type, expectedType) { + binding.GetPromise().NodeId = upNode.GetId() + return []c.NodeID{binding.GetPromise().NodeId}, true + } + errs.Collect(errors.NewMismatchingTypesErr(nodeID, binding.GetPromise().Var, param.Type.String(), expectedType.String())) + } + } + case *flyte.BindingData_Scalar: + literalType := literalTypeForScalar(binding.GetScalar()) + if literalType == nil { + errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(binding.GetScalar().GetValue()).String())) + } + + if !AreTypesCastable(literalType, expectedType) { + errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, literalType.String(), expectedType.String())) + } + + return []c.NodeID{}, !errs.HasErrors() + default: + errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(binding.GetValue()).String())) + } + + return nil, !errs.HasErrors() +} + +func ValidateBindings(w c.WorkflowBuilder, node c.Node, bindings []*flyte.Binding, params *flyte.VariableMap, + errs errors.CompileErrors) (ok bool) { + + providedBindings := sets.NewString() + for _, binding := range bindings { + if param, ok := findVariableByName(params, binding.GetVar()); !ok { + errs.Collect(errors.NewVariableNameNotFoundErr(node.GetId(), node.GetId(), binding.GetVar())) + } else if binding.GetBinding() == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Binding")) + } else if providedBindings.Has(binding.GetVar()) { + errs.Collect(errors.NewParameterBoundMoreThanOnceErr(node.GetId(), binding.GetVar())) + } else { + providedBindings.Insert(binding.GetVar()) + if upstreamNodes, bindingOk := validateBinding(w, node.GetId(), binding.GetVar(), binding.GetBinding(), param.Type, errs.NewScope()); bindingOk { + for _, upNode := range upstreamNodes { + // Add implicit Edges + w.AddExecutionEdge(upNode, node.GetId()) + } + } + } + } + + // If we missed binding some params, add errors + for paramName := range params.Variables { + if !providedBindings.Has(paramName) { + errs.Collect(errors.NewParameterNotBoundErr(node.GetId(), paramName)) + } + } + + 
return !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/validators/branch.go b/flytepropeller/pkg/compiler/validators/branch.go new file mode 100644 index 0000000000..a85131078f --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/branch.go @@ -0,0 +1,92 @@ +package validators + +import ( + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +func validateBranchInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs errors.CompileErrors) (iface *flyte.TypedInterface, ok bool) { + if branch := node.GetBranchNode(); branch == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch")) + return + } + + if ifBlock := node.GetBranchNode().IfElse; ifBlock == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch.IfElse")) + return + } + + if ifCase := node.GetBranchNode().IfElse.Case; ifCase == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch.IfElse.Case")) + return + } + + if thenNode := node.GetBranchNode().IfElse.Case.ThenNode; thenNode == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Branch.IfElse.Case.ThenNode")) + return + } + + finalInputParameterNames := sets.NewString() + finalOutputParameterNames := sets.NewString() + + var inputs map[string]*flyte.Variable + var outputs map[string]*flyte.Variable + inputsSet := sets.NewString() + outputsSet := sets.NewString() + + validateIfaceMatch := func(nodeId string, iface2 *flyte.TypedInterface, errsScope errors.CompileErrors) (match bool) { + inputs2, inputs2Set := buildVariablesIndex(iface2.Inputs) + validateVarsSetMatch(nodeId, inputs, inputs2, inputsSet, inputs2Set, errsScope.NewScope()) + finalInputParameterNames = finalInputParameterNames.Intersection(inputs2Set) + + outputs2, outputs2Set := buildVariablesIndex(iface2.Outputs) + validateVarsSetMatch(nodeId, outputs, outputs2, outputsSet, outputs2Set, errsScope.NewScope()) + finalOutputParameterNames = finalOutputParameterNames.Intersection(outputs2Set) + + return !errsScope.HasErrors() + } + + cases := make([]*flyte.IfBlock, 0, len(node.GetBranchNode().IfElse.Other)+1) + caseBlock := node.GetBranchNode().IfElse.Case + cases = append(cases, caseBlock) + + if otherCases := node.GetBranchNode().IfElse.Other; otherCases != nil { + cases = append(cases, otherCases...) 
+ } + + for _, block := range cases { + if block.ThenNode == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "IfElse.Case.ThenNode")) + continue + } + + n := w.NewNodeBuilder(block.ThenNode) + if iface == nil { + // if this is the first node to validate, just assume all other nodes will match the interface + if iface, ok = ValidateUnderlyingInterface(w, n, errs.NewScope()); ok { + inputs, inputsSet = buildVariablesIndex(iface.Inputs) + finalInputParameterNames = finalInputParameterNames.Union(inputsSet) + + outputs, outputsSet = buildVariablesIndex(iface.Outputs) + finalOutputParameterNames = finalOutputParameterNames.Union(outputsSet) + } + } else { + if iface2, ok2 := ValidateUnderlyingInterface(w, n, errs.NewScope()); ok2 { + validateIfaceMatch(n.GetId(), iface2, errs.NewScope()) + } + } + } + + if !errs.HasErrors() { + iface = &flyte.TypedInterface{ + Inputs: filterVariables(iface.Inputs, finalInputParameterNames), + Outputs: filterVariables(iface.Outputs, finalOutputParameterNames), + } + } else { + iface = nil + } + + return iface, !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/validators/condition.go b/flytepropeller/pkg/compiler/validators/condition.go new file mode 100644 index 0000000000..0c800a9b87 --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/condition.go @@ -0,0 +1,55 @@ +package validators + +import ( + "fmt" + + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operand, + errs errors.CompileErrors) (literalType *flyte.LiteralType, ok bool) { + if operand == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), paramName)) + } else if operand.GetPrimitive() != nil { + // no validation + literalType = literalTypeForPrimitive(operand.GetPrimitive()) + } else if operand.GetVar() != "" { + if node.GetInterface() != nil { + if param, paramOk := validateInputVar(node, operand.GetVar(), errs.NewScope()); paramOk { + literalType = param.GetType() + } + } + } else { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), fmt.Sprintf("%v.%v", paramName, "Val"))) + } + + return literalType, !errs.HasErrors() +} + +func ValidateBooleanExpression(node c.NodeBuilder, expr *flyte.BooleanExpression, errs errors.CompileErrors) (ok bool) { + if expr == nil { + errs.Collect(errors.NewBranchNodeHasNoCondition(node.GetId())) + } else { + if expr.GetComparison() != nil { + op1Type, op1Valid := validateOperand(node, "RightValue", + expr.GetComparison().GetRightValue(), errs.NewScope()) + op2Type, op2Valid := validateOperand(node, "LeftValue", + expr.GetComparison().GetLeftValue(), errs.NewScope()) + if op1Valid && op2Valid { + if op1Type.String() != op2Type.String() { + errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue", + op1Type.String(), op2Type.String())) + } + } + } else if expr.GetConjunction() != nil { + ValidateBooleanExpression(node, expr.GetConjunction().LeftExpression, errs.NewScope()) + ValidateBooleanExpression(node, expr.GetConjunction().RightExpression, errs.NewScope()) + } else { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Expr")) + } + } + + return !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/validators/interface.go b/flytepropeller/pkg/compiler/validators/interface.go new file mode 100644 index 0000000000..f5cc3b8bf6 --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/interface.go 
@@ -0,0 +1,128 @@ +package validators + +import ( + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +// Validate interface has its required attributes set +func ValidateInterface(nodeID c.NodeID, iface *core.TypedInterface, errs errors.CompileErrors) ( + typedInterface *core.TypedInterface, ok bool) { + + if iface == nil { + iface = &core.TypedInterface{} + } + + // validate InputsRef/OutputsRef parameters required attributes are set + if iface.Inputs != nil && iface.Inputs.Variables != nil { + validateVariables(nodeID, iface.Inputs, errs.NewScope()) + } else { + iface.Inputs = &core.VariableMap{Variables: map[string]*core.Variable{}} + } + + if iface.Outputs != nil && iface.Outputs.Variables != nil { + validateVariables(nodeID, iface.Outputs, errs.NewScope()) + } else { + iface.Outputs = &core.VariableMap{Variables: map[string]*core.Variable{}} + } + + return iface, !errs.HasErrors() +} + +// Validates underlying interface of a node and returns the effective Typed Interface. +func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs errors.CompileErrors) (iface *core.TypedInterface, ok bool) { + switch node.GetCoreNode().GetTarget().(type) { + case *core.Node_TaskNode: + if node.GetTaskNode().GetReferenceId() == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "TaskNode.ReferenceId")) + } else if task, taskOk := w.GetTask(*node.GetTaskNode().GetReferenceId()); taskOk { + iface = task.GetInterface() + if iface == nil { + // Default value for no interface is nil, initialize an empty interface + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + } + } + } else { + errs.Collect(errors.NewTaskReferenceNotFoundErr(node.GetId(), node.GetTaskNode().GetReferenceId().String())) + } + case *core.Node_WorkflowNode: + if node.GetWorkflowNode().GetLaunchplanRef().String() == w.GetCoreWorkflow().Template.Id.String() { + iface = w.GetCoreWorkflow().Template.Interface + if iface == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.Interface")) + } + } else if node.GetWorkflowNode().GetLaunchplanRef() != nil { + if launchPlan, launchPlanOk := w.GetLaunchPlan(*node.GetWorkflowNode().GetLaunchplanRef()); launchPlanOk { + inputs := launchPlan.GetExpectedInputs() + if inputs == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.ExpectedInputs")) + } + + outputs := launchPlan.GetExpectedOutputs() + if outputs == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.ExpectedOutputs")) + } + + // Compute exposed inputs as the union of all required inputs and any input overwritten by the node. 
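+ // Illustrative example: a launch plan expecting {x: required, y: defaulted} exposes only "x" here unless the node also binds "y" explicitly.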
+ exposedInputs := map[string]*core.Variable{} + for name, p := range inputs.Parameters { + if p.GetRequired() { + exposedInputs[name] = p.Var + } else if _, found := findBindingByVariableName(node.GetInputs(), name); found { + exposedInputs[name] = p.Var + } + // else, the param has a default value and is not being overwritten by the node + } + + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: exposedInputs, + }, + Outputs: outputs, + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr( + node.GetId(), + fmt.Sprintf("%v", node.GetWorkflowNode().GetLaunchplanRef()))) + } + } else if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + if wf, wfOk := w.GetSubWorkflow(*node.GetWorkflowNode().GetSubWorkflowRef()); wfOk { + if wf.Template == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.Template")) + } else { + iface = wf.Template.Interface + if iface == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "WorkflowNode.Template.Interface")) + } + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr( + node.GetId(), + fmt.Sprintf("%v", node.GetWorkflowNode().GetSubWorkflowRef()))) + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr( + node.GetId(), + fmt.Sprintf("%v/%v", node.GetWorkflowNode().GetLaunchplanRef(), node.GetWorkflowNode().GetSubWorkflowRef()))) + } + case *core.Node_BranchNode: + iface, _ = validateBranchInterface(w, node, errs.NewScope()) + default: + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target")) + } + + if iface != nil { + ValidateInterface(node.GetId(), iface, errs.NewScope()) + } + + if !errs.HasErrors() { + node.SetInterface(iface) + } + + return iface, !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/validators/interface_test.go b/flytepropeller/pkg/compiler/validators/interface_test.go new file mode 100644 index 0000000000..f85d4f7d9a --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/interface_test.go @@ -0,0 +1,282 @@ +package validators + +import ( + "testing" + + "github.com/lyft/flyteidl/clients/go/coreutils" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/common/mocks" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestValidateInterface(t *testing.T) { + t.Run("Happy path", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ok := ValidateInterface( + c.NodeID("node1"), + &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + }, + errs.NewScope(), + ) + + assertNonEmptyInterface(t, iface, ok, errs) + }) + + t.Run("Empty Inputs/Outputs", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ok := ValidateInterface( + c.NodeID("node1"), + &core.TypedInterface{}, + errs.NewScope(), + ) + + assertNonEmptyInterface(t, iface, ok, errs) + }) + + t.Run("Empty Interface", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ok := ValidateInterface( + c.NodeID("node1"), + nil, + errs.NewScope(), + ) + + assertNonEmptyInterface(t, iface, ok, errs) + }) +} + +func assertNonEmptyInterface(t testing.TB, iface *core.TypedInterface, ifaceOk bool, errs errors.CompileErrors) { + assert.True(t, ifaceOk) + assert.NotNil(t, iface) + assert.False(t, errs.HasErrors()) + if !ifaceOk { 
+ t.Fatal(errs) + } + + assert.NotNil(t, iface.Inputs) + assert.NotNil(t, iface.Inputs.Variables) + assert.NotNil(t, iface.Outputs) + assert.NotNil(t, iface.Outputs.Variables) +} + +func TestValidateUnderlyingInterface(t *testing.T) { + t.Run("Invalid empty node", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{}) + nodeBuilder.On("GetId").Return("node_1") + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + assert.Nil(t, iface) + assert.True(t, errs.HasErrors()) + }) + + t.Run("Task Node", func(t *testing.T) { + task := mocks.Task{} + task.On("GetInterface").Return(nil) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetTask", mock.MatchedBy(func(id core.Identifier) bool { + return id.String() == (&core.Identifier{ + Name: "Task_1", + }).String() + })).Return(&task, true) + + taskNode := &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{ + Name: "Task_1", + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: &core.Node_TaskNode{ + TaskNode: taskNode, + }, + }) + + nodeBuilder.On("GetTaskNode").Return(taskNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("Workflow Node", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_1", + }, + }, + }) + workflowNode := &core.WorkflowNode{ + Reference: &core.WorkflowNode_LaunchplanRef{ + LaunchplanRef: &core.Identifier{ + Name: "Ref_1", + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: workflowNode, + }, + }) + + nodeBuilder.On("GetWorkflowNode").Return(workflowNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + nodeBuilder.On("GetInputs").Return([]*core.Binding{}) + + t.Run("Self", func(t *testing.T) { + errs := errors.NewCompileErrors() + _, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + }, + }, + }) + + errs = errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("LP_Ref", func(t *testing.T) { + lp := mocks.InterfaceProvider{} + lp.On("GetID").Return(&core.Identifier{Name: "Ref_1"}) + lp.On("GetExpectedInputs").Return(&core.ParameterMap{ + Parameters: map[string]*core.Parameter{ + "required": { + Var: &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + Behavior: 
&core.Parameter_Required{ + Required: true, + }, + }, + "default_value": { + Var: &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + Behavior: &core.Parameter_Default{ + Default: coreutils.MustMakeLiteral(5), + }, + }, + }, + }) + lp.On("GetExpectedOutputs").Return(&core.VariableMap{}) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetLaunchPlan", mock.Anything).Return(nil, false) + + errs := errors.NewCompileErrors() + _, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + + wfBuilder = mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetLaunchPlan", matchIdentifier(core.Identifier{Name: "Ref_1"})).Return(&lp, true) + + errs = errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("Subwf", func(t *testing.T) { + subWf := core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{}, + Outputs: &core.VariableMap{}, + }, + }, + } + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetLaunchPlan", mock.Anything).Return(nil, false) + + errs := errors.NewCompileErrors() + _, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assert.False(t, ifaceOk) + + wfBuilder = mocks.WorkflowBuilder{} + wfBuilder.On("GetCoreWorkflow").Return(&core.CompiledWorkflow{ + Template: &core.WorkflowTemplate{ + Id: &core.Identifier{ + Name: "Ref_2", + }, + }, + }) + + wfBuilder.On("GetSubWorkflow", matchIdentifier(core.Identifier{Name: "Ref_1"})).Return(&subWf, true) + + workflowNode.Reference = &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: &core.Identifier{Name: "Ref_1"}, + } + + errs = errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + }) +} + +func matchIdentifier(id core.Identifier) interface{} { + return mock.MatchedBy(func(arg core.Identifier) bool { + return arg.String() == id.String() + }) +} diff --git a/flytepropeller/pkg/compiler/validators/node.go b/flytepropeller/pkg/compiler/validators/node.go new file mode 100644 index 0000000000..cf8f9a4b42 --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/node.go @@ -0,0 +1,116 @@ +// This package contains validators for all elements of the workflow spec (node, task, branch, interface, bindings... etc.) +package validators + +import ( + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" +) + +// Computes output parameters after applying all aliases -if any-. 
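+// For example (illustrative): an alias {Var: "out", Alias: "o"} makes the effective output map expose the variable under "o" instead of "out".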
+func validateEffectiveOutputParameters(n c.NodeBuilder, errs errors.CompileErrors) ( + params *flyte.VariableMap, ok bool) { + aliases := make(map[string]string, len(n.GetOutputAliases())) + for _, alias := range n.GetOutputAliases() { + if _, found := aliases[alias.Var]; found { + errs.Collect(errors.NewDuplicateAliasErr(n.GetId(), alias.Alias)) + } else { + aliases[alias.Var] = alias.Alias + } + } + + if n.GetInterface() != nil { + params = &flyte.VariableMap{ + Variables: make(map[string]*flyte.Variable, len(n.GetInterface().GetOutputs().Variables)), + } + + for paramName, param := range n.GetInterface().GetOutputs().Variables { + if alias, found := aliases[paramName]; found { + if newParam, paramOk := withVariableName(param); paramOk { + params.Variables[alias] = newParam + } else { + errs.Collect(errors.NewParameterNotBoundErr(n.GetId(), alias)) + } + + delete(aliases, paramName) + } else { + params.Variables[paramName] = param + } + } + + // If there are still more aliases at this point, they point to non-existent variables. + for _, alias := range aliases { + errs.Collect(errors.NewParameterNotBoundErr(n.GetId(), alias)) + } + } + + return params, !errs.HasErrors() +} + +func validateBranchNode(w c.WorkflowBuilder, n c.NodeBuilder, errs errors.CompileErrors) bool { + cases := make([]*flyte.IfBlock, 0, len(n.GetBranchNode().IfElse.Other)+1) + cases = append(cases, n.GetBranchNode().IfElse.Case) + cases = append(cases, n.GetBranchNode().IfElse.Other...) + for _, block := range cases { + // Validate condition + ValidateBooleanExpression(n, block.Condition, errs.NewScope()) + + if block.GetThenNode() == nil { + errs.Collect(errors.NewBranchNodeNotSpecified(n.GetId())) + } else { + wrapperNode := w.NewNodeBuilder(block.GetThenNode()) + if ValidateNode(w, wrapperNode, errs.NewScope()) { + // Add to the global nodes to be able to reference it later + w.AddNode(wrapperNode, errs.NewScope()) + w.AddExecutionEdge(n.GetId(), block.GetThenNode().Id) + } + } + } + + return !errs.HasErrors() +} + +func validateNodeID(w c.WorkflowBuilder, nodeID string, errs errors.CompileErrors) (node c.NodeBuilder, ok bool) { + if nodeID == "" { + n, _ := w.GetNode(c.StartNodeID) + return n, !errs.HasErrors() + } else if node, ok = w.GetNode(nodeID); !ok { + errs.Collect(errors.NewNodeReferenceNotFoundErr(nodeID, nodeID)) + } + + return node, !errs.HasErrors() +} + +func ValidateNode(w c.WorkflowBuilder, n c.NodeBuilder, errs errors.CompileErrors) (ok bool) { + if n.GetId() == "" { + errs.Collect(errors.NewValueRequiredErr("", "Id")) + } + + if _, ifaceOk := ValidateUnderlyingInterface(w, n, errs.NewScope()); ifaceOk { + // Validate node output aliases + validateEffectiveOutputParameters(n, errs.NewScope()) + } + + // Validate branch node conditions and inner nodes. 
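+ // Each valid "then" node is validated recursively, registered on the workflow, and linked to this branch node with an execution edge.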
+ if n.GetBranchNode() != nil { + validateBranchNode(w, n, errs.NewScope()) + } else if workflowN := n.GetWorkflowNode(); workflowN != nil && workflowN.GetSubWorkflowRef() != nil { + if wf, wfOk := w.GetSubWorkflow(*workflowN.GetSubWorkflowRef()); wfOk { + if subWorkflow, workflowOk := w.ValidateWorkflow(wf, errs.NewScope()); workflowOk { + n.SetSubWorkflow(subWorkflow) + } + } else { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr(n.GetId(), workflowN.GetSubWorkflowRef().String())) + } + } else if taskN := n.GetTaskNode(); taskN != nil && taskN.GetReferenceId() != nil { + if task, found := w.GetTask(*taskN.GetReferenceId()); found { + n.SetTask(task) + } else if taskN.GetReferenceId() == nil { + errs.Collect(errors.NewValueRequiredErr(n.GetId(), "TaskNode.ReferenceId")) + } else { + errs.Collect(errors.NewTaskReferenceNotFoundErr(n.GetId(), taskN.GetReferenceId().String())) + } + } + + return !errs.HasErrors() +} diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go new file mode 100644 index 0000000000..f980237172 --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -0,0 +1,154 @@ +package validators + +import ( + structpb "github.com/golang/protobuf/ptypes/struct" + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type typeChecker interface { + CastsFrom(*flyte.LiteralType) bool +} + +type trivialChecker struct { + literalType *flyte.LiteralType +} + +type voidChecker struct{} + +type mapTypeChecker struct { + literalType *flyte.LiteralType +} + +type collectionTypeChecker struct { + literalType *flyte.LiteralType +} + +type schemaTypeChecker struct { + literalType *flyte.LiteralType +} + +// The trivial type checker merely checks if types match exactly. +func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Everything is nullable currently + if isVoid(upstreamType) { + return true + } + // Ignore metadata when comparing types. + upstreamTypeCopy := *upstreamType + downstreamTypeCopy := *t.literalType + upstreamTypeCopy.Metadata = &structpb.Struct{} + downstreamTypeCopy.Metadata = &structpb.Struct{} + return upstreamTypeCopy.String() == downstreamTypeCopy.String() +} + +// The void type matches everything +func (t voidChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + return true +} + +// For a map type checker, we need to ensure both the key types and value types match. +func (t mapTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Maps are nullable + if isVoid(upstreamType) { + return true + } + + mapLiteralType := upstreamType.GetMapValueType() + if mapLiteralType != nil { + return getTypeChecker(t.literalType.GetMapValueType()).CastsFrom(mapLiteralType) + } + return false +} + +// For a collection type, we need to ensure that the nesting is correct and the final sub-types match. +func (t collectionTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Collections are nullable + if isVoid(upstreamType) { + return true + } + + collectionType := upstreamType.GetCollectionType() + if collectionType != nil { + return getTypeChecker(t.literalType.GetCollectionType()).CastsFrom(collectionType) + } + return false +} + +// Schemas are more complex types in the Flyte ecosystem. A schema is considered castable in the following +// cases. +// +// 1. The downstream schema has no column types specified. In such a case, it accepts all schema input since it is +// generic. +// +// 2. 
The downstream schema has a subset of the upstream columns and they match perfectly. +// +func (t schemaTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + // Schemas are nullable + if isVoid(upstreamType) { + return true + } + + schemaType := upstreamType.GetSchema() + if schemaType == nil { + return false + } + + // If no columns are specified, this is a generic schema and it can accept any schema type. + if len(t.literalType.GetSchema().Columns) == 0 { + return true + } + + nameToTypeMap := make(map[string]flyte.SchemaType_SchemaColumn_SchemaColumnType) + for _, column := range schemaType.Columns { + nameToTypeMap[column.Name] = column.Type + } + + // Check that the downstream schema is a strict sub-set of the upstream schema. + for _, column := range t.literalType.GetSchema().Columns { + upstreamType, ok := nameToTypeMap[column.Name] + if !ok { + return false + } + if upstreamType != column.Type { + return false + } + } + return true +} + +func isVoid(t *flyte.LiteralType) bool { + switch t.GetType().(type) { + case *flyte.LiteralType_Simple: + return t.GetSimple() == flyte.SimpleType_NONE + default: + return false + } +} + +func getTypeChecker(t *flyte.LiteralType) typeChecker { + switch t.GetType().(type) { + case *flyte.LiteralType_CollectionType: + return collectionTypeChecker{ + literalType: t, + } + case *flyte.LiteralType_MapValueType: + return mapTypeChecker{ + literalType: t, + } + case *flyte.LiteralType_Schema: + return schemaTypeChecker{ + literalType: t, + } + default: + if isVoid(t) { + return voidChecker{} + } + return trivialChecker{ + literalType: t, + } + } +} + +func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool { + return getTypeChecker(downstreamType).CastsFrom(upstreamType) +} diff --git a/flytepropeller/pkg/compiler/validators/typing_test.go b/flytepropeller/pkg/compiler/validators/typing_test.go new file mode 100644 index 0000000000..2e4654653a --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/typing_test.go @@ -0,0 +1,359 @@ +package validators + +import ( + "testing" + + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestSimpleLiteralCasting(t *testing.T) { + t.Run("BaseCase_Integer", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.True(t, castable, "Integers should be castable to other integers") + }) + + t.Run("IntegerToFloat", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + ) + assert.False(t, castable, "Integers should not be castable to floats") + }) + + t.Run("FloatToInteger", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.False(t, castable, "Floats should not be castable to integers") + }) + + t.Run("VoidToInteger", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, + }, + &core.LiteralType{ + Type: 
&core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.True(t, castable, "Floats are nullable") + }) + + t.Run("IgnoreMetadata", func(t *testing.T) { + s := structpb.Struct{ + Fields: map[string]*structpb.Value{ + "a": {}, + }, + } + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + Metadata: &s, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.True(t, castable, "Metadata should be ignored") + }) +} + +func TestCollectionCasting(t *testing.T) { + t.Run("BaseCase_SingleIntegerCollection", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + ) + assert.True(t, castable, "[Integer] should be castable to [Integer].") + }) + + t.Run("SingleIntegerCollectionToSingleFloatCollection", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + }, + }, + ) + assert.False(t, castable, "[Integer] should not be castable to [Float]") + }) + + t.Run("MismatchedNestLevels_Scalar", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.False(t, castable, "[Integer] should not be castable to Integer") + }) + + t.Run("MismatchedNestLevels_Collections", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + }, + }, + ) + assert.False(t, castable, "[Integer] should not be castable to [[Integer]]") + }) + + t.Run("Nullable_Collections", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + }, + }, + ) + assert.True(t, castable, "Collections are nullable") + }) +} + +func TestMapCasting(t *testing.T) { + t.Run("BaseCase_SingleIntegerMap", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: 
&core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + ) + assert.True(t, castable, "{k: Integer} should be castable to {k: Integer}.") + }) + + t.Run("ScalarIntegerMapToScalarFloatMap", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, + }, + }, + }, + ) + assert.False(t, castable, "{k: Integer} should not be castable to {k: Float}") + }) + + t.Run("MismatchedMapNestLevels_Scalar", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + ) + assert.False(t, castable, "{k: Integer} should not be castable to Integer") + }) + + t.Run("MismatchedMapNestLevels_Maps", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + }, + }, + ) + assert.False(t, castable, "{k: Integer} should not be castable to {k: {k: Integer}}") + }) +} + +func TestSchemaCasting(t *testing.T) { + genericSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{}, + }, + }, + } + subsetIntegerSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{ + { + Name: "a", + Type: core.SchemaType_SchemaColumn_INTEGER, + }, + }, + }, + }, + } + supersetIntegerAndFloatSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{ + { + Name: "a", + Type: core.SchemaType_SchemaColumn_INTEGER, + }, + { + Name: "b", + Type: core.SchemaType_SchemaColumn_FLOAT, + }, + }, + }, + }, + } + mismatchedSubsetSchema := &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: &core.SchemaType{ + Columns: []*core.SchemaType_SchemaColumn{ + { + Name: "a", + Type: core.SchemaType_SchemaColumn_FLOAT, + }, + }, + }, + }, + } + + t.Run("BaseCase_GenericSchema", func(t *testing.T) { + castable := AreTypesCastable(genericSchema, genericSchema) + assert.True(t, castable, "Schema() should be castable to Schema()") + }) + + t.Run("GenericSchemaToNonGeneric", func(t *testing.T) { + castable := AreTypesCastable(genericSchema, subsetIntegerSchema) + assert.False(t, castable, "Schema() should not be castable to Schema(a=Integer)") + }) + + t.Run("NonGenericSchemaToGeneric", func(t *testing.T) { + castable := AreTypesCastable(subsetIntegerSchema, genericSchema) + assert.True(t, castable, "Schema(a=Integer) should 
be castable to Schema()") + }) + + t.Run("SupersetToSubsetTypedSchema", func(t *testing.T) { + castable := AreTypesCastable(supersetIntegerAndFloatSchema, subsetIntegerSchema) + assert.True(t, castable, "Schema(a=Integer, b=Float) should be castable to Schema(a=Integer)") + }) + + t.Run("SubsetToSupersetSchema", func(t *testing.T) { + castable := AreTypesCastable(subsetIntegerSchema, supersetIntegerAndFloatSchema) + assert.False(t, castable, "Schema(a=Integer) should not be castable to Schema(a=Integer, b=Float)") + }) + + t.Run("MismatchedColumns", func(t *testing.T) { + castable := AreTypesCastable(subsetIntegerSchema, mismatchedSubsetSchema) + assert.False(t, castable, "Schema(a=Integer) should not be castable to Schema(a=Float)") + }) + + t.Run("MismatchedColumnsFlipped", func(t *testing.T) { + castable := AreTypesCastable(mismatchedSubsetSchema, subsetIntegerSchema) + assert.False(t, castable, "Schema(a=Float) should not be castable to Schema(a=Integer)") + }) + + t.Run("SchemasAreNullable", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + }, + subsetIntegerSchema) + assert.True(t, castable, "Schemas are nullable") + }) +} diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go new file mode 100644 index 0000000000..19d6ee8357 --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -0,0 +1,199 @@ +package validators + +import ( + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "k8s.io/apimachinery/pkg/util/sets" +) + +func findBindingByVariableName(bindings []*core.Binding, name string) (binding *core.Binding, found bool) { + for _, b := range bindings { + if b.Var == name { + return b, true + } + } + + return nil, false +} + +func findVariableByName(vars *core.VariableMap, name string) (variable *core.Variable, found bool) { + if vars == nil || vars.Variables == nil { + return nil, false + } + + variable, found = vars.Variables[name] + return +} + +// Gets literal type for scalar value. This can be used to compare the underlying type of two scalars for compatibility. +func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { + // TODO: Should we just pass the type information with the value? That way we don't have to guess? 
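+ // Blob and Schema scalars carry their own type information; the remaining scalar kinds map to a fixed simple type, and unrecognized values yield nil.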
+ var literalType *core.LiteralType + switch scalar.GetValue().(type) { + case *core.Scalar_Primitive: + literalType = literalTypeForPrimitive(scalar.GetPrimitive()) + case *core.Scalar_Blob: + if scalar.GetBlob().GetMetadata() == nil { + return nil + } + + literalType = &core.LiteralType{Type: &core.LiteralType_Blob{Blob: scalar.GetBlob().GetMetadata().GetType()}} + case *core.Scalar_Binary: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + case *core.Scalar_Schema: + literalType = &core.LiteralType{ + Type: &core.LiteralType_Schema{ + Schema: scalar.GetSchema().Type, + }, + } + case *core.Scalar_NoneType: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + case *core.Scalar_Error: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_ERROR}} + default: + return nil + } + + return literalType +} + +func literalTypeForPrimitive(primitive *core.Primitive) *core.LiteralType { + simpleType := core.SimpleType_NONE + switch primitive.GetValue().(type) { + case *core.Primitive_Integer: + simpleType = core.SimpleType_INTEGER + case *core.Primitive_FloatValue: + simpleType = core.SimpleType_FLOAT + case *core.Primitive_StringValue: + simpleType = core.SimpleType_STRING + case *core.Primitive_Boolean: + simpleType = core.SimpleType_BOOLEAN + case *core.Primitive_Datetime: + simpleType = core.SimpleType_DATETIME + case *core.Primitive_Duration: + simpleType = core.SimpleType_DURATION + } + + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: simpleType}} +} + +func buildVariablesIndex(params *core.VariableMap) (map[string]*core.Variable, sets.String) { + paramMap := make(map[string]*core.Variable, len(params.Variables)) + paramSet := sets.NewString() + for paramName, param := range params.Variables { + paramMap[paramName] = param + paramSet.Insert(paramName) + } + + return paramMap, paramSet +} + +func filterVariables(vars *core.VariableMap, varNames sets.String) *core.VariableMap { + res := &core.VariableMap{ + Variables: make(map[string]*core.Variable, len(varNames)), + } + + for paramName, param := range vars.Variables { + if varNames.Has(paramName) { + res.Variables[paramName] = param + } + } + + return res +} + +func withVariableName(param *core.Variable) (newParam *core.Variable, ok bool) { + if raw, err := proto.Marshal(param); err == nil { + newParam = &core.Variable{} + if err = proto.Unmarshal(raw, newParam); err == nil { + ok = true + } + } + + return +} + +// Gets LiteralType for literal, nil if the value of literal is unknown, or type None if the literal is a non-homogeneous +// type. +func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { + switch l.GetValue().(type) { + case *core.Literal_Scalar: + return literalTypeForScalar(l.GetScalar()) + case *core.Literal_Collection: + if len(l.GetCollection().Literals) == 0 { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + // Ensure literal collection types are homogeneous. 
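+ // If any element's type is not castable to the type inferred from earlier elements, the whole collection degrades to a NONE-typed literal.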
+ var innerType *core.LiteralType + for _, x := range l.GetCollection().Literals { + otherType := LiteralTypeForLiteral(x) + if innerType != nil && !AreTypesCastable(otherType, innerType) { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + innerType = otherType + } + + return &core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: innerType}} + case *core.Literal_Map: + if len(l.GetMap().Literals) == 0 { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + // Ensure literal map types are homogeneous. + var innerType *core.LiteralType + for _, x := range l.GetMap().Literals { + otherType := LiteralTypeForLiteral(x) + if innerType != nil && !AreTypesCastable(otherType, innerType) { + return &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}} + } + + innerType = otherType + } + + return &core.LiteralType{Type: &core.LiteralType_MapValueType{MapValueType: innerType}} + } + + return nil +} + +// Converts a literal to a non-promise binding data. +func LiteralToBinding(l *core.Literal) *core.BindingData { + switch l.GetValue().(type) { + case *core.Literal_Scalar: + return &core.BindingData{ + Value: &core.BindingData_Scalar{ + Scalar: l.GetScalar(), + }, + } + case *core.Literal_Collection: + x := make([]*core.BindingData, 0, len(l.GetCollection().Literals)) + for _, sub := range l.GetCollection().Literals { + x = append(x, LiteralToBinding(sub)) + } + + return &core.BindingData{ + Value: &core.BindingData_Collection{ + Collection: &core.BindingDataCollection{ + Bindings: x, + }, + }, + } + case *core.Literal_Map: + x := make(map[string]*core.BindingData, len(l.GetMap().Literals)) + for key, val := range l.GetMap().Literals { + x[key] = LiteralToBinding(val) + } + + return &core.BindingData{ + Value: &core.BindingData_Map{ + Map: &core.BindingDataMap{ + Bindings: x, + }, + }, + } + } + + return nil +} diff --git a/flytepropeller/pkg/compiler/validators/vars.go b/flytepropeller/pkg/compiler/validators/vars.go new file mode 100644 index 0000000000..04058b841f --- /dev/null +++ b/flytepropeller/pkg/compiler/validators/vars.go @@ -0,0 +1,81 @@ +package validators + +import ( + flyte "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + "k8s.io/apimachinery/pkg/util/sets" +) + +func validateOutputVar(n c.NodeBuilder, paramName string, errs errors.CompileErrors) ( + param *flyte.Variable, ok bool) { + if outputs, effectiveOk := validateEffectiveOutputParameters(n, errs.NewScope()); effectiveOk { + var paramFound bool + if param, paramFound = findVariableByName(outputs, paramName); !paramFound { + errs.Collect(errors.NewVariableNameNotFoundErr(n.GetId(), n.GetId(), paramName)) + } + } + + return param, !errs.HasErrors() +} + +func validateInputVar(n c.NodeBuilder, paramName string, errs errors.CompileErrors) (param *flyte.Variable, ok bool) { + if n.GetInterface() == nil { + return nil, false + } + + if param, ok = findVariableByName(n.GetInterface().GetInputs(), paramName); !ok { + errs.Collect(errors.NewVariableNameNotFoundErr(n.GetId(), n.GetId(), paramName)) + } + + return +} + +func validateVarType(nodeID c.NodeID, paramName string, param *flyte.Variable, + expectedType *flyte.LiteralType, errs errors.CompileErrors) (ok bool) { + if param.GetType().String() != expectedType.String() { + errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, 
param.GetType().String(), expectedType.String()))
+	}
+
+	return !errs.HasErrors()
+}
+
+func validateVarsSetMatch(nodeID string, params1, params2 map[string]*flyte.Variable,
+	params1Set, params2Set sets.String, errs errors.CompileErrors) (match bool) {
+	// Validate that parameters that exist in both interfaces have compatible types.
+	inBoth := params1Set.Intersection(params2Set)
+	for paramName := range inBoth {
+		if validateVarType(nodeID, paramName, params1[paramName], params2[paramName].Type, errs.NewScope()) {
+			validateVarType(nodeID, paramName, params2[paramName], params1[paramName].Type, errs.NewScope())
+		}
+	}
+
+	// All remaining params on either side indicate errors
+	inLeftSide := params1Set.Difference(params2Set)
+	for range inLeftSide {
+		errs.Collect(errors.NewMismatchingInterfacesErr(nodeID, nodeID))
+	}
+
+	inRightSide := params2Set.Difference(params1Set)
+	for range inRightSide {
+		errs.Collect(errors.NewMismatchingInterfacesErr(nodeID, nodeID))
+	}
+
+	return !errs.HasErrors()
+}
+
+// Validate parameters have their required attributes set
+func validateVariables(nodeID c.NodeID, params *flyte.VariableMap, errs errors.CompileErrors) (ok bool) {
+
+	for paramName, param := range params.Variables {
+		if len(paramName) == 0 {
+			errs.Collect(errors.NewValueRequiredErr(nodeID, "paramName"))
+		}
+
+		if param.Type == nil {
+			errs.Collect(errors.NewValueRequiredErr(nodeID, "param.Type"))
+		}
+	}
+
+	return !errs.HasErrors()
+}
diff --git a/flytepropeller/pkg/compiler/workflow_compiler.go b/flytepropeller/pkg/compiler/workflow_compiler.go
new file mode 100755
index 0000000000..2ec6e39fcd
--- /dev/null
+++ b/flytepropeller/pkg/compiler/workflow_compiler.go
@@ -0,0 +1,330 @@
+// This package provides compiler services for flyte workflows. It performs static analysis on the Workflow and produces
+// CompilerErrors for any detected issue. A flyte workflow should only be considered valid for execution if it passed through
+// the compiler first. The intended usage for the compiler is as follows:
+// 1) Call GetRequirements(...) and load/retrieve all tasks/workflows referenced in the response.
+// 2) Call CompileWorkflow(...) and make sure it reports no errors.
+// 3) Use one of the transformer packages (e.g. transformer/k8s) to build the final executable workflow.
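+//
+// A minimal sketch of these three steps (illustrative only; variable names are placeholders, and the
+// Example functions in workflow_compiler_test.go in this patch show complete, runnable versions):
+//
+//	reqs, err := GetRequirements(wfTemplate, subWorkflows)
+//	// ...load everything reqs lists: compile tasks via CompileTask(...), resolve launch plans...
+//	closure, err := CompileWorkflow(wfTemplate, subWorkflows, compiledTasks, launchPlans)
+//	// closure.Primary can then be handed to a transformer (e.g. transformer/k8s).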
+// +// +-------------------+ +// | start(StartNode) | +// +-------------------+ +// | +// | wf_input +// v +// +--------+ +-------------------+ +// | static | --> | node_1(TaskNode) | +// +--------+ +-------------------+ +// | | +// | | x +// | v +// | +-------------------+ +// +----------> | node_2(TaskNode) | +// +-------------------+ +// | +// | n2_output +// v +// +-------------------+ +// | end(EndNode) | +// +-------------------+ +// +-------------------+ +// | Workflow Id: repo | +// +-------------------+ +package compiler + +import ( + "strings" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + v "github.com/lyft/flytepropeller/pkg/compiler/validators" + "k8s.io/apimachinery/pkg/util/sets" +) + +// Updates workflows and tasks references to reflect the needed ones for this workflow (ignoring subworkflows) +func (w *workflowBuilder) updateRequiredReferences() { + reqs := getRequirements(w.CoreWorkflow.Template, w.allSubWorkflows, false, errors.NewCompileErrors()) + workflows := map[c.WorkflowIDKey]c.InterfaceProvider{} + tasks := c.TaskIndex{} + for _, workflowID := range reqs.launchPlanIds { + if wf, ok := w.allLaunchPlans[workflowID.String()]; ok { + workflows[workflowID.String()] = wf + } + } + + for _, taskID := range reqs.taskIds { + if task, ok := w.allTasks[taskID.String()]; ok { + tasks[taskID.String()] = task + } + } + + w.Tasks = tasks + w.LaunchPlans = workflows +} + +// Validates the coreWorkflow contains no cycles and that all nodes are reachable. +func (w workflowBuilder) validateReachable(errs errors.CompileErrors) (ok bool) { + neighbors := func(nodeId string) sets.String { + downNodes := w.downstreamNodes[nodeId] + if downNodes == nil { + return sets.String{} + } + + return downNodes + } + + // TODO: If a branch node can exist in a cycle and not actually be a cycle since it can branch off... + if cycle, visited, detected := detectCycle(c.StartNodeID, neighbors); detected { + errs.Collect(errors.NewCycleDetectedInWorkflowErr(c.StartNodeID, strings.Join(cycle, ">"))) + } else { + // If no cycles are detected, we expect all nodes to have been visited. Otherwise there are unreachable + // node(s).. + if visited.Len() != len(w.Nodes) { + // report unreachable nodes + allNodes := toNodeIdsSet(w.Nodes) + unreachableNodes := allNodes.Difference(visited).Difference(sets.NewString(c.EndNodeID)) + if len(unreachableNodes) > 0 { + errs.Collect(errors.NewUnreachableNodesErr(c.StartNodeID, strings.Join(toSlice(unreachableNodes), ","))) + } + } + } + + return !errs.HasErrors() +} + +// Adds unique nodes to the workflow. 
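+// If a node with the same id was already added, a DuplicateIDFound error is collected; the new
+// node still replaces the previous entry in the index and is appended to the core template.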
+func (w workflowBuilder) AddNode(n c.NodeBuilder, errs errors.CompileErrors) (node c.NodeBuilder, ok bool) { + if _, ok := w.Nodes[n.GetId()]; ok { + errs.Collect(errors.NewDuplicateIDFoundErr(n.GetId())) + } + + node = n + w.Nodes[n.GetId()] = node + ok = !errs.HasErrors() + w.CoreWorkflow.Template.Nodes = append(w.CoreWorkflow.Template.Nodes, node.GetCoreNode()) + return +} + +func (w workflowBuilder) AddExecutionEdge(nodeFrom, nodeTo c.NodeID) { + if nodeFrom == "" { + nodeFrom = c.StartNodeID + } + + if _, found := w.downstreamNodes[nodeFrom]; !found { + w.downstreamNodes[nodeFrom] = sets.String{} + w.CoreWorkflow.Connections.Downstream[nodeFrom] = &core.ConnectionSet_IdList{} + } + + if _, found := w.upstreamNodes[nodeTo]; !found { + w.upstreamNodes[nodeTo] = sets.String{} + w.CoreWorkflow.Connections.Upstream[nodeTo] = &core.ConnectionSet_IdList{ + Ids: make([]string, 1), + } + } + + w.downstreamNodes[nodeFrom].Insert(nodeTo) + w.upstreamNodes[nodeTo].Insert(nodeFrom) + w.CoreWorkflow.Connections.Downstream[nodeFrom].Ids = w.downstreamNodes[nodeFrom].List() + w.CoreWorkflow.Connections.Upstream[nodeTo].Ids = w.upstreamNodes[nodeTo].List() +} + +func (w workflowBuilder) AddEdges(n c.NodeBuilder, errs errors.CompileErrors) (ok bool) { + if n.GetInterface() == nil { + // If there were errors computing node's interface, don't add any edges and just bail. + return + } + + // Add explicitly declared edges + if n.GetUpstreamNodeIds() != nil { + for _, upNode := range n.GetUpstreamNodeIds() { + w.AddExecutionEdge(upNode, n.GetId()) + } + } + + // Add implicit Edges + return v.ValidateBindings(&w, n, n.GetInputs(), n.GetInterface().GetInputs(), errs.NewScope()) +} + +// Contains the main validation logic for the coreWorkflow. If successful, it'll build an executable Workflow. +func (w workflowBuilder) ValidateWorkflow(fg *flyteWorkflow, errs errors.CompileErrors) (c.Workflow, bool) { + // Initialize workflow + wf := w.newWorkflowBuilder(fg) + wf.updateRequiredReferences() + + // Start building out the workflow + // Create global sentinel nodeBuilder with the workflow as its interface. + startNode := &core.Node{ + Id: c.StartNodeID, + } + + var ok bool + if wf.CoreWorkflow.Template.Interface, ok = v.ValidateInterface(c.StartNodeID, wf.CoreWorkflow.Template.Interface, errs.NewScope()); !ok { + return nil, !errs.HasErrors() + } + + checkpoint := make([]*core.Node, 0, len(fg.Template.Nodes)) + checkpoint = append(checkpoint, fg.Template.Nodes...) 
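+	// Reset the template's node list and connections: AddNode below repopulates
+	// CoreWorkflow.Template.Nodes as each node is validated, and AddExecutionEdge/AddEdges
+	// rebuild the connection sets.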
+ fg.Template.Nodes = make([]*core.Node, 0, len(fg.Template.Nodes)) + wf.GetCoreWorkflow().Connections = &core.ConnectionSet{ + Downstream: make(map[string]*core.ConnectionSet_IdList), + Upstream: make(map[string]*core.ConnectionSet_IdList), + } + + globalInputNode, _ := wf.AddNode(wf.NewNodeBuilder(startNode), errs) + globalInputNode.SetInterface(&core.TypedInterface{Outputs: wf.CoreWorkflow.Template.Interface.Inputs}) + + endNode := &core.Node{Id: c.EndNodeID} + globalOutputNode, _ := wf.AddNode(wf.NewNodeBuilder(endNode), errs) + globalOutputNode.SetInterface(&core.TypedInterface{Inputs: wf.CoreWorkflow.Template.Interface.Outputs}) + globalOutputNode.SetInputs(wf.CoreWorkflow.Template.Outputs) + + // Add and validate all other nodes + for _, n := range checkpoint { + if node, addOk := wf.AddNode(wf.NewNodeBuilder(n), errs.NewScope()); addOk { + v.ValidateNode(&wf, node, errs.NewScope()) + } + } + + // Add explicitly and implicitly declared edges + for nodeID, n := range wf.Nodes { + if nodeID == c.StartNodeID { + continue + } + + wf.AddEdges(n, errs.NewScope()) + } + + // Add execution edges for orphan nodes that don't have any inward/outward edges. + for nodeID := range wf.Nodes { + if nodeID == c.StartNodeID || nodeID == c.EndNodeID { + continue + } + + if _, foundUpStream := wf.upstreamNodes[nodeID]; !foundUpStream { + wf.AddExecutionEdge(c.StartNodeID, nodeID) + } + + if _, foundDownStream := wf.downstreamNodes[nodeID]; !foundDownStream { + wf.AddExecutionEdge(nodeID, c.EndNodeID) + } + } + + // Validate workflow outputs are bound + if _, wfIfaceOk := v.ValidateInterface(globalOutputNode.GetId(), globalOutputNode.GetInterface(), errs.NewScope()); wfIfaceOk { + v.ValidateBindings(&wf, globalOutputNode, globalOutputNode.GetInputs(), + globalOutputNode.GetInterface().GetInputs(), errs.NewScope()) + } + + // Validate no cycles are detected. + wf.validateReachable(errs.NewScope()) + + return wf, !errs.HasErrors() +} + +// Validates that all requirements for the coreWorkflow and its subworkflows are present. +func (w workflowBuilder) validateAllRequirements(errs errors.CompileErrors) bool { + reqs := getRequirements(w.CoreWorkflow.Template, w.allSubWorkflows, true, errs) + + for _, lp := range reqs.launchPlanIds { + if _, ok := w.allLaunchPlans[lp.String()]; !ok { + errs.Collect(errors.NewWorkflowReferenceNotFoundErr(c.StartNodeID, lp.String())) + } + } + + for _, taskID := range reqs.taskIds { + if _, ok := w.allTasks[taskID.String()]; !ok { + errs.Collect(errors.NewTaskReferenceNotFoundErr(c.StartNodeID, taskID.String())) + } + } + + return !errs.HasErrors() +} + +// Compiles a flyte workflow a and all of its dependencies into a single executable Workflow. Refer to GetRequirements() +// to obtain a list of launchplan and Task ids to load/compile first. +// Returns an executable Workflow (if no errors are found) or a list of errors that must be addressed before the Workflow +// can be executed. Cast the error to errors.CompileErrors to inspect individual errors. 
+func CompileWorkflow(primaryWf *core.WorkflowTemplate, subworkflows []*core.WorkflowTemplate, tasks []*core.CompiledTask, + launchPlans []c.InterfaceProvider) (*core.CompiledWorkflowClosure, error) { + + errs := errors.NewCompileErrors() + + if primaryWf == nil { + errs.Collect(errors.NewValueRequiredErr("root", "wf")) + return nil, errs + } + + wf := proto.Clone(primaryWf).(*core.WorkflowTemplate) + + if tasks == nil { + errs.Collect(errors.NewValueRequiredErr("root", "tasks")) + return nil, errs + } + + // Validate all tasks are valid... invalid tasks won't be passed on to the workflow validator + uniqueTasks := sets.NewString() + taskBuilders := make([]c.Task, 0, len(tasks)) + for _, task := range tasks { + if task.Template == nil || task.Template.Id == nil { + errs.Collect(errors.NewValueRequiredErr("task", "Template.Id")) + return nil, errs + } + + if uniqueTasks.Has(task.Template.Id.String()) { + continue + } + + taskBuilders = append(taskBuilders, &taskBuilder{flyteTask: task.Template}) + uniqueTasks.Insert(task.Template.Id.String()) + } + + // Validate overall requirements of the coreWorkflow. + wfIndex, ok := c.NewWorkflowIndex(toCompiledWorkflows(subworkflows...), errs.NewScope()) + if !ok { + return nil, errs + } + + compiledWf := &core.CompiledWorkflow{Template: wf} + + gb := newWorfklowBuilder(compiledWf, wfIndex, c.NewTaskIndex(taskBuilders...), toInterfaceProviderMap(launchPlans)) + // Terminate early if there are some required component not present. + if !gb.validateAllRequirements(errs.NewScope()) { + return nil, errs + } + + validatedWf, ok := gb.ValidateWorkflow(compiledWf, errs.NewScope()) + if ok { + compiledTasks := make([]*core.CompiledTask, 0, len(taskBuilders)) + for _, t := range taskBuilders { + compiledTasks = append(compiledTasks, &core.CompiledTask{Template: t.GetCoreTask()}) + } + + return &core.CompiledWorkflowClosure{ + Primary: validatedWf.GetCoreWorkflow(), + Tasks: compiledTasks, + }, nil + } + + return nil, errs +} + +func (w workflowBuilder) newWorkflowBuilder(fg *flyteWorkflow) workflowBuilder { + return newWorfklowBuilder(fg, w.allSubWorkflows, w.allTasks, w.allLaunchPlans) +} + +func newWorfklowBuilder(fg *flyteWorkflow, wfIndex c.WorkflowIndex, tasks c.TaskIndex, + workflows map[string]c.InterfaceProvider) workflowBuilder { + + return workflowBuilder{ + CoreWorkflow: fg, + LaunchPlans: map[string]c.InterfaceProvider{}, + Nodes: c.NewNodeIndex(), + Tasks: c.NewTaskIndex(), + downstreamNodes: c.StringAdjacencyList{}, + upstreamNodes: c.StringAdjacencyList{}, + allSubWorkflows: wfIndex, + allLaunchPlans: workflows, + allTasks: tasks, + } +} diff --git a/flytepropeller/pkg/compiler/workflow_compiler_test.go b/flytepropeller/pkg/compiler/workflow_compiler_test.go new file mode 100755 index 0000000000..eccdbc3b5f --- /dev/null +++ b/flytepropeller/pkg/compiler/workflow_compiler_test.go @@ -0,0 +1,659 @@ +package compiler + +import ( + "fmt" + "strings" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/compiler/errors" + v "github.com/lyft/flytepropeller/pkg/compiler/validators" + "github.com/lyft/flytepropeller/pkg/visualize" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func createEmptyVariableMap() *core.VariableMap { + res := &core.VariableMap{ + Variables: map[string]*core.Variable{}, + } + return res +} + +func 
createVariableMap(variableMap map[string]*core.Variable) *core.VariableMap { + res := &core.VariableMap{ + Variables: variableMap, + } + return res +} + +func dumpIdentifierNames(ids []common.Identifier) []string { + res := make([]string, 0, len(ids)) + + for _, id := range ids { + res = append(res, id.Name) + } + + return res +} + +func ExampleCompileWorkflow_basic() { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Nodes: []*core.Node{ + { + Id: "FirstNode", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task_123"}, + }, + }, + }, + }, + }, + } + + // Detect what other workflows/tasks does this coreWorkflow reference + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(inputWorkflow, subWorkflows) + if err != nil { + fmt.Printf("failed to get requirements. Error: %v", err) + return + } + + fmt.Printf("Needed Tasks: [%v], Needed Workflows [%v]\n", + strings.Join(dumpIdentifierNames(reqs.GetRequiredTaskIds()), ","), + strings.Join(dumpIdentifierNames(reqs.GetRequiredLaunchPlanIds()), ",")) + + // Replace with logic to satisfy the requirements + workflows := make([]common.InterfaceProvider, 0) + tasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + Command: []string{"cmd"}, + Args: []string{"args"}, + }, + }, + }, + } + + compiledTasks := make([]*core.CompiledTask, 0, len(tasks)) + for _, task := range tasks { + compiledTask, err := CompileTask(task) + if err != nil { + fmt.Printf("failed to compile task [%v]. 
Error: %v", task.Id, err) + return + } + + compiledTasks = append(compiledTasks, compiledTask) + } + + output, errs := CompileWorkflow(inputWorkflow, subWorkflows, compiledTasks, workflows) + fmt.Printf("Compiled Workflow in GraphViz: %v\n", visualize.ToGraphViz(output.Primary)) + fmt.Printf("Compile Errors: %v\n", errs) + + // Output: + // Needed Tasks: [task_123], Needed Workflows [] + // Compiled Workflow in GraphViz: digraph G {rankdir=TB;workflow[label="Workflow Id: name:"repo" "];node[style=filled];"start-node(start)" [shape=Msquare];"start-node(start)" -> "FirstNode()" [label="execution",style="dashed"];"FirstNode()" -> "end-node(end)" [label="execution",style="dashed"];} + // Compile Errors: +} + +func ExampleCompileWorkflow_inputsOutputsBinding() { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "wf_input": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "wf_output": { + Type: getIntegerLiteralType(), + }, + }), + }, + Nodes: []*core.Node{ + { + Id: "node_1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newVarBinding("", "wf_input", "x"), newIntegerBinding(124, "y"), + }, + }, + { + Id: "node_2", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newIntegerBinding(124, "y"), newVarBinding("node_1", "x", "x"), + }, + OutputAliases: []*core.Alias{{Var: "x", Alias: "n2_output"}}, + }, + }, + Outputs: []*core.Binding{newVarBinding("node_2", "n2_output", "wf_output")}, + } + + // Detect what other graphs/tasks does this coreWorkflow reference + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(inputWorkflow, subWorkflows) + if err != nil { + fmt.Printf("Failed to get requirements. Error: %v", err) + return + } + + fmt.Printf("Needed Tasks: [%v], Needed Graphs [%v]\n", + strings.Join(dumpIdentifierNames(reqs.GetRequiredTaskIds()), ","), + strings.Join(dumpIdentifierNames(reqs.GetRequiredLaunchPlanIds()), ",")) + + // Replace with logic to satisfy the requirements + graphs := make([]common.InterfaceProvider, 0) + inputTasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, + Metadata: &core.TaskMetadata{}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + "y": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image://", + Command: []string{"cmd"}, + Args: []string{"args"}, + }, + }, + }, + } + + // Compile all tasks before proceeding with Workflow + compiledTasks := make([]*core.CompiledTask, 0, len(inputTasks)) + for _, task := range inputTasks { + compiledTask, err := CompileTask(task) + if err != nil { + fmt.Printf("Failed to compile task [%v]. 
Error: %v", task.Id, err) + return + } + + compiledTasks = append(compiledTasks, compiledTask) + } + + output, errs := CompileWorkflow(inputWorkflow, subWorkflows, compiledTasks, graphs) + if errs != nil { + fmt.Printf("Compile Errors: %v\n", errs) + } else { + fmt.Printf("Compiled Workflow in GraphViz: %v\n", visualize.ToGraphViz(output.Primary)) + } + + // Output: + // Needed Tasks: [task_123], Needed Graphs [] + // Compiled Workflow in GraphViz: digraph G {rankdir=TB;workflow[label="Workflow Id: name:"repo" "];node[style=filled];"start-node(start)" [shape=Msquare];"start-node(start)" -> "node_1()" [label="wf_input",style="solid"];"node_1()" -> "node_2()" [label="x",style="solid"];"static" -> "node_1()" [label=""];"node_2()" -> "end-node(end)" [label="n2_output",style="solid"];"static" -> "node_2()" [label=""];} +} + +func ExampleCompileWorkflow_compileErrors() { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createEmptyVariableMap(), + Outputs: createEmptyVariableMap(), + }, + Nodes: []*core.Node{ + { + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "task_123"}, + }, + }, + }, + }, + }, + } + + // Detect what other workflows/tasks does this coreWorkflow reference + subWorkflows := make([]*core.WorkflowTemplate, 0) + reqs, err := GetRequirements(inputWorkflow, subWorkflows) + if err != nil { + fmt.Printf("Failed to get requirements. Error: %v", err) + return + } + + fmt.Printf("Needed Tasks: [%v], Needed Workflows [%v]\n", + strings.Join(dumpIdentifierNames(reqs.GetRequiredTaskIds()), ","), + strings.Join(dumpIdentifierNames(reqs.GetRequiredLaunchPlanIds()), ",")) + + // Replace with logic to satisfy the requirements + workflows := make([]common.InterfaceProvider, 0) + _, errs := CompileWorkflow(inputWorkflow, subWorkflows, []*core.CompiledTask{}, workflows) + fmt.Printf("Compile Errors: %v\n", errs) + + // Output: + // Needed Tasks: [task_123], Needed Workflows [] + // Compile Errors: Collected Errors: 1 + // Error 0: Code: TaskReferenceNotFound, Node Id: start-node, Description: Referenced Task [name:"task_123" ] not found. 
+} + +func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newStringPrimitive(value string) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_StringValue{StringValue: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +func getIntegerLiteralType() *core.LiteralType { + return getSimpleLiteralType(core.SimpleType_INTEGER) +} + +func getSimpleLiteralType(simpleType core.SimpleType) *core.LiteralType { + return &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: simpleType, + }, + } +} + +func newIntegerBinding(value int64, toVar string) *core.Binding { + return &core.Binding{ + Binding: &core.BindingData{ + Value: &core.BindingData_Scalar{Scalar: newIntegerLiteral(value).GetScalar()}, + }, + Var: toVar, + } +} + +func newVarBinding(fromNodeID, fromVar, toVar string) *core.Binding { + return &core.Binding{ + Binding: &core.BindingData{ + Value: &core.BindingData_Promise{ + Promise: &core.OutputReference{ + NodeId: fromNodeID, + Var: fromVar, + }, + }, + }, + Var: toVar, + } +} + +func TestComparisonExpression_MissingLeftRight(t *testing.T) { + bExpr := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_GT, + }, + }, + } + + errs := errors.NewCompileErrors() + v.ValidateBooleanExpression(&nodeBuilder{flyteNode: &flyteNode{}}, bExpr, errs) + assert.Error(t, errs) + assert.Equal(t, 2, errs.ErrorCount()) +} + +func TestComparisonExpression(t *testing.T) { + bExpr := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_GT, + LeftValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newIntegerPrimitive(123)}}, + RightValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newStringPrimitive("hello")}}, + }, + }, + } + + errs := errors.NewCompileErrors() + v.ValidateBooleanExpression(&nodeBuilder{flyteNode: &flyteNode{}}, bExpr, errs) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 1, errs.ErrorCount()) +} + +func TestBooleanExpression_BranchNodeHasNoCondition(t *testing.T) { + bExpr := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: &core.ConjunctionExpression{ + Operator: core.ConjunctionExpression_AND, + RightExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: &core.ComparisonExpression{ + Operator: core.ComparisonExpression_GT, + LeftValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newIntegerPrimitive(123)}}, + RightValue: &core.Operand{Val: &core.Operand_Primitive{Primitive: newIntegerPrimitive(345)}}, + }, + }, + }, + }, + }, + } + + errs := errors.NewCompileErrors() + v.ValidateBooleanExpression(&nodeBuilder{flyteNode: &flyteNode{}}, bExpr, errs) + assert.True(t, errs.HasErrors()) + assert.Equal(t, 1, errs.ErrorCount()) + for e := range *errs.Errors() { + assert.Equal(t, errors.BranchNodeHasNoCondition, e.Code()) + } +} + +func newNodeIDSet(nodeIDs ...common.NodeID) sets.String { + return sets.NewString(nodeIDs...) 
+} + +func TestValidateReachable(t *testing.T) { + graph := &workflowBuilder{} + graph.downstreamNodes = map[string]sets.String{ + v1alpha1.StartNodeID: newNodeIDSet("1"), + "1": newNodeIDSet("5", "2"), + "2": newNodeIDSet("3"), + "3": newNodeIDSet("4"), + "4": newNodeIDSet(v1alpha1.EndNodeID), + } + + for range graph.downstreamNodes { + graph.Nodes = common.NewNodeIndex(graph.NewNodeBuilder(nil)) + } + + errs := errors.NewCompileErrors() + assert.False(t, graph.validateReachable(errs)) + assert.True(t, errs.HasErrors()) +} + +func TestValidateUnderlyingInterface(parentT *testing.T) { + graphIface := &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + } + + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: graphIface, + Nodes: []*core.Node{ + { + Id: "node_123", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + }, + }, + } + + taskIface := &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + "y": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + } + + inputTasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, + Metadata: &core.TaskMetadata{}, + Interface: taskIface, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "Image://", + Command: []string{"blah"}, + Args: []string{"bloh"}, + }, + }, + }, + } + + errs := errors.NewCompileErrors() + compiledTasks := make([]common.Task, 0, len(inputTasks)) + for _, inputTask := range inputTasks { + t, _ := compileTaskInternal(inputTask, errs) + compiledTasks = append(compiledTasks, t) + assert.False(parentT, errs.HasErrors()) + if errs.HasErrors() { + assert.FailNow(parentT, errs.Error()) + } + } + + g := newWorfklowBuilder( + &core.CompiledWorkflow{Template: inputWorkflow}, + mustBuildWorkflowIndex(inputWorkflow), + common.NewTaskIndex(compiledTasks...), + map[string]common.InterfaceProvider{}) + (&g).Tasks = common.NewTaskIndex(compiledTasks...) 
+ + parentT.Run("TaskNode", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: inputWorkflow.Nodes[0]}, errs) + assert.True(t, ifaceOk) + assert.False(t, errs.HasErrors()) + assert.Equal(t, taskIface, iface) + }) + + parentT.Run("GraphNode", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: inputWorkflow.Id, + }, + }, + }, + }}, errs) + assert.True(t, ifaceOk) + assert.False(t, errs.HasErrors()) + assert.Equal(t, graphIface, iface) + }) + + parentT.Run("BranchNode", func(branchT *testing.T) { + branchT.Run("OneCase", func(t *testing.T) { + errs := errors.NewCompileErrors() + iface, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: &core.Node{ + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: inputWorkflow.Nodes[0], + }, + }, + }, + }, + }}, errs) + assert.True(t, ifaceOk) + assert.False(t, errs.HasErrors()) + assert.Equal(t, taskIface, iface) + }) + + branchT.Run("TwoCases", func(t *testing.T) { + errs := errors.NewCompileErrors() + _, ifaceOk := v.ValidateUnderlyingInterface(&g, &nodeBuilder{flyteNode: &core.Node{ + Target: &core.Node_BranchNode{ + BranchNode: &core.BranchNode{ + IfElse: &core.IfElseBlock{ + Case: &core.IfBlock{ + ThenNode: inputWorkflow.Nodes[0], + }, + Other: []*core.IfBlock{ + { + ThenNode: &core.Node{ + Target: &core.Node_WorkflowNode{ + WorkflowNode: &core.WorkflowNode{ + Reference: &core.WorkflowNode_SubWorkflowRef{ + SubWorkflowRef: inputWorkflow.Id, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }}, errs) + assert.False(t, ifaceOk) + assert.True(t, errs.HasErrors()) + }) + }) +} + +func TestCompileWorkflow(t *testing.T) { + inputWorkflow := &core.WorkflowTemplate{ + Id: &core.Identifier{Name: "repo"}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + }, + Nodes: []*core.Node{ + { + Id: "node_123", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newIntegerBinding(123, "x"), newIntegerBinding(123, "y"), + }, + }, + { + Id: "node_456", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{Reference: &core.TaskNode_ReferenceId{ReferenceId: &core.Identifier{Name: "task_123"}}}, + }, + Inputs: []*core.Binding{ + newIntegerBinding(123, "y"), newVarBinding("node_123", "x", "x"), + }, + UpstreamNodeIds: []string{"node_123"}, + }, + }, + Outputs: []*core.Binding{newVarBinding("node_456", "x", "x")}, + } + + inputTasks := []*core.TaskTemplate{ + { + Id: &core.Identifier{Name: "task_123"}, Metadata: &core.TaskMetadata{}, + Interface: &core.TypedInterface{ + Inputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + "y": { + Type: getIntegerLiteralType(), + }, + }), + Outputs: createVariableMap(map[string]*core.Variable{ + "x": { + Type: getIntegerLiteralType(), + }, + }), + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Command: []string{}, + Image: "image://123", 
+ }, + }, + }, + } + + errors.SetConfig(errors.Config{PanicOnError: true}) + defer errors.SetConfig(errors.Config{}) + output, errs := CompileWorkflow(inputWorkflow, []*core.WorkflowTemplate{}, mustCompileTasks(inputTasks), []common.InterfaceProvider{}) + assert.NoError(t, errs) + assert.NotNil(t, output) + if output != nil { + t.Logf("Graph Repr: %v", visualize.ToGraphViz(output.Primary)) + + assert.Equal(t, []string{"node_123"}, output.Primary.Connections.Upstream["node_456"].Ids) + } +} + +func mustCompileTasks(tasks []*core.TaskTemplate) []*core.CompiledTask { + res := make([]*core.CompiledTask, 0, len(tasks)) + for _, t := range tasks { + compiledT, err := CompileTask(t) + if err != nil { + panic(err) + } + + res = append(res, compiledT) + } + return res +} + +func mustBuildWorkflowIndex(wfs ...*core.WorkflowTemplate) common.WorkflowIndex { + compiledWfs := make([]*core.CompiledWorkflow, 0, len(wfs)) + for _, wf := range wfs { + compiledWfs = append(compiledWfs, &core.CompiledWorkflow{Template: wf}) + } + + err := errors.NewCompileErrors() + if index, ok := common.NewWorkflowIndex(compiledWfs, err); !ok { + panic(err) + } else { + return index + } +} diff --git a/flytepropeller/pkg/controller/catalog/catalog_client.go b/flytepropeller/pkg/controller/catalog/catalog_client.go new file mode 100644 index 0000000000..53a9b8aa5a --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/catalog_client.go @@ -0,0 +1,28 @@ +package catalog + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +type Client interface { + Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) + Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error +} + +func NewCatalogClient(store storage.ProtobufStore) Client { + catalogConfig := GetConfig() + + var catalogClient Client + if catalogConfig.Type == LegacyDiscoveryType { + catalogClient = NewLegacyDiscovery(catalogConfig.Endpoint, store) + } else if catalogConfig.Type == NoOpDiscoveryType { + catalogClient = NewNoOpDiscovery() + } + + logger.Infof(context.Background(), "Created Catalog client, type: %v", catalogConfig.Type) + return catalogClient +} diff --git a/flytepropeller/pkg/controller/catalog/config_flags.go b/flytepropeller/pkg/controller/catalog/config_flags.go new file mode 100755 index 0000000000..d67dd751ac --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/config_flags.go @@ -0,0 +1,47 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package catalog + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. 
+func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Discovery Implementation to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), defaultConfig.Endpoint, " Endpoint for discovery service") + return cmdFlags +} diff --git a/flytepropeller/pkg/controller/catalog/config_flags_test.go b/flytepropeller/pkg/controller/catalog/config_flags_test.go new file mode 100755 index 0000000000..a2538822b8 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/config_flags_test.go @@ -0,0 +1,146 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package catalog + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("type"); err == nil { + assert.Equal(t, string(defaultConfig.Type), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("type", testValue) + if vString, err := cmdFlags.GetString("type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("endpoint"); err == nil { + assert.Equal(t, string(defaultConfig.Endpoint), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("endpoint", testValue) + if vString, err := cmdFlags.GetString("endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flytepropeller/pkg/controller/catalog/discovery_config.go b/flytepropeller/pkg/controller/catalog/discovery_config.go new file mode 100644 index 0000000000..bbc1bab8ff --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/discovery_config.go @@ -0,0 +1,34 @@ +package catalog + +import ( + "github.com/lyft/flytestdlib/config" +) + +//go:generate pflags Config --default-var defaultConfig + +const ConfigSectionKey = "catalog-cache" + +var ( + defaultConfig = &Config{ + Type: NoOpDiscoveryType, + } + + configSection = config.MustRegisterSection(ConfigSectionKey, defaultConfig) +) + +type DiscoveryType = string + +const ( + NoOpDiscoveryType DiscoveryType = "noop" + LegacyDiscoveryType DiscoveryType = "legacy" +) + +type Config struct { + Type DiscoveryType `json:"type" pflag:"\"noop\",Discovery Implementation to use"` + Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for discovery service"` +} + +// Gets loaded 
config for Discovery +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} diff --git a/flytepropeller/pkg/controller/catalog/legacy_discovery.go b/flytepropeller/pkg/controller/catalog/legacy_discovery.go new file mode 100644 index 0000000000..5383c466f7 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/legacy_discovery.go @@ -0,0 +1,202 @@ +package catalog + +import ( + "context" + "encoding/base64" + "fmt" + "time" + + "github.com/golang/protobuf/proto" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/datacatalog" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/pbhash" + "github.com/lyft/flytestdlib/storage" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +const maxGrpcMsgSizeBytes = 41943040 + +// LegacyDiscovery encapsulates interactions with the Discovery service using a protobuf provided gRPC client. +type LegacyDiscovery struct { + client datacatalog.ArtifactsClient + store storage.ProtobufStore +} + +// Hash each value in the map and return it as the parameter value to be used to generate the Provenance. +func TransformToInputParameters(ctx context.Context, m *core.LiteralMap) ([]*datacatalog.Parameter, error) { + var params = []*datacatalog.Parameter{} + + // Note: The Discovery service will ensure that the output parameters are sorted so the hash is consistent. + // If the values of the literalmap are also a map, pbhash ensures that maps are deterministically hashed as well + for k, typedValue := range m.GetLiterals() { + inputHash, err := pbhash.ComputeHashString(ctx, typedValue) + if err != nil { + return nil, err + } + params = append(params, &datacatalog.Parameter{ + Name: k, + Value: inputHash, + }) + } + + return params, nil +} + +func TransformToOutputParameters(ctx context.Context, m *core.LiteralMap) ([]*datacatalog.Parameter, error) { + var params = []*datacatalog.Parameter{} + for k, typedValue := range m.GetLiterals() { + bytes, err := proto.Marshal(typedValue) + + if err != nil { + return nil, err + } + params = append(params, &datacatalog.Parameter{ + Name: k, + Value: base64.StdEncoding.EncodeToString(bytes), + }) + } + return params, nil +} + +func TransformFromParameters(m []*datacatalog.Parameter) (*core.LiteralMap, error) { + paramsMap := make(map[string]*core.Literal) + + for _, p := range m { + bytes, err := base64.StdEncoding.DecodeString(p.GetValue()) + if err != nil { + return nil, err + } + literal := &core.Literal{} + if err = proto.Unmarshal(bytes, literal); err != nil { + return nil, err + } + paramsMap[p.Name] = literal + } + return &core.LiteralMap{ + Literals: paramsMap, + }, nil +} + +func (d *LegacyDiscovery) Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + inputs := &core.LiteralMap{} + + taskInterface := task.Interface + // only download if there are inputs to the task + if taskInterface != nil && taskInterface.Inputs != nil && len(taskInterface.Inputs.Variables) > 0 { + if err := d.store.ReadProtobuf(ctx, inputPath, inputs); err != nil { + return nil, err + } + } + + inputParams, err := TransformToInputParameters(ctx, inputs) + if err != nil { + return nil, err + } + + artifactID := &datacatalog.ArtifactId{ + Name: fmt.Sprintf("%s:%s:%s", task.Id.Project, task.Id.Domain, task.Id.Name), + Version: task.Metadata.DiscoveryVersion, + Inputs: inputParams, + } + options := []grpc.CallOption{ + 
grpc.MaxCallRecvMsgSize(maxGrpcMsgSizeBytes), + grpc.MaxCallSendMsgSize(maxGrpcMsgSizeBytes), + } + + request := &datacatalog.GetRequest{ + Id: &datacatalog.GetRequest_ArtifactId{ + ArtifactId: artifactID, + }, + } + resp, err := d.client.Get(ctx, request, options...) + + logger.Infof(ctx, "Discovery Get response for artifact |%v|, resp: |%v|, error: %v", artifactID, resp, err) + if err != nil { + return nil, err + } + return TransformFromParameters(resp.Artifact.Outputs) +} + +func GetDefaultGrpcOptions() []grpc_retry.CallOption { + return []grpc_retry.CallOption{ + grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)), + grpc_retry.WithCodes(codes.DeadlineExceeded, codes.Unavailable, codes.Canceled), + grpc_retry.WithMax(5), + } +} +func (d *LegacyDiscovery) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + inputs := &core.LiteralMap{} + outputs := &core.LiteralMap{} + + taskInterface := task.Interface + // only download if there are inputs to the task + if taskInterface != nil && taskInterface.Inputs != nil && len(taskInterface.Inputs.Variables) > 0 { + if err := d.store.ReadProtobuf(ctx, inputPath, inputs); err != nil { + return err + } + } + + // only download if there are outputs to the task + if taskInterface != nil && taskInterface.Outputs != nil && len(taskInterface.Outputs.Variables) > 0 { + if err := d.store.ReadProtobuf(ctx, outputPath, outputs); err != nil { + return err + } + } + + outputParams, err := TransformToOutputParameters(ctx, outputs) + if err != nil { + return err + } + + inputParams, err := TransformToInputParameters(ctx, inputs) + if err != nil { + return err + } + + artifactID := &datacatalog.ArtifactId{ + Name: fmt.Sprintf("%s:%s:%s", task.Id.Project, task.Id.Domain, task.Id.Name), + Version: task.Metadata.DiscoveryVersion, + Inputs: inputParams, + } + executionID := fmt.Sprintf("%s:%s:%s", execID.GetNodeExecutionId().GetExecutionId().GetProject(), + execID.GetNodeExecutionId().GetExecutionId().GetDomain(), execID.GetNodeExecutionId().GetExecutionId().GetName()) + request := &datacatalog.CreateRequest{ + Ref: artifactID, + ReferenceId: executionID, + Revision: time.Now().Unix(), + Outputs: outputParams, + } + options := []grpc.CallOption{ + grpc.MaxCallRecvMsgSize(maxGrpcMsgSizeBytes), + grpc.MaxCallSendMsgSize(maxGrpcMsgSizeBytes), + } + + resp, err := d.client.Create(ctx, request, options...) + logger.Infof(ctx, "Discovery Put response for artifact |%v|, resp: |%v|, err: %v", artifactID, resp, err) + return err +} + +func NewLegacyDiscovery(discoveryEndpoint string, store storage.ProtobufStore) *LegacyDiscovery { + + // No discovery endpoint passed. Skip client creation. 
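+	// (nil is also returned further below if dialing the endpoint fails.)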
+ if discoveryEndpoint == "" { + return nil + } + + opts := GetDefaultGrpcOptions() + retryInterceptor := grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)) + conn, err := grpc.Dial(discoveryEndpoint, grpc.WithInsecure(), retryInterceptor) + + if err != nil { + return nil + } + client := datacatalog.NewArtifactsClient(conn) + + return &LegacyDiscovery{ + client: client, + store: store, + } +} diff --git a/flytepropeller/pkg/controller/catalog/legacy_discovery_test.go b/flytepropeller/pkg/controller/catalog/legacy_discovery_test.go new file mode 100644 index 0000000000..4f37dabc53 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/legacy_discovery_test.go @@ -0,0 +1,286 @@ +package catalog + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/clients/go/datacatalog/mocks" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/datacatalog" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func init() { + labeled.SetMetricKeys(contextutils.TaskIDKey) +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +func TestTransformToInputParameters(t *testing.T) { + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(200) + + params, err := TransformToInputParameters(context.Background(), &core.LiteralMap{ + Literals: paramsMap, + }) + assert.Nil(t, err) + assert.Equal(t, "out1", params[0].Name) + assert.Equal(t, "c6i2T7NODjwnlxmXKRCNDk/AN4pZpRGGFX49kT6DT/c=", params[0].Value) +} + +func TestTransformToOutputParameters(t *testing.T) { + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + + params, err := TransformToOutputParameters(context.Background(), &core.LiteralMap{ + Literals: paramsMap, + }) + assert.Nil(t, err) + assert.Equal(t, "out1", params[0].Name) + assert.Equal(t, "CgQKAghk", params[0].Value) +} + +func TestTransformFromParameters(t *testing.T) { + params := []*datacatalog.Parameter{ + {Name: "out1", Value: "CgQKAghk"}, + } + literalMap, err := TransformFromParameters(params) + assert.Nil(t, err) + + val, exists := literalMap.Literals["out1"] + assert.True(t, exists) + assert.Equal(t, int64(100), val.GetScalar().GetPrimitive().GetInteger()) +} + +func TestLegacyDiscovery_Get(t *testing.T) { + ctx := context.Background() + + paramMap := &core.LiteralMap{Literals: map[string]*core.Literal{ + "out1": newIntegerLiteral(100), + }} + task := &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "0.0.1", + }, + Interface: 
&core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + } + + inputPath := storage.DataReference("test-data/inputs.pb") + + t.Run("notfound", func(t *testing.T) { + mockClient := &mocks.ArtifactsClient{} + store := createInmemoryDataStore(t, promutils.NewScope("get_test_notfound")) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + discovery := LegacyDiscovery{client: mockClient, store: store} + mockClient.On("Get", + ctx, + mock.MatchedBy(func(o *datacatalog.GetRequest) bool { + assert.Equal(t, o.GetArtifactId().Name, "project:domain:name") + params, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetArtifactId().Inputs, params) + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(nil, status.Error(codes.NotFound, "")) + resp, err := discovery.Get(ctx, task, inputPath) + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("found", func(t *testing.T) { + mockClient := &mocks.ArtifactsClient{} + store := createInmemoryDataStore(t, promutils.NewScope("get_test_found")) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + discovery := LegacyDiscovery{client: mockClient, store: store} + outputs, err := TransformToOutputParameters(context.Background(), paramMap) + assert.NoError(t, err) + response := &datacatalog.GetResponse{ + Artifact: &datacatalog.Artifact{ + Outputs: outputs, + }, + } + mockClient.On("Get", + ctx, + mock.MatchedBy(func(o *datacatalog.GetRequest) bool { + assert.Equal(t, o.GetArtifactId().Name, "project:domain:name") + assert.Equal(t, o.GetArtifactId().Version, "0.0.1") + params, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetArtifactId().Inputs, params) + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(response, nil) + resp, err := discovery.Get(ctx, task, inputPath) + assert.NoError(t, err) + assert.NotNil(t, resp) + val, exists := resp.Literals["out1"] + assert.True(t, exists) + assert.Equal(t, int64(100), val.GetScalar().GetPrimitive().GetInteger()) + }) +} + +func TestLegacyDiscovery_Put(t *testing.T) { + ctx := context.Background() + + inputPath := storage.DataReference("test-data/inputs.pb") + outputPath := storage.DataReference("test-data/ouputs.pb") + + paramMap := &core.LiteralMap{Literals: map[string]*core.Literal{ + "out1": newIntegerLiteral(100), + }} + task := &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "0.0.1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": 
&core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + } + + execID := &core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "runID", + }, + }, + } + + t.Run("failed", func(t *testing.T) { + mockClient := &mocks.ArtifactsClient{} + store := createInmemoryDataStore(t, promutils.NewScope("put_test_failed")) + discovery := LegacyDiscovery{client: mockClient, store: store} + mockClient.On("Create", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateRequest) bool { + assert.Equal(t, o.GetRef().Name, "project:domain:name") + assert.Equal(t, o.GetReferenceId(), "project:domain:runID") + inputs, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + outputs, err := TransformToOutputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetRef().Inputs, inputs) + assert.Equal(t, o.GetOutputs(), outputs) + + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(nil, status.Error(codes.AlreadyExists, "")) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + err = store.WriteProtobuf(ctx, outputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + err = discovery.Put(ctx, task, execID, inputPath, outputPath) + assert.Error(t, err) + }) + + t.Run("success", func(t *testing.T) { + store := createInmemoryDataStore(t, promutils.NewScope("put_test_success")) + mockClient := &mocks.ArtifactsClient{} + discovery := LegacyDiscovery{client: mockClient, store: store} + mockClient.On("Create", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateRequest) bool { + assert.Equal(t, o.GetRef().Name, "project:domain:name") + assert.Equal(t, o.GetRef().Version, "0.0.1") + inputs, err := TransformToInputParameters(context.Background(), paramMap) + assert.NoError(t, err) + outputs, err := TransformToOutputParameters(context.Background(), paramMap) + assert.NoError(t, err) + assert.Equal(t, o.GetRef().Inputs, inputs) + assert.Equal(t, o.GetOutputs(), outputs) + + return true + }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + mock.MatchedBy(func(opts grpc.CallOption) bool { return true }), + ).Return(nil, nil) + + err := store.WriteProtobuf(ctx, inputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + err = store.WriteProtobuf(ctx, outputPath, storage.Options{}, paramMap) + assert.NoError(t, err) + + err = discovery.Put(ctx, task, execID, inputPath, outputPath) + assert.NoError(t, err) + }) +} diff --git a/flytepropeller/pkg/controller/catalog/mock_catalog.go b/flytepropeller/pkg/controller/catalog/mock_catalog.go new file mode 100644 index 0000000000..03885ff614 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/mock_catalog.go @@ -0,0 +1,21 @@ +package catalog + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/storage" +) + +type MockCatalogClient struct { + GetFunc func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) + PutFunc func(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error +} + +func (m *MockCatalogClient) Get(ctx 
context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return m.GetFunc(ctx, task, inputPath) +} + +func (m *MockCatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return m.PutFunc(ctx, task, execID, inputPath, outputPath) +} diff --git a/flytepropeller/pkg/controller/catalog/no_op_discovery.go b/flytepropeller/pkg/controller/catalog/no_op_discovery.go new file mode 100644 index 0000000000..9d3fc93326 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/no_op_discovery.go @@ -0,0 +1,29 @@ +package catalog + +import ( + "context" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// NoOpDiscovery +type NoOpDiscovery struct{} + +func (d *NoOpDiscovery) Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + logger.Infof(ctx, "No-op Discovery Get invoked. Returning NotFound") + return nil, status.Error(codes.NotFound, "No-op Discovery default behavior.") +} + +func (d *NoOpDiscovery) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + logger.Infof(ctx, "No-op Discovery Put invoked. Doing nothing") + return nil +} + +func NewNoOpDiscovery() *NoOpDiscovery { + return &NoOpDiscovery{} +} diff --git a/flytepropeller/pkg/controller/catalog/no_op_discovery_test.go b/flytepropeller/pkg/controller/catalog/no_op_discovery_test.go new file mode 100644 index 0000000000..ded727c872 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/no_op_discovery_test.go @@ -0,0 +1,26 @@ +package catalog + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var noopDiscovery Client = &NoOpDiscovery{} + +func TestNoopDiscovery_Get(t *testing.T) { + ctx := context.Background() + resp, err := noopDiscovery.Get(ctx, nil, "") + assert.Nil(t, resp) + assert.Error(t, err) + assert.True(t, status.Code(err) == codes.NotFound) +} + +func TestNoopDiscovery_Put(t *testing.T) { + ctx := context.Background() + err := noopDiscovery.Put(ctx, nil, nil, "", "") + assert.Nil(t, err) +} diff --git a/flytepropeller/pkg/controller/completed_workflows.go b/flytepropeller/pkg/controller/completed_workflows.go new file mode 100644 index 0000000000..43197c7154 --- /dev/null +++ b/flytepropeller/pkg/controller/completed_workflows.go @@ -0,0 +1,87 @@ +package controller + +import ( + "strconv" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const controllerAgentName = "flyteworkflow-controller" +const workflowTerminationStatusKey = "termination-status" +const workflowTerminatedValue = "terminated" +const hourOfDayCompletedKey = "hour-of-day" + +// This function creates a label selector, that will ignore all objects (in this case workflow) that DOES NOT have a +// label key=workflowTerminationStatusKey with a value=workflowTerminatedValue +func IgnoreCompletedWorkflowsLabelSelector() *v1.LabelSelector { + return &v1.LabelSelector{ + MatchExpressions: []v1.LabelSelectorRequirement{ + { + Key: workflowTerminationStatusKey, + Operator: v1.LabelSelectorOpNotIn, + Values: 
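+ // Note: NotIn also matches objects that do not carry the label key at all, so this selector
+ // still selects brand-new workflows that were never labeled; only workflows already marked
+ // as terminated are filtered out.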
[]string{workflowTerminatedValue}, + }, + }, + } +} + +// Creates a new LabelSelector that selects all workflows that have the completed Label +func CompletedWorkflowsLabelSelector() *v1.LabelSelector { + return &v1.LabelSelector{ + MatchLabels: map[string]string{ + workflowTerminationStatusKey: workflowTerminatedValue, + }, + } +} + +func SetCompletedLabel(w *v1alpha1.FlyteWorkflow, currentTime time.Time) { + if w.Labels == nil { + w.Labels = make(map[string]string) + } + w.Labels[workflowTerminationStatusKey] = workflowTerminatedValue + w.Labels[hourOfDayCompletedKey] = strconv.Itoa(currentTime.Hour()) +} + +func HasCompletedLabel(w *v1alpha1.FlyteWorkflow) bool { + if w.Labels != nil { + v, ok := w.Labels[workflowTerminationStatusKey] + if ok { + return v == workflowTerminatedValue + } + } + return false +} + +// Calculates a list of all the hours that should be deleted given the current hour of the day and the retentionperiod in hours +// Usually this is a list of all hours out of the 24 hours in the day - retention period - the current hour of the day +func CalculateHoursToDelete(retentionPeriodHours, currentHourOfDay int) []string { + numberOfHoursToDelete := 24 - retentionPeriodHours + hoursToDelete := make([]string, 0, numberOfHoursToDelete) + + for i := 0; i < currentHourOfDay-retentionPeriodHours; i++ { + hoursToDelete = append(hoursToDelete, strconv.Itoa(i)) + } + maxHourOfDay := 24 + if currentHourOfDay-retentionPeriodHours < 0 { + maxHourOfDay = 24 + (currentHourOfDay - retentionPeriodHours) + } + for i := currentHourOfDay + 1; i < maxHourOfDay; i++ { + hoursToDelete = append(hoursToDelete, strconv.Itoa(i)) + } + return hoursToDelete +} + +// Creates a new selector that selects all completed workflows and workflows with completed hour label outside of the +// retention window +func CompletedWorkflowsSelectorOutsideRetentionPeriod(retentionPeriodHours int, currentTime time.Time) *v1.LabelSelector { + hoursToDelete := CalculateHoursToDelete(retentionPeriodHours, currentTime.Hour()) + s := CompletedWorkflowsLabelSelector() + s.MatchExpressions = append(s.MatchExpressions, v1.LabelSelectorRequirement{ + Key: hourOfDayCompletedKey, + Operator: v1.LabelSelectorOpIn, + Values: hoursToDelete, + }) + return s +} diff --git a/flytepropeller/pkg/controller/completed_workflows_test.go b/flytepropeller/pkg/controller/completed_workflows_test.go new file mode 100644 index 0000000000..7b57d938d8 --- /dev/null +++ b/flytepropeller/pkg/controller/completed_workflows_test.go @@ -0,0 +1,160 @@ +package controller + +import ( + "testing" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestIgnoreCompletedWorkflowsLabelSelector(t *testing.T) { + s := IgnoreCompletedWorkflowsLabelSelector() + assert.NotNil(t, s) + assert.Empty(t, s.MatchLabels) + assert.NotEmpty(t, s.MatchExpressions) + r := s.MatchExpressions[0] + assert.Equal(t, workflowTerminationStatusKey, r.Key) + assert.Equal(t, v1.LabelSelectorOpNotIn, r.Operator) + assert.Equal(t, []string{workflowTerminatedValue}, r.Values) +} + +func TestCompletedWorkflowsLabelSelector(t *testing.T) { + s := CompletedWorkflowsLabelSelector() + assert.NotEmpty(t, s.MatchLabels) + v, ok := s.MatchLabels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) +} + +func TestHasCompletedLabel(t *testing.T) { + + n := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + t.Run("no-labels", 
func(t *testing.T) { + + w := &v1alpha1.FlyteWorkflow{} + assert.Empty(t, w.Labels) + assert.False(t, HasCompletedLabel(w)) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + assert.True(t, HasCompletedLabel(w)) + }) + + t.Run("existing-lables", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Labels: map[string]string{ + "x": "v", + }, + }, + } + assert.NotEmpty(t, w.Labels) + assert.False(t, HasCompletedLabel(w)) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + v, ok = w.Labels["x"] + assert.True(t, ok) + assert.Equal(t, "v", v) + assert.True(t, HasCompletedLabel(w)) + }) +} + +func TestSetCompletedLabel(t *testing.T) { + n := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + t.Run("no-labels", func(t *testing.T) { + + w := &v1alpha1.FlyteWorkflow{} + assert.Empty(t, w.Labels) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + }) + + t.Run("existing-lables", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Labels: map[string]string{ + "x": "v", + }, + }, + } + assert.NotEmpty(t, w.Labels) + SetCompletedLabel(w, n) + assert.NotEmpty(t, w.Labels) + v, ok := w.Labels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + v, ok = w.Labels["x"] + assert.True(t, ok) + assert.Equal(t, "v", v) + }) + +} + +func TestCalculateHoursToDelete(t *testing.T) { + assert.Equal(t, []string{ + "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", + }, CalculateHoursToDelete(6, 5)) + + assert.Equal(t, []string{ + "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(6, 6)) + + assert.Equal(t, []string{ + "0", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(6, 7)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "23", + }, CalculateHoursToDelete(6, 22)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", + }, CalculateHoursToDelete(6, 23)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", + }, CalculateHoursToDelete(0, 23)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "21", "22", "23", + }, CalculateHoursToDelete(0, 20)) + + assert.Equal(t, []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(0, 0)) + + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", + }, CalculateHoursToDelete(0, 12)) + + assert.Equal(t, []string{"13"}, CalculateHoursToDelete(22, 12)) + assert.Equal(t, []string{"1"}, CalculateHoursToDelete(22, 0)) + assert.Equal(t, []string{"0"}, 
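+ // Because the function keeps the current hour plus the preceding retention-period hours,
+ // a 22-hour retention leaves exactly one deletable hour: the one 23 hours in the past
+ // (currentHour+1 mod 24). At hour 23 that wraps around to hour "0", expected below.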
CalculateHoursToDelete(22, 23)) + assert.Equal(t, []string{"23"}, CalculateHoursToDelete(22, 22)) +} + +func TestCompletedWorkflowsSelectorOutsideRetentionPeriod(t *testing.T) { + n := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + s := CompletedWorkflowsSelectorOutsideRetentionPeriod(2, n) + v, ok := s.MatchLabels[workflowTerminationStatusKey] + assert.True(t, ok) + assert.Equal(t, workflowTerminatedValue, v) + assert.NotEmpty(t, s.MatchExpressions) + r := s.MatchExpressions[0] + assert.Equal(t, hourOfDayCompletedKey, r.Key) + assert.Equal(t, v1.LabelSelectorOpIn, r.Operator) + assert.Equal(t, 21, len(r.Values)) + assert.Equal(t, []string{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, r.Values) +} diff --git a/flytepropeller/pkg/controller/composite_workqueue.go b/flytepropeller/pkg/controller/composite_workqueue.go new file mode 100644 index 0000000000..a03e79e115 --- /dev/null +++ b/flytepropeller/pkg/controller/composite_workqueue.go @@ -0,0 +1,172 @@ +package controller + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/workqueue" +) + +// A CompositeWorkQueue can be used in cases where the work is enqueued by two sources. It can be enqueued by either +// 1. Informer for the Primary Object itself. In case of FlytePropeller, this is the workflow object +// 2. Informer or any other process that enqueues the top-level object for re-evaluation in response to one of the +// sub-objects being ready. In the case of FlytePropeller this is the "Node/Task" updates, will re-enqueue the workflow +// to be re-evaluated +type CompositeWorkQueue interface { + workqueue.RateLimitingInterface + // Specialized interface that should be called to start the migration of work from SubQueue to primaryQueue + Start(ctx context.Context) + // Shutsdown all the queues that are in the context + ShutdownAll() + // Adds the item explicitly to the subqueue + AddToSubQueue(item interface{}) + // Adds the item explicitly to the subqueue, using a rate limiter + AddToSubQueueRateLimited(item interface{}) + // Adds the item explicitly to the subqueue after some duration + AddToSubQueueAfter(item interface{}, duration time.Duration) +} + +// SimpleWorkQueue provides a simple RateLimitingInterface, but ensures that the compositeQueue interface works +// with a default queue. +type SimpleWorkQueue struct { + // workqueue is a rate limited work queue. This is used to queue work to be + // processed instead of performing it as soon as a change happens. This + // means we can ensure we only process a fixed amount of resources at a + // time, and makes it easy to ensure we are never processing the same item + // simultaneously in two different workers. 
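+ // In this simple variant there is no separate sub-queue: every AddToSubQueue* call below just
+ // forwards to this single rate limited queue. A rough, illustrative sketch of how a caller is
+ // expected to use the shared CompositeWorkQueue surface (names as defined in this package):
+ //
+ //	q, err := NewCompositeWorkQueue(ctx, cfg.Queue, scope)
+ //	if err != nil {
+ //		return err
+ //	}
+ //	q.Start(ctx)                      // no-op here; starts the sub-queue drain loop for the batching variant
+ //	q.Add("namespace/name")           // primary enqueue, e.g. from the workflow informer
+ //	q.AddToSubQueue("namespace/name") // secondary enqueue, e.g. from node/task updates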
+ workqueue.RateLimitingInterface +} + +func (s *SimpleWorkQueue) Start(ctx context.Context) { +} + +func (s *SimpleWorkQueue) ShutdownAll() { + s.ShutDown() +} + +func (s *SimpleWorkQueue) AddToSubQueue(item interface{}) { + s.Add(item) +} + +func (s *SimpleWorkQueue) AddToSubQueueAfter(item interface{}, duration time.Duration) { + s.AddAfter(item, duration) +} + +func (s *SimpleWorkQueue) AddToSubQueueRateLimited(item interface{}) { + s.AddRateLimited(item) +} + +// A BatchingWorkQueue consists of 2 queues and migrates items from sub-queue to parent queue as a batch at a specified +// interval +type BatchingWorkQueue struct { + // workqueue is a rate limited work queue. This is used to queue work to be + // processed instead of performing it as soon as a change happens. This + // means we can ensure we only process a fixed amount of resources at a + // time, and makes it easy to ensure we are never processing the same item + // simultaneously in two different workers. + workqueue.RateLimitingInterface + + subQueue workqueue.RateLimitingInterface + batchingInterval time.Duration + batchSize int +} + +func (b *BatchingWorkQueue) Start(ctx context.Context) { + logger.Infof(ctx, "Batching queue started") + go wait.Until(func() { + b.runSubQueueHandler(ctx) + }, b.batchingInterval, ctx.Done()) +} + +func (b *BatchingWorkQueue) runSubQueueHandler(ctx context.Context) { + logger.Debugf(ctx, "Subqueue handler batch round") + defer logger.Debugf(ctx, "Exiting SubQueue handler batch round") + if b.subQueue.ShuttingDown() { + return + } + numToRetrieve := b.batchSize + if b.batchSize == -1 || b.batchSize > b.subQueue.Len() { + numToRetrieve = b.subQueue.Len() + } + + logger.Debugf(ctx, "Dynamically configured batch size [%d]", b.batchSize) + // Run batches forever + objectsRetrieved := make([]interface{}, numToRetrieve) + for i := 0; i < numToRetrieve; i++ { + obj, shutdown := b.subQueue.Get() + if obj != nil { + // We expect strings to come off the workqueue. These are of the + // form namespace/name. We do this as the delayed nature of the + // workqueue means the items in the informer cache may actually be + // more up to date that when the item was initially put onto the + // workqueue. + if key, ok := obj.(string); ok { + objectsRetrieved[i] = key + } + } + if shutdown { + logger.Warningf(ctx, "NodeQ shutdown invoked. Shutting down poller.") + // We cannot add after shutdown, so just quit! + return + } + + } + + for _, obj := range objectsRetrieved { + b.Add(obj) + // Finally, if no error occurs we Forget this item so it does not + // get queued again until another change happens. 
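+ // (Forget only clears the sub-queue rate limiter's backoff/failure state for this item; the
+ // Done call below is what marks the in-flight item as finished so a subsequent Add can make
+ // it visible again.)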
+ b.subQueue.Forget(obj) + b.subQueue.Done(obj) + } + +} + +func (b *BatchingWorkQueue) ShutdownAll() { + b.subQueue.ShutDown() + b.ShutDown() +} + +func (b *BatchingWorkQueue) AddToSubQueue(item interface{}) { + b.subQueue.Add(item) +} + +func (b *BatchingWorkQueue) AddToSubQueueAfter(item interface{}, duration time.Duration) { + b.subQueue.AddAfter(item, duration) +} + +func (b *BatchingWorkQueue) AddToSubQueueRateLimited(item interface{}) { + b.subQueue.AddRateLimited(item) +} + +func NewCompositeWorkQueue(ctx context.Context, cfg config.CompositeQueueConfig, scope promutils.Scope) (CompositeWorkQueue, error) { + workQ, err := NewWorkQueue(ctx, cfg.Queue, scope.CurrentScope()) + if err != nil { + return nil, errors.Wrapf(err, "failed to create WorkQueue in CompositeQueue type Batch") + } + switch cfg.Type { + case config.CompositeQueueBatch: + subQ, err := NewWorkQueue(ctx, cfg.Sub, scope.NewSubScope("sub").CurrentScope()) + if err != nil { + return nil, errors.Wrapf(err, "failed to create SubQueue in CompositeQueue type Batch") + } + return &BatchingWorkQueue{ + RateLimitingInterface: workQ, + batchSize: cfg.BatchSize, + batchingInterval: cfg.BatchingInterval.Duration, + subQueue: subQ, + }, nil + case config.CompositeQueueSimple: + fallthrough + default: + } + return &SimpleWorkQueue{ + RateLimitingInterface: workQ, + }, nil +} diff --git a/flytepropeller/pkg/controller/composite_workqueue_test.go b/flytepropeller/pkg/controller/composite_workqueue_test.go new file mode 100644 index 0000000000..4239102ad6 --- /dev/null +++ b/flytepropeller/pkg/controller/composite_workqueue_test.go @@ -0,0 +1,146 @@ +package controller + +import ( + "context" + "testing" + "time" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +func TestNewCompositeWorkQueue(t *testing.T) { + ctx := context.TODO() + + t.Run("simple", func(t *testing.T) { + testScope := promutils.NewScope("test1") + cfg := config2.CompositeQueueConfig{} + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + switch q.(type) { + case *SimpleWorkQueue: + return + default: + assert.FailNow(t, "SimpleWorkQueue expected") + } + }) + + t.Run("batch", func(t *testing.T) { + testScope := promutils.NewScope("test2") + cfg := config2.CompositeQueueConfig{ + Type: config2.CompositeQueueBatch, + BatchSize: -1, + BatchingInterval: config.Duration{Duration: time.Second * 1}, + } + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + switch q.(type) { + case *BatchingWorkQueue: + assert.Equal(t, -1, q.(*BatchingWorkQueue).batchSize) + assert.Equal(t, time.Second*1, q.(*BatchingWorkQueue).batchingInterval) + return + default: + assert.FailNow(t, "BatchWorkQueue expected") + } + }) +} + +func TestSimpleWorkQueue(t *testing.T) { + ctx := context.TODO() + testScope := promutils.NewScope("test") + cfg := config2.CompositeQueueConfig{} + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + + t.Run("AddSubQueue", func(t *testing.T) { + q.AddToSubQueue("x") + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "x", i.(string)) + q.Done(i) + }) + + t.Run("AddAfterSubQueue", func(t *testing.T) { + q.AddToSubQueueAfter("y", time.Nanosecond*0) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "y", i.(string)) + q.Done(i) + }) + + t.Run("AddRateLimitedSubQueue", 
func(t *testing.T) { + q.AddToSubQueueRateLimited("z") + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "z", i.(string)) + q.Done(i) + }) + + t.Run("shutdown", func(t *testing.T) { + q.ShutdownAll() + _, s := q.Get() + assert.True(t, s) + }) +} + +func TestBatchingQueue(t *testing.T) { + ctx := context.TODO() + testScope := promutils.NewScope("test_batch") + cfg := config2.CompositeQueueConfig{ + Type: config2.CompositeQueueBatch, + BatchSize: -1, + BatchingInterval: config.Duration{Duration: time.Nanosecond * 1}, + } + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + + batchQueue := q.(*BatchingWorkQueue) + + t.Run("AddSubQueue", func(t *testing.T) { + q.AddToSubQueue("x") + assert.Equal(t, 0, q.Len()) + batchQueue.runSubQueueHandler(ctx) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "x", i.(string)) + q.Done(i) + }) + + t.Run("AddAfterSubQueue", func(t *testing.T) { + q.AddToSubQueueAfter("y", time.Nanosecond*0) + assert.Equal(t, 0, q.Len()) + batchQueue.runSubQueueHandler(ctx) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "y", i.(string)) + q.Done(i) + }) + + t.Run("AddRateLimitedSubQueue", func(t *testing.T) { + q.AddToSubQueueRateLimited("z") + assert.Equal(t, 0, q.Len()) + batchQueue.Start(ctx) + i, s := q.Get() + assert.False(t, s) + assert.Equal(t, "z", i.(string)) + q.Done(i) + }) + + t.Run("shutdown", func(t *testing.T) { + q.AddToSubQueue("g") + q.ShutdownAll() + assert.Equal(t, 0, q.Len()) + batchQueue.runSubQueueHandler(ctx) + i, s := q.Get() + assert.True(t, s) + assert.Nil(t, i) + q.Done(i) + }) +} diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go new file mode 100644 index 0000000000..3062906341 --- /dev/null +++ b/flytepropeller/pkg/controller/config/config.go @@ -0,0 +1,93 @@ +package config + +import ( + "github.com/lyft/flytestdlib/config" + "k8s.io/apimachinery/pkg/types" +) + +//go:generate pflags Config + +const configSectionKey = "propeller" + +var ConfigSection = config.MustRegisterSection(configSectionKey, &Config{}) + +// NOTE: when adding new fields, do not mark them as "omitempty" if it's desirable to read the value from env variables. +// Config that uses the flytestdlib Config module to generate commandline and load config files. This configuration is +// the base configuration to start propeller +type Config struct { + KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` + MasterURL string `json:"master"` + Workers int `json:"workers" pflag:"2,Number of threads to process workflows"` + WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:"\"30s\",Frequency of re-evaluating workflows"` + DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:"\"60s\",Frequency of re-evaluating downstream tasks"` + LimitNamespace string `json:"limit-namespace" pflag:"\"all\",Namespaces to watch for this propeller"` + ProfilerPort config.Port `json:"prof-port" pflag:"\"10254\",Profiler port"` + MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. 
If not specified, the data will be stored in the base container directly."` + Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + MetricsPrefix string `json:"metrics-prefix" pflag:"\"flyte:\",An optional prefix for all published metrics."` + EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"false, Enable remote Workflow launcher to Admin"` + MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"50,Maximum number of retries per workflow"` + MaxTTLInHours int `json:"max-ttl-hours" pflag:"23,Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"` + GCInterval config.Duration `json:"gc-interval" pflag:"\"30m\",Run periodic GC every 30 minutes"` + LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` + PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` +} + +type CompositeQueueType = string + +const ( + CompositeQueueSimple CompositeQueueType = "simple" + CompositeQueueBatch CompositeQueueType = "batch" +) + +type CompositeQueueConfig struct { + Type CompositeQueueType `json:"type" pflag:"\"simple\",Type of composite queue to use for the WorkQueue"` + Queue WorkqueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + Sub WorkqueueConfig `json:"sub-queue,omitempty" pflag:",SubQueue configuration, affects the way the nodes cause the top-level Work to be re-evaluated."` + BatchingInterval config.Duration `json:"batching-interval" pflag:"\"1s\",Duration for which downstream updates are buffered"` + BatchSize int `json:"batch-size" pflag:"-1,Number of downstream triggered top-level objects to re-enqueue every duration. -1 indicates all available."` +} + +type WorkqueueType = string + +const ( + WorkqueueTypeDefault WorkqueueType = "default" + WorkqueueTypeBucketRateLimiter WorkqueueType = "bucket" + WorkqueueTypeExponentialFailureRateLimiter WorkqueueType = "expfailure" + WorkqueueTypeMaxOfRateLimiter WorkqueueType = "maxof" +) + +// prototypical configuration to configure a workqueue. We may want to generalize this in a package like k8sutils +type WorkqueueConfig struct { + // Refer to https://github.com/kubernetes/client-go/tree/master/util/workqueue + Type WorkqueueType `json:"type" pflag:"\"default\",Type of RateLimiter to use for the WorkQueue"` + BaseDelay config.Duration `json:"base-delay" pflag:"\"10s\",base backoff delay for failure"` + MaxDelay config.Duration `json:"max-delay" pflag:"\"10s\",Max backoff delay for failure"` + Rate int64 `json:"rate" pflag:"int64(10),Bucket Refill rate per second"` + Capacity int `json:"capacity" pflag:"100,Bucket capacity as number of items"` +} + +// Contains leader election configuration. +type LeaderElectionConfig struct { + // Enable or disable leader election. + Enabled bool `json:"enabled" pflag:",Enables/Disables leader election."` + + // Determines the name of the configmap that leader election will use for holding the leader lock. + LockConfigMap types.NamespacedName `json:"lock-config-map" pflag:",ConfigMap namespace/name to use for resource lock."` + + // Duration that non-leader candidates will wait to force acquire leadership. 
This is measured against time of last + // observed ack + LeaseDuration config.Duration `json:"lease-duration" pflag:"\"15s\",Duration that non-leader candidates will wait to force acquire leadership. This is measured against time of last observed ack."` + + // RenewDeadline is the duration that the acting master will retry refreshing leadership before giving up. + RenewDeadline config.Duration `json:"renew-deadline" pflag:"\"10s\",Duration that the acting master will retry refreshing leadership before giving up."` + + // RetryPeriod is the duration the LeaderElector clients should wait between tries of actions. + RetryPeriod config.Duration `json:"retry-period" pflag:"\"2s\",Duration the LeaderElector clients should wait between tries of actions."` +} + +// Extracts the Configuration from the global config module in flytestdlib and returns the corresponding type-casted object. +// TODO What if the type is incorrect? +func GetConfig() *Config { + return ConfigSection.GetConfig().(*Config) +} diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go new file mode 100755 index 0000000000..f496572031 --- /dev/null +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -0,0 +1,78 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "kube-config"), *new(string), "Path to kubernetes client config file.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "master"), *new(string), "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "workers"), 2, "Number of threads to process workflows") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "workflow-reeval-duration"), "30s", "Frequency of re-evaluating workflows") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "downstream-eval-duration"), "60s", "Frequency of re-evaluating downstream tasks") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "limit-namespace"), "all", "Namespaces to watch for this propeller") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "prof-port"), "10254", "Profiler port") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metadata-prefix"), *new(string), "MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. 
If not specified, the data will be stored in the base container directly.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.type"), "simple", "Type of composite queue to use for the WorkQueue") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.queue.type"), "default", "Type of RateLimiter to use for the WorkQueue") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.queue.base-delay"), "10s", "base backoff delay for failure") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.queue.max-delay"), "10s", "Max backoff delay for failure") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "queue.queue.rate"), int64(10), "Bucket Refill rate per second") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "queue.queue.capacity"), 100, "Bucket capacity as number of items") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.type"), "default", "Type of RateLimiter to use for the WorkQueue") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.base-delay"), "10s", "base backoff delay for failure") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.max-delay"), "10s", "Max backoff delay for failure") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.rate"), int64(10), "Bucket Refill rate per second") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "queue.sub-queue.capacity"), 100, "Bucket capacity as number of items") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "queue.batching-interval"), "1s", "Duration for which downstream updates are buffered") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "queue.batch-size"), -1, "Number of downstream triggered top-level objects to re-enqueue every duration. -1 indicates all available.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-prefix"), "flyte:", "An optional prefix for all published metrics.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "enable-admin-launcher"), false, " Enable remote Workflow launcher to Admin") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-workflow-retries"), 50, "Maximum number of retries per workflow") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-ttl-hours"), 23, "Maximum number of hours a completed workflow should be retained. Number between 1-23 hours") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "gc-interval"), "30m", "Run periodic GC every 30 minutes") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "leader-election.enabled"), *new(bool), "Enables/Disables leader election.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.lock-config-map.Namespace"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.lock-config-map.Name"), *new(string), "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.lease-duration"), "15s", "Duration that non-leader candidates will wait to force acquire leadership. 
This is measured against time of last observed ack.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.renew-deadline"), "10s", "Duration that the acting master will retry refreshing leadership before giving up.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "leader-election.retry-period"), "2s", "Duration the LeaderElector clients should wait between tries of actions.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "publish-k8s-events"), *new(bool), "Enable events publishing to K8s events API.") + return cmdFlags +} diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go new file mode 100755 index 0000000000..b7f3ff4a8c --- /dev/null +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -0,0 +1,828 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_kube-config", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("kube-config"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("kube-config", testValue) + if vString, err := cmdFlags.GetString("kube-config"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.KubeConfigPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_master", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("master"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("master", testValue) + if vString, err := cmdFlags.GetString("master"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MasterURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_workers", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("workers"); err == nil { + assert.Equal(t, int(2), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("workers", testValue) + if vInt, err := cmdFlags.GetInt("workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Workers) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_workflow-reeval-duration", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("workflow-reeval-duration"); err == nil { + assert.Equal(t, string("30s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t 
*testing.T) { + testValue := "30s" + + cmdFlags.Set("workflow-reeval-duration", testValue) + if vString, err := cmdFlags.GetString("workflow-reeval-duration"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.WorkflowReEval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_downstream-eval-duration", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("downstream-eval-duration"); err == nil { + assert.Equal(t, string("60s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "60s" + + cmdFlags.Set("downstream-eval-duration", testValue) + if vString, err := cmdFlags.GetString("downstream-eval-duration"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DownstreamEval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_limit-namespace", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("limit-namespace"); err == nil { + assert.Equal(t, string("all"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("limit-namespace", testValue) + if vString, err := cmdFlags.GetString("limit-namespace"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LimitNamespace) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_prof-port", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("prof-port"); err == nil { + assert.Equal(t, string("10254"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10254" + + cmdFlags.Set("prof-port", testValue) + if vString, err := cmdFlags.GetString("prof-port"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ProfilerPort) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_metadata-prefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("metadata-prefix"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("metadata-prefix", testValue) + if vString, err := cmdFlags.GetString("metadata-prefix"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MetadataPrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.type"); err == nil { + assert.Equal(t, string("simple"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.type", testValue) + if vString, err := cmdFlags.GetString("queue.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.type", func(t *testing.T) { + t.Run("DefaultValue", 
func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.queue.type"); err == nil { + assert.Equal(t, string("default"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.queue.type", testValue) + if vString, err := cmdFlags.GetString("queue.queue.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Queue.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.base-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.queue.base-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.queue.base-delay", testValue) + if vString, err := cmdFlags.GetString("queue.queue.base-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Queue.BaseDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.max-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.queue.max-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.queue.max-delay", testValue) + if vString, err := cmdFlags.GetString("queue.queue.max-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Queue.MaxDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.rate", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("queue.queue.rate"); err == nil { + assert.Equal(t, int64(int64(10)), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.queue.rate", testValue) + if vInt64, err := cmdFlags.GetInt64("queue.queue.rate"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.Queue.Queue.Rate) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.queue.capacity", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("queue.queue.capacity"); err == nil { + assert.Equal(t, int(100), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.queue.capacity", testValue) + if vInt, err := cmdFlags.GetInt("queue.queue.capacity"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Queue.Queue.Capacity) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.sub-queue.type"); err == nil { + assert.Equal(t, string("default"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" 
+ + cmdFlags.Set("queue.sub-queue.type", testValue) + if vString, err := cmdFlags.GetString("queue.sub-queue.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Sub.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.base-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.sub-queue.base-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.sub-queue.base-delay", testValue) + if vString, err := cmdFlags.GetString("queue.sub-queue.base-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Sub.BaseDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.max-delay", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.sub-queue.max-delay"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + cmdFlags.Set("queue.sub-queue.max-delay", testValue) + if vString, err := cmdFlags.GetString("queue.sub-queue.max-delay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.Sub.MaxDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.rate", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("queue.sub-queue.rate"); err == nil { + assert.Equal(t, int64(int64(10)), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.sub-queue.rate", testValue) + if vInt64, err := cmdFlags.GetInt64("queue.sub-queue.rate"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.Queue.Sub.Rate) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.sub-queue.capacity", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("queue.sub-queue.capacity"); err == nil { + assert.Equal(t, int(100), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.sub-queue.capacity", testValue) + if vInt, err := cmdFlags.GetInt("queue.sub-queue.capacity"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Queue.Sub.Capacity) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.batching-interval", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("queue.batching-interval"); err == nil { + assert.Equal(t, string("1s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1s" + + cmdFlags.Set("queue.batching-interval", testValue) + if vString, err := cmdFlags.GetString("queue.batching-interval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Queue.BatchingInterval) + + } else { + 
assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_queue.batch-size", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("queue.batch-size"); err == nil { + assert.Equal(t, int(-1), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("queue.batch-size", testValue) + if vInt, err := cmdFlags.GetInt("queue.batch-size"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Queue.BatchSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_metrics-prefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { + assert.Equal(t, string("flyte:"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("metrics-prefix", testValue) + if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MetricsPrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_enable-admin-launcher", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("enable-admin-launcher"); err == nil { + assert.Equal(t, bool(false), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("enable-admin-launcher", testValue) + if vBool, err := cmdFlags.GetBool("enable-admin-launcher"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.EnableAdminLauncher) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_max-workflow-retries", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("max-workflow-retries"); err == nil { + assert.Equal(t, int(50), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("max-workflow-retries", testValue) + if vInt, err := cmdFlags.GetInt("max-workflow-retries"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.MaxWorkflowRetries) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_max-ttl-hours", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("max-ttl-hours"); err == nil { + assert.Equal(t, int(23), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("max-ttl-hours", testValue) + if vInt, err := cmdFlags.GetInt("max-ttl-hours"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.MaxTTLInHours) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_gc-interval", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("gc-interval"); err == nil { + assert.Equal(t, string("30m"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "30m" + + 
cmdFlags.Set("gc-interval", testValue) + if vString, err := cmdFlags.GetString("gc-interval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.GCInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("leader-election.enabled"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("leader-election.enabled", testValue) + if vBool, err := cmdFlags.GetBool("leader-election.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LeaderElection.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.lock-config-map.Namespace", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Namespace"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("leader-election.lock-config-map.Namespace", testValue) + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Namespace"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.LockConfigMap.Namespace) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.lock-config-map.Name", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Name"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("leader-election.lock-config-map.Name", testValue) + if vString, err := cmdFlags.GetString("leader-election.lock-config-map.Name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.LockConfigMap.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.lease-duration", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.lease-duration"); err == nil { + assert.Equal(t, string("15s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "15s" + + cmdFlags.Set("leader-election.lease-duration", testValue) + if vString, err := cmdFlags.GetString("leader-election.lease-duration"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.LeaseDuration) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.renew-deadline", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.renew-deadline"); err == nil { + assert.Equal(t, string("10s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "10s" + + 
cmdFlags.Set("leader-election.renew-deadline", testValue) + if vString, err := cmdFlags.GetString("leader-election.renew-deadline"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.RenewDeadline) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_leader-election.retry-period", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("leader-election.retry-period"); err == nil { + assert.Equal(t, string("2s"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "2s" + + cmdFlags.Set("leader-election.retry-period", testValue) + if vString, err := cmdFlags.GetString("leader-election.retry-period"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LeaderElection.RetryPeriod) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_publish-k8s-events", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("publish-k8s-events"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("publish-k8s-events", testValue) + if vBool, err := cmdFlags.GetBool("publish-k8s-events"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.PublishK8sEvents) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go new file mode 100644 index 0000000000..25d8fb4316 --- /dev/null +++ b/flytepropeller/pkg/controller/controller.go @@ -0,0 +1,305 @@ +package controller + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + + "github.com/lyft/flytepropeller/pkg/controller/config" + "github.com/lyft/flytepropeller/pkg/controller/workflowstore" + + "github.com/lyft/flyteidl/clients/go/admin" + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/clock" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/tools/leaderelection" + "k8s.io/client-go/tools/record" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" + flyteScheme "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/scheme" + informers "github.com/lyft/flytepropeller/pkg/client/informers/externalversions" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/workflow" +) + +type metrics struct { + Scope promutils.Scope + EnqueueCountWf prometheus.Counter + EnqueueCountTask prometheus.Counter +} + +// Controller is the controller implementation for FlyteWorkflow resources +type Controller struct { + workerPool *WorkerPool + flyteworkflowSynced 
cache.InformerSynced + workQueue CompositeWorkQueue + gc *GarbageCollector + numWorkers int + workflowStore workflowstore.FlyteWorkflow + // recorder is an event recorder for recording Event resources to the + // Kubernetes API. + recorder record.EventRecorder + metrics *metrics + leaderElector *leaderelection.LeaderElector +} + +// Runs either as a leader -if configured- or as a standalone process. +func (c *Controller) Run(ctx context.Context) error { + if c.leaderElector == nil { + logger.Infof(ctx, "Running without leader election.") + return c.run(ctx) + } + + logger.Infof(ctx, "Attempting to acquire leader lease and act as leader.") + go c.leaderElector.Run(ctx) + <-ctx.Done() + return nil +} + +// Start the actual work of controller (e.g. GC, consume and process queue items... etc.) +func (c *Controller) run(ctx context.Context) error { + // Initializing WorkerPool + logger.Info(ctx, "Initializing controller") + if err := c.workerPool.Initialize(ctx); err != nil { + return err + } + + // Start the GC + if err := c.gc.StartGC(ctx); err != nil { + logger.Errorf(ctx, "failed to start background GC") + return err + } + + // Start the informer factories to begin populating the informer caches + logger.Info(ctx, "Starting FlyteWorkflow controller") + return c.workerPool.Run(ctx, c.numWorkers, c.flyteworkflowSynced) +} + +// Called from leader elector -if configured- to start running as the leader. +func (c *Controller) onStartedLeading(ctx context.Context) { + ctx, cancelNow := context.WithCancel(context.Background()) + logger.Infof(ctx, "Acquired leader lease.") + go func() { + if err := c.run(ctx); err != nil { + logger.Panic(ctx, err) + } + }() + + <-ctx.Done() + logger.Infof(ctx, "Lost leader lease.") + cancelNow() +} + +// enqueueFlyteWorkflow takes a FlyteWorkflow resource and converts it into a namespace/name +// string which is then put onto the work queue. This method should *not* be +// passed resources of any type other than FlyteWorkflow. +func (c *Controller) enqueueFlyteWorkflow(obj interface{}) { + ctx := context.TODO() + wf, ok := obj.(*v1alpha1.FlyteWorkflow) + if !ok { + logger.Errorf(ctx, "Received a non Workflow object") + return + } + key := wf.GetK8sWorkflowID() + logger.Infof(ctx, "==> Enqueueing workflow [%v]", key) + c.workQueue.Add(key.String()) +} + +func (c *Controller) enqueueWorkflowForNodeUpdates(wID v1alpha1.WorkflowID) { + if wID == "" { + return + } + namespace, name, err := cache.SplitMetaNamespaceKey(wID) + if err != nil { + if _, err2 := c.workflowStore.Get(context.TODO(), namespace, name); err2 != nil { + if workflowstore.IsNotFound(err) { + // Workflow is not found in storage, was probably deleted, but one of the sub-objects sent an event + return + } + } + c.metrics.EnqueueCountTask.Inc() + c.workQueue.AddToSubQueue(wID) + } +} + +func (c *Controller) getWorkflowUpdatesHandler() cache.ResourceEventHandler { + return cache.ResourceEventHandlerFuncs{ + AddFunc: c.enqueueFlyteWorkflow, + UpdateFunc: func(old, new interface{}) { + // TODO we might need to handle updates to the workflow itself. + // Initially maybe we should not support it at all + c.enqueueFlyteWorkflow(new) + }, + DeleteFunc: func(obj interface{}) { + // There is a corner case where the obj is not in fact a valid resource (it sends a DeletedFinalStateUnknown + // object instead) -it has to do with missing some event that leads to not knowing the final state of the + // resource. 
In which case, we can't use the regular metaAccessor to read obj name/namespace but should + // instead use cache.DeletionHandling* helper functions that know how to deal with that. + + key, err := cache.DeletionHandlingMetaNamespaceKeyFunc(obj) + if err != nil { + logger.Errorf(context.TODO(), "Unable to get key for deleted obj. Error[%v]", err) + return + } + + _, name, err := cache.SplitMetaNamespaceKey(key) + if err != nil { + logger.Errorf(context.TODO(), "Unable to split enqueued key into namespace/execId. Error[%v]", err) + return + } + + logger.Infof(context.TODO(), "Deletion triggered for %v", name) + }, + } +} + +func newControllerMetrics(scope promutils.Scope) *metrics { + c := scope.MustNewCounterVec("wf_enqueue", "workflow enqueue count.", "type") + return &metrics{ + Scope: scope, + EnqueueCountWf: c.WithLabelValues("wf"), + EnqueueCountTask: c.WithLabelValues("task"), + } +} + +func newK8sEventRecorder(ctx context.Context, kubeclientset kubernetes.Interface, publishK8sEvents bool) record.EventRecorder { + // Create event broadcaster + // Add FlyteWorkflow controller types to the default Kubernetes Scheme so Events can be + // logged for FlyteWorkflow Controller types. + err := flyteScheme.AddToScheme(scheme.Scheme) + if err != nil { + logger.Panicf(ctx, "failed to add flyte workflows scheme, %s", err.Error()) + } + logger.Info(ctx, "Creating event broadcaster") + eventBroadcaster := record.NewBroadcaster() + eventBroadcaster.StartLogging(logger.InfofNoCtx) + if publishK8sEvents { + eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeclientset.CoreV1().Events("")}) + } + return eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: controllerAgentName}) +} + +// NewController returns a new FlyteWorkflow controller +func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Interface, flytepropellerClientset clientset.Interface, + flyteworkflowInformerFactory informers.SharedInformerFactory, kubeClient executors.Client, scope promutils.Scope) (*Controller, error) { + + var wfLauncher launchplan.Executor + if cfg.EnableAdminLauncher { + adminClient, err := admin.InitializeAdminClientFromConfig(ctx) + if err != nil { + logger.Errorf(ctx, "failed to initialize Admin client, err :%s", err.Error()) + return nil, err + } + wfLauncher, err = launchplan.NewAdminLaunchPlanExecutor(ctx, adminClient, cfg.DownstreamEval.Duration, + launchplan.GetAdminConfig(), scope.NewSubScope("admin_launcher")) + if err != nil { + logger.Errorf(ctx, "failed to create Admin workflow Launcher, err: %v", err.Error()) + return nil, err + } + + if err := wfLauncher.Initialize(ctx); err != nil { + logger.Errorf(ctx, "failed to initialize Admin workflow Launcher, err: %v", err.Error()) + return nil, err + } + } else { + wfLauncher = launchplan.NewFailFastLaunchPlanExecutor() + } + + logger.Info(ctx, "Setting up event sink and recorder") + eventSink, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create EventSink [%v], error %v", events.GetConfig(ctx).Type, err) + } + gc, err := NewGarbageCollector(cfg, scope, clock.RealClock{}, kubeclientset.CoreV1().Namespaces(), flytepropellerClientset.FlyteworkflowV1alpha1(), cfg.LimitNamespace) + if err != nil { + logger.Errorf(ctx, "failed to initialize GC for workflows") + return nil, errors.Wrapf(err, "failed to initialize WF GC") + } + + eventRecorder := newK8sEventRecorder(ctx, kubeclientset, cfg.PublishK8sEvents) + controller := 
&Controller{ + metrics: newControllerMetrics(scope), + recorder: eventRecorder, + gc: gc, + numWorkers: cfg.Workers, + } + + lock, err := newResourceLock(kubeclientset.CoreV1(), eventRecorder, cfg.LeaderElection) + if err != nil { + logger.Errorf(ctx, "failed to initialize resource lock.") + return nil, errors.Wrapf(err, "failed to initialize resource lock.") + } + + if lock != nil { + logger.Infof(ctx, "Creating leader elector for the controller.") + controller.leaderElector, err = newLeaderElector(lock, cfg.LeaderElection, controller.onStartedLeading, func() { + logger.Fatal(ctx, "Lost leader state. Shutting down.") + }) + + if err != nil { + logger.Errorf(ctx, "failed to initialize leader elector.") + return nil, errors.Wrapf(err, "failed to initialize leader elector.") + } + } + + // WE are disabling this as the metrics have high cardinality. Metrics seem to be emitted per pod and this has problems + // when we create new pods + // Set Client Metrics Provider + // setClientMetricsProvider(scope.NewSubScope("k8s_client")) + + // obtain references to shared index informers for FlyteWorkflow. + flyteworkflowInformer := flyteworkflowInformerFactory.Flyteworkflow().V1alpha1().FlyteWorkflows() + controller.flyteworkflowSynced = flyteworkflowInformer.Informer().HasSynced + + sCfg := storage.GetConfig() + if sCfg == nil { + logger.Errorf(ctx, "Storage configuration missing.") + } + + store, err := storage.NewDataStore(sCfg, scope.NewSubScope("metastore")) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create Metadata storage") + } + + logger.Info(ctx, "Setting up Catalog client.") + catalogClient := catalog.NewCatalogClient(store) + + workQ, err := NewCompositeWorkQueue(ctx, cfg.Queue, scope) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create WorkQueue [%v]", scope.CurrentScope()) + } + controller.workQueue = workQ + + controller.workflowStore = workflowstore.NewPassthroughWorkflowStore(ctx, scope, flytepropellerClientset.FlyteworkflowV1alpha1(), flyteworkflowInformer.Lister()) + + nodeExecutor, err := nodes.NewExecutor(ctx, store, controller.enqueueWorkflowForNodeUpdates, + cfg.DownstreamEval.Duration, eventSink, wfLauncher, catalogClient, kubeClient, scope) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create Controller.") + } + + workflowExecutor, err := workflow.NewExecutor(ctx, store, controller.enqueueWorkflowForNodeUpdates, eventSink, controller.recorder, cfg.MetadataPrefix, nodeExecutor, scope) + if err != nil { + return nil, err + } + + handler := NewPropellerHandler(ctx, cfg, controller.workflowStore, workflowExecutor, scope) + controller.workerPool = NewWorkerPool(ctx, scope, workQ, handler) + + logger.Info(ctx, "Setting up event handlers") + // Set up an event handler for when FlyteWorkflow resources change + flyteworkflowInformer.Informer().AddEventHandler(controller.getWorkflowUpdatesHandler()) + return controller, nil +} diff --git a/flytepropeller/pkg/controller/executors/contextual.go b/flytepropeller/pkg/controller/executors/contextual.go new file mode 100644 index 0000000000..7d02a6f9a7 --- /dev/null +++ b/flytepropeller/pkg/controller/executors/contextual.go @@ -0,0 +1,30 @@ +package executors + +import ( + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type ContextualWorkflow struct { + v1alpha1.WorkflowMetaExtended + v1alpha1.ExecutableSubWorkflow + v1alpha1.NodeStatusGetter +} + +func NewBaseContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow) v1alpha1.ExecutableWorkflow { + return 
&ContextualWorkflow{ + ExecutableSubWorkflow: baseWorkflow, + WorkflowMetaExtended: baseWorkflow, + NodeStatusGetter: baseWorkflow.GetExecutionStatus(), + } +} + +// Creates a contextual workflow using the provided interface implementations. +func NewSubContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow, subWF v1alpha1.ExecutableSubWorkflow, + nodeStatus v1alpha1.ExecutableNodeStatus) v1alpha1.ExecutableWorkflow { + + return &ContextualWorkflow{ + ExecutableSubWorkflow: subWF, + WorkflowMetaExtended: baseWorkflow, + NodeStatusGetter: nodeStatus, + } +} diff --git a/flytepropeller/pkg/controller/executors/kube.go b/flytepropeller/pkg/controller/executors/kube.go new file mode 100644 index 0000000000..4c790c3975 --- /dev/null +++ b/flytepropeller/pkg/controller/executors/kube.go @@ -0,0 +1,56 @@ +package executors + +import ( + "context" + + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +//go:generate mockery -name Client + +// A friendly controller-runtime client that gets passed to executors +type Client interface { + // GetClient returns a client configured with the Config + GetClient() client.Client + + // GetCache returns a cache.Cache + GetCache() cache.Cache +} + +type fallbackClientReader struct { + orderedClients []client.Client +} + +func (c fallbackClientReader) Get(ctx context.Context, key client.ObjectKey, out runtime.Object) (err error) { + for _, k8sClient := range c.orderedClients { + if err = k8sClient.Get(ctx, key, out); err == nil { + return nil + } + } + + return +} + +func (c fallbackClientReader) List(ctx context.Context, opts *client.ListOptions, list runtime.Object) (err error) { + for _, k8sClient := range c.orderedClients { + if err = k8sClient.List(ctx, opts, list); err == nil { + return nil + } + } + + return +} + +// Creates a new k8s client that uses the cached client for reads and falls back to making API +// calls if it failed. Write calls will always go to raw client directly. +func NewFallbackClient(cachedClient, rawClient client.Client) client.Client { + return client.DelegatingClient{ + Reader: fallbackClientReader{ + orderedClients: []client.Client{cachedClient, rawClient}, + }, + StatusClient: rawClient, + Writer: rawClient, + } +} diff --git a/flytepropeller/pkg/controller/executors/mocks/Client.go b/flytepropeller/pkg/controller/executors/mocks/Client.go new file mode 100644 index 0000000000..bc7af46707 --- /dev/null +++ b/flytepropeller/pkg/controller/executors/mocks/Client.go @@ -0,0 +1,45 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import cache "sigs.k8s.io/controller-runtime/pkg/cache" +import client "sigs.k8s.io/controller-runtime/pkg/client" + +import mock "github.com/stretchr/testify/mock" + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// GetCache provides a mock function with given fields: +func (_m *Client) GetCache() cache.Cache { + ret := _m.Called() + + var r0 cache.Cache + if rf, ok := ret.Get(0).(func() cache.Cache); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(cache.Cache) + } + } + + return r0 +} + +// GetClient provides a mock function with given fields: +func (_m *Client) GetClient() client.Client { + ret := _m.Called() + + var r0 client.Client + if rf, ok := ret.Get(0).(func() client.Client); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.Client) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/controller/executors/mocks/fake.go b/flytepropeller/pkg/controller/executors/mocks/fake.go new file mode 100644 index 0000000000..27fb94060d --- /dev/null +++ b/flytepropeller/pkg/controller/executors/mocks/fake.go @@ -0,0 +1,13 @@ +package mocks + +import ( + "sigs.k8s.io/controller-runtime/pkg/cache/informertest" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func NewFakeKubeClient() *Client { + c := Client{} + c.On("GetClient").Return(fake.NewFakeClient()) + c.On("GetCache").Return(&informertest.FakeInformers{}) + return &c +} diff --git a/flytepropeller/pkg/controller/executors/node.go b/flytepropeller/pkg/controller/executors/node.go new file mode 100644 index 0000000000..3dd1cb2707 --- /dev/null +++ b/flytepropeller/pkg/controller/executors/node.go @@ -0,0 +1,100 @@ +package executors + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// Phase of the node +type NodePhase int + +const ( + // Indicates that the node is not yet ready to be executed and is pending any previous nodes completion + NodePhasePending NodePhase = iota + // Indicates that the node was queued and will start running soon + NodePhaseQueued + // Indicates that the payload associated with this node is being executed and is not yet done + NodePhaseRunning + // Indicates that the nodes payload has been successfully completed, but any downstream nodes from this node may not yet have completed + // We could make Success = running, but this enables more granular control + NodePhaseSuccess + // Complete indicates successful completion of a node. For singular nodes (nodes that have only one execution) success = complete, but, the executor + // will always signal completion + NodePhaseComplete + // Node failed in execution, either this node or anything in the downstream chain + NodePhaseFailed + // Internal error observed. This state should always be accompanied with an `error`. if not the behavior is undefined + NodePhaseUndefined +) + +func (p NodePhase) String() string { + switch p { + case NodePhaseRunning: + return "Running" + case NodePhaseQueued: + return "Queued" + case NodePhasePending: + return "Pending" + case NodePhaseFailed: + return "Failed" + case NodePhaseSuccess: + return "Success" + case NodePhaseComplete: + return "Complete" + case NodePhaseUndefined: + return "Undefined" + } + return fmt.Sprintf("Unknown - %d", p) +} + +// Core Node Executor that is used to execute a node. 
This is a recursive node executor and understands node dependencies +type Node interface { + // This method is used specifically to set inputs for start node. This is because start node does not retrieve inputs + // from predecessors, but the inputs are inputs to the workflow or inputs to the parent container (workflow) node. + SetInputsForStartNode(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *core.LiteralMap) (NodeStatus, error) + + // This is the main entrypoint to execute a node. It recursively depth-first goes through all ready nodes and starts their execution + // This returns either + // - 1. It finds a blocking node (not ready, or running) + // - 2. A node fails and hence the workflow will fail + // - 3. The final/end node has completed and the workflow should be stopped + RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + + // This aborts the given node. If the given node is complete then it recursively finds the running nodes and aborts them + AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error + + // This method should be used to initialize Node executor + Initialize(ctx context.Context) error +} + +// Helper struct to allow passing of status between functions +type NodeStatus struct { + NodePhase NodePhase + Err error +} + +func (n *NodeStatus) IsComplete() bool { + return n.NodePhase == NodePhaseComplete +} + +func (n *NodeStatus) HasFailed() bool { + return n.NodePhase == NodePhaseFailed +} + +func (n *NodeStatus) PartiallyComplete() bool { + return n.NodePhase == NodePhaseSuccess +} + +var NodeStatusPending = NodeStatus{NodePhase: NodePhasePending} +var NodeStatusQueued = NodeStatus{NodePhase: NodePhaseQueued} +var NodeStatusRunning = NodeStatus{NodePhase: NodePhaseRunning} +var NodeStatusSuccess = NodeStatus{NodePhase: NodePhaseSuccess} +var NodeStatusComplete = NodeStatus{NodePhase: NodePhaseComplete} +var NodeStatusUndefined = NodeStatus{NodePhase: NodePhaseUndefined} + +func NodeStatusFailed(err error) NodeStatus { + return NodeStatus{NodePhase: NodePhaseFailed, Err: err} +} diff --git a/flytepropeller/pkg/controller/executors/workflow.go b/flytepropeller/pkg/controller/executors/workflow.go new file mode 100644 index 0000000000..31db1cab01 --- /dev/null +++ b/flytepropeller/pkg/controller/executors/workflow.go @@ -0,0 +1,13 @@ +package executors + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type Workflow interface { + Initialize(ctx context.Context) error + HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) error + HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error +} diff --git a/flytepropeller/pkg/controller/finalizer.go b/flytepropeller/pkg/controller/finalizer.go new file mode 100644 index 0000000000..f1a8ba8ebd --- /dev/null +++ b/flytepropeller/pkg/controller/finalizer.go @@ -0,0 +1,36 @@ +package controller + +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + +const FinalizerKey = "flyte-finalizer" + +// NOTE: Some of these APIs are exclusive and do not compare the actual values of the finalizers. +// the intention of this module is to set only one opaque finalizer at a time. If you want to set multiple (not common) +// finalizers, use this module carefully and at your own risk! 
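For orientation, a minimal illustrative sketch (not part of this patch) of how the helpers defined just below in finalizer.go (SetFinalizerIfEmpty, IsDeleted, ResetFinalizers) are intended to be combined during a reconcile pass. It assumes the same package context and the metav1 import aliased as v1 above; the function name reconcileFinalizers and the cleanup callback are hypothetical placeholders:

// Illustrative only: combines the finalizer helpers from this file in a reconcile pass.
// cleanup is a hypothetical placeholder for whatever teardown the caller needs.
func reconcileFinalizers(obj v1.Object, cleanup func()) {
    if IsDeleted(obj) {
        // Deletion timestamp is set: run teardown, then clear the finalizers so the API server can remove the object.
        cleanup()
        ResetFinalizers(obj)
        return
    }
    // Object is live: make sure exactly one opaque finalizer guards it until cleanup has run.
    SetFinalizerIfEmpty(obj, FinalizerKey)
}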
+ +// Sets a new finalizer in case the finalizer is empty +func SetFinalizerIfEmpty(meta v1.Object, finalizer string) { + if !HasFinalizer(meta) { + meta.SetFinalizers([]string{finalizer}) + } +} + +// Check if the deletion timestamp is set, this is set automatically when an object is deleted +func IsDeleted(meta v1.Object) bool { + return meta.GetDeletionTimestamp() != nil +} + +// Reset all the finalizers on the object +func ResetFinalizers(meta v1.Object) { + meta.SetFinalizers([]string{}) +} + +// Currently we only compare the lengths of finalizers. If you add finalizers directly these API;'s will not work +func FinalizersIdentical(o1 v1.Object, o2 v1.Object) bool { + return len(o1.GetFinalizers()) == len(o2.GetFinalizers()) +} + +// Check if any finalizer is set +func HasFinalizer(meta v1.Object) bool { + return len(meta.GetFinalizers()) != 0 +} diff --git a/flytepropeller/pkg/controller/finalizer_test.go b/flytepropeller/pkg/controller/finalizer_test.go new file mode 100644 index 0000000000..05401806d6 --- /dev/null +++ b/flytepropeller/pkg/controller/finalizer_test.go @@ -0,0 +1,70 @@ +package controller + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/batch/v1" + v12 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestFinalizersIdentical(t *testing.T) { + noFinalizer := &v1.Job{} + withFinalizer := &v1.Job{} + withFinalizer.SetFinalizers([]string{"t1"}) + + assert.True(t, FinalizersIdentical(noFinalizer, noFinalizer)) + assert.True(t, FinalizersIdentical(withFinalizer, withFinalizer)) + assert.False(t, FinalizersIdentical(noFinalizer, withFinalizer)) + withMultipleFinalizers := &v1.Job{} + withMultipleFinalizers.SetFinalizers([]string{"f1", "f2"}) + assert.False(t, FinalizersIdentical(withMultipleFinalizers, withFinalizer)) + + withDiffFinalizer := &v1.Job{} + withDiffFinalizer.SetFinalizers([]string{"f1"}) + assert.True(t, FinalizersIdentical(withFinalizer, withDiffFinalizer)) +} + +func TestIsDeleted(t *testing.T) { + noTermTS := &v1.Job{} + termedTS := &v1.Job{} + n := v12.Now() + termedTS.SetDeletionTimestamp(&n) + + assert.True(t, IsDeleted(termedTS)) + assert.False(t, IsDeleted(noTermTS)) +} + +func TestHasFinalizer(t *testing.T) { + noFinalizer := &v1.Job{} + withFinalizer := &v1.Job{} + withFinalizer.SetFinalizers([]string{"t1"}) + + assert.False(t, HasFinalizer(noFinalizer)) + assert.True(t, HasFinalizer(withFinalizer)) +} + +func TestSetFinalizerIfEmpty(t *testing.T) { + noFinalizer := &v1.Job{} + withFinalizer := &v1.Job{} + withFinalizer.SetFinalizers([]string{"t1"}) + + assert.False(t, HasFinalizer(noFinalizer)) + SetFinalizerIfEmpty(noFinalizer, "f1") + assert.True(t, HasFinalizer(noFinalizer)) + assert.Equal(t, []string{"f1"}, noFinalizer.GetFinalizers()) + + SetFinalizerIfEmpty(withFinalizer, "f1") + assert.Equal(t, []string{"t1"}, withFinalizer.GetFinalizers()) +} + +func TestResetFinalizer(t *testing.T) { + noFinalizer := &v1.Job{} + ResetFinalizers(noFinalizer) + assert.Equal(t, []string{}, noFinalizer.GetFinalizers()) + + withFinalizer := &v1.Job{} + withFinalizer.SetFinalizers([]string{"t1"}) + ResetFinalizers(withFinalizer) + assert.Equal(t, []string{}, withFinalizer.GetFinalizers()) +} diff --git a/flytepropeller/pkg/controller/garbage_collector.go b/flytepropeller/pkg/controller/garbage_collector.go new file mode 100644 index 0000000000..2361ff701b --- /dev/null +++ b/flytepropeller/pkg/controller/garbage_collector.go @@ -0,0 +1,145 @@ +package controller + +import ( + "context" + "runtime/pprof" + "time" + + 
"github.com/lyft/flytepropeller/pkg/controller/config" + + "strings" + + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/clock" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +type gcMetrics struct { + gcRoundSuccess labeled.Counter + gcRoundFailure labeled.Counter + gcTime labeled.StopWatch +} + +// Garbage collector is an active background cleanup service, that deletes all workflows that are completed and older +// than the configured TTL +type GarbageCollector struct { + wfClient v1alpha1.FlyteworkflowV1alpha1Interface + namespaceClient corev1.NamespaceInterface + ttlHours int + interval time.Duration + clk clock.Clock + metrics *gcMetrics + namespace string +} + +// Issues a background deletion command with label selector for all completed workflows outside of the retention period +func (g *GarbageCollector) deleteWorkflows(ctx context.Context) error { + + s := CompletedWorkflowsSelectorOutsideRetentionPeriod(g.ttlHours-1, g.clk.Now()) + + // Delete doesn't support 'all' namespaces. Let's fetch namespaces and loop over each. + if g.namespace == "" || strings.ToLower(g.namespace) == "all" || strings.ToLower(g.namespace) == "all-namespaces" { + namespaceList, err := g.namespaceClient.List(v1.ListOptions{}) + if err != nil { + return err + } + for _, n := range namespaceList.Items { + namespaceCtx := contextutils.WithNamespace(ctx, n.GetName()) + logger.Infof(namespaceCtx, "Triggering Workflow delete for namespace: [%s]", n.GetName()) + + if err := g.deleteWorkflowsForNamespace(n.GetName(), s); err != nil { + g.metrics.gcRoundFailure.Inc(namespaceCtx) + logger.Errorf(namespaceCtx, "Garbage collection failed for for namespace: [%s]. Error : [%v]", n.GetName(), err) + } else { + g.metrics.gcRoundSuccess.Inc(namespaceCtx) + } + } + } else { + namespaceCtx := contextutils.WithNamespace(ctx, g.namespace) + logger.Infof(namespaceCtx, "Triggering Workflow delete for namespace: [%s]", g.namespace) + if err := g.deleteWorkflowsForNamespace(g.namespace, s); err != nil { + g.metrics.gcRoundFailure.Inc(namespaceCtx) + logger.Errorf(namespaceCtx, "Garbage collection failed for for namespace: [%s]. 
Error : [%v]", g.namespace, err) + } else { + g.metrics.gcRoundSuccess.Inc(namespaceCtx) + } + } + return nil +} + +func (g *GarbageCollector) deleteWorkflowsForNamespace(namespace string, labelSelector *v1.LabelSelector) error { + gracePeriodZero := int64(0) + propagation := v1.DeletePropagationBackground + + return g.wfClient.FlyteWorkflows(namespace).DeleteCollection( + &v1.DeleteOptions{ + GracePeriodSeconds: &gracePeriodZero, + PropagationPolicy: &propagation, + }, + v1.ListOptions{ + LabelSelector: v1.FormatLabelSelector(labelSelector), + }, + ) +} + +// A periodic GC running +func (g *GarbageCollector) runGC(ctx context.Context, ticker clock.Ticker) { + logger.Infof(ctx, "Background workflow garbage collection started, with duration [%s], TTL [%d] hours", g.interval.String(), g.ttlHours) + + ctx = contextutils.WithGoroutineLabel(ctx, "gc-worker") + pprof.SetGoroutineLabels(ctx) + defer ticker.Stop() + for { + select { + case <-ticker.C(): + logger.Infof(ctx, "Garbage collector running...") + t := g.metrics.gcTime.Start(ctx) + if err := g.deleteWorkflows(ctx); err != nil { + logger.Errorf(ctx, "Garbage collection failed in this round.Error : [%v]", err) + } + t.Stop() + case <-ctx.Done(): + logger.Infof(ctx, "Garbage collector stopping") + return + + } + } +} + +// Use this method to start a background garbage collection routine. Use the context to signal an exit signal +func (g *GarbageCollector) StartGC(ctx context.Context) error { + if g.ttlHours <= 0 { + logger.Warningf(ctx, "Garbage collector is disabled, as ttl [%d] is <=0", g.ttlHours) + return nil + } + ticker := g.clk.NewTicker(g.interval) + go g.runGC(ctx, ticker) + return nil +} + +func NewGarbageCollector(cfg *config.Config, scope promutils.Scope, clk clock.Clock, namespaceClient corev1.NamespaceInterface, wfClient v1alpha1.FlyteworkflowV1alpha1Interface, namespace string) (*GarbageCollector, error) { + ttl := 23 + if cfg.MaxTTLInHours < 23 { + ttl = cfg.MaxTTLInHours + } else { + logger.Warningf(context.TODO(), "defaulting max ttl for workflows to 23 hours, since configured duration is larger than 23 [%d]", cfg.MaxTTLInHours) + } + return &GarbageCollector{ + wfClient: wfClient, + ttlHours: ttl, + interval: cfg.GCInterval.Duration, + namespaceClient: namespaceClient, + metrics: &gcMetrics{ + gcTime: labeled.NewStopWatch("gc_latency", "time taken to issue a delete for TTL'ed workflows", time.Millisecond, scope), + gcRoundSuccess: labeled.NewCounter("gc_success", "successful executions of delete request", scope), + gcRoundFailure: labeled.NewCounter("gc_failure", "failure to delete workflows", scope), + }, + clk: clk, + namespace: namespace, + }, nil +} diff --git a/flytepropeller/pkg/controller/garbage_collector_test.go b/flytepropeller/pkg/controller/garbage_collector_test.go new file mode 100644 index 0000000000..8874af55cf --- /dev/null +++ b/flytepropeller/pkg/controller/garbage_collector_test.go @@ -0,0 +1,166 @@ +package controller + +import ( + "context" + "sync" + "testing" + "time" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + corev1Types "k8s.io/api/core/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/clock" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +func TestNewGarbageCollector(t *testing.T) { + t.Run("enabled", func(t 
*testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 2, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), clock.NewFakeClock(time.Now()), nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, 2, gc.ttlHours) + }) + + t.Run("enabledBeyond23Hours", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 24, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), clock.NewFakeClock(time.Now()), nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, 23, gc.ttlHours) + }) + + t.Run("ttl0", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 0, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), nil, nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, 0, gc.ttlHours) + assert.NoError(t, gc.StartGC(context.TODO())) + + }) + + t.Run("ttl-1", func(t *testing.T) { + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: -1, + } + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), nil, nil, nil, "flyte") + assert.NoError(t, err) + assert.Equal(t, -1, gc.ttlHours) + assert.NoError(t, gc.StartGC(context.TODO())) + }) +} + +type mockWfClient struct { + v1alpha1.FlyteWorkflowInterface + DeleteCollectionCb func(options *v1.DeleteOptions, listOptions v1.ListOptions) error +} + +func (m *mockWfClient) DeleteCollection(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + return m.DeleteCollectionCb(options, listOptions) +} + +type mockClient struct { + v1alpha1.FlyteworkflowV1alpha1Client + FlyteWorkflowsCb func(namespace string) v1alpha1.FlyteWorkflowInterface +} + +func (m *mockClient) FlyteWorkflows(namespace string) v1alpha1.FlyteWorkflowInterface { + return m.FlyteWorkflowsCb(namespace) +} + +type mockNamespaceClient struct { + corev1.NamespaceInterface + ListCb func(opts v1.ListOptions) (*corev1Types.NamespaceList, error) +} + +func (m *mockNamespaceClient) List(opts v1.ListOptions) (*corev1Types.NamespaceList, error) { + return m.ListCb(opts) +} + +func TestGarbageCollector_StartGC(t *testing.T) { + wg := sync.WaitGroup{} + b := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + mockWfClient := &mockWfClient{ + DeleteCollectionCb: func(options *v1.DeleteOptions, listOptions v1.ListOptions) error { + assert.NotNil(t, options) + assert.NotNil(t, listOptions) + assert.Equal(t, "hour-of-day in (0,1,10,11,12,13,14,15,16,17,18,19,2,20,21,3,4,5,6,7,8,9),termination-status=terminated", listOptions.LabelSelector) + wg.Done() + return nil + }, + } + + mockClient := &mockClient{ + FlyteWorkflowsCb: func(namespace string) v1alpha1.FlyteWorkflowInterface { + return mockWfClient + }, + } + + mockNamespaceInvoked := false + mockNamespaceClient := &mockNamespaceClient{ + ListCb: func(opts v1.ListOptions) (*corev1Types.NamespaceList, error) { + mockNamespaceInvoked = true + return &corev1Types.NamespaceList{ + Items: []corev1Types.Namespace{ + { + ObjectMeta: v1.ObjectMeta{ + Name: "ns1", + }, + }, + { + ObjectMeta: v1.ObjectMeta{ + Name: "ns2", + }, + }, + }, + }, nil + }, + } + cfg := &config2.Config{ + GCInterval: config.Duration{Duration: time.Minute * 30}, + MaxTTLInHours: 2, + } + + t.Run("one-namespace", func(t *testing.T) { + fakeClock := clock.NewFakeClock(b) + mockNamespaceInvoked = false + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), fakeClock, 
mockNamespaceClient, mockClient, "flyte") + assert.NoError(t, err) + wg.Add(1) + ctx := context.TODO() + ctx, cancel := context.WithCancel(ctx) + assert.NoError(t, gc.StartGC(ctx)) + fakeClock.Step(time.Minute * 30) + wg.Wait() + cancel() + assert.False(t, mockNamespaceInvoked) + }) + + t.Run("all-namespace", func(t *testing.T) { + fakeClock := clock.NewFakeClock(b) + mockNamespaceInvoked = false + gc, err := NewGarbageCollector(cfg, promutils.NewTestScope(), fakeClock, mockNamespaceClient, mockClient, "all") + assert.NoError(t, err) + wg.Add(2) + ctx := context.TODO() + ctx, cancel := context.WithCancel(ctx) + assert.NoError(t, gc.StartGC(ctx)) + fakeClock.Step(time.Minute * 30) + wg.Wait() + cancel() + assert.True(t, mockNamespaceInvoked) + }) +} diff --git a/flytepropeller/pkg/controller/handler.go b/flytepropeller/pkg/controller/handler.go new file mode 100644 index 0000000000..798d7b6c0a --- /dev/null +++ b/flytepropeller/pkg/controller/handler.go @@ -0,0 +1,188 @@ +package controller + +import ( + "context" + "fmt" + "runtime/debug" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/config" + "github.com/lyft/flytepropeller/pkg/controller/workflowstore" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flytepropeller/pkg/controller/executors" +) + +// TODO Lets move everything to use controller runtime + +type propellerMetrics struct { + Scope promutils.Scope + DeepCopyTime promutils.StopWatch + RawWorkflowTraversalTime promutils.StopWatch + SystemError prometheus.Counter + AbortError prometheus.Counter + PanicObserved prometheus.Counter + RoundSkipped prometheus.Counter + WorkflowNotFound prometheus.Counter +} + +func newPropellerMetrics(scope promutils.Scope) *propellerMetrics { + roundScope := scope.NewSubScope("round") + return &propellerMetrics{ + Scope: scope, + DeepCopyTime: roundScope.MustNewStopWatch("deepcopy", "Total time to deep copy wf object", time.Millisecond), + RawWorkflowTraversalTime: roundScope.MustNewStopWatch("raw", "Total time to traverse the workflow", time.Millisecond), + SystemError: roundScope.MustNewCounter("system_error", "Failure to reconcile a workflow, system error"), + AbortError: roundScope.MustNewCounter("abort_error", "Failure to abort a workflow, system error"), + PanicObserved: roundScope.MustNewCounter("panic", "Panic during handling or aborting workflow"), + RoundSkipped: roundScope.MustNewCounter("skipped", "Round Skipped because of stale workflow"), + WorkflowNotFound: roundScope.MustNewCounter("not_found", "workflow not found in the cache"), + } +} + +type Propeller struct { + wfStore workflowstore.FlyteWorkflow + workflowExecutor executors.Workflow + metrics *propellerMetrics + cfg *config.Config +} + +func (p *Propeller) Initialize(ctx context.Context) error { + return p.workflowExecutor.Initialize(ctx) +} + +// reconciler compares the actual state with the desired, and attempts to +// converge the two. It then updates the GetExecutionStatus block of the FlyteWorkflow resource +// with the current status of the resource. +// Every FlyteWorkflow transitions through the following +// +// The Workflow to be worked on is identified for the given namespace and executionID (which is the name of the workflow) +// The return value should be an error, in the case, we wish to retry this workflow +//
+//
+//     +--------+        +--------+        +--------+     +--------+
+//     |        |        |        |        |        |     |        |
+//     | Ready  +--------> Running+--------> Succeeding---> Success|
+//     |        |        |        |        |        |     |        |
+//     +--------+        +--------+        +---------     +--------+
+//         |                  |
+//         |                  |
+//         |             +----v---+        +--------+
+//         |             |        |        |        |
+//         +-------------> Failing+--------> Failed |
+//                       |        |        |        |
+//                       +--------+        +--------+
+// 
+func (p *Propeller) Handle(ctx context.Context, namespace, name string) error { + logger.Infof(ctx, "Processing Workflow.") + defer logger.Infof(ctx, "Completed processing workflow.") + + // Get the FlyteWorkflow resource with this namespace/name + w, err := p.wfStore.Get(ctx, namespace, name) + if err != nil { + if workflowstore.IsNotFound(err) { + p.metrics.WorkflowNotFound.Inc() + logger.Warningf(ctx, "Workflow namespace[%v]/name[%v] not found, may be deleted.", namespace, name) + return nil + } + if workflowstore.IsWorkflowStale(err) { + p.metrics.RoundSkipped.Inc() + logger.Warningf(ctx, "Workflow namespace[%v]/name[%v] Stale.", namespace, name) + return nil + } + logger.Warningf(ctx, "Failed to GetWorkflow, retrying with back-off. Error: [%v]", err) + return err + } + + t := p.metrics.DeepCopyTime.Start() + wfDeepCopy := w.DeepCopy() + t.Stop() + ctx = contextutils.WithWorkflowID(ctx, wfDeepCopy.GetID()) + + maxRetries := uint32(p.cfg.MaxWorkflowRetries) + if IsDeleted(wfDeepCopy) || (wfDeepCopy.Status.FailedAttempts > maxRetries) { + var err error + func() { + defer func() { + if r := recover(); r != nil { + stack := debug.Stack() + err = fmt.Errorf("panic when aborting workflow, Stack: [%s]", string(stack)) + p.metrics.PanicObserved.Inc() + } + }() + err = p.workflowExecutor.HandleAbortedWorkflow(ctx, wfDeepCopy, maxRetries) + }() + if err != nil { + p.metrics.AbortError.Inc() + return err + } + } else { + if wfDeepCopy.GetExecutionStatus().IsTerminated() { + if HasCompletedLabel(wfDeepCopy) && !HasFinalizer(wfDeepCopy) { + logger.Debugf(ctx, "Workflow is terminated.") + return nil + } + // NOTE: This should never really happen, but in case we externally mark the workflow as terminated + // We should allow cleanup + logger.Warn(ctx, "Workflow is marked as terminated but doesn't have the completed label, marking it as completed.") + } else { + SetFinalizerIfEmpty(wfDeepCopy, FinalizerKey) + + func() { + t := p.metrics.RawWorkflowTraversalTime.Start() + defer func() { + t.Stop() + if r := recover(); r != nil { + stack := debug.Stack() + err = fmt.Errorf("panic when handling workflow, Stack: [%s]", string(stack)) + p.metrics.PanicObserved.Inc() + } + }() + err = p.workflowExecutor.HandleFlyteWorkflow(ctx, wfDeepCopy) + }() + + if err != nil { + logger.Errorf(ctx, "Error when trying to reconcile workflow. Error [%v]", err) + // Let's mark these as system errors. + // We only want to increase failed attempts and discard any other partial changes to the CRD. + wfDeepCopy = w.DeepCopy() + wfDeepCopy.GetExecutionStatus().IncFailedAttempts() + wfDeepCopy.GetExecutionStatus().SetMessage(err.Error()) + p.metrics.SystemError.Inc() + } else { + // No updates to the status were detected, so we will skip writing to KubeAPI + if wfDeepCopy.Status.Equals(&w.Status) { + logger.Info(ctx, "WF hasn't been updated in this round.") + return nil + } + } + } + } + // If the end result is a terminated workflow, we remove the labels + if wfDeepCopy.GetExecutionStatus().IsTerminated() { + // We add a completed label so that we can avoid polling for this workflow + SetCompletedLabel(wfDeepCopy, time.Now()) + ResetFinalizers(wfDeepCopy) + } + // TODO we will need to call UpdateStatus when it is supported. But to preserve metadata like (label/finalizer) we will need to use update + + // update the GetExecutionStatus block of the FlyteWorkflow resource. UpdateStatus will not + // allow changes to the Spec of the resource, which is ideal for ensuring + // nothing other than resource status has been updated. 
+ return p.wfStore.Update(ctx, wfDeepCopy, workflowstore.PriorityClassCritical) +} + +func NewPropellerHandler(_ context.Context, cfg *config.Config, wfStore workflowstore.FlyteWorkflow, executor executors.Workflow, scope promutils.Scope) *Propeller { + + metrics := newPropellerMetrics(scope) + return &Propeller{ + metrics: metrics, + wfStore: wfStore, + workflowExecutor: executor, + cfg: cfg, + } +} diff --git a/flytepropeller/pkg/controller/handler_test.go b/flytepropeller/pkg/controller/handler_test.go new file mode 100644 index 0000000000..3c8dd7dcf5 --- /dev/null +++ b/flytepropeller/pkg/controller/handler_test.go @@ -0,0 +1,408 @@ +package controller + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytepropeller/pkg/controller/config" + "github.com/lyft/flytepropeller/pkg/controller/workflowstore" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +type mockExecutor struct { + HandleCb func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error + HandleAbortedCb func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error +} + +func (m *mockExecutor) Initialize(ctx context.Context) error { + return nil +} + +func (m *mockExecutor) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + return m.HandleAbortedCb(ctx, w, maxRetries) +} + +func (m *mockExecutor) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + return m.HandleCb(ctx, w) +} + +func TestPropeller_Handle(t *testing.T) { + scope := promutils.NewTestScope() + ctx := context.TODO() + s := workflowstore.NewInMemoryWorkflowStore() + exec := &mockExecutor{} + cfg := &config.Config{ + MaxWorkflowRetries: 0, + } + + p := NewPropellerHandler(ctx, cfg, s, exec, scope) + + const namespace = "test" + const name = "123" + t.Run("notPresent", func(t *testing.T) { + assert.NoError(t, p.Handle(ctx, namespace, name)) + }) + + t.Run("terminated", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseFailed, + }, + })) + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) + + t.Run("happy", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 1, len(r.Finalizers)) + assert.False(t, HasCompletedLabel(r)) + }) + + t.Run("error", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: 
&v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + return fmt.Errorf("failed") + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseReady, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + assert.False(t, HasCompletedLabel(r)) + }) + + t.Run("abort", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + FailedAttempts: 1, + }, + })) + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseFailed, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + }) + + t.Run("abort_panics", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"x"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + FailedAttempts: 1, + Phase: v1alpha1.WorkflowPhaseRunning, + }, + })) + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + panic("error") + } + assert.Error(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseRunning, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 1, len(r.Finalizers)) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + }) + + t.Run("noUpdate", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, 1, len(r.Finalizers)) + }) + + t.Run("handlingPanics", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + panic("error") + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, 
err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, 1, len(r.Finalizers)) + assert.Equal(t, uint32(1), r.Status.FailedAttempts) + }) + + t.Run("noUpdate", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.False(t, HasCompletedLabel(r)) + assert.Equal(t, 1, len(r.Finalizers)) + }) + + t.Run("retriesExhaustedFinalize", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseRunning, + FailedAttempts: 1, + }, + })) + abortCalled := false + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.Status.UpdatePhase(v1alpha1.WorkflowPhaseFailed, "Aborted") + abortCalled = true + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + assert.True(t, abortCalled) + }) + + t.Run("deletedShouldBeFinalized", func(t *testing.T) { + n := v1.Now() + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + DeletionTimestamp: &n, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + w.Status.UpdatePhase(v1alpha1.WorkflowPhaseAborted, "Aborted") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseAborted, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) + + t.Run("deletedButAbortFailed", func(t *testing.T) { + n := v1.Now() + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + DeletionTimestamp: &n, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowPhaseSucceeding, + }, + })) + + exec.HandleAbortedCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + return fmt.Errorf("failed") + } + + assert.Error(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, 
v1alpha1.WorkflowPhaseSucceeding, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, []string{"f1"}, r.Finalizers) + assert.False(t, HasCompletedLabel(r)) + }) + + t.Run("removefinalizerOnTerminateSuccess", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSuccess, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseSuccess, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) + + t.Run("removefinalizerOnTerminateFailure", func(t *testing.T) { + assert.NoError(t, s.Create(ctx, &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + Finalizers: []string{"f1"}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + }, + })) + exec.HandleCb = func(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + w.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseFailed, "done") + return nil + } + assert.NoError(t, p.Handle(ctx, namespace, name)) + + r, err := s.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, r.GetExecutionStatus().GetPhase()) + assert.Equal(t, 0, len(r.Finalizers)) + assert.True(t, HasCompletedLabel(r)) + }) +} + +func TestPropellerHandler_Initialize(t *testing.T) { + scope := promutils.NewTestScope() + ctx := context.TODO() + s := workflowstore.NewInMemoryWorkflowStore() + exec := &mockExecutor{} + cfg := &config.Config{ + MaxWorkflowRetries: 0, + } + + p := NewPropellerHandler(ctx, cfg, s, exec, scope) + + assert.NoError(t, p.Initialize(ctx)) +} diff --git a/flytepropeller/pkg/controller/leaderelection.go b/flytepropeller/pkg/controller/leaderelection.go new file mode 100644 index 0000000000..8c6cd11323 --- /dev/null +++ b/flytepropeller/pkg/controller/leaderelection.go @@ -0,0 +1,78 @@ +package controller + +import ( + "context" + "fmt" + "os" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "k8s.io/apimachinery/pkg/util/rand" + + v1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/tools/leaderelection" + "k8s.io/client-go/tools/leaderelection/resourcelock" + "k8s.io/client-go/tools/record" +) + +const ( + // Env var to lookup pod name in. 
In pod spec, you will have to specify it like this: + // env: + // - name: POD_NAME + // valueFrom: + // fieldRef: + // fieldPath: metadata.name + podNameEnvVar = "POD_NAME" +) + +// NewResourceLock creates a new config map resource lock for use in a leader election loop +func newResourceLock(corev1 v1.CoreV1Interface, eventRecorder record.EventRecorder, options config.LeaderElectionConfig) ( + resourcelock.Interface, error) { + + if !options.Enabled { + return nil, nil + } + + // Default the LeaderElectionID + if len(options.LockConfigMap.String()) == 0 { + return nil, fmt.Errorf("to enable leader election, a config map must be provided") + } + + // Leader id, needs to be unique + return resourcelock.New(resourcelock.ConfigMapsResourceLock, + options.LockConfigMap.Namespace, + options.LockConfigMap.Name, + corev1, + resourcelock.ResourceLockConfig{ + Identity: getUniqueLeaderID(), + EventRecorder: eventRecorder, + }) +} + +func getUniqueLeaderID() string { + val, found := os.LookupEnv(podNameEnvVar) + if found { + return val + } + + id, err := os.Hostname() + if err != nil { + id = "" + } + + return fmt.Sprintf("%v_%v", id, rand.String(10)) +} + +func newLeaderElector(lock resourcelock.Interface, cfg config.LeaderElectionConfig, + leaderFn func(ctx context.Context), leaderStoppedFn func()) (*leaderelection.LeaderElector, error) { + return leaderelection.NewLeaderElector(leaderelection.LeaderElectionConfig{ + Lock: lock, + LeaseDuration: cfg.LeaseDuration.Duration, + RenewDeadline: cfg.RenewDeadline.Duration, + RetryPeriod: cfg.RetryPeriod.Duration, + Callbacks: leaderelection.LeaderCallbacks{ + OnStartedLeading: leaderFn, + OnStoppedLeading: leaderStoppedFn, + }, + }) +} diff --git a/flytepropeller/pkg/controller/nodes/branch/comparator.go b/flytepropeller/pkg/controller/nodes/branch/comparator.go new file mode 100644 index 0000000000..bf7e26ce9f --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/branch/comparator.go @@ -0,0 +1,139 @@ +package branch + +import ( + "reflect" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/pkg/errors" +) + +type comparator func(lValue *core.Primitive, rValue *core.Primitive) bool +type comparators struct { + gt comparator + eq comparator +} + +var primitiveBooleanType = reflect.TypeOf(&core.Primitive_Boolean{}).String() + +var perTypeComparators = map[string]comparators{ + reflect.TypeOf(&core.Primitive_FloatValue{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetFloatValue() > rValue.GetFloatValue() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetFloatValue() == rValue.GetFloatValue() + }, + }, + reflect.TypeOf(&core.Primitive_Integer{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetInteger() > rValue.GetInteger() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetInteger() == rValue.GetInteger() + }, + }, + reflect.TypeOf(&core.Primitive_Boolean{}).String(): { + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetBoolean() == rValue.GetBoolean() + }, + }, + reflect.TypeOf(&core.Primitive_StringValue{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetStringValue() > rValue.GetStringValue() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetStringValue() == rValue.GetStringValue() + }, + }, + 
reflect.TypeOf(&core.Primitive_Datetime{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDatetime().GetSeconds() > rValue.GetDatetime().GetSeconds() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDatetime().GetSeconds() == rValue.GetDatetime().GetSeconds() + }, + }, + reflect.TypeOf(&core.Primitive_Duration{}).String(): { + gt: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDuration().GetSeconds() > rValue.GetDuration().GetSeconds() + }, + eq: func(lValue *core.Primitive, rValue *core.Primitive) bool { + return lValue.GetDuration().GetSeconds() == rValue.GetDuration().GetSeconds() + }, + }, +} + +func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { + lValueType := reflect.TypeOf(lValue.Value) + rValueType := reflect.TypeOf(rValue.Value) + if lValueType != rValueType { + return false, errors.Errorf("Comparison between different primitives types. lVal[%v]:rVal[%v]", lValueType, rValueType) + } + comps, ok := perTypeComparators[lValueType.String()] + if !ok { + return false, errors.Errorf("Comparator not defined for type: [%v]", lValueType.String()) + } + isBoolean := false + if lValueType.String() == primitiveBooleanType { + isBoolean = true + } + switch op { + case core.ComparisonExpression_GT: + if isBoolean { + return false, errors.Errorf("[GT] not defined for boolean operands.") + } + return comps.gt(lValue, rValue), nil + case core.ComparisonExpression_GTE: + if isBoolean { + return false, errors.Errorf("[GTE] not defined for boolean operands.") + } + return comps.eq(lValue, rValue) || comps.gt(lValue, rValue), nil + case core.ComparisonExpression_LT: + if isBoolean { + return false, errors.Errorf("[LT] not defined for boolean operands.") + } + return !(comps.gt(lValue, rValue) || comps.eq(lValue, rValue)), nil + case core.ComparisonExpression_LTE: + if isBoolean { + return false, errors.Errorf("[LTE] not defined for boolean operands.") + } + return !comps.gt(lValue, rValue), nil + case core.ComparisonExpression_EQ: + return comps.eq(lValue, rValue), nil + case core.ComparisonExpression_NEQ: + return !comps.eq(lValue, rValue), nil + } + return false, errors.Errorf("Unsupported operator type in Propeller. System error.") +} + +func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { + if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. RHS Variable is non primitive.") + } + return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op) +} + +func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. 
LHS Variable is non primitive.") + } + return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op) +} + +func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. LHS Variable is non primitive.") + } + if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { + return false, errors.Errorf("Only primitives can be compared. RHS Variable is non primitive") + } + return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op) +} diff --git a/flytepropeller/pkg/controller/nodes/branch/comparator_test.go b/flytepropeller/pkg/controller/nodes/branch/comparator_test.go new file mode 100644 index 0000000000..c34f28b0d4 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/branch/comparator_test.go @@ -0,0 +1,403 @@ +package branch + +import ( + "fmt" + "testing" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" +) + +func TestEvaluate_int(t *testing.T) { + p1 := utils.MustMakePrimitive(1) + p2 := utils.MustMakePrimitive(2) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_float(t *testing.T) { + p1 := utils.MustMakePrimitive(1.0) + p2 := utils.MustMakePrimitive(2.0) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, 
err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_string(t *testing.T) { + p1 := utils.MustMakePrimitive("a") + p2 := utils.MustMakePrimitive("b") + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_datetime(t *testing.T) { + p1 := utils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC)) + p2 := 
utils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC)) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_duration(t *testing.T) { + p1 := utils.MustMakePrimitive(10 * time.Second) + p2 := utils.MustMakePrimitive(11 * time.Second) + { + // p1 > p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 >= p2 = false + b, err := Evaluate(p1, p2, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + } + { + // p1 < p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + // p1 <= p2 = true + b, err := Evaluate(p1, p2, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_LTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, 
core.ComparisonExpression_LT) + assert.NoError(t, err) + assert.False(t, b) + } + { + b, err := Evaluate(p1, p1, core.ComparisonExpression_GTE) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_GT) + assert.NoError(t, err) + assert.False(t, b) + } +} + +func TestEvaluate_boolean(t *testing.T) { + p1 := utils.MustMakePrimitive(true) + p2 := utils.MustMakePrimitive(false) + f := func(op core.ComparisonExpression_Operator) { + // GT/LT = false + msg := fmt.Sprintf("Evaluating: [%s]", op.String()) + b, err := Evaluate(p1, p2, op) + assert.Error(t, err, msg) + assert.False(t, b, msg) + b, err = Evaluate(p2, p1, op) + assert.Error(t, err, msg) + assert.False(t, b, msg) + b, err = Evaluate(p1, p1, op) + assert.Error(t, err, msg) + assert.False(t, b, msg) + } + f(core.ComparisonExpression_GT) + f(core.ComparisonExpression_LT) + f(core.ComparisonExpression_GTE) + f(core.ComparisonExpression_LTE) + + { + b, err := Evaluate(p1, p2, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p2, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.False(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_EQ) + assert.NoError(t, err) + assert.True(t, b) + b, err = Evaluate(p1, p1, core.ComparisonExpression_NEQ) + assert.NoError(t, err) + assert.False(t, b) + } +} diff --git a/flytepropeller/pkg/controller/nodes/branch/evaluator.go b/flytepropeller/pkg/controller/nodes/branch/evaluator.go new file mode 100644 index 0000000000..24377a11c4 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/branch/evaluator.go @@ -0,0 +1,139 @@ +package branch + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + regErrors "github.com/pkg/errors" +) + +func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *handler.Data) (bool, error) { + var lValue *core.Literal + var rValue *core.Literal + var lPrim *core.Primitive + var rPrim *core.Primitive + + if expr.GetLeftValue().GetPrimitive() == nil { + if nodeInputs == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()] + if lValue == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + } else { + lPrim = expr.GetLeftValue().GetPrimitive() + } + + if expr.GetRightValue().GetPrimitive() == nil { + if nodeInputs == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()] + if rValue == nil { + return false, regErrors.Errorf("Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar()) + } + } else { + rPrim = expr.GetRightValue().GetPrimitive() + } + + if lValue != nil && rValue != nil { + return EvaluateLiterals(lValue, rValue, expr.GetOperator()) + } + if lValue != nil && rPrim != nil { + return Evaluate2(lValue, rPrim, expr.GetOperator()) + } + if lPrim != nil && rValue != nil { + return Evaluate1(lPrim, rValue, expr.GetOperator()) + } + return Evaluate(lPrim, rPrim, expr.GetOperator()) +} + +func EvaluateBooleanExpression(expr 
*core.BooleanExpression, nodeInputs *handler.Data) (bool, error) { + if expr.GetComparison() != nil { + return EvaluateComparison(expr.GetComparison(), nodeInputs) + } + if expr.GetConjunction() == nil { + return false, regErrors.Errorf("No Comparison or Conjunction found in Branch node expression.") + } + lvalue, err := EvaluateBooleanExpression(expr.GetConjunction().GetLeftExpression(), nodeInputs) + if err != nil { + return false, err + } + rvalue, err := EvaluateBooleanExpression(expr.GetConjunction().GetRightExpression(), nodeInputs) + if err != nil { + return false, err + } + if expr.GetConjunction().GetOperator() == core.ConjunctionExpression_OR { + return lvalue || rvalue, nil + } + return lvalue && rvalue, nil +} + +func EvaluateIfBlock(block v1alpha1.ExecutableIfBlock, nodeInputs *handler.Data, skippedNodeIds []*v1alpha1.NodeID) (*v1alpha1.NodeID, []*v1alpha1.NodeID, error) { + if ok, err := EvaluateBooleanExpression(block.GetCondition(), nodeInputs); err != nil { + return nil, skippedNodeIds, err + } else if ok { + // Set status to running + return block.GetThenNode(), skippedNodeIds, err + } + // This branch is not taken + return nil, append(skippedNodeIds, block.GetThenNode()), nil +} + +// Decides the branch to be taken, returns the nodeId of the selected node or an error +// The branchnode is marked as success. This is used by downstream node to determine if it can be executed +// All downstream nodes are marked as skipped +func DecideBranch(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, nodeID v1alpha1.NodeID, node v1alpha1.ExecutableBranchNode, nodeInputs *handler.Data) (*v1alpha1.NodeID, error) { + var selectedNodeID *v1alpha1.NodeID + var skippedNodeIds []*v1alpha1.NodeID + var err error + + selectedNodeID, skippedNodeIds, err = EvaluateIfBlock(node.GetIf(), nodeInputs, skippedNodeIds) + if err != nil { + return nil, err + } + + for _, block := range node.GetElseIf() { + if selectedNodeID != nil { + skippedNodeIds = append(skippedNodeIds, block.GetThenNode()) + } else { + selectedNodeID, skippedNodeIds, err = EvaluateIfBlock(block, nodeInputs, skippedNodeIds) + if err != nil { + return nil, err + } + } + } + if node.GetElse() != nil { + if selectedNodeID == nil { + selectedNodeID = node.GetElse() + } else { + skippedNodeIds = append(skippedNodeIds, node.GetElse()) + } + } + for _, nodeIDPtr := range skippedNodeIds { + skippedNodeID := *nodeIDPtr + n, ok := w.GetNode(skippedNodeID) + if !ok { + return nil, errors.Errorf(errors.DownstreamNodeNotFoundError, nodeID, "Downstream node [%v] not found", skippedNodeID) + } + nStatus := w.GetNodeExecutionStatus(n.GetID()) + logger.Infof(ctx, "Branch Setting Node[%v] status to Skipped!", skippedNodeID) + nStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.Now(), "Branch evaluated to false") + } + + if selectedNodeID == nil { + if node.GetElseFail() != nil { + return nil, errors.Errorf(errors.UserProvidedError, nodeID, node.GetElseFail().Message) + } + return nil, errors.Errorf(errors.NoBranchTakenError, nodeID, "No branch satisfied") + } + logger.Infof(ctx, "Branch Node[%v] selected!", *selectedNodeID) + return selectedNodeID, nil +} diff --git a/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go b/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go new file mode 100644 index 0000000000..cab3672cb6 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go @@ -0,0 +1,667 @@ +package branch + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + 
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" +) + +// Creates a ComparisonExpression, comparing 2 literals +func getComparisonExpression(lV interface{}, op core.ComparisonExpression_Operator, rV interface{}) (*core.ComparisonExpression, *handler.Data) { + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "x", + }, + }, + Operator: op, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(lV), + "y": utils.MustMakePrimitiveLiteral(rV), + }, + } + return exp, inputs +} + +func createUnaryConjunction(l *core.ComparisonExpression, op core.ConjunctionExpression_LogicalOperator, r *core.ComparisonExpression) *core.ConjunctionExpression { + return &core.ConjunctionExpression{ + LeftExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: l, + }, + }, + Operator: op, + RightExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: r, + }, + }, + } +} + +func TestEvaluateComparison(t *testing.T) { + t.Run("ComparePrimitives", func(t *testing.T) { + // Compare primitives + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(2), + }, + }, + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("ComparePrimitiveAndLiteral", func(t *testing.T) { + // Compare lVal -> primitive and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "y": utils.MustMakePrimitiveLiteral(2), + }, + } + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("CompareLiteralAndPrimitive", func(t *testing.T) { + + // Compare lVal -> literal and rVal -> primitive + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "x", + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(2), + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + "y": utils.MustMakePrimitiveLiteral(3), + }, + } + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.False(t, v) + }) + + t.Run("CompareLiterals", func(t *testing.T) { + // Compare lVal -> literal and rVal -> literal + exp, inputs := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + v, err := EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.True(t, v) + }) + + t.Run("CompareLiterals2", func(t *testing.T) { + // Compare lVal -> literal and rVal -> literal + exp, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + v, err := 
EvaluateComparison(exp, inputs) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("ComparePrimitiveAndLiteralNotFound", func(t *testing.T) { + // Compare lVal -> primitive and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{}, + } + _, err := EvaluateComparison(exp, inputs) + assert.Error(t, err) + + _, err = EvaluateComparison(exp, nil) + assert.Error(t, err) + }) + + t.Run("CompareLiteralNotFoundAndPrimitive", func(t *testing.T) { + // Compare lVal -> primitive and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "y", + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: utils.MustMakePrimitive(1), + }, + }, + } + inputs := &handler.Data{ + Literals: map[string]*core.Literal{}, + } + _, err := EvaluateComparison(exp, inputs) + assert.Error(t, err) + + _, err = EvaluateComparison(exp, nil) + assert.Error(t, err) + }) + +} + +func TestEvaluateBooleanExpression(t *testing.T) { + { + // Simple comparison only + ce, inputs := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + exp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: ce, + }, + } + v, err := EvaluateBooleanExpression(exp, inputs) + assert.NoError(t, err) + assert.True(t, v) + } + { + // AND of 2 comparisons. Inputs are the same for both. + l, lInputs := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + exp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_AND, r), + }, + } + v, err := EvaluateBooleanExpression(exp, lInputs) + assert.NoError(t, err) + assert.False(t, v) + } + { + // OR of 2 comparisons + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + exp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_OR, r), + }, + } + v, err := EvaluateBooleanExpression(exp, inputs) + assert.NoError(t, err) + assert.True(t, v) + } + { + // Conjunction of comparison and a conjunction, AND + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + innerExp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_OR, r), + }, + } + + outerComparison := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "a", + }, + }, + Operator: core.ComparisonExpression_GT, + RightValue: &core.Operand{ + Val: &core.Operand_Var{ + Var: "b", + }, + }, + } + outerInputs := &handler.Data{ + Literals: map[string]*core.Literal{ + "a": utils.MustMakePrimitiveLiteral(5), + "b": utils.MustMakePrimitiveLiteral(4), + }, + } + + outerExp := &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: &core.ConjunctionExpression{ + LeftExpression: &core.BooleanExpression{ + Expr: 
&core.BooleanExpression_Comparison{ + Comparison: outerComparison, + }, + }, + Operator: core.ConjunctionExpression_AND, + RightExpression: innerExp, + }, + }, + } + + for k, v := range inputs.Literals { + outerInputs.Literals[k] = v + } + + v, err := EvaluateBooleanExpression(outerExp, outerInputs) + assert.NoError(t, err) + assert.True(t, v) + } +} + +func TestEvaluateIfBlock(t *testing.T) { + { + // AND of 2 comparisons + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + thenNode := "test" + block := &v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_AND, r), + }, + }, + }, + ThenNode: &thenNode, + } + + skippedNodeIds := make([]*v1alpha1.NodeID, 0) + accp, skippedNodeIds, err := EvaluateIfBlock(block, inputs, skippedNodeIds) + assert.NoError(t, err) + assert.Nil(t, accp) + assert.Equal(t, 1, len(skippedNodeIds)) + assert.Equal(t, "test", *skippedNodeIds[0]) + } + { + // OR of 2 comparisons + l, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + r, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + thenNode := "test" + block := &v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Conjunction{ + Conjunction: createUnaryConjunction(l, core.ConjunctionExpression_OR, r), + }, + }, + }, + ThenNode: &thenNode, + } + + skippedNodeIds := make([]*v1alpha1.NodeID, 0) + accp, skippedNodeIds, err := EvaluateIfBlock(block, inputs, skippedNodeIds) + assert.NoError(t, err) + assert.NotNil(t, accp) + assert.Equal(t, "test", *accp) + assert.Equal(t, 0, len(skippedNodeIds)) + } +} + +func TestDecideBranch(t *testing.T) { + ctx := context.Background() + + t.Run("EmptyIfBlock", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{}, + }, + } + branchNode := &v1alpha1.BranchNodeSpec{} + b, err := DecideBranch(ctx, w, "n1", branchNode, nil) + assert.Error(t, err) + assert.Nil(t, b) + }) + + t.Run("MissingThenNode", func(t *testing.T) { + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{}, + }, + } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: nil, + }, + } + b, err := DecideBranch(ctx, w, "n1", branchNode, inputs) + assert.Error(t, err) + assert.Nil(t, b) + assert.Equal(t, errors.NoBranchTakenError, err.(*errors.NodeError).Code) + }) + + t.Run("WithThenNode", func(t *testing.T) { + n1 := "n1" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + }, + }, + } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + 
ThenNode: &n1, + }, + } + b, err := DecideBranch(ctx, w, "n1", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n1, *b) + }) + + t.Run("RepeatedCondition", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp, inputs := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: &n2, + }, + }, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n1, *b) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n2].GetPhase()) + assert.Nil(t, w.Status.NodeStatus[n1]) + }) + + t.Run("SecondCondition", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_EQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n2, *b) + assert.Nil(t, w.Status.NodeStatus[n2]) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n1].GetPhase()) + }) + + t.Run("ElseCase", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + n3 := "n3" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + Else: &n3, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, n3, *b) + assert.Equal(t, v1alpha1.NodePhaseSkipped, 
w.Status.NodeStatus[n1].GetPhase()) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n2].GetPhase()) + }) + + t.Run("MissingNode", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + n3 := "n3" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + Else: &n3, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.Error(t, err) + assert.Nil(t, b) + assert.Equal(t, errors.DownstreamNodeNotFoundError, err.(*errors.NodeError).Code) + }) + + t.Run("ElseFailCase", func(t *testing.T) { + n1 := "n1" + n2 := "n2" + userError := "User error" + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "w1", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + } + exp1, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + exp2, _ := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + branchNode := &v1alpha1.BranchNodeSpec{ + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp1, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp2, + }, + }, + }, + ThenNode: &n2, + }, + }, + ElseFail: &v1alpha1.Error{ + Error: &core.Error{ + Message: userError, + }, + }, + } + b, err := DecideBranch(ctx, w, "n", branchNode, inputs) + assert.Error(t, err) + assert.Nil(t, b) + assert.Equal(t, errors.UserProvidedError, err.(*errors.NodeError).Code) + assert.Equal(t, userError, err.(*errors.NodeError).Message) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n1].GetPhase()) + assert.Equal(t, v1alpha1.NodePhaseSkipped, w.Status.NodeStatus[n2].GetPhase()) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/branch/handler.go b/flytepropeller/pkg/controller/nodes/branch/handler.go new file mode 100644 index 0000000000..f3531771ed --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/branch/handler.go @@ -0,0 +1,135 @@ +package branch + +import ( + "context" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +type branchHandler struct { + nodeExecutor executors.Node + recorder events.NodeEventRecorder +} + +func (b *branchHandler) recurseDownstream(ctx context.Context, w v1alpha1.ExecutableWorkflow, nodeStatus 
v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Status, error) { + downstreamStatus, err := b.nodeExecutor.RecursiveNodeHandler(ctx, w, branchTakenNode) + if err != nil { + return handler.StatusUndefined, err + } + + if downstreamStatus.IsComplete() { + // For branch node we set the output node to be the same as the child nodes output + childNodeStatus := w.GetNodeExecutionStatus(branchTakenNode.GetID()) + nodeStatus.SetDataDir(childNodeStatus.GetDataDir()) + return handler.StatusSuccess, nil + } + + if downstreamStatus.HasFailed() { + return handler.StatusFailed(downstreamStatus.Err), nil + } + + return handler.StatusRunning, nil +} + +func (b *branchHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + logger.Debugf(ctx, "Starting Branch Node") + branch := node.GetBranchNode() + if branch == nil { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "Invoked branch handler, for a non branch node.")), nil + } + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + branchStatus := nodeStatus.GetOrCreateBranchStatus() + finalNode, err := DecideBranch(ctx, w, node.GetID(), branch, nodeInputs) + if err != nil { + branchStatus.SetBranchNodeError() + logger.Debugf(ctx, "Branch evaluation failed. Error [%s]", err) + return handler.StatusFailed(err), nil + } + branchStatus.SetBranchNodeSuccess(*finalNode) + var ok bool + childNode, ok := w.GetNode(*finalNode) + if !ok { + logger.Debugf(ctx, "Branch downstream finalized node not found. FinalizedNode [%s]", *finalNode) + return handler.StatusFailed(errors.Errorf(errors.DownstreamNodeNotFoundError, w.GetID(), node.GetID(), "Downstream node [%v] not found", *finalNode)), nil + } + i := node.GetID() + childNodeStatus := w.GetNodeExecutionStatus(childNode.GetID()) + childNodeStatus.SetParentNodeID(&i) + + logger.Debugf(ctx, "Recursing down branch node") + return b.recurseDownstream(ctx, w, nodeStatus, childNode) +} + +func (b *branchHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + branch := node.GetBranchNode() + if branch == nil { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "Invoked branch handler, for a non branch node.")), nil + } + // If the branch was already evaluated i.e, Node is in Running status + branchStatus := nodeStatus.GetOrCreateBranchStatus() + userError := branch.GetElseFail() + finalNodeID := branchStatus.GetFinalizedNode() + if finalNodeID == nil { + if userError != nil { + // We should never reach here, but for safety and completeness + return handler.StatusFailed(errors.Errorf(errors.UserProvidedError, w.GetID(), node.GetID(), userError.Message)), nil + } + return handler.StatusRunning, errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "No node finalized through previous branch evaluation.") + } + var ok bool + branchTakenNode, ok := w.GetNode(*finalNodeID) + if !ok { + return handler.StatusFailed(errors.Errorf(errors.DownstreamNodeNotFoundError, w.GetID(), node.GetID(), "Downstream node [%v] not found", *finalNodeID)), nil + } + // Recurse downstream + return b.recurseDownstream(ctx, w, nodeStatus, branchTakenNode) +} + +func (b *branchHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) 
{ + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "A Branch node cannot enter a failing state")), nil +} + +func (b *branchHandler) Initialize(ctx context.Context) error { + return nil +} + +func (b *branchHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + branch := node.GetBranchNode() + if branch == nil { + return errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "Invoked branch handler, for a non branch node.") + } + // If the branch was already evaluated i.e, Node is in Running status + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + branchStatus := nodeStatus.GetOrCreateBranchStatus() + userError := branch.GetElseFail() + finalNodeID := branchStatus.GetFinalizedNode() + if finalNodeID == nil { + if userError != nil { + // We should never reach here, but for safety and completeness + return errors.Errorf(errors.UserProvidedError, w.GetID(), node.GetID(), userError.Message) + } + return errors.Errorf(errors.IllegalStateError, w.GetID(), node.GetID(), "No node finalized through previous branch evaluation.") + } + var ok bool + branchTakenNode, ok := w.GetNode(*finalNodeID) + + if !ok { + return errors.Errorf(errors.DownstreamNodeNotFoundError, w.GetID(), node.GetID(), "Downstream node [%v] not found", *finalNodeID) + } + // Recurse downstream + return b.nodeExecutor.AbortHandler(ctx, w, branchTakenNode) +} + +func New(executor executors.Node, eventSink events.EventSink, scope promutils.Scope) handler.IFace { + branchScope := scope.NewSubScope("branch") + return &branchHandler{ + nodeExecutor: executor, + recorder: events.NewNodeEventRecorder(eventSink, branchScope), + } +} diff --git a/flytepropeller/pkg/controller/nodes/branch/handler_test.go b/flytepropeller/pkg/controller/nodes/branch/handler_test.go new file mode 100644 index 0000000000..c64cf96a3e --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/branch/handler_test.go @@ -0,0 +1,236 @@ +package branch + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/stretchr/testify/assert" +) + +type recursiveNodeHandlerFn func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) +type abortNodeHandlerCbFn func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error + +type mockNodeExecutor struct { + executors.Node + RecursiveNodeHandlerCB recursiveNodeHandlerFn + AbortNodeHandlerCB abortNodeHandlerCbFn +} + +func (m *mockNodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + return m.RecursiveNodeHandlerCB(ctx, w, currentNode) +} + +func (m *mockNodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { + return m.AbortNodeHandlerCB(ctx, w, currentNode) +} + +func TestBranchHandler_RecurseDownstream(t *testing.T) { + ctx := context.TODO() + m := 
&mockNodeExecutor{} + branch := New(m, events.NewMockEventSink(), promutils.NewTestScope()).(*branchHandler) + childNodeID := "child" + childDatadir := v1alpha1.DataReference("test") + w := &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + childNodeID: { + DataDir: childDatadir, + }, + }, + }, + } + expectedError := fmt.Errorf("error") + + recursiveNodeHandlerFnArchetype := func(status executors.NodeStatus, err error) recursiveNodeHandlerFn { + return func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + return status, err + } + } + + tests := []struct { + name string + recursiveNodeHandlerFn recursiveNodeHandlerFn + nodeStatus v1alpha1.ExecutableNodeStatus + branchTakenNode v1alpha1.ExecutableNode + isErr bool + expectedStatus handler.Status + }{ + {"childNodeError", recursiveNodeHandlerFnArchetype(executors.NodeStatusUndefined, expectedError), + nil, &v1alpha1.NodeSpec{}, true, handler.StatusUndefined}, + {"childPending", recursiveNodeHandlerFnArchetype(executors.NodeStatusPending, nil), + nil, &v1alpha1.NodeSpec{}, false, handler.StatusRunning}, + {"childStillRunning", recursiveNodeHandlerFnArchetype(executors.NodeStatusRunning, nil), + nil, &v1alpha1.NodeSpec{}, false, handler.StatusRunning}, + {"childFailure", recursiveNodeHandlerFnArchetype(executors.NodeStatusFailed(expectedError), nil), + nil, &v1alpha1.NodeSpec{}, false, handler.StatusFailed(expectedError)}, + {"childComplete", recursiveNodeHandlerFnArchetype(executors.NodeStatusComplete, nil), + &v1alpha1.NodeStatus{}, &v1alpha1.NodeSpec{ID: childNodeID}, false, handler.StatusSuccess}, + {"childCompleteNoStatus", recursiveNodeHandlerFnArchetype(executors.NodeStatusComplete, nil), + &v1alpha1.NodeStatus{}, &v1alpha1.NodeSpec{ID: "deadbeef"}, false, handler.StatusSuccess}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + m.RecursiveNodeHandlerCB = test.recursiveNodeHandlerFn + h, err := branch.recurseDownstream(ctx, w, test.nodeStatus, test.branchTakenNode) + if test.isErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedStatus, h) + if test.nodeStatus != nil { + assert.Equal(t, w.GetNodeExecutionStatus(test.branchTakenNode.GetID()).GetDataDir(), test.nodeStatus.GetDataDir()) + } + }) + } +} + +func TestBranchHandler_AbortNode(t *testing.T) { + ctx := context.TODO() + m := &mockNodeExecutor{} + branch := New(m, events.NewMockEventSink(), promutils.NewTestScope()) + b1 := "b1" + n1 := "n1" + n2 := "n2" + + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "test", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + n1: { + ID: n1, + }, + n2: { + ID: n2, + }, + }, + }, + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + b1: { + Phase: v1alpha1.NodePhaseRunning, + BranchStatus: &v1alpha1.BranchNodeStatus{ + FinalizedNodeID: &n1, + }, + }, + }, + }, + } + exp, _ := getComparisonExpression(1.0, core.ComparisonExpression_EQ, 1.0) + + branchNode := &v1alpha1.BranchNodeSpec{ + + If: v1alpha1.IfBlock{ + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: exp, + }, + }, + }, + ThenNode: &n1, + }, + ElseIf: []*v1alpha1.IfBlock{ + { + Condition: v1alpha1.BooleanExpression{ + BooleanExpression: &core.BooleanExpression{ + Expr: &core.BooleanExpression_Comparison{ + Comparison: 
exp, + }, + }, + }, + ThenNode: &n2, + }, + }, + } + + t.Run("NoBranchNode", func(t *testing.T) { + + err := branch.AbortNode(ctx, w, &v1alpha1.NodeSpec{}) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.IllegalStateError)) + }) + + t.Run("BranchNodeNoEval", func(t *testing.T) { + + err := branch.AbortNode(ctx, w, &v1alpha1.NodeSpec{ + BranchNode: branchNode}) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.IllegalStateError)) + }) + + t.Run("BranchNodeSuccess", func(t *testing.T) { + m.AbortNodeHandlerCB = func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { + assert.Equal(t, n1, currentNode.GetID()) + return nil + } + err := branch.AbortNode(ctx, w, &v1alpha1.NodeSpec{ + ID: b1, + BranchNode: branchNode}) + assert.NoError(t, err) + }) +} + +func TestBranchHandler_Initialize(t *testing.T) { + ctx := context.TODO() + m := &mockNodeExecutor{} + branch := New(m, events.NewMockEventSink(), promutils.NewTestScope()) + assert.NoError(t, branch.Initialize(ctx)) +} + +// TODO incomplete test suite, add more +func TestBranchHandler_StartNode(t *testing.T) { + ctx := context.TODO() + m := &mockNodeExecutor{} + branch := New(m, events.NewMockEventSink(), promutils.NewTestScope()) + childNodeID := "child" + childDatadir := v1alpha1.DataReference("test") + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "test", + }, + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + childNodeID: { + DataDir: childDatadir, + }, + }, + }, + } + _, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) + + tests := []struct { + name string + node v1alpha1.ExecutableNode + isErr bool + expectedStatus handler.Status + }{ + {"NoBranchNode", &v1alpha1.NodeSpec{}, false, handler.StatusFailed(nil)}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, err := branch.StartNode(ctx, w, test.node, inputs) + if test.isErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedStatus.Phase, s.Phase) + + }) + } +} + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} diff --git a/flytepropeller/pkg/controller/nodes/common/output_resolver.go b/flytepropeller/pkg/controller/nodes/common/output_resolver.go new file mode 100644 index 0000000000..b356ab7760 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/common/output_resolver.go @@ -0,0 +1,62 @@ +package common + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +) + +func CreateAliasMap(aliases []v1alpha1.Alias) map[string]string { + aliasToVarMap := make(map[string]string, len(aliases)) + for _, alias := range aliases { + aliasToVarMap[alias.GetAlias()] = alias.GetVar() + } + return aliasToVarMap +} + +// A simple output resolver that expects an outputs.pb at the data directory of the node. 
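A minimal usage sketch, not part of this patch: assuming a caller in another package already has a configured storage.ProtobufStore (pbStore), an ExecutableWorkflow (wf), an ExecutableNode (node), and a context (ctx) in scope, resolving a single output variable named "out" might look roughly like this:

	// Illustrative only; pbStore, wf, node, ctx and the variable name "out" are placeholders.
	resolver := common.NewSimpleOutputsResolver(pbStore)

	// Reads outputs.pb from the node's data directory and returns the literal bound to "out",
	// after applying any output aliases declared on the node.
	lit, err := resolver.ExtractOutput(ctx, wf, node, "out")
	if err != nil {
		return err
	}
	_ = lit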
+type SimpleOutputsResolver struct { + store storage.ProtobufStore +} + +func (r SimpleOutputsResolver) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + bindToVar handler.VarName) (values *core.Literal, err error) { + d := &handler.Data{} + nodeStatus := w.GetNodeExecutionStatus(n.GetID()) + outputsFileRef := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) + if err := r.store.ReadProtobuf(ctx, outputsFileRef, d); err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, "Failed to GetPrevious data from dataDir [%v]", nodeStatus.GetDataDir()) + } + + if d.Literals == nil { + return nil, errors.Errorf(errors.OutputsNotFoundError, n.GetID(), + "Outputs not found at [%v]", outputsFileRef) + } + + aliasMap := CreateAliasMap(n.GetOutputAlias()) + if variable, ok := aliasMap[bindToVar]; ok { + logger.Debugf(ctx, "Mapping [%v].[%v] -> [%v].[%v]", n.GetID(), variable, n.GetID(), bindToVar) + bindToVar = variable + } + + l, ok := d.Literals[bindToVar] + if !ok { + return nil, errors.Errorf(errors.OutputsNotFoundError, n.GetID(), + "Failed to find [%v].[%v]", n.GetID(), bindToVar) + } + + return l, nil +} + +// Creates a simple output resolver that expects an outputs.pb at the data directory of the node. +func NewSimpleOutputsResolver(store storage.ProtobufStore) SimpleOutputsResolver { + return SimpleOutputsResolver{ + store: store, + } +} diff --git a/flytepropeller/pkg/controller/nodes/common/output_resolver_test.go b/flytepropeller/pkg/controller/nodes/common/output_resolver_test.go new file mode 100644 index 0000000000..42eaa76d6a --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/common/output_resolver_test.go @@ -0,0 +1,30 @@ +package common + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestCreateAliasMap(t *testing.T) { + { + aliases := []v1alpha1.Alias{ + {Alias: core.Alias{Var: "x", Alias: "y"}}, + } + m := CreateAliasMap(aliases) + assert.Equal(t, map[string]string{ + "y": "x", + }, m) + } + { + var aliases []v1alpha1.Alias + m := CreateAliasMap(aliases) + assert.Equal(t, map[string]string{}, m) + } + { + m := CreateAliasMap(nil) + assert.Equal(t, map[string]string{}, m) + } +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler.go b/flytepropeller/pkg/controller/nodes/dynamic/handler.go new file mode 100644 index 0000000000..fe8af1cb36 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler.go @@ -0,0 +1,391 @@ +package dynamic + +import ( + "context" + "time" + + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytepropeller/pkg/compiler" + common2 "github.com/lyft/flytepropeller/pkg/compiler/common" + "github.com/lyft/flytepropeller/pkg/controller/executors" + + "github.com/lyft/flytepropeller/pkg/controller/nodes/common" + + "github.com/lyft/flytestdlib/promutils" + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +) + +type dynamicNodeHandler struct { + handler.IFace + metrics metrics + simpleResolver common.SimpleOutputsResolver + store 
*storage.DataStore + nodeExecutor executors.Node + enQWorkflow v1alpha1.EnqueueWorkflow +} + +type metrics struct { + buildDynamicWorkflow labeled.StopWatch + retrieveDynamicJobSpec labeled.StopWatch +} + +func newMetrics(scope promutils.Scope) metrics { + return metrics{ + buildDynamicWorkflow: labeled.NewStopWatch("build_dynamic_workflow", "Overhead for building a dynamic workflow in memory.", time.Microsecond, scope), + retrieveDynamicJobSpec: labeled.NewStopWatch("retrieve_dynamic_spec", "Overhead of downloading and unmarshaling dynamic job spec", time.Microsecond, scope), + } +} + +func (e dynamicNodeHandler) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + bindToVar handler.VarName) (values *core.Literal, err error) { + outputResolver, casted := e.IFace.(handler.OutputResolver) + if !casted { + return e.simpleResolver.ExtractOutput(ctx, w, n, bindToVar) + } + + return outputResolver.ExtractOutput(ctx, w, n, bindToVar) +} + +func (e dynamicNodeHandler) getDynamicJobSpec(ctx context.Context, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (*core.DynamicJobSpec, error) { + t := e.metrics.retrieveDynamicJobSpec.Start(ctx) + defer t.Stop() + + futuresFilePath, err := e.store.ConstructReference(ctx, nodeStatus.GetDataDir(), v1alpha1.GetFutureFile()) + if err != nil { + logger.Warnf(ctx, "Failed to construct data path for futures file. Error: %v", err) + return nil, err + } + + // If no futures file produced, then declare success and return. + if metadata, err := e.store.Head(ctx, futuresFilePath); err != nil { + logger.Warnf(ctx, "Failed to read futures file. Error: %v", err) + return nil, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to do HEAD on futures file.") + } else if !metadata.Exists() { + return nil, nil + } + + djSpec := &core.DynamicJobSpec{} + if err := e.store.ReadProtobuf(ctx, futuresFilePath, djSpec); err != nil { + logger.Warnf(ctx, "Failed to read futures file. Error: %v", err) + return nil, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to read futures protobuf file.") + } + + return djSpec, nil +} + +func (e dynamicNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec, + w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) ( + *core.WorkflowTemplate, error) { + + iface, err := underlyingInterface(w, node) + if err != nil { + return nil, err + } + + // Modify node IDs to include lineage, the entire system assumes node IDs are unique per parent WF. + // We keep track of the original node ids because that's where inputs are written to. + parentNodeID := node.GetID() + for _, n := range djSpec.Nodes { + newID, err := hierarchicalNodeID(parentNodeID, n.Id) + if err != nil { + return nil, err + } + + // Instantiate a nodeStatus using the modified name but set its data directory using the original name. + subNodeStatus := nodeStatus.GetNodeExecutionStatus(newID) + originalNodePath, err := e.store.ConstructReference(ctx, nodeStatus.GetDataDir(), n.Id) + if err != nil { + return nil, err + } + + subNodeStatus.SetDataDir(originalNodePath) + subNodeStatus.ResetDirty() + + n.Id = newID + } + + if node.GetTaskID() != nil { + // If the parent is a task, pass down data children nodes should inherit. 
+ parentTask, err := w.GetTask(*node.GetTaskID()) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to find task [%v].", node.GetTaskID()) + } + + for _, t := range djSpec.Tasks { + if t.GetContainer() != nil && parentTask.CoreTask().GetContainer() != nil { + t.GetContainer().Config = append(t.GetContainer().Config, parentTask.CoreTask().GetContainer().Config...) + } + } + } + + for _, o := range djSpec.Outputs { + err = updateBindingNodeIDsWithLineage(parentNodeID, o.Binding) + if err != nil { + return nil, err + } + } + + return &core.WorkflowTemplate{ + Id: &core.Identifier{ + Project: w.GetExecutionID().Project, + Domain: w.GetExecutionID().Domain, + Version: rand.String(10), + Name: rand.String(10), + ResourceType: core.ResourceType_WORKFLOW, + }, + Nodes: djSpec.Nodes, + Outputs: djSpec.Outputs, + Interface: iface, + }, nil +} + +// For any node that is not in a NEW/READY state in the recording, CheckNodeStatus will be invoked. The implementation should handle +// idempotency and return the current observed state of the node +func (e dynamicNodeHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, + previousNodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + + var status handler.Status + var err error + switch previousNodeStatus.GetOrCreateDynamicNodeStatus().GetDynamicNodePhase() { + case v1alpha1.DynamicNodePhaseExecuting: + // If the node succeeded, check if it generated a futures.pb file to execute. + dynamicWF, nStatus, _, err := e.buildContextualDynamicWorkflow(ctx, w, node, previousNodeStatus) + if err != nil { + return handler.StatusFailed(err), nil + } + + s, err := e.progressDynamicWorkflow(ctx, nStatus, dynamicWF) + if err == nil && s == handler.StatusSuccess { + // After the dynamic node completes we need to copy the outputs from its end nodes to the parent nodes status + endNode := dynamicWF.GetNodeExecutionStatus(v1alpha1.EndNodeID) + outputPath := v1alpha1.GetOutputsFile(endNode.GetDataDir()) + destinationPath := v1alpha1.GetOutputsFile(previousNodeStatus.GetDataDir()) + logger.Infof(ctx, "Dynamic workflow completed, copying outputs from the end-node [%s] to the parent node data dir [%s]", outputPath, destinationPath) + if err := e.store.CopyRaw(ctx, outputPath, destinationPath, storage.Options{}); err != nil { + logger.Errorf(ctx, "Failed to copy outputs from dynamic sub-wf [%s] to [%s]. Error: %s", outputPath, destinationPath, err.Error()) + return handler.StatusUndefined, errors.Wrapf(errors.StorageError, node.GetID(), err, "Failed to copy outputs from dynamic sub-wf [%s] to [%s]. Error: %s", outputPath, destinationPath, err.Error()) + } + if successHandler, ok := e.IFace.(handler.PostNodeSuccessHandler); ok { + return successHandler.HandleNodeSuccess(ctx, w, node) + } + logger.Warnf(ctx, "Bad configuration for dynamic node, no post node success handler found!") + } + return s, err + default: + // Invoke the underlying check node status. + status, err = e.IFace.CheckNodeStatus(ctx, w, node, previousNodeStatus) + + if err != nil { + return status, err + } + + if status.Phase != handler.PhaseSuccess { + return status, err + } + + // If the node succeeded, check if it generated a futures.pb file to execute. 
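The futures-file probe performed by getDynamicJobSpec above follows an "optional protobuf" pattern against the flytestdlib DataStore: HEAD first, treat absence as "not dynamic", and only deserialize when the object exists. A condensed sketch using a hypothetical helper name (readIfExists), with the same Head/Exists/ReadProtobuf calls and assuming the usual proto import:

    // readIfExists is a hypothetical helper; absence of the object is not an error.
    func readIfExists(ctx context.Context, store *storage.DataStore,
        ref storage.DataReference, msg proto.Message) (found bool, err error) {
        md, err := store.Head(ctx, ref)
        if err != nil {
            return false, err
        }
        if !md.Exists() {
            return false, nil // no futures.pb => plain, non-dynamic node
        }
        return true, store.ReadProtobuf(ctx, ref, msg)
    }
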
+ _, _, isDynamic, err := e.buildContextualDynamicWorkflow(ctx, w, node, previousNodeStatus) + if err != nil { + return handler.StatusFailed(err), nil + } + + if !isDynamic { + if successHandler, ok := e.IFace.(handler.PostNodeSuccessHandler); ok { + return successHandler.HandleNodeSuccess(ctx, w, node) + } + logger.Warnf(ctx, "Bad configuration for dynamic node, no post node success handler found!") + return status, err + } + + // Mark the node as a dynamic node executing its child nodes. Next time check node status is called, it'll go + // directly to progress the dynamically generated workflow. + previousNodeStatus.GetOrCreateDynamicNodeStatus().SetDynamicNodePhase(v1alpha1.DynamicNodePhaseExecuting) + + return handler.StatusRunning, nil + } +} + +func (e dynamicNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, + previousNodeStatus v1alpha1.ExecutableNodeStatus) (dynamicWf v1alpha1.ExecutableWorkflow, status v1alpha1.ExecutableNodeStatus, isDynamic bool, err error) { + + t := e.metrics.buildDynamicWorkflow.Start(ctx) + defer t.Stop() + + var nStatus v1alpha1.ExecutableNodeStatus + // We will only get here if the Phase is success. The downside is that this is an overhead for all nodes that are + // not dynamic. But given that we will only check once, it should be ok. + // TODO: Check for node.is_dynamic once the IDL changes are in and SDK migration has happened. + djSpec, err := e.getDynamicJobSpec(ctx, node, previousNodeStatus) + if err != nil { + return nil, nil, false, err + } + + if djSpec == nil { + return nil, status, false, nil + } + + rootNodeStatus := w.GetNodeExecutionStatus(node.GetID()) + if node.GetTaskID() != nil { + // TODO: This is a hack to set parent task execution id, we should move to node-node relationship. 
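What gets attached to that synthetic status is the parent task's execution identifier, built later in this file by getTaskExecutionIdentifier; roughly, it ties the generated sub-nodes back to the task attempt that produced the futures file:

    execID := &core.TaskExecutionIdentifier{
        TaskId:       task.CoreTask().Id,      // the task that emitted futures.pb
        RetryAttempt: nodeStatus.GetAttempts(), // current attempt of the parent node
        NodeExecutionId: &core.NodeExecutionIdentifier{
            NodeId:      node.GetID(),
            ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier,
        },
    }
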
+ execID, err := e.getTaskExecutionIdentifier(ctx, w, node) + if err != nil { + return nil, nil, false, err + } + + dynamicNode := &v1alpha1.NodeSpec{ + ID: "dynamic-node", + } + + nStatus = rootNodeStatus.GetNodeExecutionStatus(dynamicNode.GetID()) + nStatus.SetDataDir(rootNodeStatus.GetDataDir()) + nStatus.SetParentTaskID(execID) + } else { + nStatus = w.GetNodeExecutionStatus(node.GetID()) + } + + var closure *core.CompiledWorkflowClosure + wf, err := e.buildDynamicWorkflowTemplate(ctx, djSpec, w, node, nStatus) + if err != nil { + return nil, nil, true, err + } + + compiledTasks, err := compileTasks(ctx, djSpec.Tasks) + if err != nil { + return nil, nil, true, err + } + + // TODO: This will currently fail if the WF references any launch plans + closure, err = compiler.CompileWorkflow(wf, djSpec.Subworkflows, compiledTasks, []common2.InterfaceProvider{}) + if err != nil { + return nil, nil, true, err + } + + subwf, err := k8s.BuildFlyteWorkflow(closure, nil, nil, "") + if err != nil { + return nil, nil, true, err + } + + return newContextualWorkflow(w, subwf, nStatus, subwf.Tasks, subwf.SubWorkflows), nStatus, true, nil +} + +func (e dynamicNodeHandler) progressDynamicWorkflow(ctx context.Context, parentNodeStatus v1alpha1.ExecutableNodeStatus, + w v1alpha1.ExecutableWorkflow) (handler.Status, error) { + + state, err := e.nodeExecutor.RecursiveNodeHandler(ctx, w, w.StartNode()) + if err != nil { + return handler.StatusUndefined, err + } + + if state.HasFailed() { + if w.GetOnFailureNode() != nil { + return handler.StatusFailing(state.Err), nil + } + return handler.StatusFailed(state.Err), nil + } + + if state.IsComplete() { + nodeID := "" + if parentNodeStatus.GetParentNodeID() != nil { + nodeID = *parentNodeStatus.GetParentNodeID() + } + + // If the WF interface has outputs, validate that the outputs file was written. 
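For orientation, the in-memory workflow being progressed here was assembled by buildContextualDynamicWorkflow above in three stages; a condensed sketch with error handling elided (all names are from the surrounding function):

    wfTmpl, _ := e.buildDynamicWorkflowTemplate(ctx, djSpec, w, node, nStatus) // relabel node IDs, derive interface
    tasks, _ := compileTasks(ctx, djSpec.Tasks)                                // compile each unique task once
    closure, _ := compiler.CompileWorkflow(wfTmpl, djSpec.Subworkflows, tasks, []common2.InterfaceProvider{})
    subwf, _ := k8s.BuildFlyteWorkflow(closure, nil, nil, "")                  // CRD-shaped workflow, executed in memory only

The result is then wrapped by newContextualWorkflow so lookups fall back to the parent workflow.
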
+ if outputBindings := w.GetOutputBindings(); len(outputBindings) > 0 { + endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + if endNodeStatus == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "No end node found in subworkflow.")), nil + } + + sourcePath := v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir()) + if metadata, err := e.store.Head(ctx, sourcePath); err == nil { + if !metadata.Exists() { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "Subworkflow is expected to produce outputs but no outputs file was written to %v.", + sourcePath)), nil + } + } else { + return handler.StatusUndefined, err + } + + destinationPath := v1alpha1.GetOutputsFile(parentNodeStatus.GetDataDir()) + if err := e.store.CopyRaw(ctx, sourcePath, destinationPath, storage.Options{}); err != nil { + return handler.StatusFailed(errors.Wrapf(errors.OutputsNotFoundError, nodeID, + err, "Failed to copy subworkflow outputs from [%v] to [%v]", + sourcePath, destinationPath)), nil + } + } + + return handler.StatusSuccess, nil + } + + if state.PartiallyComplete() { + // Re-enqueue the workflow + e.enQWorkflow(w.GetK8sWorkflowID().String()) + } + + return handler.StatusRunning, nil +} + +func (e dynamicNodeHandler) getTaskExecutionIdentifier(_ context.Context, w v1alpha1.ExecutableWorkflow, + node v1alpha1.ExecutableNode) (*core.TaskExecutionIdentifier, error) { + + taskID := node.GetTaskID() + task, err := w.GetTask(*taskID) + if err != nil { + return nil, errors.Wrapf(errors.BadSpecificationError, node.GetID(), err, "Unable to find task for taskId: [%v]", *taskID) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + return &core.TaskExecutionIdentifier{ + TaskId: task.CoreTask().Id, + RetryAttempt: nodeStatus.GetAttempts(), + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, + }, + }, nil +} + +func (e dynamicNodeHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + + previousNodeStatus := w.GetNodeExecutionStatus(node.GetID()) + switch previousNodeStatus.GetOrCreateDynamicNodeStatus().GetDynamicNodePhase() { + case v1alpha1.DynamicNodePhaseExecuting: + dynamicWF, _, isDynamic, err := e.buildContextualDynamicWorkflow(ctx, w, node, previousNodeStatus) + if err != nil { + return err + } + + if !isDynamic { + return nil + } + + return e.nodeExecutor.AbortHandler(ctx, dynamicWF, dynamicWF.StartNode()) + default: + // Invoke the underlying abort node. 
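Everything outside the executing-dynamic phase simply delegates to the wrapped handler; the constructor below (New) does that decorator wiring. A usage sketch mirroring the tests in this change, with hypothetical variable names (eventSink, factory, catalogClient, kubeClient, nodeExec):

    underlying := task.NewTaskHandlerForFactory(eventSink, store, enqueueWfFunc, factory, catalogClient, kubeClient, scope)
    dyn := New(underlying, nodeExec, enqueueWfFunc, store, scope)
    // dyn satisfies handler.IFace; non-dynamic nodes pass straight through to `underlying`.
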
+ return e.IFace.AbortNode(ctx, w, node) + } +} + +func New(underlying handler.IFace, nodeExecutor executors.Node, enQWorkflow v1alpha1.EnqueueWorkflow, store *storage.DataStore, + scope promutils.Scope) handler.IFace { + + return dynamicNodeHandler{ + IFace: underlying, + metrics: newMetrics(scope), + nodeExecutor: nodeExecutor, + enQWorkflow: enQWorkflow, + store: store, + } +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go new file mode 100644 index 0000000000..f96e157b9d --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go @@ -0,0 +1,261 @@ +package dynamic + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + pluginsV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + typesV1 "k8s.io/api/core/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" +) + +const DataDir = storage.DataReference("test-data") +const NodeID = "n1" + +var ( + enqueueWfFunc = func(id string) {} + fakeKubeClient = mocks2.NewFakeKubeClient() +) + +func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +func createTask(id string, ttype string, discoverable bool) *v1alpha1.TaskSpec { + return &v1alpha1.TaskSpec{ + TaskTemplate: &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: ttype, + Metadata: &core.TaskMetadata{Discoverable: discoverable}, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{}, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + }, + } +} + +func mockCatalogClient() catalog.Client { + return &catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, nil + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } +} + +func createWf(id string, execID string, project string, domain string, name string) *v1alpha1.FlyteWorkflow { + return &v1alpha1.FlyteWorkflow{ + ExecutionID: v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: execID, + }, + }, + Status: 
v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + NodeID: { + DataDir: DataDir, + }, + }, + }, + ObjectMeta: v1.ObjectMeta{ + Name: name, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: id, + }, + } +} + +func createStartNode() *v1alpha1.NodeSpec { + return &v1alpha1.NodeSpec{ + ID: NodeID, + Kind: v1alpha1.NodeKindStart, + Resources: &typesV1.ResourceRequirements{ + Requests: typesV1.ResourceList{ + typesV1.ResourceCPU: resource.MustParse("1"), + }, + }, + } +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func TestTaskHandler_CheckNodeStatusDiscovery(t *testing.T) { + ctx := context.Background() + + taskID := "t1" + tk := createTask(taskID, "container", true) + tk.Id.Project = "flytekit" + w := createWf("w1", "w2-exec", "projTest", "domainTest", "checkNodeTestName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: tk, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("TaskExecDoneDiscoveryWriteFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusSucceeded, nil) + d := &task.FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == tk.Type { + return taskExec, nil + } + return nil, fmt.Errorf("no match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, nil + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return status.Errorf(codes.DeadlineExceeded, "") + }, + } + store := createInmemoryDataStore(t, promutils.NewTestScope()) + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(1) + err1 := store.WriteProtobuf(ctx, "test-data/inputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + err2 := store.WriteProtobuf(ctx, "test-data/outputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + assert.NoError(t, err1) + assert.NoError(t, err2) + + th := New( + task.NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()), + nil, + enqueueWfFunc, + store, + promutils.NewTestScope(), + ) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) + + t.Run("TaskExecDoneDiscoveryMissingOutputs", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusSucceeded, nil) + d := &task.FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == tk.Type { + 
return taskExec, nil + } + return nil, fmt.Errorf("no match") + }, + } + store := createInmemoryDataStore(t, promutils.NewTestScope()) + th := New( + task.NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()), + nil, + enqueueWfFunc, + store, + promutils.NewTestScope(), + ) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseRetryableFailure, s.Phase, "received: %s", s.Phase.String()) + }) + + t.Run("TaskExecDoneDiscoveryWriteSuccess", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusSucceeded, nil) + d := &task.FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == tk.Type { + return taskExec, nil + } + return nil, fmt.Errorf("no match") + }, + } + store := createInmemoryDataStore(t, promutils.NewTestScope()) + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + err1 := store.WriteProtobuf(ctx, "test-data/inputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + err2 := store.WriteProtobuf(ctx, "test-data/outputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + assert.NoError(t, err1) + assert.NoError(t, err2) + th := New( + task.NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()), + nil, + enqueueWfFunc, + store, + promutils.NewTestScope(), + ) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/subworkflow.go b/flytepropeller/pkg/controller/nodes/dynamic/subworkflow.go new file mode 100644 index 0000000000..3869a5c219 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/dynamic/subworkflow.go @@ -0,0 +1,89 @@ +package dynamic + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/storage" +) + +// Defines a sub-contextual workflow that is built in-memory to represent a dynamic job execution plan. 
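The wrapper below embeds the parent ExecutableWorkflow and only overrides lookups, so tasks and sub-workflows that exist solely in the dynamic spec resolve from the extra maps while everything else falls through to the parent. A usage sketch with hypothetical names (parentWF, compiledSubWF, parentNodeStatus, "dn0-generated-task"):

    cwf := newContextualWorkflow(parentWF, compiledSubWF, parentNodeStatus,
        compiledSubWF.Tasks, compiledSubWF.SubWorkflows)
    tk, err := cwf.GetTask("dn0-generated-task") // served from extraTasks, not the parent spec
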
+type contextualWorkflow struct { + v1alpha1.ExecutableWorkflow + + extraTasks map[v1alpha1.TaskID]*v1alpha1.TaskSpec + extraWorkflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec + status *ContextualWorkflowStatus +} + +func newContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow, + subwf v1alpha1.ExecutableSubWorkflow, + status v1alpha1.ExecutableNodeStatus, + tasks map[v1alpha1.TaskID]*v1alpha1.TaskSpec, + workflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec) v1alpha1.ExecutableWorkflow { + + return &contextualWorkflow{ + ExecutableWorkflow: executors.NewSubContextualWorkflow(baseWorkflow, subwf, status), + extraTasks: tasks, + extraWorkflows: workflows, + status: newContextualWorkflowStatus(baseWorkflow.GetExecutionStatus(), status), + } +} + +func (w contextualWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + return w.status +} + +func (w contextualWorkflow) GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) { + if task, found := w.extraTasks[id]; found { + return task, nil + } + + return w.ExecutableWorkflow.GetTask(id) +} + +func (w contextualWorkflow) FindSubWorkflow(id v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow { + if wf, found := w.extraWorkflows[id]; found { + return wf + } + + return w.ExecutableWorkflow.FindSubWorkflow(id) +} + +// A contextual workflow status to override some of the implementations. +type ContextualWorkflowStatus struct { + v1alpha1.ExecutableWorkflowStatus + baseStatus v1alpha1.ExecutableNodeStatus +} + +func (w ContextualWorkflowStatus) GetDataDir() v1alpha1.DataReference { + return w.baseStatus.GetDataDir() +} + +// Overrides default node data dir to work around the contractual assumption between Propeller and Futures to write all +// sub-node inputs into current node data directory. +// E.g. +// if current node data dir is /wf_exec/node-1/data/ +// and the task ran and yielded 2 nodes, the structure will look like this: +// /wf_exec/node-1/data/ +// |_ inputs.pb +// |_ futures.pb +// |_ sub-node1/inputs.pb +// |_ sub-node2/inputs.pb +// TODO: This is just a stop-gap until we transition the DynamicJobSpec to be a full-fledged workflow spec. +// TODO: this will allow us to have proper data bindings between nodes then we can stop making assumptions about data refs. 
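Concretely, with hypothetical references: if the parent node's data dir is s3://bkt/exec/n1/data and a generated node is named sub1, the override below yields s3://bkt/exec/n1/data/sub1 (where the SDK already wrote sub1/inputs.pb) rather than a fresh directory under the workflow's own data dir. A one-line usage sketch, with cwfStatus and dataStore as assumed names:

    dir, err := cwfStatus.ConstructNodeDataDir(ctx, dataStore, "sub1")
    // dir == "s3://bkt/exec/n1/data/sub1" given the hypothetical parent dir above
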
+func (w ContextualWorkflowStatus) ConstructNodeDataDir(ctx context.Context, constructor storage.ReferenceConstructor, + name v1alpha1.NodeID) (storage.DataReference, error) { + return constructor.ConstructReference(ctx, w.GetDataDir(), name) +} + +func newContextualWorkflowStatus(baseWfStatus v1alpha1.ExecutableWorkflowStatus, + baseStatus v1alpha1.ExecutableNodeStatus) *ContextualWorkflowStatus { + + return &ContextualWorkflowStatus{ + ExecutableWorkflowStatus: baseWfStatus, + baseStatus: baseStatus, + } +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/subworkflow_test.go b/flytepropeller/pkg/controller/nodes/dynamic/subworkflow_test.go new file mode 100644 index 0000000000..9433b6ee58 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/dynamic/subworkflow_test.go @@ -0,0 +1,52 @@ +package dynamic + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" +) + +func TestNewContextualWorkflow(t *testing.T) { + wf := &mocks.ExecutableWorkflow{} + calledBase := false + wf.On("GetAnnotations").Return(map[string]string{}).Run(func(_ mock.Arguments) { + calledBase = true + }) + + wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) + + subwf := &mocks.ExecutableSubWorkflow{} + cWF := newContextualWorkflow(wf, subwf, nil, nil, nil) + cWF.GetAnnotations() + + assert.True(t, calledBase) +} + +func TestConstructNodeDataDir(t *testing.T) { + wf := &mocks.ExecutableWorkflow{} + wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) + + wfStatus := &mocks.ExecutableWorkflowStatus{} + wfStatus.On("GetDataDir").Return(storage.DataReference("fk://wrong/")).Run(func(_ mock.Arguments) { + assert.FailNow(t, "Should call the override") + }) + + nodeStatus := &mocks.ExecutableNodeStatus{} + nodeStatus.On("GetDataDir").Return(storage.DataReference("fk://right/")) + + ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + cWF := newContextualWorkflowStatus(wfStatus, nodeStatus) + + dataDir, err := cWF.ConstructNodeDataDir(context.TODO(), ds, "my_node") + assert.NoError(t, err) + assert.NotNil(t, dataDir) + assert.Equal(t, "fk://right/my_node", dataDir.String()) +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/utils.go b/flytepropeller/pkg/controller/nodes/dynamic/utils.go new file mode 100644 index 0000000000..42c9e081c0 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/dynamic/utils.go @@ -0,0 +1,105 @@ +package dynamic + +import ( + "context" + + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/utils" +) + +// Constructs the expected interface of a given node. 
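Only Outputs is populated here: the generated sub-workflow must expose the same outputs the parent node declared, while its inputs are bound internally by the futures spec. A usage sketch for a task node whose task declares a single integer output out1 (as in the tests below):

    iface, err := underlyingInterface(w, node)
    // err == nil, iface.Inputs == nil,
    // iface.Outputs.Variables["out1"].GetType().GetSimple() == core.SimpleType_INTEGER
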
+func underlyingInterface(w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (*core.TypedInterface, error) { + iface := &core.TypedInterface{} + if node.GetTaskID() != nil { + t, err := w.GetTask(*node.GetTaskID()) + if err != nil { + // Should never happen + return nil, err + } + + iface.Outputs = t.CoreTask().GetInterface().Outputs + } else if wfNode := node.GetWorkflowNode(); wfNode != nil { + if wfRef := wfNode.GetSubWorkflowRef(); wfRef != nil { + t := w.FindSubWorkflow(*wfRef) + if t == nil { + // Should never happen + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Couldn't find subworkflow [%v].", wfRef) + } + + iface.Outputs = t.GetOutputs().VariableMap + } else { + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Unknown interface") + } + } else if node.GetBranchNode() != nil { + if ifBlock := node.GetBranchNode().GetIf(); ifBlock != nil && ifBlock.GetThenNode() != nil { + bn, found := w.GetNode(*ifBlock.GetThenNode()) + if !found { + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Couldn't find branch node [%v]", + *ifBlock.GetThenNode()) + } + + return underlyingInterface(w, bn) + } + + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Empty branch detected.") + } else { + return nil, errors.Errorf(errors.IllegalStateError, node.GetID(), "Unknown interface.") + } + + return iface, nil +} + +func hierarchicalNodeID(parentNodeID, nodeID string) (string, error) { + return utils.FixedLengthUniqueIDForParts(20, parentNodeID, nodeID) +} + +func updateBindingNodeIDsWithLineage(parentNodeID string, binding *core.BindingData) (err error) { + switch b := binding.Value.(type) { + case *core.BindingData_Promise: + b.Promise.NodeId, err = hierarchicalNodeID(parentNodeID, b.Promise.NodeId) + if err != nil { + return err + } + case *core.BindingData_Collection: + for _, item := range b.Collection.Bindings { + err = updateBindingNodeIDsWithLineage(parentNodeID, item) + if err != nil { + return err + } + } + case *core.BindingData_Map: + for _, item := range b.Map.Bindings { + err = updateBindingNodeIDsWithLineage(parentNodeID, item) + if err != nil { + return err + } + } + } + + return nil +} + +func compileTasks(_ context.Context, tasks []*core.TaskTemplate) ([]*core.CompiledTask, error) { + compiledTasks := make([]*core.CompiledTask, 0, len(tasks)) + visitedTasks := sets.NewString() + for _, t := range tasks { + if visitedTasks.Has(t.Id.String()) { + continue + } + + ct, err := compiler.CompileTask(t) + if err != nil { + return nil, err + } + + compiledTasks = append(compiledTasks, ct) + visitedTasks.Insert(t.Id.String()) + } + + return compiledTasks, nil +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go b/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go new file mode 100644 index 0000000000..e5df771a64 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go @@ -0,0 +1,77 @@ +package dynamic + +import ( + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/stretchr/testify/mock" + + "github.com/stretchr/testify/assert" +) + +func TestHierarchicalNodeID(t *testing.T) { + t.Run("empty parent", func(t *testing.T) { + actual, err := hierarchicalNodeID("", "abc") + assert.NoError(t, err) + assert.Equal(t, "-abc", actual) + }) + + t.Run("long result", func(t *testing.T) { + actual, err := 
hierarchicalNodeID("abcdefghijklmnopqrstuvwxyz", "abc") + assert.NoError(t, err) + assert.Equal(t, "fpa3kc3y", actual) + }) +} + +func TestUnderlyingInterface(t *testing.T) { + expectedIface := &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "in": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + } + wf := &mocks.ExecutableWorkflow{} + + subWF := &mocks.ExecutableSubWorkflow{} + wf.On("FindSubWorkflow", mock.Anything).Return(subWF) + subWF.On("GetOutputs").Return(&v1alpha1.OutputVarMap{VariableMap: expectedIface.Outputs}) + + task := &mocks.ExecutableTask{} + wf.On("GetTask", mock.Anything).Return(task, nil) + task.On("CoreTask").Return(&core.TaskTemplate{ + Interface: expectedIface, + }) + + n := &mocks.ExecutableNode{} + wf.On("GetNode", mock.Anything).Return(n) + emptyStr := "" + n.On("GetTaskID").Return(&emptyStr) + + iface, err := underlyingInterface(wf, n) + assert.NoError(t, err) + assert.NotNil(t, iface) + assert.Equal(t, expectedIface, iface) + + n = &mocks.ExecutableNode{} + n.On("GetTaskID").Return(nil) + + wfNode := &mocks.ExecutableWorkflowNode{} + n.On("GetWorkflowNode").Return(wfNode) + wfNode.On("GetSubWorkflowRef").Return(&emptyStr) + + iface, err = underlyingInterface(wf, n) + assert.NoError(t, err) + assert.NotNil(t, iface) + assert.Equal(t, expectedIface, iface) +} diff --git a/flytepropeller/pkg/controller/nodes/end/handler.go b/flytepropeller/pkg/controller/nodes/end/handler.go new file mode 100644 index 0000000000..5c53d9de13 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/end/handler.go @@ -0,0 +1,52 @@ +package end + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +type endHandler struct { + store storage.ProtobufStore +} + +func (e *endHandler) Initialize(ctx context.Context) error { + return nil +} + +func (e *endHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + if nodeInputs != nil { + logger.Debugf(ctx, "Workflow has outputs. Storing them.") + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + o := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) + so := storage.Options{} + if err := e.store.WriteProtobuf(ctx, o, so, nodeInputs); err != nil { + logger.Errorf(ctx, "Failed to store workflow outputs. 
Error [%s]", err) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to store workflow outputs, as end-node") + } + } + logger.Debugf(ctx, "End node success") + return handler.StatusSuccess, nil +} + +func (e *endHandler) CheckNodeStatus(ctx context.Context, g v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + return handler.StatusSuccess, nil +} + +func (e *endHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, node.GetID(), "End node cannot enter a failing state")), nil +} + +func (e *endHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + return nil +} + +func New(store storage.ProtobufStore) handler.IFace { + return &endHandler{ + store: store, + } +} diff --git a/flytepropeller/pkg/controller/nodes/end/handler_test.go b/flytepropeller/pkg/controller/nodes/end/handler_test.go new file mode 100644 index 0000000000..a5ab02c4b2 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/end/handler_test.go @@ -0,0 +1,135 @@ +package end + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + flyteassert "github.com/lyft/flytepropeller/pkg/utils/assert" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + regErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +var testScope = promutils.NewScope("end_test") + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func init() { + labeled.SetMetricKeys(contextutils.NodeIDKey) +} + +type TestProtoDataStore struct { + ReadProtobufCb func(ctx context.Context, reference storage.DataReference, msg proto.Message) error + WriteProtobufCb func(ctx context.Context, reference storage.DataReference, opts storage.Options, msg proto.Message) error +} + +func (t TestProtoDataStore) ReadProtobuf(ctx context.Context, reference storage.DataReference, msg proto.Message) error { + return t.ReadProtobufCb(ctx, reference, msg) +} + +func (t TestProtoDataStore) WriteProtobuf(ctx context.Context, reference storage.DataReference, opts storage.Options, msg proto.Message) error { + return t.WriteProtobufCb(ctx, reference, opts, msg) +} + +func TestEndHandler_CheckNodeStatus(t *testing.T) { + e := endHandler{} + s, err := e.CheckNodeStatus(context.TODO(), nil, nil, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) +} + +func TestEndHandler_HandleFailingNode(t *testing.T) { + e := endHandler{} + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + } + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: v1alpha1.WorkflowID("w1"), + }, + } + s, err := e.HandleFailingNode(context.TODO(), w, node) + assert.NoError(t, err) + assert.Equal(t, errors.IllegalStateError, 
s.Err.(*errors.NodeError).Code) +} + +func TestEndHandler_Initialize(t *testing.T) { + e := endHandler{} + assert.NoError(t, e.Initialize(context.TODO())) +} + +func TestEndHandler_StartNode(t *testing.T) { + inMem := createInmemoryDataStore(t, testScope.NewSubScope("x")) + e := New(inMem) + ctx := context.Background() + + inputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral("hello"), + "y": utils.MustMakePrimitiveLiteral("blah"), + }, + } + + outputRef := v1alpha1.DataReference("testRef") + + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + } + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: v1alpha1.WorkflowID("w1"), + }, + } + w.Status.NodeStatus = map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.EndNodeID: { + DataDir: outputRef, + }, + } + + t.Run("NoInputs", func(t *testing.T) { + s, err := e.StartNode(ctx, w, node, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) + + outputLoc := v1alpha1.GetOutputsFile(outputRef) + t.Run("WithInputs", func(t *testing.T) { + s, err := e.StartNode(ctx, w, node, inputs) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + actual := &core.LiteralMap{} + if assert.NoError(t, inMem.ReadProtobuf(ctx, outputLoc, actual)) { + flyteassert.EqualLiteralMap(t, inputs, actual) + } + }) + + t.Run("StoreFailure", func(t *testing.T) { + store := &TestProtoDataStore{ + WriteProtobufCb: func(ctx context.Context, reference v1alpha1.DataReference, opts storage.Options, msg proto.Message) error { + return regErrors.Errorf("Fail") + }, + } + e := New(store) + s, err := e.StartNode(ctx, w, node, inputs) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.CausedByError)) + assert.Equal(t, handler.StatusUndefined, s) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/errors/codes.go b/flytepropeller/pkg/controller/nodes/errors/codes.go new file mode 100644 index 0000000000..0d0b3da099 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/errors/codes.go @@ -0,0 +1,26 @@ +package errors + +type ErrorCode string + +const ( + NotYetImplementedError ErrorCode = "NotYetImplementedError" + DownstreamNodeNotFoundError ErrorCode = "DownstreamNodeNotFound" + UserProvidedError ErrorCode = "UserProvidedError" + IllegalStateError ErrorCode = "IllegalStateError" + BadSpecificationError ErrorCode = "BadSpecificationError" + UnsupportedTaskTypeError ErrorCode = "UnsupportedTaskType" + BindingResolutionError ErrorCode = "BindingResolutionError" + CausedByError ErrorCode = "CausedByError" + RuntimeExecutionError ErrorCode = "RuntimeExecutionError" + SubWorkflowExecutionFailed ErrorCode = "SubWorkflowExecutionFailed" + RemoteChildWorkflowExecutionFailed ErrorCode = "RemoteChildWorkflowExecutionFailed" + NoBranchTakenError ErrorCode = "NoBranchTakenError" + OutputsNotFoundError ErrorCode = "OutputsNotFoundError" + StorageError ErrorCode = "StorageError" + EventRecordingFailed ErrorCode = "EventRecordingFailed" + CatalogCallFailed ErrorCode = "CatalogCallFailed" +) + +func (e ErrorCode) String() string { + return string(e) +} diff --git a/flytepropeller/pkg/controller/nodes/errors/errors.go b/flytepropeller/pkg/controller/nodes/errors/errors.go new file mode 100644 index 0000000000..05c096ddf7 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/errors/errors.go @@ -0,0 +1,80 @@ +package errors + +import ( + "fmt" + + "github.com/pkg/errors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type ErrorMessage 
= string + +type NodeError struct { + errors.StackTrace + Code ErrorCode + Message ErrorMessage + Node v1alpha1.NodeID +} + +func (n *NodeError) Error() string { + return fmt.Sprintf("failed at Node[%s]. %v: %v", n.Node, n.Code, n.Message) +} + +type NodeErrorWithCause struct { + *NodeError + cause error +} + +func (n *NodeErrorWithCause) Error() string { + return fmt.Sprintf("%v, caused by: %v", n.NodeError.Error(), errors.Cause(n)) +} + +func (n *NodeErrorWithCause) Cause() error { + return n.cause +} + +func errorf(c ErrorCode, n v1alpha1.NodeID, msgFmt string, args ...interface{}) *NodeError { + return &NodeError{ + Code: c, + Node: n, + Message: fmt.Sprintf(msgFmt, args...), + } +} + +func Errorf(c ErrorCode, n v1alpha1.NodeID, msgFmt string, args ...interface{}) error { + return errorf(c, n, msgFmt, args...) +} + +func Wrapf(c ErrorCode, n v1alpha1.NodeID, cause error, msgFmt string, args ...interface{}) error { + return &NodeErrorWithCause{ + NodeError: errorf(c, n, msgFmt, args...), + cause: cause, + } +} + +func Matches(err error, code ErrorCode) bool { + errCode, isNodeError := GetErrorCode(err) + if isNodeError { + return code == errCode + } + return false +} + +func GetErrorCode(err error) (code ErrorCode, isNodeError bool) { + isNodeError = false + e, ok := err.(*NodeError) + if ok { + code = e.Code + isNodeError = true + return + } + + e2, ok := err.(*NodeErrorWithCause) + if ok { + code = e2.Code + isNodeError = true + return + } + return +} diff --git a/flytepropeller/pkg/controller/nodes/errors/errors_test.go b/flytepropeller/pkg/controller/nodes/errors/errors_test.go new file mode 100644 index 0000000000..2386e92058 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/errors/errors_test.go @@ -0,0 +1,48 @@ +package errors + +import ( + "fmt" + "testing" + + extErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestErrorf(t *testing.T) { + msg := "msg" + err := Errorf(IllegalStateError, "n1", "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*NodeError) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "n1", e.Node) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, err, extErrors.Cause(e)) + assert.Equal(t, "failed at Node[n1]. IllegalStateError: Message [msg]", err.Error()) +} + +func TestErrorfWithCause(t *testing.T) { + cause := extErrors.Errorf("Some Error") + msg := "msg" + err := Wrapf(IllegalStateError, "n1", cause, "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*NodeErrorWithCause) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "n1", e.Node) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, cause, extErrors.Cause(e)) + assert.Equal(t, "failed at Node[n1]. 
IllegalStateError: Message [msg], caused by: Some Error", err.Error()) +} + +func TestMatches(t *testing.T) { + err := Errorf(IllegalStateError, "n1", "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + cause := extErrors.Errorf("Some Error") + err = Wrapf(IllegalStateError, "n1", cause, "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + assert.False(t, Matches(cause, IllegalStateError)) + assert.False(t, Matches(cause, BadSpecificationError)) +} diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go new file mode 100644 index 0000000000..7d0ea5ee32 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -0,0 +1,540 @@ +package nodes + +import ( + "context" + "fmt" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/clients/go/events" + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/utils" +) + +type nodeMetrics struct { + FailureDuration labeled.StopWatch + SuccessDuration labeled.StopWatch + ResolutionFailure labeled.Counter + InputsWriteFailure labeled.Counter + + // Measures the latency between the last parent node stoppedAt time and current node's queued time. + TransitionLatency labeled.StopWatch + // Measures the latency between the time a node's been queued to the time the handler reported the executable moved + // to running state + QueuingLatency labeled.StopWatch +} + +type nodeExecutor struct { + nodeHandlerFactory HandlerFactory + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + nodeRecorder events.NodeEventRecorder + metrics *nodeMetrics +} + +// In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued +// Before we start the node execution, we need to transition this Node status to Queued. +// This is because a node execution has to exist before task/wf executions can start. +func (c *nodeExecutor) queueNodeIfReady(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + logger.Debugf(ctx, "Node not yet started") + // Query the nodes information to figure out if it can be executed. + predicatePhase, err := CanExecute(ctx, w, node) + if err != nil { + logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) + return handler.StatusUndefined, err + } + if predicatePhase == PredicatePhaseSkip { + logger.Debugf(ctx, "Node upstream node was skipped. 
Skipping!") + return handler.StatusSkipped, nil + } else if predicatePhase == PredicatePhaseNotReady { + logger.Debugf(ctx, "Node not ready for executing.") + return handler.StatusNotStarted, nil + } + + if len(nodeStatus.GetDataDir()) == 0 { + // Predicate ready, lets Resolve the data + dataDir, err := w.GetExecutionStatus().ConstructNodeDataDir(ctx, c.store, node.GetID()) + if err != nil { + return handler.StatusUndefined, err + } + + nodeStatus.SetDataDir(dataDir) + } + + return handler.StatusQueued, nil +} + +func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { + if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { + // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) + t, err := GetParentNodeMaxEndTime(ctx, w, node) + if err != nil { + logger.Warnf(ctx, "Failed to record transition latency for node. Error: %s", err.Error()) + return + } + if !t.IsZero() { + c.metrics.TransitionLatency.Observe(ctx, t.Time, time.Now()) + } + } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure && nodeStatus.GetLastUpdatedAt() != nil { + c.metrics.TransitionLatency.Observe(ctx, nodeStatus.GetLastUpdatedAt().Time, time.Now()) + } +} + +// Start the node execution. This implies that the node will start processing +func (c *nodeExecutor) startNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, h handler.IFace) (handler.Status, error) { + + // TODO: Performance problem, we may be in a retry loop and do not need to resolve the inputs again. + // For now we will do this. + dataDir := nodeStatus.GetDataDir() + var nodeInputs *handler.Data + if !node.IsStartNode() { + // Can execute + var err error + nodeInputs, err = Resolve(ctx, c.nodeHandlerFactory, w, node.GetID(), node.GetInputBindings(), c.store) + // TODO we need to handle retryable, network errors here!! + if err != nil { + c.metrics.ResolutionFailure.Inc(ctx) + logger.Warningf(ctx, "Failed to resolve inputs for Node. Error [%v]", err) + return handler.StatusFailed(err), nil + } + + if nodeInputs != nil { + inputsFile := v1alpha1.GetInputsFile(dataDir) + if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to store inputs for Node. Error [%v]. InputsFile [%s]", err, inputsFile) + return handler.StatusUndefined, errors.Wrapf( + errors.StorageError, node.GetID(), err, "Failed to store inputs for Node. InputsFile [%s]", inputsFile) + } + } + + logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) + } + + // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed + // all the overhead that we have to compute. Any failures after this will incur this penalty, but it could be due + // to various external reasons - like queuing, overuse of quota, plugin overhead etc. 
+ c.RecordTransitionLatency(ctx, w, node, nodeStatus) + + // Start node execution + return h.StartNode(ctx, w, node, nodeInputs) +} + +func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + logger.Debugf(ctx, "Handling Node [%s]", node.GetID()) + defer logger.Debugf(ctx, "Completed node [%s]", node.GetID()) + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(node.GetKind()) + if err != nil { + return handler.StatusUndefined, err + } + + // Important to note that we have special optimization for start node only (not end node) + // We specifically ignore queueing of start node and directly move the start node to "starting" + // This prevents an extra event to Admin and an extra write to etcD. This is also because of the fact that start node does not have tasks and do not need to send out task events. + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + var status handler.Status + if !node.IsStartNode() && !node.IsEndNode() && nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted { + // We only send the queued event to Admin in case the node was never started and when it is not either StartNode. + // We do not do this for endNode, because endNode may still be not executable. This is because StartNode + // completes as soon as started. + return c.queueNodeIfReady(ctx, w, node, nodeStatus) + } else if node.IsEndNode() { + status, err = c.queueNodeIfReady(ctx, w, node, nodeStatus) + if err == nil && status.Phase == handler.PhaseQueued { + status, err = c.startNode(ctx, w, node, nodeStatus, h) + } + } else if node.IsStartNode() || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued || + nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure { + // If the node is either StartNode or was previously queued or failed in a previous attempt, we will call + // the start method on the node handler + status, err = c.startNode(ctx, w, node, nodeStatus, h) + } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseFailing { + status, err = h.HandleFailingNode(ctx, w, node) + } else { + status, err = h.CheckNodeStatus(ctx, w, node, nodeStatus) + } + + return status, err +} + +func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { + err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent) + // TODO: add unit tests for this specific path + if err != nil && eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", + nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) + return nil + } + return err +} + +func (c *nodeExecutor) TransitionToPhase(ctx context.Context, execID *core.WorkflowExecutionIdentifier, + node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, toStatus handler.Status) (executors.NodeStatus, error) { + + previousNodePhase := nodeStatus.GetPhase() + // TODO GC analysis. We will create a ton of node-events but never publish them. 
We could first check for the PhaseChange and if so then do this processing + + nodeEvent := &event.NodeExecutionEvent{ + Id: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: execID, + }, + InputUri: v1alpha1.GetInputsFile(nodeStatus.GetDataDir()).String(), + } + + var returnStatus executors.NodeStatus + errMsg := "" + errCode := "NodeExecutionUnknownError" + if toStatus.Err != nil { + errMsg = toStatus.Err.Error() + code, ok := errors.GetErrorCode(toStatus.Err) + if ok { + errCode = code.String() + } + } + + // If there is a child workflow, include the execution of the child workflow in the event + if nodeStatus.GetWorkflowNodeStatus() != nil { + nodeEvent.TargetMetadata = &event.NodeExecutionEvent_WorkflowNodeMetadata{ + WorkflowNodeMetadata: &event.WorkflowNodeMetadata{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: execID.Project, + Domain: execID.Domain, + Name: nodeStatus.GetWorkflowNodeStatus().GetWorkflowExecutionName(), + }, + }, + } + } + + switch toStatus.Phase { + case handler.PhaseNotStarted: + return executors.NodeStatusPending, nil + // TODO we should not need handler.PhaseQueued since we have added Task StateMachine. Remove it. + case handler.PhaseQueued: + nodeEvent.Phase = core.NodeExecution_QUEUED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseQueued, v1.NewTime(toStatus.OccurredAt), "") + + returnStatus = executors.NodeStatusQueued + if !toStatus.OccurredAt.IsZero() { + nodeEvent.OccurredAt = utils.GetProtoTime(&v1.Time{Time: toStatus.OccurredAt}) + } else { + nodeEvent.OccurredAt = ptypes.TimestampNow() // TODO: add queueAt in nodeStatus + } + + case handler.PhaseRunning: + nodeEvent.Phase = core.NodeExecution_RUNNING + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, v1.NewTime(toStatus.OccurredAt), "") + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStartedAt()) + returnStatus = executors.NodeStatusRunning + + if nodeStatus.GetQueuedAt() != nil && nodeStatus.GetStartedAt() != nil { + c.metrics.QueuingLatency.Observe(ctx, nodeStatus.GetQueuedAt().Time, nodeStatus.GetStartedAt().Time) + } + case handler.PhaseRetryableFailure: + maxAttempts := uint32(0) + if node.GetRetryStrategy() != nil && node.GetRetryStrategy().MinAttempts != nil { + maxAttempts = uint32(*node.GetRetryStrategy().MinAttempts) + } + + nodeEvent.OutputResult = &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errCode, + Message: fmt.Sprintf("Retries [%d/%d], %s", nodeStatus.GetAttempts(), maxAttempts, errMsg), + ErrorUri: v1alpha1.GetOutputErrorFile(nodeStatus.GetDataDir()).String(), + }, + } + + if nodeStatus.IncrementAttempts() >= maxAttempts { + logger.Debugf(ctx, "All retries have exhausted, failing node. [%d/%d]", nodeStatus.GetAttempts(), maxAttempts) + // Failure + nodeEvent.Phase = core.NodeExecution_FAILED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, v1.NewTime(toStatus.OccurredAt), errMsg) + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + returnStatus = executors.NodeStatusFailed(toStatus.Err) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + } else { + // retry + // TODO add a nodeEvent of retryableFailure (it is not a terminal event). + // For now, we don't send an event for node retryable failures. + nodeEvent = nil + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRetryableFailure, v1.NewTime(toStatus.OccurredAt), errMsg) + returnStatus = executors.NodeStatusRunning + + // Reset all executors' state to start a fresh attempt. 
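MinAttempts is used in this branch as the total attempt budget rather than a minimum. A short worked example of the decision above, assuming IncrementAttempts returns the new attempt count:

    // With RetryStrategy{MinAttempts: 3} the sequence is:
    //   failure #1 -> IncrementAttempts() == 1, 1 < 3 -> clear sub-state, stay Running
    //   failure #2 -> 2 < 3                           -> clear sub-state, stay Running
    //   failure #3 -> 3 >= 3                          -> phase Failed, terminal event sent
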
+ nodeStatus.ClearTaskStatus() + nodeStatus.ClearWorkflowStatus() + nodeStatus.ClearDynamicNodeStatus() + + // Required for transition (backwards compatibility) + if nodeStatus.GetLastUpdatedAt() != nil { + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetLastUpdatedAt().Time) + } + } + + case handler.PhaseSkipped: + nodeEvent.Phase = core.NodeExecution_SKIPPED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.NewTime(toStatus.OccurredAt), "") + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + returnStatus = executors.NodeStatusSuccess + + case handler.PhaseSucceeding: + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeding, v1.NewTime(toStatus.OccurredAt), "") + // Currently we do not record events for this + return executors.NodeStatusRunning, nil + + case handler.PhaseSuccess: + nodeEvent.Phase = core.NodeExecution_SUCCEEDED + reason := "" + if nodeStatus.IsCached() { + reason = "Task Skipped due to Discovery Cache Hit." + } + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, v1.NewTime(toStatus.OccurredAt), reason) + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + if metadata, err := c.store.Head(ctx, v1alpha1.GetOutputsFile(nodeStatus.GetDataDir())); err == nil && metadata.Exists() { + nodeEvent.OutputResult = &event.NodeExecutionEvent_OutputUri{ + OutputUri: v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()).String(), + } + } + + returnStatus = executors.NodeStatusSuccess + c.metrics.SuccessDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + + case handler.PhaseFailing: + nodeEvent.Phase = core.NodeExecution_FAILING + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailing, v1.NewTime(toStatus.OccurredAt), "") + nodeEvent.OccurredAt = utils.GetProtoTime(nil) + returnStatus = executors.NodeStatusRunning + + case handler.PhaseFailed: + nodeEvent.Phase = core.NodeExecution_FAILED + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, v1.NewTime(toStatus.OccurredAt), errMsg) + nodeEvent.OccurredAt = utils.GetProtoTime(nodeStatus.GetStoppedAt()) + nodeEvent.OutputResult = &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errCode, + Message: errMsg, + ErrorUri: v1alpha1.GetOutputErrorFile(nodeStatus.GetDataDir()).String(), + }, + } + returnStatus = executors.NodeStatusFailed(toStatus.Err) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + + case handler.PhaseUndefined: + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, node.GetID(), "unexpected undefined state received, without an error") + } + + // We observe that the phase has changed, and so we will record this event. 
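Recording is gated on an actual phase change and funneled through IdempotentRecordEvent above; the outcomes separate into three cases. A condensed sketch of the handling that follows:

    err := c.IdempotentRecordEvent(ctx, nodeEvent)
    switch {
    case err == nil:
        // recorded, or an AlreadyExists duplicate swallowed inside IdempotentRecordEvent
    case eventsErr.IsEventAlreadyInTerminalStateError(err):
        // propeller and the control plane disagree on a terminal phase -> node is failed
    default:
        // any other recording error -> EventRecordingFailed, surfaced so the round is retried
    }
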
+ if nodeEvent != nil && previousNodePhase != nodeStatus.GetPhase() {
+ if nodeStatus.GetParentTaskID() != nil {
+ nodeEvent.ParentTaskMetadata = &event.ParentTaskExecutionMetadata{
+ Id: nodeStatus.GetParentTaskID(),
+ }
+ }
+
+ logger.Debugf(ctx, "Recording NodeEvent for Phase transition [%s] -> [%s]", previousNodePhase.String(), nodeStatus.GetPhase().String())
+ err := c.IdempotentRecordEvent(ctx, nodeEvent)
+
+ if err != nil && eventsErr.IsEventAlreadyInTerminalStateError(err) {
+ logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error())
+ return executors.NodeStatusFailed(errors.Wrapf(errors.IllegalStateError, node.GetID(), err,
+ "phase mismatch between propeller and control plane; Propeller State: %s", returnStatus.NodePhase)), nil
+ } else if err != nil {
+ logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error())
+ return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, node.GetID(), err, "failed to record node event")
+ }
+ }
+ return returnStatus, nil
+}
+
+func (c *nodeExecutor) executeNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (executors.NodeStatus, error) {
+ handlerStatus, err := c.handleNode(ctx, w, node)
+ if err != nil {
+ logger.Warningf(ctx, "Node handling failed with an error [%v]", err.Error())
+ return executors.NodeStatusUndefined, err
+ }
+ nodeStatus := w.GetNodeExecutionStatus(node.GetID())
+ return c.TransitionToPhase(ctx, w.GetExecutionID().WorkflowExecutionIdentifier, node, nodeStatus, handlerStatus)
+}
+
+// The search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from
+// the currentNode. Visiting a node means invoking RecursiveNodeHandler on it. A visit may be partial, complete, or may result in a failure.
+func (c *nodeExecutor) handleDownstream(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) {
+ logger.Debugf(ctx, "Handling downstream Nodes")
+ // This node has succeeded. Handle all downstream nodes
+ downstreamNodes, err := w.FromNode(currentNode.GetID())
+ if err != nil {
+ logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err)
+ return executors.NodeStatusFailed(err), nil
+ }
+ if len(downstreamNodes) == 0 {
+ logger.Debugf(ctx, "No downstream nodes found. Complete.")
+ return executors.NodeStatusComplete, nil
+ }
+ // If any downstream node has failed, fail all
+ // Else, if all have succeeded, then succeed
+ // Else, if any one is still running, then downstream is still running
+ allCompleted := true
+ partialNodeCompletion := false
+ for _, downstreamNodeName := range downstreamNodes {
+ downstreamNode, ok := w.GetNode(downstreamNodeName)
+ if !ok {
+ return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", downstreamNodeName)), nil
+ }
+ state, err := c.RecursiveNodeHandler(ctx, w, downstreamNode)
+ if err != nil {
+ return executors.NodeStatusUndefined, err
+ }
+ if state.HasFailed() {
+ logger.Debugf(ctx, "Some downstream node has failed, %s", state.Err.Error())
+ return state, nil
+ }
+ if !state.IsComplete() {
+ allCompleted = false
+ }
+
+ if state.PartiallyComplete() {
+ // This implies that one of the downstream nodes has completed and the workflow is ready for propagation
+ // We do not propagate in the current cycle to make it possible to store the state between transitions
+ partialNodeCompletion = true
+ }
+ }
+ if allCompleted {
+ logger.Debugf(ctx, "All downstream nodes completed")
+ return executors.NodeStatusComplete, nil
+ }
+ if partialNodeCompletion {
+ return executors.NodeStatusSuccess, nil
+ }
+ return executors.NodeStatusPending, nil
+}
+
+func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, inputs *handler.Data) (executors.NodeStatus, error) {
+ startNode := w.StartNode()
+ if startNode == nil {
+ return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, v1alpha1.StartNodeID, "Start node not found")), nil
+ }
+ ctx = contextutils.WithNodeID(ctx, startNode.GetID())
+ if inputs == nil {
+ logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs")
+ return executors.NodeStatusComplete, nil
+ }
+ // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts them to its own outputs
+ nodeStatus := w.GetNodeExecutionStatus(startNode.GetID())
+ if nodeStatus.GetDataDir() == "" {
+ return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs")
+ }
+ outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir())
+ so := storage.Options{}
+ if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil {
+ logger.Errorf(ctx, "Failed to write protobuf (metadata). 
Error [%v]", err) + return executors.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") + } + return executors.NodeStatusComplete, nil +} + +func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + switch nodeStatus.GetPhase() { + case v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseSucceeding: + logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) + return c.executeNode(currentNodeCtx, w, currentNode) + // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped + // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped + // at a time. As we iterate down, further nodes will be skipped + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSkipped: + return c.handleDownstream(ctx, w, currentNode) + case v1alpha1.NodePhaseFailed: + logger.Debugf(currentNodeCtx, "Node Failed") + return executors.NodeStatusFailed(errors.Errorf(errors.RuntimeExecutionError, currentNode.GetID(), "Node Failed.")), nil + } + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), "Should never reach here") +} + +func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + nodeStatus := w.GetNodeExecutionStatus(currentNode.GetID()) + switch nodeStatus.GetPhase() { + case v1alpha1.NodePhaseRunning: + // Abort this node + h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + if err != nil { + return err + } + return h.AbortNode(ctx, w, currentNode) + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSkipped: + // Abort downstream nodes + downstreamNodes, err := w.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes. 
Error [%v]", err) + return nil + } + for _, d := range downstreamNodes { + downstreamNode, ok := w.GetNode(d) + if !ok { + return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) + } + if err := c.AbortHandler(ctx, w, downstreamNode); err != nil { + return err + } + } + return nil + } + return nil +} + +func (c *nodeExecutor) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing Core Node Executor") + return nil +} + +func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, + revalPeriod time.Duration, eventSink events.EventSink, workflowLauncher launchplan.Executor, + catalogClient catalog.Client, kubeClient executors.Client, scope promutils.Scope) (executors.Node, error) { + + nodeScope := scope.NewSubScope("node") + exec := &nodeExecutor{ + store: store, + enqueueWorkflow: enQWorkflow, + nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope), + metrics: &nodeMetrics{ + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), + ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), + TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + }, + } + nodeHandlerFactory, err := NewHandlerFactory( + ctx, + exec, + eventSink, + workflowLauncher, + enQWorkflow, + revalPeriod, + store, + catalogClient, + kubeClient, + nodeScope, + ) + exec.nodeHandlerFactory = nodeHandlerFactory + return exec, err +} diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go new file mode 100644 index 0000000000..2cfcc4caff --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -0,0 +1,1479 @@ +package nodes + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + "time" + + mocks4 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + pluginV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + goerrors "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + 
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mocks3 "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks2 "github.com/lyft/flytepropeller/pkg/controller/nodes/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + "github.com/lyft/flytepropeller/pkg/utils" + flyteassert "github.com/lyft/flytepropeller/pkg/utils/assert" +) + +var fakeKubeClient = mocks4.NewFakeKubeClient() + +func createSingletonTaskExecutorFactory() task.Factory { + return &task.FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginV1.Executor, error) { + return nil, nil + }, + ListAllTaskExecutorsCb: func() []pluginV1.Executor { + return []pluginV1.Executor{} + }, + } +} + +func init() { + flytek8s.InitializeFake() +} + +func TestSetInputsForStartNode(t *testing.T) { + ctx := context.Background() + mockStorage := createInmemoryDataStore(t, testScope.NewSubScope("f")) + catalogClient := catalog.NewCatalogClient(mockStorage) + enQWf := func(workflowID v1alpha1.WorkflowID) {} + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + exec, err := NewExecutor(ctx, mockStorage, enQWf, time.Second, events.NewMockEventSink(), launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + inputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral("hello"), + "y": utils.MustMakePrimitiveLiteral("blah"), + }, + } + + t.Run("NoInputs", func(t *testing.T) { + w := createDummyBaseWorkflow() + w.DummyStartNode = &v1alpha1.NodeSpec{ + ID: v1alpha1.StartNodeID, + } + s, err := exec.SetInputsForStartNode(ctx, w, nil) + assert.NoError(t, err) + assert.Equal(t, executors.NodeStatusComplete, s) + }) + + t.Run("WithInputs", func(t *testing.T) { + w := createDummyBaseWorkflow() + w.GetNodeExecutionStatus(v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data") + w.DummyStartNode = &v1alpha1.NodeSpec{ + ID: v1alpha1.StartNodeID, + } + s, err := exec.SetInputsForStartNode(ctx, w, inputs) + assert.NoError(t, err) + assert.Equal(t, executors.NodeStatusComplete, s) + actual := &core.LiteralMap{} + if assert.NoError(t, mockStorage.ReadProtobuf(ctx, "s3://test-bucket/exec/start-node/data/outputs.pb", actual)) { + flyteassert.EqualLiteralMap(t, inputs, actual) + } + }) + + t.Run("DataDirNotSet", func(t *testing.T) { + w := createDummyBaseWorkflow() + w.DummyStartNode = &v1alpha1.NodeSpec{ + ID: v1alpha1.StartNodeID, + } + s, err := exec.SetInputsForStartNode(ctx, w, inputs) + assert.Error(t, err) + assert.Equal(t, executors.NodeStatusUndefined, s) + }) + + failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) + execFail, err := NewExecutor(ctx, failStorage, enQWf, time.Second, events.NewMockEventSink(), launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + t.Run("StorageFailure", func(t *testing.T) { + w := createDummyBaseWorkflow() + w.GetNodeExecutionStatus(v1alpha1.StartNodeID).SetDataDir("s3://test-bucket/exec/start-node/data") + w.DummyStartNode = &v1alpha1.NodeSpec{ + ID: v1alpha1.StartNodeID, + } + s, err := 
execFail.SetInputsForStartNode(ctx, w, inputs) + assert.Error(t, err) + assert.Equal(t, executors.NodeStatusUndefined, s) + }) +} + +func TestNodeExecutor_TransitionToPhase(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + catalogClient := catalog.NewCatalogClient(memStore) + execIface, err := NewExecutor(ctx, memStore, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + execID := &core.WorkflowExecutionIdentifier{} + nodeID := "n1" + + expectedErr := fmt.Errorf("test err") + taskErr := fmt.Errorf("task failed") + + // TABLE Tests + tests := []struct { + name string + nodeStatus v1alpha1.ExecutableNodeStatus + toStatus handler.Status + expectedErr bool + expectedNodeStatus executors.NodeStatus + }{ + {"notStarted", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusNotStarted, false, executors.NodeStatusPending}, + {"running", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusRunning, false, executors.NodeStatusRunning}, + {"runningRepeated", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusRunning, false, executors.NodeStatusRunning}, + {"success", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSuccess, false, executors.NodeStatusSuccess}, + {"succeeding", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSucceeding, false, executors.NodeStatusRunning}, + {"failing", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusFailing(nil), false, executors.NodeStatusRunning}, + {"failed", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseFailing}, handler.StatusFailed(taskErr), false, executors.NodeStatusFailed(taskErr)}, + {"undefined", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusUndefined, true, executors.NodeStatusUndefined}, + {"skipped", &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted}, handler.StatusSkipped, false, executors.NodeStatusSuccess}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + n, err := exec.TransitionToPhase(ctx, execID, node, test.nodeStatus, test.toStatus) + if test.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedNodeStatus, n) + }) + } + + // Testing retries + t.Run("noRetryAttemptSet", func(t *testing.T) { + now := v1.Now() + status := &mocks.ExecutableNodeStatus{} + status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + status.On("GetAttempts").Return(uint32(0)) + status.On("GetDataDir").Return(storage.DataReference("x")) + status.On("IncrementAttempts").Return(uint32(1)) + status.On("UpdatePhase", v1alpha1.NodePhaseFailed, mock.Anything, mock.AnythingOfType("string")) + status.On("GetQueuedAt").Return(&now) + status.On("GetStartedAt").Return(&now) + status.On("GetStoppedAt").Return(&now) + status.On("GetWorkflowNodeStatus").Return(nil) + + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + 
node.On("GetRetryStrategy").Return(nil) + + n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed"))) + assert.NoError(t, err) + assert.Equal(t, executors.NodePhaseFailed, n.NodePhase) + }) + + // Testing retries + t.Run("maxAttempt0", func(t *testing.T) { + now := v1.Now() + status := &mocks.ExecutableNodeStatus{} + status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + status.On("GetAttempts").Return(uint32(0)) + status.On("GetDataDir").Return(storage.DataReference("x")) + status.On("IncrementAttempts").Return(uint32(1)) + status.On("UpdatePhase", v1alpha1.NodePhaseFailed, mock.Anything, mock.AnythingOfType("string")) + status.On("GetQueuedAt").Return(&now) + status.On("GetStartedAt").Return(&now) + status.On("GetStoppedAt").Return(&now) + status.On("GetWorkflowNodeStatus").Return(nil) + + maxAttempts := 0 + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + node.On("GetRetryStrategy").Return(&v1alpha1.RetryStrategy{MinAttempts: &maxAttempts}) + + n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed"))) + assert.NoError(t, err) + assert.Equal(t, executors.NodePhaseFailed, n.NodePhase) + }) + + // Testing retries + t.Run("retryAttemptsRemaining", func(t *testing.T) { + now := v1.Now() + status := &mocks.ExecutableNodeStatus{} + status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + status.On("GetAttempts").Return(uint32(0)) + status.On("GetDataDir").Return(storage.DataReference("x")) + status.On("IncrementAttempts").Return(uint32(1)) + status.On("UpdatePhase", v1alpha1.NodePhaseRetryableFailure, mock.Anything, mock.AnythingOfType("string")) + status.On("GetQueuedAt").Return(&now) + status.On("GetStartedAt").Return(&now) + status.On("GetLastUpdatedAt").Return(&now) + status.On("GetWorkflowNodeStatus").Return(nil) + var s *v1alpha1.TaskNodeStatus + status.On("UpdateTaskNodeStatus", s).Times(10) + status.On("ClearTaskStatus").Return() + status.On("ClearWorkflowStatus").Return() + status.On("ClearDynamicNodeStatus").Return() + + maxAttempts := 2 + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + node.On("GetRetryStrategy").Return(&v1alpha1.RetryStrategy{MinAttempts: &maxAttempts}) + + n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed"))) + assert.NoError(t, err) + assert.Equal(t, executors.NodePhaseRunning, n.NodePhase, "%+v", n) + }) + + // Testing retries + t.Run("retriesExhausted", func(t *testing.T) { + now := v1.Now() + status := &mocks.ExecutableNodeStatus{} + status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + status.On("GetAttempts").Return(uint32(0)) + status.On("GetDataDir").Return(storage.DataReference("x")) + // Change to return 3 + status.On("IncrementAttempts").Return(uint32(3)) + status.On("UpdatePhase", v1alpha1.NodePhaseFailed, mock.Anything, mock.AnythingOfType("string")) + status.On("GetQueuedAt").Return(&now) + status.On("GetStartedAt").Return(&now) + status.On("GetStoppedAt").Return(&now) + status.On("GetWorkflowNodeStatus").Return(nil) + + maxAttempts := 2 + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + node.On("GetRetryStrategy").Return(&v1alpha1.RetryStrategy{MinAttempts: &maxAttempts}) + + n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRetryableFailure(fmt.Errorf("failed"))) + assert.NoError(t, err) + assert.Equal(t, executors.NodePhaseFailed, n.NodePhase, "%+v", n.NodePhase) + }) + + 
t.Run("eventSendFailure", func(t *testing.T) { + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + // In case Report event fails + mockEventSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return expectedErr + } + n, err := exec.TransitionToPhase(ctx, execID, node, &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSuccess) + assert.Error(t, err) + assert.Equal(t, expectedErr, goerrors.Cause(err)) + assert.Equal(t, executors.NodeStatusUndefined, n) + }) + + t.Run("eventSendMismatch", func(t *testing.T) { + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + // In case Report event fails + mockEventSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + } + } + n, err := exec.TransitionToPhase(ctx, execID, node, &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning}, handler.StatusSuccess) + assert.NoError(t, err) + assert.Equal(t, executors.NodePhaseFailed, n.NodePhase) + }) + + // Testing that workflow execution name is queried in running + t.Run("childWorkflows", func(t *testing.T) { + now := v1.Now() + + wfNodeStatus := &mocks.ExecutableWorkflowNodeStatus{} + wfNodeStatus.On("GetWorkflowExecutionName").Return("childWfName") + + status := &mocks.ExecutableNodeStatus{} + status.On("GetPhase").Return(v1alpha1.NodePhaseQueued) + status.On("GetAttempts").Return(uint32(0)) + status.On("GetDataDir").Return(storage.DataReference("x")) + status.On("IncrementAttempts").Return(uint32(1)) + status.On("UpdatePhase", v1alpha1.NodePhaseRunning, mock.Anything, mock.AnythingOfType("string")) + status.On("GetStartedAt").Return(&now) + status.On("GetQueuedAt").Return(&now) + status.On("GetStoppedAt").Return(&now) + status.On("GetOrCreateWorkflowStatus").Return(wfNodeStatus) + status.On("ClearTaskStatus").Return() + status.On("ClearWorkflowStatus").Return() + status.On("GetWorkflowNodeStatus").Return(wfNodeStatus) + + node := &mocks.ExecutableNode{} + node.On("GetID").Return(nodeID) + node.On("GetRetryStrategy").Return(nil) + + n, err := exec.TransitionToPhase(ctx, execID, node, status, handler.StatusRunning) + assert.NoError(t, err) + assert.Equal(t, executors.NodePhaseRunning, n.NodePhase) + wfNodeStatus.AssertCalled(t, "GetWorkflowExecutionName") + }) +} + +func TestNodeExecutor_Initialize(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + catalogClient := catalog.NewCatalogClient(memStore) + + execIface, err := NewExecutor(ctx, memStore, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + assert.NoError(t, exec.Initialize(ctx)) +} + +func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := 
catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, + launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createStartNodeWf := func(p v1alpha1.NodePhase, _ int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + startNode := &v1alpha1.NodeSpec{ + Kind: v1alpha1.NodeKindStart, + ID: v1alpha1.StartNodeID, + } + startNodeStatus := &v1alpha1.NodeStatus{ + Phase: p, + } + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.StartNodeID: startNodeStatus, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + v1alpha1.StartNodeID: startNode, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.StartNodeID: {defaultNodeID}, + }, + }, + }, + }, startNode, startNodeStatus + + } + + // Recurse Child Node Queued previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"nys->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + {"nys->error", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("StartNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o *handler.Data) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) + + mockWf, startNode, startNodeStatus := createStartNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), startNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, startNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), startNodeStatus.GetPhase().String()) + }) + } + } +} + +func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := 
createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + // Node not yet started + { + createSingleNodeWf := func(parentPhase v1alpha1.NodePhase, _ int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + Kind: v1alpha1.NodeKindEnd, + } + ns := &v1alpha1.NodeStatus{} + + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.EndNodeID: ns, + v1alpha1.StartNodeID: { + Phase: parentPhase, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + v1alpha1.EndNodeID: n, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.EndNodeID: {v1alpha1.StartNodeID}, + }, + }, + }, + }, n, ns + + } + tests := []struct { + name string + parentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + expectedError bool + }{ + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, + {"success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil) + h.On("StartNode", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(handler.StatusQueued, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Recurse End Node Queued previously + { + createSingleNodeWf := func(endNodePhase v1alpha1.NodePhase, _ int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + Kind: v1alpha1.NodeKindEnd, + } + ns := &v1alpha1.NodeStatus{ + Phase: endNodePhase, + } + 
+ return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.EndNodeID: ns, + v1alpha1.StartNodeID: { + Phase: v1alpha1.NodePhaseSucceeded, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + v1alpha1.StartNodeID: { + ID: v1alpha1.StartNodeID, + Kind: v1alpha1.NodeKindStart, + }, + v1alpha1.EndNodeID: n, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.EndNodeID: {v1alpha1.StartNodeID}, + }, + DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.StartNodeID: {v1alpha1.EndNodeID}, + }, + }, + }, + }, n, ns + + } + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("StartNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o *handler.Data) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil) + + mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + startNode := mockWf.StartNode() + startStatus := mockWf.GetNodeExecutionStatus(startNode.GetID()) + assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + +} + +func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := 
catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createSingleNodeWf := func(p v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + ns := &v1alpha1.NodeStatus{ + Phase: p, + } + + startNode := &v1alpha1.NodeSpec{ + Kind: v1alpha1.NodeKindStart, + ID: v1alpha1.StartNodeID, + } + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: v1alpha1.NodePhaseSucceeded, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + v1alpha1.StartNodeID: startNode, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + v1alpha1.StartNodeID: {defaultNodeID}, + }, + }, + }, + }, n, ns + + } + + // Recursion test with child Node not yet started + { + nodeN0 := "n0" + nodeN2 := "n2" + ctx := context.Background() + connections := &v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + nodeN2: {nodeN0}, + }, + } + + setupNodePhase := func(n0Phase, n2Phase, expectedN2Phase v1alpha1.NodePhase) (*mocks.ExecutableWorkflow, *mocks.ExecutableNodeStatus) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("GetPhase").Return(n2Phase) + mockN2Status.On("SetDataDir", mock.AnythingOfType(reflect.TypeOf(storage.DataReference("x")).String())) + mockN2Status.On("GetDataDir").Return(storage.DataReference("blah")) + mockN2Status.On("GetWorkflowNodeStatus").Return(nil) + mockN2Status.On("GetStoppedAt").Return(nil) + mockN2Status.On("UpdatePhase", expectedN2Phase, mock.Anything, mock.AnythingOfType("string")) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.ExecutableNode{} + mockNode.On("GetID").Return(nodeN2) + mockNode.On("GetBranchNode").Return(nil) + mockNode.On("GetKind").Return(v1alpha1.NodeKindTask) + mockNode.On("IsStartNode").Return(false) + mockNode.On("IsEndNode").Return(false) + + mockNodeN0 := &mocks.ExecutableNode{} + mockNodeN0.On("GetID").Return(nodeN0) + mockNodeN0.On("GetBranchNode").Return(nil) + mockNodeN0.On("GetKind").Return(v1alpha1.NodeKindTask) + mockNodeN0.On("IsStartNode").Return(false) + mockNodeN0.On("IsEndNode").Return(false) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(n0Phase) + mockN0Status.On("IsDirty").Return(false) + + mockWfStatus := &mocks.ExecutableWorkflowStatus{} + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("StartNode").Return(mockNodeN0) + mockWf.On("GetNode", nodeN2).Return(mockNode, true) + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + mockWf.On("FromNode", nodeN0).Return([]string{nodeN2}, 
nil) + mockWf.On("FromNode", nodeN2).Return([]string{}, fmt.Errorf("did not expect")) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{}) + mockWf.On("GetExecutionStatus").Return(mockWfStatus) + mockWfStatus.On("GetDataDir").Return(storage.DataReference("x")) + return mockWf, mockN2Status + } + + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + parentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + updateCalled bool + }{ + {"notYetStarted->notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseNotYetStarted, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusNotStarted, nil + }, false, false}, + + {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSkipped, nil + }, false, true}, + + {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, _ := setupNodePhase(test.parentNodePhase, test.currentNodePhase, test.expectedNodePhase) + startNode := mockWf.StartNode() + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + }) + } + } + + // Recurse Child Node Queued previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + 
exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("StartNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o *handler.Data) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + startNode := mockWf.StartNode() + startStatus := mockWf.GetNodeExecutionStatus(startNode.GetID()) + assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Recurse Child Node started previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at running + {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"running->failed", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"running->success", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + + {"previously-failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"previously-success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, executors.NodePhaseComplete, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + startNode := 
mockWf.StartNode() + s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } +} + +func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createSingleNodeWf := func(p v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + ns := &v1alpha1.NodeStatus{ + Phase: p, + } + + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: v1alpha1.NodePhaseSucceeded, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + }, + }, + }, n, ns + + } + + // Node not yet started + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + {"notYetStarted->running", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"notYetStarted->failed", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"notYetStarted->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) 
+ s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node queued previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, func() (handler.Status, error) { + return handler.StatusQueued, nil + }, false}, + + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("StartNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o *handler.Data) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at running + {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseRunning, func() (handler.Status, error) { + return 
handler.StatusRunning, nil + }, false}, + + {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusFailing(fmt.Errorf("err")), nil + }, false}, + + {"running->failed", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + {"running->success", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously and is failing + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Failing + // TODO this should be illegal + {"failing->running", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseRunning, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRunning, nil + }, false}, + + {"failing->failed", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusFailed(fmt.Errorf("err")), nil + }, false}, + + // TODO this should be illegal + {"failing->success", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Status, error) { + return handler.StatusSuccess, nil + }, false}, + + {"failing->error", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailing, executors.NodePhaseUndefined, func() (handler.Status, error) { + return handler.StatusUndefined, fmt.Errorf("err") + }, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("HandleFailingNode", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o 
v1alpha1.ExecutableNode) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously and retryable failure + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + // Starting at Queued + {"running->retryable", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRetryableFailure, executors.NodePhaseRunning, func() (handler.Status, error) { + return handler.StatusRetryableFailure(fmt.Errorf("err")), nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 2) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(1), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } + + // Node started previously and retryable failure - but exhausted attempts + { + tests := []struct { + name string + currentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + handlerReturn func() (handler.Status, error) + expectedError bool + }{ + {"running->retryable", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Status, error) { + return handler.StatusRetryableFailure(fmt.Errorf("err")), nil + }, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + + h := &mocks3.IFace{} + h.On("CheckNodeStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableWorkflow) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNode) bool { return true }), + mock.MatchedBy(func(o v1alpha1.ExecutableNodeStatus) bool { return true }), + ).Return(test.handlerReturn()) + + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + 
+ mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 1) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(1), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } +} + +func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { + ctx := context.Background() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + factory := createSingletonTaskExecutorFactory() + task.SetTestFactory(factory) + assert.True(t, task.IsTestModeEnabled()) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + catalogClient := catalog.NewCatalogClient(store) + + execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + + defaultNodeID := "n1" + + createSingleNodeWf := func(parentPhase v1alpha1.NodePhase, maxAttempts int) (v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) { + n := &v1alpha1.NodeSpec{ + ID: defaultNodeID, + Kind: v1alpha1.NodeKindTask, + RetryStrategy: &v1alpha1.RetryStrategy{ + MinAttempts: &maxAttempts, + }, + } + ns := &v1alpha1.NodeStatus{} + + return &v1alpha1.FlyteWorkflow{ + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + defaultNodeID: ns, + v1alpha1.StartNodeID: { + Phase: parentPhase, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + defaultNodeID: n, + }, + Connections: v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + defaultNodeID: {v1alpha1.StartNodeID}, + }, + }, + }, + }, n, ns + + } + + // Node not yet started + { + tests := []struct { + name string + parentNodePhase v1alpha1.NodePhase + expectedNodePhase v1alpha1.NodePhase + expectedPhase executors.NodePhase + expectedError bool + }{ + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + h := &mocks3.IFace{} + hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) + + mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + if test.expectedError { + assert.Error(t, err) + } else { + 
assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) + assert.Equal(t, test.expectedNodePhase, mockNodeStatus.GetPhase(), "expected %s, received %s", test.expectedNodePhase.String(), mockNodeStatus.GetPhase().String()) + }) + } + } +} + +func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { + testScope := promutils.NewTestScope() + type fields struct { + nodeHandlerFactory HandlerFactory + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + nodeRecorder events.NodeEventRecorder + metrics *nodeMetrics + } + type args struct { + w v1alpha1.ExecutableWorkflow + node v1alpha1.ExecutableNode + nodeStatus v1alpha1.ExecutableNodeStatus + } + + nsf := func(phase v1alpha1.NodePhase, lastUpdated *time.Time) *mocks.ExecutableNodeStatus { + ns := &mocks.ExecutableNodeStatus{} + ns.On("GetPhase").Return(phase) + var t *v1.Time + if lastUpdated != nil { + t = &v1.Time{Time: *lastUpdated} + } + ns.On("GetLastUpdatedAt").Return(t) + return ns + } + testTime := time.Now() + tests := []struct { + name string + fields fields + args args + recordingExpected bool + }{ + { + "retryable-failure", + fields{metrics: &nodeMetrics{TransitionLatency: labeled.NewStopWatch("test", "xyz", time.Millisecond, testScope)}}, + args{nodeStatus: nsf(v1alpha1.NodePhaseRetryableFailure, &testTime)}, + true, + }, + { + "retryable-failure-notime", + fields{metrics: &nodeMetrics{TransitionLatency: labeled.NewStopWatch("test2", "xyz", time.Millisecond, testScope)}}, + args{nodeStatus: nsf(v1alpha1.NodePhaseRetryableFailure, nil)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &nodeExecutor{ + nodeHandlerFactory: tt.fields.nodeHandlerFactory, + enqueueWorkflow: tt.fields.enqueueWorkflow, + store: tt.fields.store, + nodeRecorder: tt.fields.nodeRecorder, + metrics: tt.fields.metrics, + } + c.RecordTransitionLatency(context.TODO(), tt.args.w, tt.args.node, tt.args.nodeStatus) + + ch := make(chan prometheus.Metric, 2) + tt.fields.metrics.TransitionLatency.Collect(ch) + assert.Equal(t, len(ch) == 1, tt.recordingExpected) + }) + } +} diff --git a/flytepropeller/pkg/controller/nodes/handler/iface.go b/flytepropeller/pkg/controller/nodes/handler/iface.go new file mode 100644 index 0000000000..4101ba09df --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/handler/iface.go @@ -0,0 +1,128 @@ +package handler + +import ( + "context" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +//go:generate mockery -all + +type Data = core.LiteralMap +type VarName = string + +type Phase int + +const ( + // Indicates that the handler was unable to Start the Node due to an internal failure + PhaseNotStarted Phase = iota + // Incase of retryable failure and should be retried + PhaseRetryableFailure + // Indicates that the node is queued because the task is queued + PhaseQueued + // Indicates that the node is currently executing and no errors have been observed + PhaseRunning + // PhaseFailing is currently used by SubWorkflow Only. 
It indicates that the Node's primary work has failed, + // but, either some cleanup or exception handling condition is in progress + PhaseFailing + // This is a terminal Status and indicates that the node execution resulted in a Failure + PhaseFailed + // This is a pre-terminal state, currently unused and indicates that the Node execution has succeeded barring any cleanup + PhaseSucceeding + // This is a terminal state and indicates successful completion of the node execution. + PhaseSuccess + // This Phase indicates that the node execution can be skipped, because of either conditional failures or user defined cases + PhaseSkipped + // This phase indicates that an error occurred and is always accompanied by `error`. the execution for that node is + // in an indeterminate state and should be retried + PhaseUndefined +) + +var PhasesToString = map[Phase]string{ + PhaseNotStarted: "NotStarted", + PhaseQueued: "Queued", + PhaseRunning: "Running", + PhaseFailing: "Failing", + PhaseFailed: "Failed", + PhaseSucceeding: "Succeeding", + PhaseSuccess: "Success", + PhaseSkipped: "Skipped", + PhaseUndefined: "Undefined", + PhaseRetryableFailure: "RetryableFailure", +} + +func (p Phase) String() string { + str, found := PhasesToString[p] + if found { + return str + } + + return "Unknown" +} + +// This encapsulates the status of the node +type Status struct { + Phase Phase + Err error + OccurredAt time.Time +} + +var StatusNotStarted = Status{Phase: PhaseNotStarted} +var StatusQueued = Status{Phase: PhaseQueued} +var StatusRunning = Status{Phase: PhaseRunning} +var StatusSucceeding = Status{Phase: PhaseSucceeding} +var StatusSuccess = Status{Phase: PhaseSuccess} +var StatusUndefined = Status{Phase: PhaseUndefined} +var StatusSkipped = Status{Phase: PhaseSkipped} + +func (s Status) WithOccurredAt(t time.Time) Status { + s.OccurredAt = t + return s +} + +func StatusFailed(err error) Status { + return Status{Phase: PhaseFailed, Err: err} +} + +func StatusRetryableFailure(err error) Status { + return Status{Phase: PhaseRetryableFailure, Err: err} +} + +func StatusFailing(err error) Status { + return Status{Phase: PhaseFailing, Err: err} +} + +type OutputResolver interface { + // Extracts a subset of node outputs to literals. + ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + bindToVar VarName) (values *core.Literal, err error) +} + +type PostNodeSuccessHandler interface { + HandleNodeSuccess(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (Status, error) +} + +// Interface that should be implemented for a node type. +type IFace interface { + //OutputResolver + + // Initialize should be called, before invoking any other methods of this handler. Initialize will be called using one thread + // only + Initialize(ctx context.Context) error + + // Start node is called for a node only if the recorded state indicates that the node was never started previously. + // the implementation should handle idempotency, even if the chance of invoking it more than once for an execution is rare. + StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *Data) (Status, error) + + // For any node that is not in a NEW/READY state in the recording, CheckNodeStatus will be invoked. 
The implementation should handle + // idempotency and return the current observed state of the node + CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, previousNodeStatus v1alpha1.ExecutableNodeStatus) (Status, error) + + // This is called in the case, a node failure is observed. + HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (Status, error) + + // Abort is invoked as a way to clean up failing/aborted workflows + AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error +} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/IFace.go b/flytepropeller/pkg/controller/nodes/handler/mocks/IFace.go new file mode 100644 index 0000000000..b21e42da5d --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/handler/mocks/IFace.go @@ -0,0 +1,105 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// IFace is an autogenerated mock type for the IFace type +type IFace struct { + mock.Mock +} + +// AbortNode provides a mock function with given fields: ctx, w, node +func (_m *IFace) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + ret := _m.Called(ctx, w, node) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) error); ok { + r0 = rf(ctx, w, node) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CheckNodeStatus provides a mock function with given fields: ctx, w, node, previousNodeStatus +func (_m *IFace) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, previousNodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + ret := _m.Called(ctx, w, node, previousNodeStatus) + + var r0 handler.Status + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) handler.Status); ok { + r0 = rf(ctx, w, node, previousNodeStatus) + } else { + r0 = ret.Get(0).(handler.Status) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, v1alpha1.ExecutableNodeStatus) error); ok { + r1 = rf(ctx, w, node, previousNodeStatus) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HandleFailingNode provides a mock function with given fields: ctx, w, node +func (_m *IFace) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + ret := _m.Called(ctx, w, node) + + var r0 handler.Status + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) handler.Status); ok { + r0 = rf(ctx, w, node) + } else { + r0 = ret.Get(0).(handler.Status) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) error); ok { + r1 = rf(ctx, w, node) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: ctx +func (_m *IFace) Initialize(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok 
:= ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StartNode provides a mock function with given fields: ctx, w, node, nodeInputs +func (_m *IFace) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *core.LiteralMap) (handler.Status, error) { + ret := _m.Called(ctx, w, node, nodeInputs) + + var r0 handler.Status + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, *core.LiteralMap) handler.Status); ok { + r0 = rf(ctx, w, node, nodeInputs) + } else { + r0 = ret.Get(0).(handler.Status) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, *core.LiteralMap) error); ok { + r1 = rf(ctx, w, node, nodeInputs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/OutputResolver.go b/flytepropeller/pkg/controller/nodes/handler/mocks/OutputResolver.go new file mode 100644 index 0000000000..92b9560cc7 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/handler/mocks/OutputResolver.go @@ -0,0 +1,37 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + +import mock "github.com/stretchr/testify/mock" +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// OutputResolver is an autogenerated mock type for the OutputResolver type +type OutputResolver struct { + mock.Mock +} + +// ExtractOutput provides a mock function with given fields: ctx, w, n, bindToVar +func (_m *OutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, bindToVar string) (*core.Literal, error) { + ret := _m.Called(ctx, w, n, bindToVar) + + var r0 *core.Literal + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) *core.Literal); ok { + r0 = rf(ctx, w, n, bindToVar) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.Literal) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) error); ok { + r1 = rf(ctx, w, n, bindToVar) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/nodes/handler_factory.go b/flytepropeller/pkg/controller/nodes/handler_factory.go new file mode 100644 index 0000000000..5869988979 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/handler_factory.go @@ -0,0 +1,76 @@ +package nodes + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/nodes/dynamic" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/branch" + "github.com/lyft/flytepropeller/pkg/controller/nodes/end" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/start" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + 
"github.com/lyft/flytestdlib/storage" + "github.com/pkg/errors" +) + +//go:generate mockery -name HandlerFactory + +type HandlerFactory interface { + GetHandler(kind v1alpha1.NodeKind) (handler.IFace, error) +} + +type handlerFactory struct { + handlers map[v1alpha1.NodeKind]handler.IFace +} + +func (f handlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.IFace, error) { + h, ok := f.handlers[kind] + if !ok { + return nil, errors.Errorf("Handler not registered for NodeKind [%v]", kind) + } + return h, nil +} + +func NewHandlerFactory(ctx context.Context, + executor executors.Node, + eventSink events.EventSink, + workflowLauncher launchplan.Executor, + enQWorkflow v1alpha1.EnqueueWorkflow, + revalPeriod time.Duration, + store *storage.DataStore, + catalogClient catalog.Client, + kubeClient executors.Client, + scope promutils.Scope, +) (HandlerFactory, error) { + + f := &handlerFactory{ + handlers: map[v1alpha1.NodeKind]handler.IFace{ + v1alpha1.NodeKindBranch: branch.New(executor, eventSink, scope), + v1alpha1.NodeKindTask: dynamic.New( + task.New(eventSink, store, enQWorkflow, revalPeriod, catalogClient, kubeClient, scope), + executor, + enQWorkflow, + store, + scope), + v1alpha1.NodeKindWorkflow: subworkflow.New(executor, eventSink, workflowLauncher, enQWorkflow, store, scope), + v1alpha1.NodeKindStart: start.New(store), + v1alpha1.NodeKindEnd: end.New(store), + }, + } + for _, v := range f.handlers { + if err := v.Initialize(ctx); err != nil { + return nil, err + } + } + return f, nil +} diff --git a/flytepropeller/pkg/controller/nodes/mocks/HandlerFactory.go b/flytepropeller/pkg/controller/nodes/mocks/HandlerFactory.go new file mode 100644 index 0000000000..024d0c3524 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/mocks/HandlerFactory.go @@ -0,0 +1,36 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +import mock "github.com/stretchr/testify/mock" + +import v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +// HandlerFactory is an autogenerated mock type for the HandlerFactory type +type HandlerFactory struct { + mock.Mock +} + +// GetHandler provides a mock function with given fields: kind +func (_m *HandlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.IFace, error) { + ret := _m.Called(kind) + + var r0 handler.IFace + if rf, ok := ret.Get(0).(func(v1alpha1.NodeKind) handler.IFace); ok { + r0 = rf(kind) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(handler.IFace) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(v1alpha1.NodeKind) error); ok { + r1 = rf(kind) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/nodes/predicate.go b/flytepropeller/pkg/controller/nodes/predicate.go new file mode 100644 index 0000000000..cd718d8bc3 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/predicate.go @@ -0,0 +1,111 @@ +package nodes + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytestdlib/logger" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// Special enum to indicate if the node under consideration is ready to be executed or should be skipped +type PredicatePhase int + +const ( + // Indicates node is not yet ready to be executed + PredicatePhaseNotReady PredicatePhase = iota + // Indicates node is ready to be executed - execution should proceed + PredicatePhaseReady + // Indicates that the node execution should be skipped as one of its parents was skipped or the branch was not taken + PredicatePhaseSkip + // Indicates failure during Predicate check + PredicatePhaseUndefined +) + +func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode) (PredicatePhase, error) { + nodeID := node.GetID() + if nodeID == v1alpha1.StartNodeID { + logger.Debugf(ctx, "Start Node id is assumed to be ready.") + return PredicatePhaseReady, nil + } + nodeStatus := w.GetNodeExecutionStatus(nodeID) + parentNodeID := nodeStatus.GetParentNodeID() + upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] + if !ok { + return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") + } + skipped := false + for _, upstreamNodeID := range upstreamNodes { + upstreamNodeStatus := w.GetNodeExecutionStatus(upstreamNodeID) + + if upstreamNodeStatus.IsDirty() { + return PredicatePhaseNotReady, nil + } + + if parentNodeID != nil && *parentNodeID == upstreamNodeID { + upstreamNode, ok := w.GetNode(upstreamNodeID) + if !ok { + return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) + } + // This only happens if current node is the child node of a branch node + if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { + logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") + return PredicatePhaseUndefined, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) + } + continue + } + + if 
upstreamNodeStatus.GetPhase() == v1alpha1.NodePhaseSkipped { + skipped = true + } else if upstreamNodeStatus.GetPhase() != v1alpha1.NodePhaseSucceeded { + return PredicatePhaseNotReady, nil + } + } + if skipped { + return PredicatePhaseSkip, nil + } + return PredicatePhaseReady, nil +} + +func GetParentNodeMaxEndTime(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode) (t v1.Time, err error) { + zeroTime := v1.NewTime(time.Time{}) + nodeID := node.GetID() + if nodeID == v1alpha1.StartNodeID { + logger.Debugf(ctx, "Start Node id is assumed to be ready.") + return zeroTime, nil + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + parentNodeID := nodeStatus.GetParentNodeID() + upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] + if !ok { + return zeroTime, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") + } + + var latest v1.Time + for _, upstreamNodeID := range upstreamNodes { + upstreamNodeStatus := w.GetNodeExecutionStatus(upstreamNodeID) + if parentNodeID != nil && *parentNodeID == upstreamNodeID { + upstreamNode, ok := w.GetNode(upstreamNodeID) + if !ok { + return zeroTime, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) + } + + // This only happens if current node is the child node of a branch node + if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { + logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") + return zeroTime, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) + } + + continue + } + + if stoppedAt := upstreamNodeStatus.GetStoppedAt(); stoppedAt != nil && stoppedAt.Unix() > latest.Unix() { + latest = *upstreamNodeStatus.GetStoppedAt() + } + } + + return latest, nil +} diff --git a/flytepropeller/pkg/controller/nodes/predicate_test.go b/flytepropeller/pkg/controller/nodes/predicate_test.go new file mode 100644 index 0000000000..ae151010b1 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/predicate_test.go @@ -0,0 +1,550 @@ +package nodes + +import ( + "context" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/stretchr/testify/assert" +) + +func TestCanExecute(t *testing.T) { + nodeN0 := "n0" + nodeN1 := "n1" + nodeN2 := "n2" + ctx := context.Background() + connections := &v1alpha1.Connections{ + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + nodeN2: {nodeN0, nodeN1}, + }, + } + + // Table tests are not really helpful here, so we decided against it + + t.Run("startNode", func(t *testing.T) { + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(v1alpha1.StartNodeID) + p, err := CanExecute(ctx, nil, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseReady, p) + }) + + t.Run("noUpstreamConnection", func(t *testing.T) { + // Setup + mockNodeStatus := &mocks.ExecutableNodeStatus{} + // No parent node + mockNodeStatus.On("GetParentNodeID").Return(nil) + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockNodeStatus) + mockWf.On("GetConnections").Return(&v1alpha1.Connections{}) + mockWf.On("GetID").Return("w1") + + p, err 
:= CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + t.Run("upstreamConnectionsNotReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsPartialReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsCompletelyReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseReady, p) + }) + + t.Run("upstreamConnectionsDirty", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + 
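+		// n1 reports IsDirty=true further down; CanExecute should treat a dirty upstream node as not yet
+		// settled and return PredicatePhaseNotReady even though both upstream phases are Succeeded.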
mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(true) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsPartialSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + t.Run("upstreamConnectionsOneSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseSkip, p) + }) + + t.Run("upstreamConnectionsAllSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + 
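+		// Both upstream nodes (n0 above, n1 below) report NodePhaseSkipped, so the node under test
+		// is expected to be skipped as well (PredicatePhaseSkip).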
mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseSkip, p) + }) + + // Failed should never happen for predicate check. Hence we return not ready + t.Run("upstreamConnectionsFailed", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(nil) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseFailed) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseFailed) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) + + // Branch node tests + + // ParentNode not found? + t.Run("upstreamConnectionsParentNodeNotFound", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(nil, false) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode has no branch node + t.Run("upstreamConnectionsParentHasNoBranch", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(nil) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + 
mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode branch not ready + t.Run("upstreamConnectionsBranchNodeNotReady", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeNotYetEvaluated) + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode branch is errored + t.Run("upstreamConnectionsBranchNodeError", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeError) + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err 
:= CanExecute(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, PredicatePhaseUndefined, p) + }) + + // ParentNode branch ready + t.Run("upstreamConnectionsBranchSuccessOtherSuccess", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeSuccess) + mockN0BranchNode := &mocks.ExecutableBranchNode{} + + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseReady, p) + }) + + // ParentNode branch ready + t.Run("upstreamConnectionsBranchSuccessOtherSkipped", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := &mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeSuccess) + + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseSkipped) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseSkip, p) + }) + + // ParentNode branch ready + t.Run("upstreamConnectionsBranchSuccessOtherRunning", func(t *testing.T) { + // Setup + mockN2Status := &mocks.ExecutableNodeStatus{} + // No parent node + mockN2Status.On("GetParentNodeID").Return(&nodeN0) + mockN2Status.On("IsDirty").Return(false) + + mockNode := &mocks.BaseNode{} + mockNode.On("GetID").Return(nodeN2) + + mockN0BranchStatus := 
&mocks.MutableBranchNodeStatus{} + mockN0BranchStatus.On("GetPhase").Return(v1alpha1.BranchNodeSuccess) + + mockN0BranchNode := &mocks.ExecutableBranchNode{} + mockN0Node := &mocks.ExecutableNode{} + mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) + mockN0Status := &mocks.ExecutableNodeStatus{} + mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) + mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("IsDirty").Return(false) + + mockN1Status := &mocks.ExecutableNodeStatus{} + mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) + mockN1Status.On("IsDirty").Return(false) + + mockWf := &mocks.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeN0).Return(mockN0Status) + mockWf.On("GetNodeExecutionStatus", nodeN1).Return(mockN1Status) + mockWf.On("GetNodeExecutionStatus", nodeN2).Return(mockN2Status) + mockWf.On("GetConnections").Return(connections) + mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) + mockWf.On("GetID").Return("w1") + + p, err := CanExecute(ctx, mockWf, mockNode) + assert.NoError(t, err) + assert.Equal(t, PredicatePhaseNotReady, p) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/resolve.go b/flytepropeller/pkg/controller/nodes/resolve.go new file mode 100644 index 0000000000..c0a2472676 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/resolve.go @@ -0,0 +1,104 @@ +package nodes + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/nodes/common" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/storage" +) + +func ResolveBindingData(ctx context.Context, h HandlerFactory, w v1alpha1.ExecutableWorkflow, bindingData *core.BindingData, store storage.ProtobufStore) (*core.Literal, error) { + literal := &core.Literal{} + if bindingData == nil { + return nil, nil + } + switch bindingData.GetValue().(type) { + case *core.BindingData_Collection: + literalCollection := make([]*core.Literal, 0, len(bindingData.GetCollection().GetBindings())) + for _, b := range bindingData.GetCollection().GetBindings() { + l, err := ResolveBindingData(ctx, h, w, b, store) + if err != nil { + return nil, err + } + + literalCollection = append(literalCollection, l) + } + literal.Value = &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literalCollection, + }, + } + case *core.BindingData_Map: + literalMap := make(map[string]*core.Literal, len(bindingData.GetMap().GetBindings())) + for k, v := range bindingData.GetMap().GetBindings() { + l, err := ResolveBindingData(ctx, h, w, v, store) + if err != nil { + return nil, err + } + + literalMap[k] = l + } + literal.Value = &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: literalMap, + }, + } + case *core.BindingData_Promise: + upstreamNodeID := bindingData.GetPromise().GetNodeId() + bindToVar := bindingData.GetPromise().GetVar() + if w == nil { + return nil, errors.Errorf(errors.IllegalStateError, upstreamNodeID, + "Trying to resolve output from previous node, without providing the workflow for variable [%s]", + bindToVar) + } + if upstreamNodeID == "" { + return nil, errors.Errorf(errors.BadSpecificationError, "missing", + "No nodeId (missing) specified for binding in Workflow.") + } + n, ok := w.GetNode(upstreamNodeID) + if !ok { + return nil, errors.Errorf(errors.IllegalStateError, 
w.GetID(), upstreamNodeID, + "Undefined node in Workflow") + } + + nodeHandler, err := h.GetHandler(n.GetKind()) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, "Failed to find handler for node kind [%v]", n.GetKind()) + } + + resolver, casted := nodeHandler.(handler.OutputResolver) + if !casted { + // If the handler doesn't implement output resolver, use simple resolver which expects an outputs.pb at the + // output location of the task. + if store == nil { + return nil, errors.Errorf(errors.IllegalStateError, w.GetID(), n.GetID(), "System error. Promise lookup without store.") + } + + resolver = common.NewSimpleOutputsResolver(store) + } + + return resolver.ExtractOutput(ctx, w, n, bindToVar) + case *core.BindingData_Scalar: + literal.Value = &core.Literal_Scalar{Scalar: bindingData.GetScalar()} + } + return literal, nil +} + +func Resolve(ctx context.Context, h HandlerFactory, w v1alpha1.ExecutableWorkflow, nodeID v1alpha1.NodeID, bindings []*v1alpha1.Binding, store storage.ProtobufStore) (*handler.Data, error) { + literalMap := make(map[string]*core.Literal, len(bindings)) + for _, binding := range bindings { + l, err := ResolveBindingData(ctx, h, w, binding.GetBinding(), store) + if err != nil { + return nil, errors.Wrapf(errors.BindingResolutionError, nodeID, err, "Error binding Var [%v].[%v]", w.GetID(), binding.GetVar()) + } + literalMap[binding.GetVar()] = l + } + return &core.LiteralMap{ + Literals: literalMap, + }, nil +} diff --git a/flytepropeller/pkg/controller/nodes/resolve_test.go b/flytepropeller/pkg/controller/nodes/resolve_test.go new file mode 100644 index 0000000000..3e5c98032b --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/resolve_test.go @@ -0,0 +1,434 @@ +package nodes + +import ( + "context" + "fmt" + "testing" + + mocks2 "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/mocks" + "github.com/stretchr/testify/mock" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/utils" + flyteassert "github.com/lyft/flytepropeller/pkg/utils/assert" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +var testScope = promutils.NewScope("test") + +type dummyBaseWorkflow struct { + DummyStartNode v1alpha1.ExecutableNode + ID v1alpha1.WorkflowID + FromNodeCb func(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) + GetNodeCb func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) + Status map[v1alpha1.NodeID]*v1alpha1.NodeStatus +} + +func (d *dummyBaseWorkflow) GetOutputBindings() []*v1alpha1.Binding { + return []*v1alpha1.Binding{} +} + +func (d *dummyBaseWorkflow) GetOnFailureNode() v1alpha1.ExecutableNode { + return nil +} + +func (d *dummyBaseWorkflow) GetNodes() []v1alpha1.NodeID { + return []v1alpha1.NodeID{d.DummyStartNode.GetID()} +} + +func (d *dummyBaseWorkflow) GetConnections() *v1alpha1.Connections { + return &v1alpha1.Connections{} +} + +func (d *dummyBaseWorkflow) GetOutputs() *v1alpha1.OutputVarMap { + return &v1alpha1.OutputVarMap{} +} + +func (d *dummyBaseWorkflow) GetExecutionID() v1alpha1.ExecutionID { + return v1alpha1.ExecutionID{ + WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{ + Name: "test", + }, + } +} + +func (d *dummyBaseWorkflow) GetK8sWorkflowID() 
types.NamespacedName { + return types.NamespacedName{ + Name: "WF_Name", + } +} + +func (d *dummyBaseWorkflow) NewControllerRef() v1.OwnerReference { + return v1.OwnerReference{} +} + +func (d *dummyBaseWorkflow) GetNamespace() string { + return d.GetK8sWorkflowID().Namespace +} + +func (d *dummyBaseWorkflow) GetCreationTimestamp() v1.Time { + return v1.Now() +} + +func (d *dummyBaseWorkflow) GetAnnotations() map[string]string { + return map[string]string{} +} + +func (d *dummyBaseWorkflow) GetLabels() map[string]string { + return map[string]string{} +} + +func (d *dummyBaseWorkflow) GetName() string { + return d.ID +} + +func (d *dummyBaseWorkflow) GetServiceAccountName() string { + return "" +} + +func (d *dummyBaseWorkflow) GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) { + return nil, nil +} + +func (d *dummyBaseWorkflow) FindSubWorkflow(subID v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow { + return nil +} + +func (d *dummyBaseWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + return nil +} + +func (d *dummyBaseWorkflow) GetNodeExecutionStatus(id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { + n, ok := d.Status[id] + if ok { + return n + } + n = &v1alpha1.NodeStatus{} + d.Status[id] = n + return n +} + +func (d *dummyBaseWorkflow) StartNode() v1alpha1.ExecutableNode { + return d.DummyStartNode +} + +func (d *dummyBaseWorkflow) GetID() v1alpha1.WorkflowID { + return d.ID +} + +func (d *dummyBaseWorkflow) FromNode(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return d.FromNodeCb(name) +} + +func (d *dummyBaseWorkflow) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + return d.GetNodeCb(nodeID) +} + +func createDummyBaseWorkflow() *dummyBaseWorkflow { + return &dummyBaseWorkflow{ + ID: "w1", + Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + v1alpha1.StartNodeID: {}, + }, + } +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func createFailingDatastore(_ testing.TB, scope promutils.Scope) *storage.DataStore { + return storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, scope)) +} + +func TestResolveBindingData(t *testing.T) { + ctx := context.Background() + outputRef := v1alpha1.DataReference("output-ref") + n1 := &v1alpha1.NodeSpec{ + ID: "n1", + OutputAliases: []v1alpha1.Alias{ + {Alias: core.Alias{ + Var: "x", + Alias: "m", + }}, + }, + } + + n2 := &v1alpha1.NodeSpec{ + ID: "n2", + OutputAliases: []v1alpha1.Alias{ + {Alias: core.Alias{ + Var: "x", + Alias: "m", + }}, + }, + } + + outputPath := v1alpha1.GetOutputsFile(outputRef) + + w := &dummyBaseWorkflow{ + Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + "n2": { + DataDir: outputRef, + }, + }, + GetNodeCb: func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + switch nodeId { + case "n1": + return n1, true + case "n2": + return n2, true + } + return nil, false + }, + } + + hf := &mocks.HandlerFactory{} + h := &mocks2.IFace{} + hf.On("GetHandler", mock.Anything).Return(h, nil) + h.On("ExtractOutput", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + + t.Run("StaticBinding", func(t *testing.T) { + w := &dummyBaseWorkflow{} + b := utils.MustMakePrimitiveBindingData(1) + l, err := ResolveBindingData(ctx, hf, w, b, nil) + assert.NoError(t, err) + flyteassert.EqualLiterals(t, 
utils.MustMakeLiteral(1), l) + }) + + t.Run("PromiseMissingNode", func(t *testing.T) { + w := &dummyBaseWorkflow{ + GetNodeCb: func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + return nil, false + }, + } + b := utils.MakeBindingDataPromise("n1", "x") + _, err := ResolveBindingData(ctx, nil, w, b, nil) + assert.Error(t, err) + }) + + t.Run("PromiseMissingStore", func(t *testing.T) { + b := utils.MakeBindingDataPromise("n1", "x") + _, err := ResolveBindingData(ctx, hf, w, b, nil) + assert.Error(t, err) + }) + + t.Run("PromiseMissing", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("1")) + b := utils.MakeBindingDataPromise("n1", "x") + _, err := ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + }) + + t.Run("PromiseMissingWithData", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("2")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"z": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + b := utils.MakeBindingDataPromise("n1", "x") + _, err = ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + }) + + t.Run("PromiseFound", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("3")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + + b := utils.MakeBindingDataPromise("n2", "x") + l, err := ResolveBindingData(ctx, hf, w, b, store) + if assert.NoError(t, err) { + flyteassert.EqualLiterals(t, utils.MustMakeLiteral(1), l) + } + }) + + t.Run("NullBinding", func(t *testing.T) { + l, err := ResolveBindingData(ctx, hf, w, nil, nil) + assert.NoError(t, err) + assert.Nil(t, l) + }) + + t.Run("NullWorkflowPromise", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("4")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + b := utils.MakeBindingDataPromise("n1", "x") + _, err = ResolveBindingData(ctx, nil, nil, b, store) + assert.Error(t, err) + }) + + t.Run("PromiseFoundAlias", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("5")) + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + b := utils.MakeBindingDataPromise("n2", "m") + l, err := ResolveBindingData(ctx, hf, w, b, store) + if assert.NoError(t, err) { + flyteassert.EqualLiterals(t, utils.MustMakeLiteral(1), l) + } + }) + + t.Run("BindingDataMap", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("6")) + // Store output of previous + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + m2 := &core.LiteralMap{} + assert.NoError(t, store.ReadProtobuf(ctx, outputPath, m2)) + // Output of current + b := utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n2", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + ) + l, err := ResolveBindingData(ctx, hf, w, b, store) + if assert.NoError(t, err) { + expected, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1, "z": 5}) + assert.NoError(t, err) + flyteassert.EqualLiteralMap(t, expected, 
l.GetMap()) + } + + }) + + t.Run("BindingDataMapFailedPromise", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("7")) + // do not store anything + + // Output of current + b := utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n1", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + ) + _, err := ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + }) + + t.Run("BindingDataCollection", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("8")) + // Store random value + m, err := utils.MakeLiteralMap(map[string]interface{}{"jj": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + + // binding of current npde + b := utils.MakeBindingDataCollection( + utils.MakeBindingDataPromise("n1", "x"), + utils.MustMakePrimitiveBindingData(5), + ) + _, err = ResolveBindingData(ctx, hf, w, b, store) + assert.Error(t, err) + + }) +} + +func TestResolve(t *testing.T) { + ctx := context.Background() + outputRef := v1alpha1.DataReference("output-ref") + n1 := &v1alpha1.NodeSpec{ + ID: "n1", + OutputAliases: []v1alpha1.Alias{ + {Alias: core.Alias{ + Var: "x", + Alias: "m", + }}, + }, + } + + outputPath := v1alpha1.GetOutputsFile(outputRef) + + w := &dummyBaseWorkflow{ + Status: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + "n1": { + DataDir: outputRef, + }, + }, + GetNodeCb: func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + if nodeId == "n1" { + return n1, true + } + return nil, false + }, + } + + t.Run("SimpleResolve", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("9")) + // Store output of previous + m, err := utils.MakeLiteralMap(map[string]interface{}{"x": 1}) + assert.NoError(t, err) + assert.NoError(t, store.WriteProtobuf(ctx, outputPath, storage.Options{}, m)) + + //bindings + b := []*v1alpha1.Binding{ + { + Binding: utils.MakeBinding("map", utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n1", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + )), + }, + { + Binding: utils.MakeBinding("simple", utils.MustMakePrimitiveBindingData(1)), + }, + } + + hf := &mocks.HandlerFactory{} + h := &mocks2.IFace{} + hf.On("GetHandler", mock.Anything).Return(h, nil) + expected, err := utils.MakeLiteralMap(map[string]interface{}{ + "map": map[string]interface{}{"x": 1, "z": 5}, + "simple": utils.MustMakePrimitiveLiteral(1), + }) + assert.NoError(t, err) + + h.On("ExtractOutput", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expected, nil) + + l, err := Resolve(ctx, hf, w, "n2", b, store) + if assert.NoError(t, err) { + assert.NotNil(t, l) + if assert.NoError(t, err) { + flyteassert.EqualLiteralMap(t, expected, l) + } + } + }) + + t.Run("SimpleResolveFail", func(t *testing.T) { + store := createInmemoryDataStore(t, testScope.NewSubScope("10")) + // Store has no previous output + + //bindings + b := []*v1alpha1.Binding{ + { + Binding: utils.MakeBinding("map", utils.MakeBindingDataMap( + utils.NewPair("x", utils.MakeBindingDataPromise("n1", "x")), + utils.NewPair("z", utils.MustMakePrimitiveBindingData(5)), + )), + }, + { + Binding: utils.MakeBinding("simple", utils.MustMakePrimitiveBindingData(1)), + }, + } + + hf := &mocks.HandlerFactory{} + h := &mocks2.IFace{} + hf.On("GetHandler", mock.Anything).Return(h, nil) + h.On("ExtractOutput", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("No outputs")) 
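+		// ExtractOutput is mocked to return an error, so Resolve is expected to fail here.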
+ + _, err := Resolve(ctx, hf, w, "n2", b, store) + assert.Error(t, err) + }) + +} diff --git a/flytepropeller/pkg/controller/nodes/start/handler.go b/flytepropeller/pkg/controller/nodes/start/handler.go new file mode 100644 index 0000000000..61034cf4ce --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/start/handler.go @@ -0,0 +1,40 @@ +package start + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/storage" +) + +type startHandler struct { + store *storage.DataStore +} + +func (s startHandler) Initialize(ctx context.Context) error { + return nil +} + +func (s *startHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + return handler.StatusSuccess, nil +} + +func (s *startHandler) CheckNodeStatus(ctx context.Context, g v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + return handler.StatusSuccess, nil +} + +func (s *startHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, node.GetID(), "start node cannot enter a failing state")), nil +} + +func (s *startHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + return nil +} + +func New(store *storage.DataStore) handler.IFace { + return &startHandler{ + store: store, + } +} diff --git a/flytepropeller/pkg/controller/nodes/start/handler_test.go b/flytepropeller/pkg/controller/nodes/start/handler_test.go new file mode 100644 index 0000000000..18004308c0 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/start/handler_test.go @@ -0,0 +1,77 @@ +package start + +import ( + "context" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +var testScope = promutils.NewScope("start_test") + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func init() { + labeled.SetMetricKeys(contextutils.NodeIDKey) +} + +func TestStartNodeHandler_Initialize(t *testing.T) { + h := startHandler{} + // Do nothing + assert.NoError(t, h.Initialize(context.TODO())) +} + +func TestStartNodeHandler_StartNode(t *testing.T) { + ctx := context.Background() + mockStorage := createInmemoryDataStore(t, testScope.NewSubScope("z")) + h := New(mockStorage) + node := &v1alpha1.NodeSpec{ + ID: v1alpha1.EndNodeID, + } + w := &v1alpha1.FlyteWorkflow{ + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: v1alpha1.WorkflowID("w1"), + }, + } + t.Run("NoInputs", func(t *testing.T) { + s, err := h.StartNode(ctx, w, node, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) + t.Run("WithInputs", func(t *testing.T) { + node := &v1alpha1.NodeSpec{ + ID: 
v1alpha1.NodeID("n1"), + InputBindings: []*v1alpha1.Binding{ + { + Binding: utils.MakeBinding("x", utils.MustMakePrimitiveBindingData("hello")), + }, + }, + } + s, err := h.StartNode(ctx, w, node, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) +} + +func TestStartNodeHandler_HandleNode(t *testing.T) { + ctx := context.Background() + h := startHandler{} + s, err := h.CheckNodeStatus(ctx, nil, nil, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler.go new file mode 100644 index 0000000000..8be3879413 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler.go @@ -0,0 +1,78 @@ +package subworkflow + +import ( + "context" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" +) + +type workflowNodeHandler struct { + recorder events.WorkflowEventRecorder + lpHandler launchPlanHandler + subWfHandler subworkflowHandler +} + +func (w *workflowNodeHandler) Initialize(ctx context.Context) error { + return nil +} + +func (w *workflowNodeHandler) StartNode(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.StartSubWorkflow(ctx, wf, node, nodeInputs) + } + + if node.GetWorkflowNode().GetLaunchPlanRefID() != nil { + return w.lpHandler.StartLaunchPlan(ctx, wf, node, nodeInputs) + } + + return handler.StatusFailed(errors.Errorf(errors.BadSpecificationError, node.GetID(), "SubWorkflow is incorrectly specified.")), nil +} + +func (w *workflowNodeHandler) CheckNodeStatus(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, status v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.CheckSubWorkflowStatus(ctx, wf, node, status) + } + + if node.GetWorkflowNode().GetLaunchPlanRefID() != nil { + return w.lpHandler.CheckLaunchPlanStatus(ctx, wf, node, status) + } + + return handler.StatusFailed(errors.Errorf(errors.BadSpecificationError, node.GetID(), "workflow node does not have a subworkflow or child workflow reference")), nil +} + +func (w *workflowNodeHandler) HandleFailingNode(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + if node.GetWorkflowNode() != nil && node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.HandleSubWorkflowFailingNode(ctx, wf, node) + } + return handler.StatusFailed(nil), nil +} + +func (w *workflowNodeHandler) AbortNode(ctx context.Context, wf v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + if node.GetWorkflowNode().GetSubWorkflowRef() != nil { + return w.subWfHandler.HandleAbort(ctx, wf, node) + } + + if node.GetWorkflowNode().GetLaunchPlanRefID() != nil { + return w.lpHandler.HandleAbort(ctx, wf, node) + } + return nil +} + +func New(executor executors.Node, eventSink events.EventSink, 
workflowLauncher launchplan.Executor, enQWorkflow v1alpha1.EnqueueWorkflow, store *storage.DataStore, scope promutils.Scope) handler.IFace { + subworkflowScope := scope.NewSubScope("workflow") + return &workflowNodeHandler{ + subWfHandler: newSubworkflowHandler(executor, enQWorkflow, store), + lpHandler: launchPlanHandler{ + store: store, + launchPlan: workflowLauncher, + }, + recorder: events.NewWorkflowEventRecorder(eventSink, subworkflowScope), + } +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go new file mode 100644 index 0000000000..124079e5d7 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go @@ -0,0 +1,230 @@ +package subworkflow + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + mockWfNode.On("GetSubWorkflowRef").Return(nil) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + wfStatus := &mocks2.MutableWorkflowNodeStatus{} + mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) + wfStatus.On("SetWorkflowExecutionName", + mock.MatchedBy(func(name string) bool { + return name == "x-n1-1" + }), + ).Return() + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + ni := &core.LiteralMap{} + + t.Run("happy", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(nil) + + s, err := h.StartNode(ctx, 
mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) +} + +func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + mockWfNode.On("GetSubWorkflowRef").Return(nil) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("stillRunning", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + }, nil) + + s, err := h.CheckNodeStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) +} + +func TestWorkflowNodeHandler_AbortNode(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + mockWfNode.On("GetSubWorkflowRef").Return(nil) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetName").Return("test") + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("abort", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := New(nil, nil, mockLPExec, nil, nil, promutils.NewTestScope()) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(nil) + + err := h.AbortNode(ctx, mockWf, mockNode) + assert.NoError(t, err) + }) + + t.Run("abort-fail", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + expectedErr := fmt.Errorf("fail") + h := New(nil, nil, mockLPExec, nil, nil, 
promutils.NewTestScope()) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(expectedErr) + + err := h.AbortNode(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, err, expectedErr) + }) +} + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go new file mode 100644 index 0000000000..03c0b71d95 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go @@ -0,0 +1,139 @@ +package subworkflow + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +type launchPlanHandler struct { + launchPlan launchplan.Executor + store *storage.DataStore +} + +func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + childID, err := GetChildWorkflowExecutionID( + w.GetExecutionID().WorkflowExecutionIdentifier, + node.GetID(), + nodeStatus.GetAttempts(), + ) + + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.RuntimeExecutionError, node.GetID(), err, "failed to create unique ID")), nil + } + + launchCtx := launchplan.LaunchContext{ + // TODO we need to add principal and nestinglevel as annotations or labels? 
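+		// Note: the admin launch plan executor bumps NestingLevel by one when it creates the child execution.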
+ Principal: "unknown", + NestingLevel: 0, + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, + }, + } + err = l.launchPlan.Launch(ctx, launchCtx, childID, node.GetWorkflowNode().GetLaunchPlanRefID().Identifier, nodeInputs) + if err != nil { + if launchplan.IsAlreadyExists(err) { + logger.Info(ctx, "Execution already exists [%s].", childID.Name) + } else if launchplan.IsUserError(err) { + return handler.StatusFailed(err), nil + } else { + return handler.StatusUndefined, err + } + } else { + logger.Infof(ctx, "Launched launchplan with ID [%s]", childID.Name) + } + + nodeStatus.GetOrCreateWorkflowStatus().SetWorkflowExecutionName(childID.Name) + return handler.StatusRunning, nil +} + +func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, status v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + // Handle launch plan + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + childID, err := GetChildWorkflowExecutionID( + w.GetExecutionID().WorkflowExecutionIdentifier, + node.GetID(), + nodeStatus.GetAttempts(), + ) + + if err != nil { + // THIS SHOULD NEVER HAPPEN + return handler.StatusFailed(errors.Wrapf(errors.RuntimeExecutionError, node.GetID(), err, "failed to create unique ID")), nil + } + + wfStatusClosure, err := l.launchPlan.GetStatus(ctx, childID) + if err != nil { + if launchplan.IsNotFound(err) { //NotFound + return handler.StatusFailed(err), nil + } + + return handler.StatusUndefined, err + } + + if wfStatusClosure == nil { + logger.Info(ctx, "Retrieved Launch Plan status is nil. This might indicate pressure on the admin cache."+ + " Consider tweaking its size to allow for more concurrent executions to be cached.") + return handler.StatusRunning, nil + } + + var wErr error + switch wfStatusClosure.GetPhase() { + case core.WorkflowExecution_ABORTED: + wErr = fmt.Errorf("launchplan execution aborted") + return handler.StatusFailed(errors.Wrapf(errors.RemoteChildWorkflowExecutionFailed, node.GetID(), wErr, "launchplan [%s] failed", childID.Name)), nil + case core.WorkflowExecution_FAILED: + wErr = fmt.Errorf("launchplan execution failed without explicit error") + if wfStatusClosure.GetError() != nil { + wErr = fmt.Errorf(" errorCode[%s]: %s", wfStatusClosure.GetError().Code, wfStatusClosure.GetError().Message) + } + return handler.StatusFailed(errors.Wrapf(errors.RemoteChildWorkflowExecutionFailed, node.GetID(), wErr, "launchplan [%s] failed", childID.Name)), nil + case core.WorkflowExecution_SUCCEEDED: + if wfStatusClosure.GetOutputs() != nil { + outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetDataDir()) + childOutput := &core.LiteralMap{} + uri := wfStatusClosure.GetOutputs().GetUri() + if uri != "" { + // Copy remote data to local S3 path + if err := l.store.ReadProtobuf(ctx, storage.DataReference(uri), childOutput); err != nil { + if storage.IsNotFound(err) { + return handler.StatusFailed(errors.Wrapf(errors.RemoteChildWorkflowExecutionFailed, node.GetID(), err, "remote output for launchplan execution was not found, uri [%s]", uri)), nil + } + return handler.StatusUndefined, errors.Wrapf(errors.RuntimeExecutionError, node.GetID(), err, "failed to read outputs from child workflow @ [%s]", uri) + } + + } else if wfStatusClosure.GetOutputs().GetValues() != nil { + // Store data to S3Path + childOutput = wfStatusClosure.GetOutputs().GetValues() + } + if err := l.store.WriteProtobuf(ctx, outputFile, 
storage.Options{}, childOutput); err != nil { + logger.Debugf(ctx, "failed to write data to Storage, err: %v", err.Error()) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to copy outputs for child workflow") + } + } + return handler.StatusSuccess, nil + } + return handler.StatusRunning, nil +} + +func (l *launchPlanHandler) HandleAbort(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + childID, err := GetChildWorkflowExecutionID( + w.GetExecutionID().WorkflowExecutionIdentifier, + node.GetID(), + nodeStatus.GetAttempts(), + ) + if err != nil { + // THIS SHOULD NEVER HAPPEN + return err + } + return l.launchPlan.Kill(ctx, childID, fmt.Sprintf("parent execution id [%s] aborted", w.GetName())) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go new file mode 100644 index 0000000000..00b52712fc --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go @@ -0,0 +1,167 @@ +package launchplan + +import ( + "context" + "fmt" + "runtime/pprof" + "time" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flytestdlib/contextutils" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/utils" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Executor for Launchplans that executes on a remote FlyteAdmin service (if configured) +type adminLaunchPlanExecutor struct { + adminClient service.AdminServiceClient + cache utils.AutoRefreshCache +} + +type executionCacheItem struct { + core.WorkflowExecutionIdentifier + ExecutionClosure *admin.ExecutionClosure + SyncError error +} + +func (e executionCacheItem) ID() string { + return e.String() +} + +func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { + req := &admin.ExecutionCreateRequest{ + Project: executionID.Project, + Domain: executionID.Domain, + Name: executionID.Name, + Spec: &admin.ExecutionSpec{ + LaunchPlan: launchPlanRef, + Metadata: &admin.ExecutionMetadata{ + Mode: admin.ExecutionMetadata_SYSTEM, + Nesting: launchCtx.NestingLevel + 1, + Principal: launchCtx.Principal, + ParentNodeExecution: launchCtx.ParentNodeExecution, + }, + Inputs: inputs, + }, + } + _, err := a.adminClient.CreateExecution(ctx, req) + if err != nil { + statusCode := status.Code(err) + switch statusCode { + case codes.AlreadyExists: + _, err := a.cache.GetOrCreate(executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + logger.Errorf(ctx, "Failed to add ExecID [%v] to auto refresh cache", executionID) + } + + return Wrapf(RemoteErrorAlreadyExists, err, "ExecID %s already exists", executionID.Name) + case codes.DataLoss, codes.DeadlineExceeded, codes.Internal, codes.Unknown, codes.Canceled: + return Wrapf(RemoteErrorSystem, err, "failed to launch workflow [%s], system error", launchPlanRef.Name) + default: + return Wrapf(RemoteErrorUser, err, "failed to launch workflow") + } + } + + _, err = a.cache.GetOrCreate(executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + logger.Info(ctx, "Failed to 
add ExecID [%v] to auto refresh cache", executionID) + } + + return nil +} + +func (a *adminLaunchPlanExecutor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) { + if executionID == nil { + return nil, fmt.Errorf("nil executionID") + } + + obj, err := a.cache.GetOrCreate(executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + if err != nil { + return nil, err + } + + item := obj.(executionCacheItem) + + return item.ExecutionClosure, item.SyncError +} + +func (a *adminLaunchPlanExecutor) Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error { + req := &admin.ExecutionTerminateRequest{ + Id: executionID, + Cause: reason, + } + _, err := a.adminClient.TerminateExecution(ctx, req) + if err != nil { + if status.Code(err) == codes.NotFound { + return nil + } + return Wrapf(RemoteErrorSystem, err, "system error") + } + return nil +} + +func (a *adminLaunchPlanExecutor) Initialize(ctx context.Context) error { + go func() { + // Set goroutine-label... + ctx = contextutils.WithGoroutineLabel(ctx, "admin-launcher") + pprof.SetGoroutineLabels(ctx) + a.cache.Start(ctx) + }() + + return nil +} + +func (a *adminLaunchPlanExecutor) syncItem(ctx context.Context, obj utils.CacheItem) ( + newItem utils.CacheItem, result utils.CacheSyncAction, err error) { + exec := obj.(executionCacheItem) + req := &admin.WorkflowExecutionGetRequest{ + Id: &exec.WorkflowExecutionIdentifier, + } + + res, err := a.adminClient.GetExecution(ctx, req) + if err != nil { + // TODO: Define which error codes are system errors (and return the error) vs user errors. + + if status.Code(err) == codes.NotFound { + err = Wrapf(RemoteErrorNotFound, err, "execID [%s] not found on remote", exec.WorkflowExecutionIdentifier.Name) + } else { + err = Wrapf(RemoteErrorSystem, err, "system error") + } + + return executionCacheItem{ + WorkflowExecutionIdentifier: exec.WorkflowExecutionIdentifier, + SyncError: err, + }, utils.Update, nil + } + + return executionCacheItem{ + WorkflowExecutionIdentifier: exec.WorkflowExecutionIdentifier, + ExecutionClosure: res.Closure, + }, utils.Update, nil +} + +func NewAdminLaunchPlanExecutor(_ context.Context, client service.AdminServiceClient, + syncPeriod time.Duration, cfg *AdminConfig, scope promutils.Scope) (Executor, error) { + exec := &adminLaunchPlanExecutor{ + adminClient: client, + } + + // TODO: make tps/burst/size configurable + cache, err := utils.NewAutoRefreshCache(exec.syncItem, utils.NewRateLimiter("adminSync", + float64(cfg.TPS), cfg.Burst), syncPeriod, cfg.MaxCacheSize, scope) + if err != nil { + return nil, err + } + + exec.cache = cache + return exec, nil +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go new file mode 100644 index 0000000000..3384cedcbe --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go @@ -0,0 +1,277 @@ +package launchplan + +import ( + "context" + "testing" + "time" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/clients/go/admin/mocks" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestAdminLaunchPlanExecutor_GetStatus(t *testing.T) { + ctx := context.TODO() + id := 
&core.WorkflowExecutionIdentifier{ + Name: "n", + Domain: "d", + Project: "p", + } + var result *admin.ExecutionClosure + + t.Run("happy", func(t *testing.T) { + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Millisecond, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + mockClient.On("GetExecution", + ctx, + mock.MatchedBy(func(o *admin.WorkflowExecutionGetRequest) bool { return true }), + ).Return(result, nil) + assert.NoError(t, err) + s, err := exec.GetStatus(ctx, id) + assert.NoError(t, err) + assert.Equal(t, result, s) + }) + + t.Run("notFound", func(t *testing.T) { + mockClient := &mocks.AdminServiceClient{} + + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil + }), + ).Return(nil, nil) + + mockClient.On("GetExecution", + mock.Anything, + mock.MatchedBy(func(o *admin.WorkflowExecutionGetRequest) bool { return true }), + ).Return(nil, status.Error(codes.NotFound, "")) + + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Millisecond, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, exec.Initialize(ctx)) + + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + + // Allow for sync to be called + time.Sleep(time.Second) + + s, err := exec.GetStatus(ctx, id) + assert.Error(t, err) + assert.Nil(t, s) + assert.True(t, IsNotFound(err)) + }) + + t.Run("other", func(t *testing.T) { + mockClient := &mocks.AdminServiceClient{} + + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil + }), + ).Return(nil, nil) + + mockClient.On("GetExecution", + mock.Anything, + mock.MatchedBy(func(o *admin.WorkflowExecutionGetRequest) bool { return true }), + ).Return(nil, status.Error(codes.Canceled, "")) + + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Millisecond, defaultAdminConfig, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, exec.Initialize(ctx)) + + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + + // Allow for sync to be called + time.Sleep(time.Second) + + s, err := exec.GetStatus(ctx, id) + assert.Error(t, err) + assert.Nil(t, s) + assert.False(t, IsNotFound(err)) + }) +} + +func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { + ctx := context.TODO() + id := &core.WorkflowExecutionIdentifier{ + Name: "n", + Domain: "d", + Project: "p", + } + + t.Run("happy", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { + return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil + }), + ).Return(nil, nil) + assert.NoError(t, err) + 
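+		// CreateExecution is mocked to succeed, so Launch should return without error.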
err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.NoError(t, err) + }) + + t.Run("notFound", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { return true }), + ).Return(nil, status.Error(codes.AlreadyExists, "")) + assert.NoError(t, err) + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.Error(t, err) + assert.True(t, IsAlreadyExists(err)) + }) + + t.Run("other", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("CreateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { return true }), + ).Return(nil, status.Error(codes.Canceled, "")) + assert.NoError(t, err) + err = exec.Launch(ctx, + LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "w", + }, + }, + }, + id, + &core.Identifier{}, + nil, + ) + assert.Error(t, err) + assert.False(t, IsAlreadyExists(err)) + }) +} + +func TestAdminLaunchPlanExecutor_Kill(t *testing.T) { + ctx := context.TODO() + id := &core.WorkflowExecutionIdentifier{ + Name: "n", + Domain: "d", + Project: "p", + } + + const reason = "reason" + t.Run("happy", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("TerminateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionTerminateRequest) bool { return o.Id == id && o.Cause == reason }), + ).Return(&admin.ExecutionTerminateResponse{}, nil) + assert.NoError(t, err) + err = exec.Kill(ctx, id, reason) + assert.NoError(t, err) + }) + + t.Run("notFound", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("TerminateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionTerminateRequest) bool { return o.Id == id && o.Cause == reason }), + ).Return(nil, status.Error(codes.NotFound, "")) + assert.NoError(t, err) + err = exec.Kill(ctx, id, reason) + assert.NoError(t, err) + }) + + t.Run("other", func(t *testing.T) { + + mockClient := &mocks.AdminServiceClient{} + exec, err := NewAdminLaunchPlanExecutor(ctx, mockClient, time.Second, defaultAdminConfig, promutils.NewTestScope()) + mockClient.On("TerminateExecution", + ctx, + mock.MatchedBy(func(o *admin.ExecutionTerminateRequest) bool { return o.Id == id && o.Cause == reason }), + ).Return(nil, status.Error(codes.Canceled, "")) + assert.NoError(t, err) + err = exec.Kill(ctx, id, reason) + assert.Error(t, err) + assert.False(t, IsNotFound(err)) + }) +} diff --git 
a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go new file mode 100644 index 0000000000..ed7b8d5171 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig.go @@ -0,0 +1,33 @@ +package launchplan + +import ( + ctrlConfig "github.com/lyft/flytepropeller/pkg/controller/config" +) + +//go:generate pflags AdminConfig --default-var defaultAdminConfig + +var ( + defaultAdminConfig = &AdminConfig{ + TPS: 5, + Burst: 10, + MaxCacheSize: 10000, + } + + adminConfigSection = ctrlConfig.ConfigSection.MustRegisterSection("admin-launcher", defaultAdminConfig) +) + +type AdminConfig struct { + // TPS indicates the maximum transactions per second to flyte admin from this client. + // If it's zero, the created client will use DefaultTPS: 5 + TPS int64 `json:"tps" pflag:",The maximum number of transactions per second to flyte admin from this client."` + + // Maximum burst for throttle. + // If it's zero, the created client will use DefaultBurst: 10. + Burst int `json:"burst" pflag:",Maximum burst for throttle"` + + MaxCacheSize int `json:"cacheSize" pflag:",Maximum cache in terms of number of items stored."` +} + +func GetAdminConfig() *AdminConfig { + return adminConfigSection.GetConfig().(*AdminConfig) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go new file mode 100755 index 0000000000..9b85965e7e --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags.go @@ -0,0 +1,48 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package launchplan + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (AdminConfig) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (AdminConfig) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in AdminConfig and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
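+// For example, a prefix of "admin-launcher." yields the flags admin-launcher.tps, admin-launcher.burst and admin-launcher.cacheSize.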
+func (cfg AdminConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("AdminConfig", pflag.ExitOnError) + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "tps"), defaultAdminConfig.TPS, "The maximum number of transactions per second to flyte admin from this client.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "burst"), defaultAdminConfig.Burst, "Maximum burst for throttle") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cacheSize"), defaultAdminConfig.MaxCacheSize, "Maximum cache in terms of number of items stored.") + return cmdFlags +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go new file mode 100755 index 0000000000..79de094636 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/adminconfig_flags_test.go @@ -0,0 +1,168 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package launchplan + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsAdminConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementAdminConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsAdminConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookAdminConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementAdminConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_AdminConfig(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookAdminConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_AdminConfig(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_AdminConfig(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_AdminConfig(val, result)) +} + +func testDecodeSlice_AdminConfig(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_AdminConfig(vStringSlice, result)) +} + +func TestAdminConfig_GetPFlagSet(t *testing.T) { + val := AdminConfig{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestAdminConfig_SetFlags(t *testing.T) { + actual := AdminConfig{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_tps", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("tps"); err == nil { + assert.Equal(t, int64(defaultAdminConfig.TPS), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("tps", testValue) + if vInt64, err := cmdFlags.GetInt64("tps"); err == nil { + testDecodeJson_AdminConfig(t, fmt.Sprintf("%v", vInt64), &actual.TPS) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_burst", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("burst"); err == nil { + assert.Equal(t, int(defaultAdminConfig.Burst), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("burst", testValue) + if vInt, err := cmdFlags.GetInt("burst"); err == nil { + testDecodeJson_AdminConfig(t, fmt.Sprintf("%v", vInt), &actual.Burst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cacheSize", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("cacheSize"); err == nil { + assert.Equal(t, int(defaultAdminConfig.MaxCacheSize), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cacheSize", testValue) + if vInt, err := cmdFlags.GetInt("cacheSize"); err == nil { + testDecodeJson_AdminConfig(t, fmt.Sprintf("%v", vInt), &actual.MaxCacheSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors.go new file mode 100644 index 0000000000..02b9203fe2 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors.go @@ -0,0 +1,57 
@@ +package launchplan + +import "fmt" + +type ErrorCode string + +const ( + RemoteErrorAlreadyExists ErrorCode = "AlreadyExists" + RemoteErrorNotFound ErrorCode = "NotFound" + RemoteErrorSystem = "SystemError" // timeouts, network error etc + RemoteErrorUser = "UserError" // Incase of bad specification, invalid arguments, etc +) + +type RemoteError struct { + Code ErrorCode + Cause error + Message string +} + +func (r RemoteError) Error() string { + return fmt.Sprintf("%s: %s, caused by [%s]", r.Code, r.Message, r.Cause.Error()) +} + +func Wrapf(code ErrorCode, cause error, msg string, args ...interface{}) error { + return &RemoteError{ + Code: code, + Cause: cause, + Message: fmt.Sprintf(msg, args...), + } +} + +// Checks if the error is of type RemoteError and the ErrorCode is of type RemoteErrorAlreadyExists +func IsAlreadyExists(err error) bool { + e, ok := err.(*RemoteError) + if ok { + return e.Code == RemoteErrorAlreadyExists + } + return false +} + +// Checks if the error is of type RemoteError and the ErrorCode is of type RemoteErrorUser +func IsUserError(err error) bool { + e, ok := err.(*RemoteError) + if ok { + return e.Code == RemoteErrorUser + } + return false +} + +// Checks if the error is of type RemoteError and the ErrorCode is of type RemoteErrorNotFound +func IsNotFound(err error) bool { + e, ok := err.(*RemoteError) + if ok { + return e.Code == RemoteErrorNotFound + } + return false +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors_test.go new file mode 100644 index 0000000000..4519218fcf --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/errors_test.go @@ -0,0 +1,36 @@ +package launchplan + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRemoteError(t *testing.T) { + t.Run("alreadyExists", func(t *testing.T) { + e := Wrapf(RemoteErrorAlreadyExists, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.True(t, IsAlreadyExists(e)) + }) + + t.Run("notfound", func(t *testing.T) { + e := Wrapf(RemoteErrorNotFound, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.True(t, IsNotFound(e)) + }) + + t.Run("alreadyExists", func(t *testing.T) { + e := Wrapf(RemoteErrorUser, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.True(t, IsUserError(e)) + }) + + t.Run("system", func(t *testing.T) { + e := Wrapf(RemoteErrorSystem, fmt.Errorf("blah"), "error") + assert.Error(t, e) + assert.False(t, IsAlreadyExists(e)) + assert.False(t, IsNotFound(e)) + assert.False(t, IsUserError(e)) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go new file mode 100644 index 0000000000..413f4a2b86 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go @@ -0,0 +1,36 @@ +package launchplan + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +//go:generate mockery -name Executor + +// A simple context that is used to start an execution of a LaunchPlan. 
It encapsulates enough parent information +// to tie the executions +type LaunchContext struct { + // Nesting level of the current workflow (parent) + NestingLevel uint32 + // Principal of the current workflow, so that billing can be tied correctly + Principal string + // If a node launched the execution, this specifies which node execution + ParentNodeExecution *core.NodeExecutionIdentifier +} + +// Interface to be implemented by the remote system that can allow workflow launching capabilities +type Executor interface { + // Start an execution of a launchplan + Launch(ctx context.Context, launchCtx LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error + + // Retrieve status of a LaunchPlan execution + GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) + + // Kill a remote execution + Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error + + // Initializes Executor. + Initialize(ctx context.Context) error +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go new file mode 100644 index 0000000000..83d28be3b0 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/Executor.go @@ -0,0 +1,79 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import admin "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +import context "context" +import core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +import launchplan "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" +import mock "github.com/stretchr/testify/mock" + +// Executor is an autogenerated mock type for the Executor type +type Executor struct { + mock.Mock +} + +// GetStatus provides a mock function with given fields: ctx, executionID +func (_m *Executor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) { + ret := _m.Called(ctx, executionID) + + var r0 *admin.ExecutionClosure + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier) *admin.ExecutionClosure); ok { + r0 = rf(ctx, executionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ExecutionClosure) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier) error); ok { + r1 = rf(ctx, executionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: ctx +func (_m *Executor) Initialize(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Kill provides a mock function with given fields: ctx, executionID, reason +func (_m *Executor) Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error { + ret := _m.Called(ctx, executionID, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, string) error); ok { + r0 = rf(ctx, executionID, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Launch provides a mock function with given fields: ctx, launchCtx, executionID, launchPlanRef, inputs +func (_m *Executor) Launch(ctx context.Context, launchCtx launchplan.LaunchContext, 
executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { + ret := _m.Called(ctx, launchCtx, executionID, launchPlanRef, inputs) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, launchplan.LaunchContext, *core.WorkflowExecutionIdentifier, *core.Identifier, *core.LiteralMap) error); ok { + r0 = rf(ctx, launchCtx, executionID, launchPlanRef, inputs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go new file mode 100644 index 0000000000..f913a8f34d --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go @@ -0,0 +1,37 @@ +package launchplan + +import ( + "context" + "fmt" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" +) + +type failFastWorkflowLauncher struct { +} + +func (failFastWorkflowLauncher) Launch(ctx context.Context, launchCtx LaunchContext, executionID *core.WorkflowExecutionIdentifier, launchPlanRef *core.Identifier, inputs *core.LiteralMap) error { + logger.Infof(ctx, "Fail: Launch Workflow requested with ExecID [%s], LaunchPlan [%s]", executionID.Name, fmt.Sprintf("%s:%s:%s", launchPlanRef.Project, launchPlanRef.Domain, launchPlanRef.Name)) + return Wrapf(RemoteErrorUser, fmt.Errorf("badly configured system"), "please enable admin workflow launch to use launchplans") +} + +func (failFastWorkflowLauncher) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, error) { + logger.Infof(ctx, "NOOP: Workflow Status ExecID [%s]", executionID.Name) + return nil, Wrapf(RemoteErrorUser, fmt.Errorf("badly configured system"), "please enable admin workflow launch to use launchplans") +} + +func (failFastWorkflowLauncher) Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error { + return nil +} + +// Initializes Executor. 
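+// For the fail-fast launcher this is a no-op.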
+func (failFastWorkflowLauncher) Initialize(ctx context.Context) error { + return nil +} + +func NewFailFastLaunchPlanExecutor() Executor { + logger.Infof(context.TODO(), "created failFast workflow launcher, will not launch subworkflows.") + return &failFastWorkflowLauncher{} +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go new file mode 100644 index 0000000000..f75d542506 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go @@ -0,0 +1,51 @@ +package launchplan + +import ( + "context" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestFailFastWorkflowLauncher(t *testing.T) { + ctx := context.TODO() + f := NewFailFastLaunchPlanExecutor() + t.Run("getStatus", func(t *testing.T) { + a, err := f.GetStatus(ctx, &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }) + assert.Nil(t, a) + assert.Error(t, err) + }) + + t.Run("launch", func(t *testing.T) { + err := f.Launch(ctx, LaunchContext{ + ParentNodeExecution: &core.NodeExecutionIdentifier{ + NodeId: "node-id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }, + }, + }, &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }, &core.Identifier{}, + nil) + assert.Error(t, err) + }) + + t.Run("kill", func(t *testing.T) { + err := f.Kill(ctx, &core.WorkflowExecutionIdentifier{ + Project: "p", + Domain: "d", + Name: "n", + }, "reason") + assert.NoError(t, err) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go new file mode 100644 index 0000000000..400203e688 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -0,0 +1,640 @@ +package subworkflow + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func createInmemoryStore(t testing.TB) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + + d, err := storage.NewDataStore(&cfg, promutils.NewTestScope()) + assert.NoError(t, err) + + return d +} + +func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + 
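+	// The attempt count becomes part of the derived child execution name (x-n1-1 below).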
mockNodeStatus.On("GetAttempts").Return(attempts) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + ni := &core.LiteralMap{} + + t.Run("happy", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(nil) + + wfStatus := &mocks2.MutableWorkflowNodeStatus{} + mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) + wfStatus.On("SetWorkflowExecutionName", + mock.MatchedBy(func(name string) bool { + return name == "x-n1-1" + }), + ).Return() + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) + + t.Run("alreadyExists", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(launchplan.Wrapf(launchplan.RemoteErrorAlreadyExists, fmt.Errorf("blah"), "failed")) + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) + + t.Run("systemError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(launchplan.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("blah"), "failed")) + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) + + t.Run("userError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("Launch", + ctx, + mock.MatchedBy(func(o launchplan.LaunchContext) bool { + return o.ParentNodeExecution.NodeId == mockNode.GetID() && + o.ParentNodeExecution.ExecutionId == parentID + }), + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + 
return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return ni == o }), + ).Return(launchplan.Wrapf(launchplan.RemoteErrorUser, fmt.Errorf("blah"), "failed")) + + s, err := h.StartLaunchPlan(ctx, mockWf, mockNode, ni) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) +} + +func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { + + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return("n1") + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("stillRunning", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseRunning) + }) + + t.Run("successNoOutputs", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseSuccess) + }) + + t.Run("successOutputURI", func(t *testing.T) { + + mockStore := createInmemoryStore(t) + mockLPExec := &mocks.Executor{} + uri := storage.DataReference("uri") + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + op := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + }, + } + err := mockStore.WriteProtobuf(ctx, uri, storage.Options{}, op) + assert.NoError(t, err) + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Uri{ + Uri: uri.String(), + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, 
mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseSuccess) + final := &core.LiteralMap{} + assert.NoError(t, mockStore.ReadProtobuf(ctx, v1alpha1.GetOutputsFile(dataDir), final)) + v, ok := final.GetLiterals()["x"] + assert.True(t, ok) + assert.Equal(t, int64(1), v.GetScalar().GetPrimitive().GetInteger()) + }) + + t.Run("successOutputs", func(t *testing.T) { + + mockStore := createInmemoryStore(t) + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + op := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + }, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Values{ + Values: op, + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseSuccess) + final := &core.LiteralMap{} + assert.NoError(t, mockStore.ReadProtobuf(ctx, v1alpha1.GetOutputsFile(dataDir), final)) + v, ok := final.GetLiterals()["x"] + assert.True(t, ok) + assert.Equal(t, int64(1), v.GetScalar().GetPrimitive().GetInteger()) + }) + + t.Run("failureError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_FAILED, + OutputResult: &admin.ExecutionClosure_Error{ + Error: &core.ExecutionError{ + Message: "msg", + Code: "code", + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("failureNoError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_FAILED, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("aborted", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_ABORTED, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("notFound", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain 
+ }), + ).Return(nil, launchplan.Wrapf(launchplan.RemoteErrorNotFound, fmt.Errorf("some error"), "not found")) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("systemError", func(t *testing.T) { + + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(nil, launchplan.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("some error"), "not found")) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) + + t.Run("dataStoreFailure", func(t *testing.T) { + + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec := &mocks.Executor{} + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + op := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": utils.MustMakePrimitiveLiteral(1), + }, + } + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Values{ + Values: op, + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.Error(t, err) + assert.Equal(t, s.Phase, handler.PhaseUndefined) + }) + + t.Run("outputURINotFound", func(t *testing.T) { + + mockStore := createInmemoryStore(t) + mockLPExec := &mocks.Executor{} + uri := storage.DataReference("uri") + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Uri{ + Uri: uri.String(), + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.NoError(t, err) + assert.Equal(t, s.Phase, handler.PhaseFailed) + }) + + t.Run("outputURISystemError", func(t *testing.T) { + + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec := &mocks.Executor{} + uri := storage.DataReference("uri") + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + + mockLPExec.On("GetStatus", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + ).Return(&admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + OutputResult: &admin.ExecutionClosure_Outputs{ + Outputs: &admin.LiteralMapBlob{ + Data: &admin.LiteralMapBlob_Uri{ + Uri: uri.String(), + }, + }, + }, + }, nil) + + s, err := h.CheckLaunchPlanStatus(ctx, mockWf, mockNode, nil) + assert.Error(t, err) + assert.Equal(t, s.Phase, 
handler.PhaseUndefined) + }) +} + +func TestLaunchPlanHandler_HandleAbort(t *testing.T) { + + ctx := context.TODO() + + nodeID := "n1" + attempts := uint32(1) + dataDir := storage.DataReference("data") + + lpID := &core.Identifier{ + Project: "p", + Domain: "d", + Name: "n", + Version: "v", + ResourceType: core.ResourceType_LAUNCH_PLAN, + } + mockWfNode := &mocks2.ExecutableWorkflowNode{} + mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + Identifier: lpID, + }) + + mockNode := &mocks2.ExecutableNode{} + mockNode.On("GetID").Return(nodeID) + mockNode.On("GetWorkflowNode").Return(mockWfNode) + + mockNodeStatus := &mocks2.ExecutableNodeStatus{} + mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.On("GetDataDir").Return(dataDir) + + parentID := &core.WorkflowExecutionIdentifier{ + Name: "x", + Domain: "y", + Project: "z", + } + mockWf := &mocks2.ExecutableWorkflow{} + mockWf.On("GetName").Return("test") + mockWf.On("GetNodeExecutionStatus", nodeID).Return(mockNodeStatus) + mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: parentID, + }) + + t.Run("abort-success", func(t *testing.T) { + mockLPExec := &mocks.Executor{} + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(nil) + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + err := h.HandleAbort(ctx, mockWf, mockNode) + assert.NoError(t, err) + }) + + t.Run("abort-fail", func(t *testing.T) { + expectedErr := fmt.Errorf("fail") + mockLPExec := &mocks.Executor{} + mockStore := storage.NewCompositeDataStore(storage.URLPathConstructor{}, storage.NewDefaultProtobufStore(utils.FailingRawStore{}, promutils.NewTestScope())) + mockLPExec.On("Kill", + ctx, + mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { + return o.Project == parentID.Project && o.Domain == parentID.Domain + }), + mock.AnythingOfType(reflect.String.String()), + ).Return(expectedErr) + + h := launchPlanHandler{ + launchPlan: mockLPExec, + store: mockStore, + } + err := h.HandleAbort(ctx, mockWf, mockNode) + assert.Error(t, err) + assert.Equal(t, err, expectedErr) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/sub_workflow.go b/flytepropeller/pkg/controller/nodes/subworkflow/sub_workflow.go new file mode 100644 index 0000000000..a712d991f4 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/sub_workflow.go @@ -0,0 +1,195 @@ +package subworkflow + +import ( + "context" + "fmt" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" +) + +//TODO Add unit tests for subworkflow handler + +// Subworkflow handler handles inline subworkflows +type subworkflowHandler struct { + nodeExecutor executors.Node + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore +} + +func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, + parentNodeStatus 
v1alpha1.ExecutableNodeStatus, startNode v1alpha1.ExecutableNode) (handler.Status, error) { + + //TODO we need to handle failing and success nodes + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, w, startNode) + if err != nil { + return handler.StatusUndefined, err + } + + if state.HasFailed() { + if w.GetOnFailureNode() != nil { + return handler.StatusFailing(state.Err), nil + } + return handler.StatusFailed(state.Err), nil + } + + if state.IsComplete() { + nodeID := "" + if parentNodeStatus.GetParentNodeID() != nil { + nodeID = *parentNodeStatus.GetParentNodeID() + } + + // If the WF interface has outputs, validate that the outputs file was written. + if outputBindings := w.GetOutputBindings(); len(outputBindings) > 0 { + endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + if endNodeStatus == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "No end node found in subworkflow.")), nil + } + + sourcePath := v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir()) + if metadata, err := s.store.Head(ctx, sourcePath); err == nil { + if !metadata.Exists() { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, nodeID, + "Subworkflow is expected to produce outputs but no outputs file was written to %v.", + sourcePath)), nil + } + } else { + return handler.StatusUndefined, err + } + + destinationPath := v1alpha1.GetOutputsFile(parentNodeStatus.GetDataDir()) + if err := s.store.CopyRaw(ctx, sourcePath, destinationPath, storage.Options{}); err != nil { + return handler.StatusFailed(errors.Wrapf(errors.OutputsNotFoundError, nodeID, + err, "Failed to copy subworkflow outputs from [%v] to [%v]", + sourcePath, destinationPath)), nil + } + } + + return handler.StatusSuccess, nil + } + + if state.PartiallyComplete() { + // Re-enqueue the workflow + s.enqueueWorkflow(w.GetK8sWorkflowID().String()) + } + + return handler.StatusRunning, nil +} + +func (s *subworkflowHandler) DoInFailureHandling(ctx context.Context, w v1alpha1.ExecutableWorkflow) (handler.Status, error) { + if w.GetOnFailureNode() != nil { + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, w, w.GetOnFailureNode()) + if err != nil { + return handler.StatusUndefined, err + } + if state.HasFailed() { + return handler.StatusFailed(state.Err), nil + } + if state.IsComplete() { + // Re-enqueue the workflow + s.enqueueWorkflow(w.GetK8sWorkflowID().String()) + return handler.StatusFailed(nil), nil + } + return handler.StatusFailing(nil), nil + } + return handler.StatusFailed(nil), nil +} + +func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No subWorkflow [%s], workflow.", subID)), nil + } + + status := w.GetNodeExecutionStatus(node.GetID()) + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) + startNode := contextualSubWorkflow.StartNode() + if startNode == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, "", "No start node found in subworkflow.")), nil + } + + // Before starting the subworkflow, lets set the inputs for the Workflow. 
The inputs for a SubWorkflow are essentially + // Copy of the inputs to the Node + nodeStatus := contextualSubWorkflow.GetNodeExecutionStatus(startNode.GetID()) + if len(nodeStatus.GetDataDir()) == 0 { + dataDir, err := contextualSubWorkflow.GetExecutionStatus().ConstructNodeDataDir(ctx, s.store, startNode.GetID()) + if err != nil { + logger.Errorf(ctx, "Failed to create metadata store key. Error [%v]", err) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to create metadata store key.") + } + + nodeStatus.SetDataDir(dataDir) + startStatus, err := s.nodeExecutor.SetInputsForStartNode(ctx, contextualSubWorkflow, nodeInputs) + if err != nil { + // TODO we are considering an error when setting inputs are retryable + return handler.StatusUndefined, err + } + + if startStatus.HasFailed() { + return handler.StatusFailed(startStatus.Err), nil + } + } + + return s.DoInlineSubWorkflow(ctx, contextualSubWorkflow, status, startNode) +} + +func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, status v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + // Handle subworkflow + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No subWorkflow [%s], workflow.", subID)), nil + } + + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) + startNode := w.StartNode() + if startNode == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No start node found in subworkflow")), nil + } + + parentNodeStatus := w.GetNodeExecutionStatus(node.GetID()) + return s.DoInlineSubWorkflow(ctx, contextualSubWorkflow, parentNodeStatus, startNode) +} + +func (s *subworkflowHandler) HandleSubWorkflowFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + status := w.GetNodeExecutionStatus(node.GetID()) + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return handler.StatusFailed(errors.Errorf(errors.SubWorkflowExecutionFailed, node.GetID(), "No subWorkflow [%s], workflow.", subID)), nil + } + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) + return s.DoInFailureHandling(ctx, contextualSubWorkflow) +} + +func (s *subworkflowHandler) HandleAbort(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := w.FindSubWorkflow(subID) + if subWorkflow == nil { + return fmt.Errorf("no sub workflow [%s] found in node [%s]", subID, node.GetID()) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, nodeStatus) + + startNode := w.StartNode() + if startNode == nil { + return fmt.Errorf("no sub workflow [%s] found in node [%s]", subID, node.GetID()) + } + + return s.nodeExecutor.AbortHandler(ctx, contextualSubWorkflow, startNode) +} + +func newSubworkflowHandler(nodeExecutor executors.Node, enqueueWorkflow v1alpha1.EnqueueWorkflow, store *storage.DataStore) subworkflowHandler { + return subworkflowHandler{ + nodeExecutor: nodeExecutor, + enqueueWorkflow: enqueueWorkflow, + store: store, + } +} diff --git 
a/flytepropeller/pkg/controller/nodes/subworkflow/util.go b/flytepropeller/pkg/controller/nodes/subworkflow/util.go new file mode 100644 index 0000000000..973a2e0b93 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/util.go @@ -0,0 +1,24 @@ +package subworkflow + +import ( + "strconv" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/utils" +) + +const maxLengthForSubWorkflow = 20 + +func GetChildWorkflowExecutionID(parentID *core.WorkflowExecutionIdentifier, id v1alpha1.NodeID, attempt uint32) (*core.WorkflowExecutionIdentifier, error) { + name, err := utils.FixedLengthUniqueIDForParts(maxLengthForSubWorkflow, parentID.Name, id, strconv.Itoa(int(attempt))) + if err != nil { + return nil, err + } + // Restriction on name is 20 chars + return &core.WorkflowExecutionIdentifier{ + Project: parentID.Project, + Domain: parentID.Domain, + Name: name, + }, nil +} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/util_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/util_test.go new file mode 100644 index 0000000000..a3e126f94b --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/subworkflow/util_test.go @@ -0,0 +1,19 @@ +package subworkflow + +import ( + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestGetChildWorkflowExecutionID(t *testing.T) { + id, err := GetChildWorkflowExecutionID( + &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "first-name-is-pretty-large", + }, "hello-world", 1) + assert.Equal(t, id.Name, "fav2uxxi") + assert.NoError(t, err) +} diff --git a/flytepropeller/pkg/controller/nodes/task/factory.go b/flytepropeller/pkg/controller/nodes/task/factory.go new file mode 100644 index 0000000000..456590b69e --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/task/factory.go @@ -0,0 +1,72 @@ +package task + +import ( + "time" + + v1 "github.com/lyft/flyteplugins/go/tasks/v1" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/pkg/errors" +) + +var testModeEnabled = false +var testTaskFactory Factory + +type Factory interface { + GetTaskExecutor(taskType v1alpha1.TaskType) (types.Executor, error) + ListAllTaskExecutors() []types.Executor +} + +// We create a simple facade so that if required we could make a Readonly cache of the Factory without any mutexes +// TODO decide if we want to make this a cache +type sealedTaskFactory struct { +} + +func (sealedTaskFactory) GetTaskExecutor(taskType v1alpha1.TaskType) (types.Executor, error) { + return v1.GetTaskExecutor(taskType) +} + +func (sealedTaskFactory) ListAllTaskExecutors() []types.Executor { + return v1.ListAllTaskExecutors() +} + +func NewFactory(revalPeriod time.Duration) Factory { + if testModeEnabled { + return testTaskFactory + } + + return sealedTaskFactory{} +} + +func SetTestFactory(tf Factory) { + testModeEnabled = true + testTaskFactory = tf +} + +func IsTestModeEnabled() bool { + return testModeEnabled +} + +func DisableTestMode() { + testTaskFactory = nil + testModeEnabled = false +} + +type FactoryFuncs struct { + GetTaskExecutorCb func(taskType v1alpha1.TaskType) (types.Executor, error) + ListAllTaskExecutorsCb func() []types.Executor +} + +func (t *FactoryFuncs) GetTaskExecutor(taskType v1alpha1.TaskType) (types.Executor, error) { + if t.GetTaskExecutorCb != nil { + return 
t.GetTaskExecutorCb(taskType) + } + return nil, errors.Errorf("No implementation provided") +} + +func (t *FactoryFuncs) ListAllTaskExecutors() []types.Executor { + if t.ListAllTaskExecutorsCb != nil { + return t.ListAllTaskExecutorsCb() + } + return nil +} diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go new file mode 100644 index 0000000000..b118107edd --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -0,0 +1,439 @@ +package task + +import ( + "context" + "fmt" + "reflect" + "runtime/debug" + "strconv" + "time" + + "github.com/lyft/flytepropeller/pkg/controller/executors" + + "sigs.k8s.io/controller-runtime/pkg/runtime/inject" + + "github.com/lyft/flytestdlib/promutils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + pluginsV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + errors2 "github.com/pkg/errors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + "github.com/lyft/flytepropeller/pkg/utils" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +const IDMaxLength = 50 + +// TODO handle retries +type taskContext struct { + taskExecutionID taskExecutionID + dataDir storage.DataReference + workflow v1alpha1.WorkflowMeta + node v1alpha1.ExecutableNode + status v1alpha1.ExecutableTaskNodeStatus + serviceAccountName string +} + +func (t *taskContext) GetCustomState() pluginsV1.CustomState { + return t.status.GetCustomState() +} + +func (t *taskContext) GetPhase() pluginsV1.TaskPhase { + return t.status.GetPhase() +} + +func (t *taskContext) GetPhaseVersion() uint32 { + return t.status.GetPhaseVersion() +} + +type taskExecutionID struct { + execName string + id core.TaskExecutionIdentifier +} + +func (te taskExecutionID) GetID() core.TaskExecutionIdentifier { + return te.id +} + +func (te taskExecutionID) GetGeneratedName() string { + return te.execName +} + +func (t *taskContext) GetOwnerID() types.NamespacedName { + return t.workflow.GetK8sWorkflowID() +} + +func (t *taskContext) GetTaskExecutionID() pluginsV1.TaskExecutionID { + return t.taskExecutionID +} + +func (t *taskContext) GetDataDir() storage.DataReference { + return t.dataDir +} + +func (t *taskContext) GetInputsFile() storage.DataReference { + return v1alpha1.GetInputsFile(t.dataDir) +} + +func (t *taskContext) GetOutputsFile() storage.DataReference { + return v1alpha1.GetOutputsFile(t.dataDir) +} + +func (t *taskContext) GetErrorFile() storage.DataReference { + return v1alpha1.GetOutputErrorFile(t.dataDir) +} + +func (t *taskContext) GetNamespace() string { + return t.workflow.GetNamespace() +} + +func (t *taskContext) GetOwnerReference() v1.OwnerReference { + return t.workflow.NewControllerRef() +} + +func (t *taskContext) GetOverrides() pluginsV1.TaskOverrides { + return t.node +} + +func (t *taskContext) GetLabels() map[string]string { + return t.workflow.GetLabels() +} + +func (t *taskContext) GetAnnotations() map[string]string { + return t.workflow.GetAnnotations() +} + +func (t *taskContext) GetK8sServiceAccount() string { + return 
t.serviceAccountName +} + +type metrics struct { + pluginPanics labeled.Counter + unsupportedTaskType labeled.Counter + discoveryPutFailureCount labeled.Counter + discoveryGetFailureCount labeled.Counter + discoveryMissCount labeled.Counter + discoveryHitCount labeled.Counter + + // TODO We should have a metric to capture custom state size +} + +type taskHandler struct { + taskFactory Factory + recorder events.TaskEventRecorder + enqueueWf v1alpha1.EnqueueWorkflow + store *storage.DataStore + scope promutils.Scope + catalogClient catalog.Client + kubeClient executors.Client + metrics *metrics +} + +func (h *taskHandler) GetTaskExecutorContext(ctx context.Context, w v1alpha1.ExecutableWorkflow, + node v1alpha1.ExecutableNode) (pluginsV1.Executor, v1alpha1.ExecutableTask, pluginsV1.TaskContext, error) { + + taskID := node.GetTaskID() + if taskID == nil { + return nil, nil, nil, errors.Errorf(errors.BadSpecificationError, node.GetID(), "Task Id not set for NodeKind `Task`") + } + task, err := w.GetTask(*taskID) + if err != nil { + return nil, nil, nil, errors.Wrapf(errors.BadSpecificationError, node.GetID(), err, "Unable to find task for taskId: [%v]", *taskID) + } + + exec, err := h.taskFactory.GetTaskExecutor(task.TaskType()) + if err != nil { + h.metrics.unsupportedTaskType.Inc(ctx) + return nil, nil, nil, errors.Wrapf(errors.UnsupportedTaskTypeError, node.GetID(), err, + "Unable to find taskExecutor for taskId: [%v]. TaskType: [%v]", *taskID, task.TaskType()) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + id := core.TaskExecutionIdentifier{ + TaskId: task.CoreTask().Id, + RetryAttempt: nodeStatus.GetAttempts(), + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, + }, + } + + uniqueID, err := utils.FixedLengthUniqueIDForParts(IDMaxLength, w.GetName(), node.GetID(), strconv.Itoa(int(id.RetryAttempt))) + if err != nil { + // SHOULD never really happen + return nil, nil, nil, err + } + + taskNodeStatus := nodeStatus.GetTaskNodeStatus() + if taskNodeStatus == nil { + mutableTaskNodeStatus := nodeStatus.GetOrCreateTaskStatus() + taskNodeStatus = mutableTaskNodeStatus + } + + return exec, task, &taskContext{ + taskExecutionID: taskExecutionID{execName: uniqueID, id: id}, + dataDir: nodeStatus.GetDataDir(), + workflow: w, + node: node, + status: taskNodeStatus, + serviceAccountName: w.GetServiceAccountName(), + }, nil +} + +func (h *taskHandler) ExtractOutput(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, + bindToVar handler.VarName) (values *core.Literal, err error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, n) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, "failed to create TaskCtx") + } + + l, err := t.ResolveOutputs(ctx, taskCtx, bindToVar) + if err != nil { + return nil, errors.Wrapf(errors.CausedByError, n.GetID(), err, + "failed to resolve output [%v] from task of type [%v]", bindToVar, task.TaskType()) + } + + return l[bindToVar], nil +} + +func (h *taskHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeInputs *handler.Data) (handler.Status, error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to create TaskCtx")), nil + } + + logger.Infof(ctx, "Executor type: [%v]. Properties: finalizer[%v]. 
disable[%v].", reflect.TypeOf(t).String(), t.GetProperties().RequiresFinalizer, t.GetProperties().DisableNodeLevelCaching) + if task.CoreTask().Metadata.Discoverable { + if t.GetProperties().DisableNodeLevelCaching { + logger.Infof(ctx, "Executor has Node-Level caching disabled. Skipping.") + } else if resp, err := h.catalogClient.Get(ctx, task.CoreTask(), taskCtx.GetInputsFile()); err != nil { + if taskStatus, ok := status.FromError(err); ok && taskStatus.Code() == codes.NotFound { + h.metrics.discoveryMissCount.Inc(ctx) + logger.Infof(ctx, "Artifact not found in Discovery. Executing Task.") + } else { + h.metrics.discoveryGetFailureCount.Inc(ctx) + logger.Errorf(ctx, "Discovery check failed. Executing Task. Err: %v", err.Error()) + } + } else if resp != nil { + h.metrics.discoveryHitCount.Inc(ctx) + if iface := task.CoreTask().Interface; iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { + if err := h.store.WriteProtobuf(ctx, taskCtx.GetOutputsFile(), storage.Options{}, resp); err != nil { + logger.Errorf(ctx, "failed to write data to Storage, err: %v", err.Error()) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to copy cached results for task.") + } + } + // SetCached. + w.GetNodeExecutionStatus(node.GetID()).SetCached() + return handler.StatusSuccess, nil + } else { + // Nil response and Nil error + h.metrics.discoveryGetFailureCount.Inc(ctx) + return handler.StatusUndefined, errors.Wrapf(errors.CatalogCallFailed, node.GetID(), err, "Nil catalog response. Failed to check Catalog for previous results") + } + } + + var taskStatus pluginsV1.TaskStatus + func() { + defer func() { + if r := recover(); r != nil { + h.metrics.pluginPanics.Inc(ctx) + stack := debug.Stack() + err = fmt.Errorf("panic when executing a plugin for TaskType [%s]. Stack: [%s]", task.TaskType(), string(stack)) + logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) + } + }() + taskStatus, err = t.StartTask(ctx, taskCtx, task.CoreTask(), nodeInputs) + }() + + if err != nil { + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to start task [retry attempt: %d]", taskCtx.GetTaskExecutionID().GetID().RetryAttempt) + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + taskNodeStatus := nodeStatus.GetOrCreateTaskStatus() + taskNodeStatus.SetPhase(taskStatus.Phase) + taskNodeStatus.SetPhaseVersion(taskStatus.PhaseVersion) + taskNodeStatus.SetCustomState(taskStatus.State) + + logger.Debugf(ctx, "Started Task Node") + return ConvertTaskPhaseToHandlerStatus(taskStatus) +} + +func ConvertTaskPhaseToHandlerStatus(taskStatus pluginsV1.TaskStatus) (handler.Status, error) { + // TODO handle retryable failure + switch taskStatus.Phase { + case pluginsV1.TaskPhaseNotReady: + return handler.StatusQueued.WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhaseQueued, pluginsV1.TaskPhaseRunning: + return handler.StatusRunning.WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhasePermanentFailure: + return handler.StatusFailed(taskStatus.Err).WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhaseRetryableFailure: + return handler.StatusRetryableFailure(taskStatus.Err).WithOccurredAt(taskStatus.OccurredAt), nil + case pluginsV1.TaskPhaseSucceeded: + return handler.StatusSuccess.WithOccurredAt(taskStatus.OccurredAt), nil + default: + return handler.StatusUndefined, errors.Errorf(errors.IllegalStateError, "received unknown task phase. 
[%s]", taskStatus.Phase.String()) + } +} + +func (h *taskHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, prevNodeStatus v1alpha1.ExecutableNodeStatus) (handler.Status, error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to create TaskCtx")), nil + } + + var taskStatus pluginsV1.TaskStatus + func() { + defer func() { + if r := recover(); r != nil { + h.metrics.pluginPanics.Inc(ctx) + stack := debug.Stack() + err = fmt.Errorf("panic when executing a plugin for TaskType [%s]. Stack: [%s]", task.TaskType(), string(stack)) + logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) + } + }() + taskStatus, err = t.CheckTaskStatus(ctx, taskCtx, task.CoreTask()) + }() + + if err != nil { + logger.Warnf(ctx, "Failed to check status") + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to check status") + } + + nodeStatus := w.GetNodeExecutionStatus(node.GetID()) + taskNodeStatus := nodeStatus.GetOrCreateTaskStatus() + taskNodeStatus.SetPhase(taskStatus.Phase) + taskNodeStatus.SetPhaseVersion(taskStatus.PhaseVersion) + taskNodeStatus.SetCustomState(taskStatus.State) + + return ConvertTaskPhaseToHandlerStatus(taskStatus) +} + +func (h *taskHandler) HandleNodeSuccess(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + t, task, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return handler.StatusFailed(errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to create TaskCtx")), nil + } + + // If the task interface has outputs, validate that the outputs file was written. + if iface := task.CoreTask().Interface; task.TaskType() != "container_array" && iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { + if metadata, err := h.store.Head(ctx, taskCtx.GetOutputsFile()); err != nil { + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to HEAD task outputs file.") + } else if !metadata.Exists() { + return handler.StatusRetryableFailure(errors.Errorf(errors.OutputsNotFoundError, node.GetID(), + "Outputs not found for task type %s, looking for output file %s", task.TaskType(), taskCtx.GetOutputsFile())), nil + } + + // ignores discovery write failures + if task.CoreTask().Metadata.Discoverable && !t.GetProperties().DisableNodeLevelCaching { + taskExecutionID := taskCtx.GetTaskExecutionID().GetID() + if err2 := h.catalogClient.Put(ctx, task.CoreTask(), &taskExecutionID, taskCtx.GetInputsFile(), taskCtx.GetOutputsFile()); err2 != nil { + h.metrics.discoveryPutFailureCount.Inc(ctx) + logger.Errorf(ctx, "Failed to write results to catalog. 
Err: %v", err2) + } else { + logger.Debugf(ctx, "Successfully cached results to discovery - Task [%s]", task.CoreTask().GetId()) + } + } + } + return handler.StatusSuccess, nil +} + +func (h *taskHandler) HandleFailingNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Status, error) { + return handler.StatusFailed(errors.Errorf(errors.IllegalStateError, node.GetID(), "A regular Task node cannot enter a failing state")), nil +} + +func (h *taskHandler) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing taskHandler") + enqueueFn := func(ownerId types.NamespacedName) error { + h.enqueueWf(ownerId.String()) + return nil + } + + initParams := pluginsV1.ExecutorInitializationParameters{ + CatalogClient: h.catalogClient, + EventRecorder: h.recorder, + DataStore: h.store, + EnqueueOwner: enqueueFn, + OwnerKind: v1alpha1.FlyteWorkflowKind, + MetricsScope: h.scope, + } + + for _, r := range h.taskFactory.ListAllTaskExecutors() { + logger.Infof(ctx, "Initializing Executor [%v]", r.GetID()) + // Inject a RuntimeClient if the executor needs one. + if _, err := inject.ClientInto(h.kubeClient.GetClient(), r); err != nil { + return errors2.Wrapf(err, "Failed to initialize [%v]", r.GetID()) + } + + if _, err := inject.CacheInto(h.kubeClient.GetCache(), r); err != nil { + return errors2.Wrapf(err, "Failed to initialize [%v]", r.GetID()) + } + + err := r.Initialize(ctx, initParams) + if err != nil { + return errors2.Wrapf(err, "Failed to Initialize TaskExecutor [%v]", r.GetID()) + } + } + + logger.Infof(ctx, "taskHandler Initialization complete") + return nil +} + +func (h *taskHandler) AbortNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) error { + t, _, taskCtx, err := h.GetTaskExecutorContext(ctx, w, node) + if err != nil { + return errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to create TaskCtx") + } + + err = t.KillTask(ctx, taskCtx, "Node aborted") + if err != nil { + return errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to abort task") + } + // TODO: Do we need to update the Node status to Failed here as well ? 
+ logger.Infof(ctx, "Invoked KillTask on Task Node.") + return nil +} + +func NewTaskHandlerForFactory(eventSink events.EventSink, store *storage.DataStore, enqueueWf v1alpha1.EnqueueWorkflow, + tf Factory, catalogClient catalog.Client, kubeClient executors.Client, scope promutils.Scope) handler.IFace { + + // create a recorder for the plugins + eventsRecorder := utils.NewPluginTaskEventRecorder(events.NewTaskEventRecorder(eventSink, scope)) + return &taskHandler{ + taskFactory: tf, + recorder: eventsRecorder, + enqueueWf: enqueueWf, + store: store, + scope: scope, + catalogClient: catalogClient, + kubeClient: kubeClient, + metrics: &metrics{ + pluginPanics: labeled.NewCounter("plugin_panic", "Task plugin paniced when trying to execute a task.", scope), + unsupportedTaskType: labeled.NewCounter("unsupported_tasktype", "No task plugin configured for task type", scope), + discoveryHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", scope), + discoveryMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope), + discoveryPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope), + discoveryGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", scope), + }, + } +} + +func New(eventSink events.EventSink, store *storage.DataStore, enqueueWf v1alpha1.EnqueueWorkflow, revalPeriod time.Duration, + catalogClient catalog.Client, kubeClient executors.Client, scope promutils.Scope) handler.IFace { + + return NewTaskHandlerForFactory(eventSink, store, enqueueWf, NewFactory(revalPeriod), + catalogClient, kubeClient, scope.NewSubScope("task")) +} diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go new file mode 100644 index 0000000000..1832e82f74 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -0,0 +1,769 @@ +package task + +import ( + "context" + "fmt" + "reflect" + "testing" + + mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + pluginsV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytestdlib/storage" + regErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + typesV1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" +) + +const DataDir = storage.DataReference("test-data") +const NodeID = "n1" + +var ( + enqueueWfFunc = func(id string) {} + fakeKubeClient = mocks2.NewFakeKubeClient() +) + +func mockCatalogClient() catalog.Client { + return &catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, nil + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) 
error { + return nil + }, + } +} + +func createWf(id string, execID string, project string, domain string, name string) *v1alpha1.FlyteWorkflow { + return &v1alpha1.FlyteWorkflow{ + ExecutionID: v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: execID, + }, + }, + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + NodeID: { + DataDir: DataDir, + }, + }, + }, + ObjectMeta: v1.ObjectMeta{ + Name: name, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: id, + }, + } +} + +func createStartNode() *v1alpha1.NodeSpec { + return &v1alpha1.NodeSpec{ + ID: NodeID, + Kind: v1alpha1.NodeKindStart, + Resources: &typesV1.ResourceRequirements{ + Requests: typesV1.ResourceList{ + typesV1.ResourceCPU: resource.MustParse("1"), + }, + }, + } +} + +func createTask(id string, ttype string, discoverable bool) *v1alpha1.TaskSpec { + return &v1alpha1.TaskSpec{ + TaskTemplate: &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: ttype, + Metadata: &core.TaskMetadata{Discoverable: discoverable}, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{}, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "out1": &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + }, + }, + }, + }, + }, + } +} + +func createDummyExec() *mocks.Executor { + dummyExec := &mocks.Executor{} + dummyExec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginsV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + dummyExec.On("GetID").Return("test") + + return dummyExec +} + +func TestTaskHandler_Initialize(t *testing.T) { + ctx := context.TODO() + t.Run("NoHandlers", func(t *testing.T) { + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, th.Initialize(context.TODO())) + }) + + t.Run("SomeHandler", func(t *testing.T) { + d := &FactoryFuncs{ + ListAllTaskExecutorsCb: func() []pluginsV1.Executor { + return []pluginsV1.Executor{ + createDummyExec(), + } + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, th.Initialize(ctx)) + }) +} + +func TestTaskHandler_HandleFailingNode(t *testing.T) { + ctx := context.Background() + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + w := createWf("w1", "w2-exec", "project", "domain", "execName1") + n := createStartNode() + s, err := th.HandleFailingNode(ctx, w, n) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseFailed, s.Phase) + assert.Error(t, s.Err) +} + +func TestTaskHandler_GetTaskExecutorContext(t *testing.T) { + ctx := context.Background() + const execName = "w1-exec" + t.Run("NoTaskId", func(t *testing.T) { + w := createWf("w1", execName, "project", "domain", "execName1") + n := createStartNode() + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + + _, _, _, err := th.GetTaskExecutorContext(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, 
errors.BadSpecificationError)) + }) + + t.Run("NoTaskMatch", func(t *testing.T) { + taskID := "t1" + w := createWf("w1", execName, "project", "domain", "execName1") + n := createStartNode() + n.TaskRef = &taskID + + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + _, _, _, err := th.GetTaskExecutorContext(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.BadSpecificationError)) + }) + + t.Run("TaskMatchNoExecutor", func(t *testing.T) { + taskID := "t1" + task := createTask(taskID, "dynamic", false) + + w := createWf("w1", execName, "project", "domain", "execName1") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + + n := createStartNode() + n.TaskRef = &taskID + + d := &FactoryFuncs{} + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + _, _, _, err := th.GetTaskExecutorContext(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.UnsupportedTaskTypeError)) + }) + + t.Run("TaskMatch", func(t *testing.T) { + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w1", execName, "project", "domain", "execName1") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + w.ServiceAccountName = "service-account" + n := createStartNode() + n.TaskRef = &taskID + + taskExec := &mocks.Executor{} + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + te, receivedTask, tc, err := th.GetTaskExecutorContext(ctx, w, n) + if assert.NoError(t, err) { + assert.Equal(t, taskExec, te) + if assert.NotNil(t, tc) { + assert.Equal(t, "execName1-n1-0", tc.GetTaskExecutionID().GetGeneratedName()) + assert.Equal(t, DataDir, tc.GetDataDir()) + assert.NotNil(t, tc.GetOverrides()) + assert.NotNil(t, tc.GetOverrides().GetResources()) + assert.NotEmpty(t, tc.GetOverrides().GetResources().Requests) + assert.Equal(t, "service-account", tc.GetK8sServiceAccount()) + } + assert.Equal(t, task, receivedTask) + } + }) + + t.Run("TaskMatchAttempt>0", func(t *testing.T) { + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w1", execName, "project", "domain", "execName1") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + status := w.Status.GetNodeExecutionStatus(n.ID).(*v1alpha1.NodeStatus) + status.Attempts = 2 + + taskExec := &mocks.Executor{} + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()).(*taskHandler) + te, receivedTask, tc, err := th.GetTaskExecutorContext(ctx, w, n) + if assert.NoError(t, err) { + assert.Equal(t, taskExec, te) + if assert.NotNil(t, tc) { + assert.Equal(t, "execName1-n1-2", tc.GetTaskExecutionID().GetGeneratedName()) + 
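The generated names asserted in these tests ("execName1-n1-0", "execName1-n1-2", and the hashed "fav2uxxi" in util_test.go earlier) come from utils.FixedLengthUniqueIDForParts, which caps the length of names handed to downstream systems. The sketch below illustrates one way such a cap can work under the assumption "join the parts when they fit, fall back to a hash otherwise"; it is not the actual utils implementation, and its encoding will not reproduce the exact strings above.

package main

import (
	"crypto/sha256"
	"encoding/base32"
	"fmt"
	"strings"
)

// fixedLengthID joins parts with "-" and, if the result is longer than maxLen,
// replaces it with a short deterministic hash. Illustrative only; the real
// utils.FixedLengthUniqueIDForParts uses its own encoding.
func fixedLengthID(maxLen int, parts ...string) string {
	joined := strings.Join(parts, "-")
	if len(joined) <= maxLen {
		return joined
	}
	sum := sha256.Sum256([]byte(joined))
	enc := strings.ToLower(base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(sum[:]))
	if len(enc) > maxLen {
		enc = enc[:maxLen]
	}
	return enc
}

func main() {
	fmt.Println(fixedLengthID(50, "execName1", "n1", "2"))                           // fits, kept verbatim: execName1-n1-2
	fmt.Println(fixedLengthID(20, "first-name-is-pretty-large", "hello-world", "1")) // too long: hashed and truncated
}

Keeping the name deterministic matters here: a retried reconcile loop must regenerate the same child execution name so the admin service can deduplicate the launch (the "alreadyExists" case tested earlier).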
assert.Equal(t, DataDir, tc.GetDataDir()) + assert.NotNil(t, tc.GetOverrides()) + assert.NotNil(t, tc.GetOverrides().GetResources()) + assert.NotEmpty(t, tc.GetOverrides().GetResources().Requests) + } + assert.Equal(t, task, receivedTask) + } + }) + +} + +func TestTaskHandler_StartNode(t *testing.T) { + ctx := context.Background() + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w2", "w2-exec", "project", "domain", "execName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("NoTaskExec", func(t *testing.T) { + taskExec := &mocks.Executor{} + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return nil, regErrors.New("No match") + } + return taskExec, nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Error(t, s.Err) + assert.True(t, errors.Matches(s.Err, errors.CausedByError)) + }) + + t.Run("TaskExecStartFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusPermanentFailure(regErrors.New("Failed")), nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseFailed, s.Phase) + }) + + t.Run("TaskExecStartPanic", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return( + func(ctx context.Context, taskCtx pluginsV1.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (pluginsV1.TaskStatus, error) { + panic("failed in execution") + }, + ) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + s, err := th.StartNode(ctx, w, n, nil) + assert.Error(t, err) + assert.Equal(t, handler.PhaseUndefined, s.Phase) + }) + + t.Run("TaskExecStarted", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) 
bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) +} + +func TestTaskHandler_StartNodeDiscoverable(t *testing.T) { + ctx := context.Background() + taskID := "t1" + task := createTask(taskID, "container", true) + task.Id.Project = "flytekit" + w := createWf("w2", "w2-exec", "flytekit", "domain", "execName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("TaskExecStartNodeDiscoveryFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, regErrors.Errorf("error") + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) + + t.Run("TaskExecStartNodeDiscoveryMiss", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + return nil, status.Errorf(codes.NotFound, "not found") + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, 
handler.StatusRunning, s) + }) + + t.Run("TaskExecStartNodeDiscoveryHit", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("StartTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + mockCatalog := catalog.MockCatalogClient{ + GetFunc: func(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + + return &core.LiteralMap{ + Literals: paramsMap, + }, nil + }, + PutFunc: func(ctx context.Context, task *core.TaskTemplate, execId *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + return nil + }, + } + store := createInmemoryDataStore(t, testScope.NewSubScope("12")) + th := NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, &mockCatalog, fakeKubeClient, promutils.NewTestScope()) + + s, err := th.StartNode(ctx, w, n, nil) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) +} + +func TestTaskHandler_AbortNode(t *testing.T) { + ctx := context.Background() + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w2", "w2-exec", "project", "domain", "execName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("NoTaskExec", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return nil, regErrors.New("No match") + } + return taskExec, nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + err := th.AbortNode(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.CausedByError)) + }) + + t.Run("TaskExecKillFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("KillTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.Anything, + ).Return(regErrors.New("Failed")) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + err := th.AbortNode(ctx, w, n) + assert.Error(t, err) + assert.True(t, errors.Matches(err, errors.CausedByError)) + }) + + t.Run("TaskExecKilled", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("KillTask", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + 
mock.Anything, + ).Return(nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + err := th.AbortNode(ctx, w, n) + assert.NoError(t, err) + }) +} + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func newIntegerPrimitive(value int64) *core.Primitive { + return &core.Primitive{Value: &core.Primitive_Integer{Integer: value}} +} + +func newScalarInteger(value int64) *core.Scalar { + return &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: newIntegerPrimitive(value), + }, + } +} + +func newIntegerLiteral(value int64) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: newScalarInteger(value), + }, + } +} + +var testScope = promutils.NewScope("test_wfexec") + +func TestTaskHandler_CheckNodeStatus(t *testing.T) { + ctx := context.Background() + + taskID := "t1" + task := createTask(taskID, "container", false) + w := createWf("w1", "w2-exec", "projTest", "domainTest", "checkNodeTestName") + w.Tasks = map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: task, + } + n := createStartNode() + n.TaskRef = &taskID + + t.Run("NoTaskExec", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return nil, regErrors.New("No match") + } + return taskExec, nil + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseNotYetStarted} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.True(t, errors.Matches(s.Err, errors.CausedByError)) + }) + + t.Run("TaskExecStartFail", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusPermanentFailure(regErrors.New("Failed")), nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.PhaseFailed, s.Phase) + }) + + t.Run("TaskExecCheckPanic", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o 
*core.TaskTemplate) bool { return true }), + ).Return(func(ctx context.Context, taskCtx pluginsV1.TaskContext, task *core.TaskTemplate) (status pluginsV1.TaskStatus, err error) { + panic("failed in execution") + }) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.Error(t, err) + assert.Equal(t, handler.PhaseUndefined, s.Phase) + }) + + t.Run("TaskExecRunning", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusRunning, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + th := NewTaskHandlerForFactory(events.NewMockEventSink(), nil, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.StatusRunning, s) + }) + + t.Run("TaskExecDone", func(t *testing.T) { + taskExec := &mocks.Executor{} + taskExec.On("GetProperties").Return(pluginsV1.ExecutorProperties{}) + taskExec.On("CheckTaskStatus", + ctx, + mock.MatchedBy(func(o pluginsV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginsV1.TaskStatusSucceeded, nil) + d := &FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginsV1.Executor, error) { + if taskType == task.Type { + return taskExec, nil + } + return nil, regErrors.New("No match") + }, + } + + store := createInmemoryDataStore(t, testScope.NewSubScope("4")) + paramsMap := make(map[string]*core.Literal) + paramsMap["out1"] = newIntegerLiteral(100) + err1 := store.WriteProtobuf(ctx, "test-data/inputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + err2 := store.WriteProtobuf(ctx, "test-data/outputs.pb", storage.Options{}, &core.LiteralMap{Literals: paramsMap}) + assert.NoError(t, err1) + assert.NoError(t, err2) + + th := NewTaskHandlerForFactory(events.NewMockEventSink(), store, enqueueWfFunc, d, mockCatalogClient(), fakeKubeClient, promutils.NewTestScope()) + + prevNodeStatus := &v1alpha1.NodeStatus{Phase: v1alpha1.NodePhaseRunning} + + s, err := th.CheckNodeStatus(ctx, w, n, prevNodeStatus) + assert.NoError(t, err) + assert.Equal(t, handler.StatusSuccess, s) + }) +} + +func TestConvertTaskPhaseToHandlerStatus(t *testing.T) { + expectedErr := fmt.Errorf("failed") + tests := []struct { + name string + status pluginsV1.TaskStatus + hs handler.Status + isError bool + }{ + {"undefined", pluginsV1.TaskStatusUndefined, handler.StatusUndefined, true}, + {"running", pluginsV1.TaskStatusRunning, handler.StatusRunning, false}, + {"queued", pluginsV1.TaskStatusQueued, handler.StatusRunning, false}, + 
{"succeeded", pluginsV1.TaskStatusSucceeded, handler.StatusSuccess, false}, + {"unknown", pluginsV1.TaskStatusUnknown, handler.StatusUndefined, true}, + {"retryable", pluginsV1.TaskStatusRetryableFailure(expectedErr), handler.StatusRetryableFailure(expectedErr), false}, + {"failed", pluginsV1.TaskStatusPermanentFailure(expectedErr), handler.StatusFailed(expectedErr), false}, + {"undefined", pluginsV1.TaskStatusUndefined, handler.StatusUndefined, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hs, err := ConvertTaskPhaseToHandlerStatus(test.status) + assert.Equal(t, hs, test.hs) + if test.isError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/flytepropeller/pkg/controller/workers.go b/flytepropeller/pkg/controller/workers.go new file mode 100644 index 0000000000..7b6303c818 --- /dev/null +++ b/flytepropeller/pkg/controller/workers.go @@ -0,0 +1,177 @@ +package controller + +import ( + "context" + "fmt" + "runtime/pprof" + "time" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/tools/cache" +) + +type Handler interface { + // Initialize the Handler + Initialize(ctx context.Context) error + // Handle method that should handle the object and try to converge the desired and the actual state + Handle(ctx context.Context, namespace, key string) error +} + +type workerPoolMetrics struct { + Scope promutils.Scope + FreeWorkers prometheus.Gauge + PerRoundTimer promutils.StopWatch + RoundError prometheus.Counter + RoundSuccess prometheus.Counter + WorkersRestarted prometheus.Counter +} + +type WorkerPool struct { + workQueue CompositeWorkQueue + metrics workerPoolMetrics + handler Handler +} + +// processNextWorkItem will read a single work item off the workqueue and +// attempt to process it, by calling the handler. +func (w *WorkerPool) processNextWorkItem(ctx context.Context) bool { + obj, shutdown := w.workQueue.Get() + + w.metrics.FreeWorkers.Dec() + defer w.metrics.FreeWorkers.Inc() + + if shutdown { + return false + } + + // We wrap this block in a func so we can defer c.workqueue.Done. + err := func(obj interface{}) error { + // We call Done here so the workqueue knows we have finished + // processing this item. We also must remember to call Forget if we + // do not want this work item being re-queued. For example, we do + // not call Forget if a transient error occurs, instead the item is + // put back on the workqueue and attempted again after a back-off + // period. + defer w.workQueue.Done(obj) + var key string + var ok bool + // We expect strings to come off the workqueue. These are of the + // form namespace/name. We do this as the delayed nature of the + // workqueue means the items in the informer cache may actually be + // more up to date that when the item was initially put onto the + // workqueue. + if key, ok = obj.(string); !ok { + // As the item in the workqueue is actually invalid, we call + // Forget here else we'd go into a loop of attempting to + // process a work item that is invalid. 
+ w.workQueue.Forget(obj) + runtime.HandleError(fmt.Errorf("expected string in workqueue but got %#v", obj)) + return nil + } + + t := w.metrics.PerRoundTimer.Start() + defer t.Stop() + + // Convert the namespace/name string into a distinct namespace and name + namespace, name, err := cache.SplitMetaNamespaceKey(key) + if err != nil { + logger.Errorf(ctx, "Unable to split enqueued key into namespace/execId. Error[%v]", err) + return nil + } + ctx = contextutils.WithNamespace(ctx, namespace) + ctx = contextutils.WithExecutionID(ctx, name) + // Reconcile the Workflow + if err := w.handler.Handle(ctx, namespace, name); err != nil { + w.metrics.RoundError.Inc() + return fmt.Errorf("error syncing '%s': %s", key, err.Error()) + } + w.metrics.RoundSuccess.Inc() + + // Finally, if no error occurs we Forget this item so it does not + // get queued again until another change happens. + w.workQueue.Forget(obj) + logger.Infof(ctx, "Successfully synced '%s'", key) + return nil + }(obj) + + if err != nil { + runtime.HandleError(err) + return true + } + + return true +} + +// runWorker is a long-running function that will continually call the +// processNextWorkItem function in order to read and process a message on the +// workqueue. +func (w *WorkerPool) runWorker(ctx context.Context) { + logger.Infof(ctx, "Started Worker") + defer logger.Infof(ctx, "Exiting Worker") + for w.processNextWorkItem(ctx) { + } +} + +func (w *WorkerPool) Initialize(ctx context.Context) error { + return w.handler.Initialize(ctx) +} + +// Run will set up the event handlers for types we are interested in, as well +// as syncing informer caches and starting workers. It will block until stopCh +// is closed, at which point it will shutdown the workqueue and wait for +// workers to finish processing their current work items. 
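Before the Run implementation below, a rough consumer-side sketch may help: a Handler implementation is handed to NewWorkerPool along with a CompositeWorkQueue, then Initialize and Run are called. loggingHandler and runPool are invented names used only for illustration and are not part of this patch.

package controller

import (
	"context"

	"github.com/lyft/flytestdlib/logger"
	"github.com/lyft/flytestdlib/promutils"
)

// loggingHandler is a trivial Handler that only logs the keys it receives.
type loggingHandler struct{}

func (l *loggingHandler) Initialize(ctx context.Context) error { return nil }

func (l *loggingHandler) Handle(ctx context.Context, namespace, key string) error {
	logger.Infof(ctx, "reconciling %s/%s", namespace, key)
	return nil
}

// runPool wires the pieces together; the CompositeWorkQueue is assumed to be
// built elsewhere (see NewCompositeWorkQueue in this package).
func runPool(ctx context.Context, scope promutils.Scope, q CompositeWorkQueue) error {
	pool := NewWorkerPool(ctx, scope, q, &loggingHandler{})
	if err := pool.Initialize(ctx); err != nil {
		return err
	}
	// Run blocks until ctx is cancelled; no informer sync funcs are passed here.
	return pool.Run(ctx, 4)
}
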
+func (w *WorkerPool) Run(ctx context.Context, threadiness int, synced ...cache.InformerSynced) error { + defer runtime.HandleCrash() + defer w.workQueue.ShutdownAll() + + // Start the informer factories to begin populating the informer caches + logger.Info(ctx, "Starting FlyteWorkflow controller") + w.metrics.WorkersRestarted.Inc() + + // Wait for the caches to be synced before starting workers + logger.Info(ctx, "Waiting for informer caches to sync") + if ok := cache.WaitForCacheSync(ctx.Done(), synced...); !ok { + return fmt.Errorf("failed to wait for caches to sync") + } + + logger.Infof(ctx, "Starting workers [%d]", threadiness) + // Launch workers to process FlyteWorkflow resources + for i := 0; i < threadiness; i++ { + w.metrics.FreeWorkers.Inc() + logger.Infof(ctx, "Starting worker [%d]", i) + workerLabel := fmt.Sprintf("worker-%v", i) + go func() { + workerCtx := contextutils.WithGoroutineLabel(ctx, workerLabel) + pprof.SetGoroutineLabels(workerCtx) + w.runWorker(workerCtx) + }() + } + + w.workQueue.Start(ctx) + logger.Info(ctx, "Started workers") + <-ctx.Done() + logger.Info(ctx, "Shutting down workers") + + return nil +} + +func NewWorkerPool(ctx context.Context, scope promutils.Scope, workQueue CompositeWorkQueue, handler Handler) *WorkerPool { + roundScope := scope.NewSubScope("round") + metrics := workerPoolMetrics{ + Scope: scope, + FreeWorkers: scope.MustNewGauge("free_workers_count", "Number of workers free"), + PerRoundTimer: roundScope.MustNewStopWatch("round_total", "Latency per round", time.Millisecond), + RoundSuccess: roundScope.MustNewCounter("success_count", "Round succeeded"), + RoundError: roundScope.MustNewCounter("error_count", "Round failed"), + WorkersRestarted: scope.MustNewCounter("workers_restarted", "Propeller worker-pool was restarted"), + } + return &WorkerPool{ + workQueue: workQueue, + metrics: metrics, + handler: handler, + } +} diff --git a/flytepropeller/pkg/controller/workers_test.go b/flytepropeller/pkg/controller/workers_test.go new file mode 100644 index 0000000000..aae7e54ebe --- /dev/null +++ b/flytepropeller/pkg/controller/workers_test.go @@ -0,0 +1,93 @@ +package controller + +import ( + "context" + "sync" + "testing" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +var testLocalScope2 = promutils.NewScope("worker_pool") + +type testHandler struct { + InitCb func(ctx context.Context) error + HandleCb func(ctx context.Context, namespace, key string) error +} + +func (t *testHandler) Initialize(ctx context.Context) error { + return t.InitCb(ctx) +} + +func (t *testHandler) Handle(ctx context.Context, namespace, key string) error { + return t.HandleCb(ctx, namespace, key) +} + +func simpleWorkQ(ctx context.Context, t *testing.T, testScope promutils.Scope) CompositeWorkQueue { + cfg := config.CompositeQueueConfig{} + q, err := NewCompositeWorkQueue(ctx, cfg, testScope) + assert.NoError(t, err) + assert.NotNil(t, q) + return q +} + +func TestWorkerPool_Run(t *testing.T) { + ctx := context.TODO() + l := testLocalScope2.NewSubScope("new") + h := &testHandler{} + q := simpleWorkQ(ctx, t, l) + w := NewWorkerPool(ctx, l, q, h) + assert.NotNil(t, w) + + t.Run("initcalled", func(t *testing.T) { + + initCalled := false + h.InitCb = func(ctx context.Context) error { + initCalled = true + return nil + } + + assert.NoError(t, w.Initialize(ctx)) + assert.True(t, initCalled) + }) + + // Bad TEST :(. 
We create 2 waitgroups, one will wait for the Run function to exit (called wg) + // Other is called handleReceived, waits for receiving a handle + // The flow is, + // - start the poll loop + // - add a key `x` + // - wait for `x` to be handled + // - cancel the loop + // - wait for loop to exit + t.Run("run", func(t *testing.T) { + childCtx, cancel := context.WithCancel(ctx) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + assert.NoError(t, w.Run(childCtx, 1, func() bool { + return true + })) + wg.Done() + }() + + handleReceived := sync.WaitGroup{} + handleReceived.Add(1) + + h.HandleCb = func(ctx context.Context, namespace, key string) error { + if key == "x" { + handleReceived.Done() + } else { + assert.FailNow(t, "x expected") + } + return nil + } + q.Add("x") + handleReceived.Wait() + + cancel() + wg.Wait() + }) +} diff --git a/flytepropeller/pkg/controller/workflow/errors/codes.go b/flytepropeller/pkg/controller/workflow/errors/codes.go new file mode 100644 index 0000000000..f4a84685aa --- /dev/null +++ b/flytepropeller/pkg/controller/workflow/errors/codes.go @@ -0,0 +1,15 @@ +package errors + +type ErrorCode string + +const ( + IllegalStateError ErrorCode = "IllegalStateError" + BadSpecificationError ErrorCode = "BadSpecificationError" + CausedByError ErrorCode = "CausedByError" + RuntimeExecutionError ErrorCode = "RuntimeExecutionError" + EventRecordingError ErrorCode = "ErrorRecordingError" +) + +func (e ErrorCode) String() string { + return string(e) +} diff --git a/flytepropeller/pkg/controller/workflow/errors/errors.go b/flytepropeller/pkg/controller/workflow/errors/errors.go new file mode 100644 index 0000000000..edf56224b3 --- /dev/null +++ b/flytepropeller/pkg/controller/workflow/errors/errors.go @@ -0,0 +1,80 @@ +package errors + +import ( + "fmt" + + "github.com/pkg/errors" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type ErrorMessage = string + +type WorkflowError struct { + errors.StackTrace + Code ErrorCode + Message ErrorMessage + Workflow v1alpha1.WorkflowID +} + +func (w *WorkflowError) Error() string { + return fmt.Sprintf("Workflow[%s] failed. %v: %v", w.Workflow, w.Code, w.Message) +} + +type WorkflowErrorWithCause struct { + *WorkflowError + cause error +} + +func (w *WorkflowErrorWithCause) Error() string { + return fmt.Sprintf("%v, caused by: %v", w.WorkflowError.Error(), errors.Cause(w)) +} + +func (w *WorkflowErrorWithCause) Cause() error { + return w.cause +} + +func errorf(c ErrorCode, w v1alpha1.WorkflowID, msgFmt string, args ...interface{}) *WorkflowError { + return &WorkflowError{ + Code: c, + Workflow: w, + Message: fmt.Sprintf(msgFmt, args...), + } +} + +func Errorf(c ErrorCode, w v1alpha1.WorkflowID, msgFmt string, args ...interface{}) error { + return errorf(c, w, msgFmt, args...) 
+} + +func Wrapf(c ErrorCode, w v1alpha1.WorkflowID, cause error, msgFmt string, args ...interface{}) error { + return &WorkflowErrorWithCause{ + WorkflowError: errorf(c, w, msgFmt, args...), + cause: cause, + } +} + +func Matches(err error, code ErrorCode) bool { + errCode, isWorkflowError := GetErrorCode(err) + if isWorkflowError { + return code == errCode + } + return false +} + +func GetErrorCode(err error) (code ErrorCode, isWorkflowError bool) { + isWorkflowError = false + e, ok := err.(*WorkflowError) + if ok { + code = e.Code + isWorkflowError = true + return + } + + e2, ok := err.(*WorkflowErrorWithCause) + if ok { + code = e2.Code + isWorkflowError = true + return + } + return +} diff --git a/flytepropeller/pkg/controller/workflow/errors/errors_test.go b/flytepropeller/pkg/controller/workflow/errors/errors_test.go new file mode 100644 index 0000000000..c773fa776f --- /dev/null +++ b/flytepropeller/pkg/controller/workflow/errors/errors_test.go @@ -0,0 +1,48 @@ +package errors + +import ( + "fmt" + "testing" + + extErrors "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestErrorf(t *testing.T) { + msg := "msg" + err := Errorf(IllegalStateError, "w1", "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*WorkflowError) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "w1", e.Workflow) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, err, extErrors.Cause(e)) + assert.Equal(t, "Workflow[w1] failed. IllegalStateError: Message [msg]", err.Error()) +} + +func TestErrorfWithCause(t *testing.T) { + cause := extErrors.Errorf("Some Error") + msg := "msg" + err := Wrapf(IllegalStateError, "w1", cause, "Message [%v]", msg) + assert.NotNil(t, err) + e := err.(*WorkflowErrorWithCause) + assert.Equal(t, IllegalStateError, e.Code) + assert.Equal(t, "w1", e.Workflow) + assert.Equal(t, fmt.Sprintf("Message [%v]", msg), e.Message) + assert.Equal(t, cause, extErrors.Cause(e)) + assert.Equal(t, "Workflow[w1] failed. 
IllegalStateError: Message [msg], caused by: Some Error", err.Error()) +} + +func TestMatches(t *testing.T) { + err := Errorf(IllegalStateError, "w1", "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + cause := extErrors.Errorf("Some Error") + err = Wrapf(IllegalStateError, "w1", cause, "Message ") + assert.True(t, Matches(err, IllegalStateError)) + assert.False(t, Matches(err, BadSpecificationError)) + + assert.False(t, Matches(cause, IllegalStateError)) + assert.False(t, Matches(cause, BadSpecificationError)) +} diff --git a/flytepropeller/pkg/controller/workflow/executor.go b/flytepropeller/pkg/controller/workflow/executor.go new file mode 100644 index 0000000000..ddbcc03eca --- /dev/null +++ b/flytepropeller/pkg/controller/workflow/executor.go @@ -0,0 +1,420 @@ +package workflow + +import ( + "context" + "fmt" + "time" + + "github.com/lyft/flyteidl/clients/go/events" + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + corev1 "k8s.io/api/core/v1" + "k8s.io/client-go/tools/record" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "github.com/lyft/flytepropeller/pkg/controller/workflow/errors" + "github.com/lyft/flytepropeller/pkg/utils" +) + +type workflowMetrics struct { + AcceptedWorkflows labeled.Counter + FailureDuration labeled.StopWatch + SuccessDuration labeled.StopWatch + IncompleteWorkflowAborted labeled.Counter + + // Measures the time between when we receive service call to create an execution and when it has moved to running state. + AcceptanceLatency labeled.StopWatch + // Measures the time between when the WF moved to succeeding/failing state and when it finally moved to a terminal state. + CompletionLatency labeled.StopWatch +} + +type Status struct { + TransitionToPhase v1alpha1.WorkflowPhase + Err error +} + +var StatusReady = Status{TransitionToPhase: v1alpha1.WorkflowPhaseReady} +var StatusRunning = Status{TransitionToPhase: v1alpha1.WorkflowPhaseRunning} +var StatusSucceeding = Status{TransitionToPhase: v1alpha1.WorkflowPhaseSucceeding} +var StatusSuccess = Status{TransitionToPhase: v1alpha1.WorkflowPhaseSuccess} + +func StatusFailing(err error) Status { + return Status{TransitionToPhase: v1alpha1.WorkflowPhaseFailing, Err: err} +} + +func StatusFailed(err error) Status { + return Status{TransitionToPhase: v1alpha1.WorkflowPhaseFailed, Err: err} +} + +type workflowExecutor struct { + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + wfRecorder events.WorkflowEventRecorder + k8sRecorder record.EventRecorder + metadataPrefix storage.DataReference + nodeExecutor executors.Node + metrics *workflowMetrics +} + +func (c *workflowExecutor) constructWorkflowMetadataPrefix(ctx context.Context, w *v1alpha1.FlyteWorkflow) (storage.DataReference, error) { + if w.GetExecutionID().WorkflowExecutionIdentifier != nil { + execID := fmt.Sprintf("%v-%v-%v", w.GetExecutionID().GetProject(), w.GetExecutionID().GetDomain(), w.GetExecutionID().GetName()) + return c.store.ConstructReference(ctx, c.metadataPrefix, execID) + } + // TODO should we use a random guid as the prefix? 
Otherwise we may get collisions + logger.Warningf(ctx, "Workflow has no ExecutionID. Using the name as the storage-prefix. This maybe unsafe!") + return c.store.ConstructReference(ctx, c.metadataPrefix, w.Name) +} + +func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { + + startNode := w.StartNode() + if startNode == nil { + return StatusFailing(errors.Errorf(errors.BadSpecificationError, w.GetID(), "StartNode not found.")), nil + } + + ref, err := c.constructWorkflowMetadataPrefix(ctx, w) + if err != nil { + return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to create metadata prefix.")), nil + } + w.GetExecutionStatus().SetDataDir(ref) + var inputs *core.LiteralMap + if w.Inputs != nil { + inputs = w.Inputs.LiteralMap + } + // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially + // Copy of the inputs to the Node + nodeStatus := w.GetNodeExecutionStatus(startNode.GetID()) + dataDir, err := c.store.ConstructReference(ctx, ref, startNode.GetID(), "data") + if err != nil { + return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to create metadata prefix for start node.")), nil + } + + logger.Infof(ctx, "Setting the MetadataDir for StartNode [%v]", dataDir) + nodeStatus.SetDataDir(dataDir) + s, err := c.nodeExecutor.SetInputsForStartNode(ctx, w, inputs) + if err != nil { + return StatusReady, err + } + + if s.HasFailed() { + return StatusFailing(errors.Wrapf(errors.CausedByError, w.GetID(), err, "failed to set inputs for Start node.")), nil + } + return StatusRunning, nil +} + +func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { + contextualWf := executors.NewBaseContextualWorkflow(w) + startNode := contextualWf.StartNode() + if startNode == nil { + return StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?")), nil + } + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, contextualWf, startNode) + if err != nil { + return StatusRunning, err + } + if state.HasFailed() { + logger.Infof(ctx, "Workflow has failed. Error [%s]", state.Err.Error()) + return StatusFailing(state.Err), nil + } + if state.IsComplete() { + return StatusSucceeding, nil + } + if state.PartiallyComplete() { + c.enqueueWorkflow(contextualWf.GetK8sWorkflowID().String()) + } + return StatusRunning, nil +} + +func (c *workflowExecutor) handleFailingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { + contextualWf := executors.NewBaseContextualWorkflow(w) + // Best effort clean-up. + if err := c.cleanupRunningNodes(ctx, contextualWf); err != nil { + logger.Errorf(ctx, "Failed to propagate Abort for workflow:%v. 
Error: %v", w.ExecutionID.WorkflowExecutionIdentifier, err) + } + + errorNode := contextualWf.GetOnFailureNode() + if errorNode != nil { + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, contextualWf, errorNode) + if err != nil { + return StatusFailing(nil), err + } + if state.HasFailed() { + return StatusFailed(state.Err), nil + } + if state.PartiallyComplete() { + // Re-enqueue the workflow + c.enqueueWorkflow(contextualWf.GetK8sWorkflowID().String()) + return StatusFailing(nil), nil + } + // Fallthrough to handle state is complete + } + return StatusFailed(errors.Errorf(errors.CausedByError, w.ID, contextualWf.GetExecutionStatus().GetMessage())), nil +} + +func (c *workflowExecutor) handleSucceedingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) Status { + logger.Infof(ctx, "Workflow completed successfully") + endNodeStatus := w.GetNodeExecutionStatus(v1alpha1.EndNodeID) + if endNodeStatus.GetPhase() == v1alpha1.NodePhaseSucceeded { + if endNodeStatus.GetDataDir() != "" { + w.Status.SetOutputReference(v1alpha1.GetOutputsFile(endNodeStatus.GetDataDir())) + } + } + return StatusSuccess +} + +func convertToExecutionError(err error, alternateErr string) *event.WorkflowExecutionEvent_Error { + if err != nil { + if code, isWorkflowErr := errors.GetErrorCode(err); isWorkflowErr { + return &event.WorkflowExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: code.String(), + Message: err.Error(), + }, + } + } + } else { + err = fmt.Errorf(alternateErr) + } + return &event.WorkflowExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: errors.RuntimeExecutionError.String(), + Message: err.Error(), + }, + } +} + +func (c *workflowExecutor) IdempotentReportEvent(ctx context.Context, e *event.WorkflowExecutionEvent) error { + err := c.wfRecorder.RecordWorkflowEvent(ctx, e) + if err != nil && eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Workflow event phase: %s, executionId %s already exist", + e.Phase.String(), e.ExecutionId) + return nil + } + return err +} + +func (c *workflowExecutor) TransitionToPhase(ctx context.Context, execID *core.WorkflowExecutionIdentifier, wStatus v1alpha1.ExecutableWorkflowStatus, toStatus Status) error { + if wStatus.GetPhase() != toStatus.TransitionToPhase { + logger.Debugf(ctx, "Transitioning/Recording event for workflow state transition [%s] -> [%s]", wStatus.GetPhase().String(), toStatus.TransitionToPhase.String()) + + wfEvent := &event.WorkflowExecutionEvent{ + ExecutionId: execID, + } + previousMsg := wStatus.GetMessage() + switch toStatus.TransitionToPhase { + case v1alpha1.WorkflowPhaseReady: + // Do nothing + return nil + case v1alpha1.WorkflowPhaseRunning: + wfEvent.Phase = core.WorkflowExecution_RUNNING + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseRunning, fmt.Sprintf("Workflow Started")) + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStartedAt()) + case v1alpha1.WorkflowPhaseFailing: + wfEvent.Phase = core.WorkflowExecution_FAILING + e := convertToExecutionError(toStatus.Err, previousMsg) + wfEvent.OutputResult = e + // Completion latency is only observed when a workflow completes successfully + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseFailing, e.Error.Message) + wfEvent.OccurredAt = utils.GetProtoTime(nil) + case v1alpha1.WorkflowPhaseFailed: + wfEvent.Phase = core.WorkflowExecution_FAILED + e := convertToExecutionError(toStatus.Err, previousMsg) + wfEvent.OutputResult = e + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseFailed, e.Error.Message) + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStoppedAt()) + 
c.metrics.FailureDuration.Observe(ctx, wStatus.GetStartedAt().Time, wStatus.GetStoppedAt().Time) + case v1alpha1.WorkflowPhaseSucceeding: + wfEvent.Phase = core.WorkflowExecution_SUCCEEDING + endNodeStatus := wStatus.GetNodeExecutionStatus(v1alpha1.EndNodeID) + // Workflow completion latency is recorded as the time it takes for the workflow to transition from end + // node started time to workflow success being sent to the control plane. + if endNodeStatus != nil && endNodeStatus.GetStartedAt() != nil { + c.metrics.CompletionLatency.Observe(ctx, endNodeStatus.GetStartedAt().Time, time.Now()) + } + + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + wfEvent.OccurredAt = utils.GetProtoTime(nil) + case v1alpha1.WorkflowPhaseSuccess: + wfEvent.Phase = core.WorkflowExecution_SUCCEEDED + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseSuccess, "") + // Not all workflows have outputs + if wStatus.GetOutputReference() != "" { + wfEvent.OutputResult = &event.WorkflowExecutionEvent_OutputUri{ + OutputUri: wStatus.GetOutputReference().String(), + } + } + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStoppedAt()) + c.metrics.SuccessDuration.Observe(ctx, wStatus.GetStartedAt().Time, wStatus.GetStoppedAt().Time) + case v1alpha1.WorkflowPhaseAborted: + wfEvent.Phase = core.WorkflowExecution_ABORTED + if wStatus.GetLastUpdatedAt() != nil { + c.metrics.CompletionLatency.Observe(ctx, wStatus.GetLastUpdatedAt().Time, time.Now()) + } + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseAborted, "") + wfEvent.OccurredAt = utils.GetProtoTime(wStatus.GetStoppedAt()) + default: + return errors.Errorf(errors.IllegalStateError, "", "Illegal transition from [%v] -> [%v]", wStatus.GetPhase().String(), toStatus.TransitionToPhase.String()) + } + + if recordingErr := c.IdempotentReportEvent(ctx, wfEvent); recordingErr != nil { + if eventsErr.IsEventAlreadyInTerminalStateError(recordingErr) { + // Move to WorkflowPhaseFailed for state mis-match + msg := fmt.Sprintf("workflow state mismatch between propeller and control plane; Propeller State: %s, ExecutionId %s", wfEvent.Phase.String(), wfEvent.ExecutionId) + logger.Warningf(ctx, msg) + wStatus.UpdatePhase(v1alpha1.WorkflowPhaseFailed, msg) + return nil + } + logger.Warningf(ctx, "Event recording failed. Error [%s]", recordingErr.Error()) + return errors.Wrapf(errors.EventRecordingError, "", recordingErr, "failed to publish event") + } + } + return nil +} + +func (c *workflowExecutor) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing Core Workflow Executor") + return c.nodeExecutor.Initialize(ctx) +} + +func (c *workflowExecutor) HandleFlyteWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + logger.Infof(ctx, "Handling Workflow [%s], id: [%s], Phase [%s]", w.GetName(), w.GetExecutionID(), w.GetExecutionStatus().GetPhase().String()) + defer logger.Infof(ctx, "Handling Workflow [%s] Done", w.GetName()) + + wStatus := w.GetExecutionStatus() + // Initialize the Status if not already initialized + switch wStatus.GetPhase() { + case v1alpha1.WorkflowPhaseReady: + newStatus, err := c.handleReadyWorkflow(ctx, w) + if err != nil { + return err + } + c.metrics.AcceptedWorkflows.Inc(ctx) + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + c.k8sRecorder.Event(w, corev1.EventTypeNormal, v1alpha1.WorkflowPhaseRunning.String(), "Workflow began execution") + + // TODO: Consider annotating with the newStatus. 
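The AcceptanceLatency observation just below, like the FailureDuration/SuccessDuration/CompletionLatency observations above, uses flytestdlib's labeled StopWatch. A rough standalone sketch of that metric type, assuming (as flytestdlib requires) that the label keys are registered once at startup; the metric and scope names here are invented for illustration:

package main

import (
	"context"
	"time"

	"github.com/lyft/flytestdlib/contextutils"
	"github.com/lyft/flytestdlib/promutils"
	"github.com/lyft/flytestdlib/promutils/labeled"
)

func main() {
	// Labeled metrics read their label values from the context, so the label
	// keys must be registered before any labeled metric is created.
	labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey)

	scope := promutils.NewScope("sketch")
	latency := labeled.NewStopWatch("example_latency", "Illustrative latency metric", time.Millisecond, scope)

	ctx := contextutils.WithProjectDomain(context.Background(), "project", "domain")
	start := time.Now()
	// ... the work being measured ...
	latency.Observe(ctx, start, time.Now())
}
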
+ acceptedAt := w.GetCreationTimestamp().Time + if w.AcceptedAt != nil && !w.AcceptedAt.IsZero() { + acceptedAt = w.AcceptedAt.Time + } + + c.metrics.AcceptanceLatency.Observe(ctx, acceptedAt, time.Now()) + return nil + + case v1alpha1.WorkflowPhaseRunning: + newStatus, err := c.handleRunningWorkflow(ctx, w) + if err != nil { + logger.Warningf(ctx, "Error in handling running workflow [%v]", err.Error()) + return err + } + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + return nil + case v1alpha1.WorkflowPhaseSucceeding: + newStatus := c.handleSucceedingWorkflow(ctx, w) + + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + c.k8sRecorder.Event(w, corev1.EventTypeNormal, v1alpha1.WorkflowPhaseSuccess.String(), "Workflow completed.") + return nil + case v1alpha1.WorkflowPhaseFailing: + newStatus, err := c.handleFailingWorkflow(ctx, w) + if err != nil { + return err + } + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, wStatus, newStatus); err != nil { + return err + } + c.k8sRecorder.Event(w, corev1.EventTypeWarning, v1alpha1.WorkflowPhaseFailed.String(), "Workflow failed.") + return nil + default: + return errors.Errorf(errors.IllegalStateError, w.ID, "Unsupported state [%s] for workflow", w.GetExecutionStatus().GetPhase().String()) + } +} + +func (c *workflowExecutor) HandleAbortedWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow, maxRetries uint32) error { + if !w.Status.IsTerminated() { + c.metrics.IncompleteWorkflowAborted.Inc(ctx) + var err error + if w.Status.FailedAttempts > maxRetries { + err = errors.Errorf(errors.RuntimeExecutionError, w.GetID(), "max number of system retry attempts [%d/%d] exhausted. Last known status message: %v", w.Status.FailedAttempts, maxRetries, w.Status.Message) + } + + // Best effort clean-up. + contextualWf := executors.NewBaseContextualWorkflow(w) + if err2 := c.cleanupRunningNodes(ctx, contextualWf); err2 != nil { + logger.Errorf(ctx, "Failed to propagate Abort for workflow:%v. Error: %v", w.ExecutionID.WorkflowExecutionIdentifier, err2) + } + + var status Status + if err != nil { + // This workflow failed, record that phase and corresponding error message. + status = StatusFailed(err) + } else { + // Otherwise, this workflow is aborted. + status = Status{ + TransitionToPhase: v1alpha1.WorkflowPhaseAborted, + } + } + + if err := c.TransitionToPhase(ctx, w.ExecutionID.WorkflowExecutionIdentifier, w.GetExecutionStatus(), status); err != nil { + return err + } + } + return nil +} + +func (c *workflowExecutor) cleanupRunningNodes(ctx context.Context, w v1alpha1.ExecutableWorkflow) error { + startNode := w.StartNode() + if startNode == nil { + return errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?") + } + + if err := c.nodeExecutor.AbortHandler(ctx, w, startNode); err != nil { + return errors.Errorf(errors.CausedByError, w.GetID(), "Failed to propagate Abort for workflow. 
Error: %v", err) + } + + return nil +} + +func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, k8sEventRecorder record.EventRecorder, metadataPrefix string, nodeExecutor executors.Node, scope promutils.Scope) (executors.Workflow, error) { + basePrefix := store.GetBaseContainerFQN(ctx) + if metadataPrefix != "" { + var err error + basePrefix, err = store.ConstructReference(ctx, basePrefix, metadataPrefix) + if err != nil { + return nil, err + } + } + logger.Infof(ctx, "Metadata will be stored in container path: [%s]", basePrefix) + + workflowScope := scope.NewSubScope("workflow") + + return &workflowExecutor{ + nodeExecutor: nodeExecutor, + store: store, + enqueueWorkflow: enQWorkflow, + wfRecorder: events.NewWorkflowEventRecorder(eventSink, workflowScope), + k8sRecorder: k8sEventRecorder, + metadataPrefix: basePrefix, + metrics: &workflowMetrics{ + AcceptedWorkflows: labeled.NewCounter("accepted", "Number of workflows accepted by propeller", workflowScope), + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, workflowScope), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, workflowScope), + IncompleteWorkflowAborted: labeled.NewCounter("workflow_aborted", "Indicates an inprogress execution was aborted", workflowScope), + AcceptanceLatency: labeled.NewStopWatch("acceptance_latency", "Delay between workflow creation and moving it to running state.", time.Millisecond, workflowScope, labeled.EmitUnlabeledMetric), + CompletionLatency: labeled.NewStopWatch("completion_latency", "Measures the time between when the WF moved to succeeding/failing state and when it finally moved to a terminal state.", time.Millisecond, workflowScope, labeled.EmitUnlabeledMetric), + }, + }, nil +} diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go new file mode 100644 index 0000000000..492e4067dd --- /dev/null +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -0,0 +1,615 @@ +package workflow + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "testing" + + mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + wfErrors "github.com/lyft/flytepropeller/pkg/controller/workflow/errors" + + "time" + + "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + pluginV1 "github.com/lyft/flyteplugins/go/tasks/v1/types" + "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/catalog" + "github.com/lyft/flytepropeller/pkg/controller/nodes" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/task" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + "github.com/lyft/flytestdlib/yamlutils" + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "k8s.io/client-go/tools/record" +) + +var ( + testScope = promutils.NewScope("test_wfexec") + fakeKubeClient = mocks2.NewFakeKubeClient() +) + +func createInmemoryDataStore(t testing.TB, scope promutils.Scope) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + d, err := storage.NewDataStore(&cfg, scope) + assert.NoError(t, err) + return d +} + +func StdOutEventRecorder() record.EventRecorder { + eventChan := make(chan string) + recorder := &record.FakeRecorder{ + Events: eventChan, + } + + go func() { + defer close(eventChan) + for { + s := <-eventChan + if s == "" { + return + } + fmt.Printf("Event: [%v]\n", s) + } + }() + return recorder +} + +func createHappyPathTaskExecutor(t assert.TestingT, store *storage.DataStore, enableAsserts bool) pluginV1.Executor { + exec := &mocks.Executor{} + exec.On("GetID").Return("task") + exec.On("GetProperties").Return(pluginV1.ExecutorProperties{}) + exec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + + exec.On("ResolveOutputs", + mock.Anything, + mock.Anything, + mock.Anything, + ). + Return(func(ctx context.Context, taskCtx pluginV1.TaskContext, varNames ...string) (values map[string]*core.Literal) { + d := &handler.Data{} + outputsFileRef := v1alpha1.GetOutputsFile(taskCtx.GetDataDir()) + assert.NoError(t, store.ReadProtobuf(ctx, outputsFileRef, d)) + assert.NotNil(t, d.Literals) + + values = make(map[string]*core.Literal, len(varNames)) + for _, varName := range varNames { + l, ok := d.Literals[varName] + assert.True(t, ok, "Expect var %v in task outputs.", varName) + + values[varName] = l + } + + return values + }, func(ctx context.Context, taskCtx pluginV1.TaskContext, varNames ...string) error { + return nil + }) + + startFn := func(ctx context.Context, taskCtx pluginV1.TaskContext, task *core.TaskTemplate, _ *core.LiteralMap) pluginV1.TaskStatus { + outputVars := task.GetInterface().Outputs.Variables + o := &core.LiteralMap{ + Literals: make(map[string]*core.Literal, len(outputVars)), + } + for k, v := range outputVars { + l, err := utils.MakeDefaultLiteralForType(v.Type) + if enableAsserts && !assert.NoError(t, err) { + assert.FailNow(t, "Failed to create default output for node [%v] Type [%v]", taskCtx.GetTaskExecutionID(), v.Type) + } + o.Literals[k] = l + } + assert.NoError(t, store.WriteProtobuf(ctx, v1alpha1.GetOutputsFile(taskCtx.GetDataDir()), storage.Options{}, o)) + + return pluginV1.TaskStatusRunning + } + + exec.On("StartTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(startFn, nil) + + checkStatusFn := func(_ context.Context, taskCtx pluginV1.TaskContext, _ *core.TaskTemplate) pluginV1.TaskStatus { + if enableAsserts { + assert.NotEmpty(t, taskCtx.GetDataDir()) + } + return pluginV1.TaskStatusSucceeded + } + exec.On("CheckTaskStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(checkStatusFn, nil) + + return exec +} + +func createFailingTaskExecutor() pluginV1.Executor { + exec := 
&mocks.Executor{} + exec.On("GetID").Return("task") + exec.On("GetProperties").Return(pluginV1.ExecutorProperties{}) + exec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + + exec.On("StartTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginV1.TaskStatusRunning, nil) + + exec.On("CheckTaskStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginV1.TaskStatusPermanentFailure(errors.New("failed")), nil) + + exec.On("KillTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.Anything, + ).Return(nil) + + return exec +} + +func createTaskExecutorErrorInCheck() pluginV1.Executor { + exec := &mocks.Executor{} + exec.On("GetID").Return("task") + exec.On("GetProperties").Return(pluginV1.ExecutorProperties{}) + exec.On("Initialize", + mock.AnythingOfType(reflect.TypeOf(context.TODO()).String()), + mock.AnythingOfType(reflect.TypeOf(pluginV1.ExecutorInitializationParameters{}).String()), + ).Return(nil) + + exec.On("StartTask", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + mock.MatchedBy(func(o *core.LiteralMap) bool { return true }), + ).Return(pluginV1.TaskStatusRunning, nil) + + exec.On("CheckTaskStatus", + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o pluginV1.TaskContext) bool { return true }), + mock.MatchedBy(func(o *core.TaskTemplate) bool { return true }), + ).Return(pluginV1.TaskStatusUndefined, errors.New("check failed")) + + return exec +} + +func createSingletonTaskExecutorFactory(te pluginV1.Executor) task.Factory { + return &task.FactoryFuncs{ + GetTaskExecutorCb: func(taskType v1alpha1.TaskType) (pluginV1.Executor, error) { + return te, nil + }, + ListAllTaskExecutorsCb: func() []pluginV1.Executor { + return []pluginV1.Executor{te} + }, + } +} + +func init() { + flytek8s.InitializeFake() +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, testScope.NewSubScope("12")) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createTaskExecutorErrorInCheck() + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + eventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, 
executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 7 + 1 + + for i := 0; i < 11; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. + v.ResetDirty() + } + fmt.Printf("\n") + + if i < 4 { + assert.NoError(t, err, "Round %d", i) + } else { + assert.Error(t, err, "Round %d", i) + } + } + assert.Equal(t, v1alpha1.WorkflowPhaseRunning.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } +} + +func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, testScope.NewSubScope("13")) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createHappyPathTaskExecutor(t, store, true) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + eventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 12 ? + for i := 0; i < 12; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + if err != nil { + t.Log(err) + } + + assert.NoError(t, err) + fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. 
+ v.ResetDirty() + } + fmt.Printf("\n") + } + + assert.Equal(t, v1alpha1.WorkflowPhaseSuccess.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } +} + +func BenchmarkWorkflowExecutor(b *testing.B) { + scope := promutils.NewScope("test3") + ctx := context.Background() + store := createInmemoryDataStore(b, scope.NewSubScope(strconv.Itoa(b.N))) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(b, err) + + te := createHappyPathTaskExecutor(b, store, false) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(b, task.IsTestModeEnabled()) + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + eventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, scope) + assert.NoError(b, err) + + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(b, err) + + assert.NoError(b, executor.Initialize(ctx)) + b.ReportAllocs() + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if err != nil { + assert.FailNow(b, "Got error reading the testdata") + } + w := &v1alpha1.FlyteWorkflow{} + err = json.Unmarshal(wJSON, w) + if err != nil { + assert.FailNow(b, "Got error unmarshalling the testdata") + } + + // Current benchmark 2ms/op + for i := 0; i < b.N; i++ { + deepW := w.DeepCopy() + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 7 + 1 + for i := 0; i < 8; i++ { + err := executor.HandleFlyteWorkflow(ctx, deepW) + if err != nil { + assert.FailNow(b, "Run the unit test first. 
Benchmark should not fail") + } + } + if deepW.Status.Phase != v1alpha1.WorkflowPhaseSuccess { + assert.FailNow(b, "Workflow did not end in the expected state") + } + } +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, promutils.NewTestScope()) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createFailingTaskExecutor() + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + recordedRunning := false + recordedFailed := false + recordedFailing := true + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + e, ok := message.(*event.WorkflowExecutionEvent) + + if ok { + assert.True(t, ok) + switch e.Phase { + case core.WorkflowExecution_RUNNING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedRunning = true + case core.WorkflowExecution_FAILING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedFailing = true + case core.WorkflowExecution_FAILED: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedFailed = true + default: + return fmt.Errorf("MockWorkflowRecorder should not have entered into any other states [%v]", e.Phase) + } + } + return nil + } + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we will run into the first failure on round 6 + + for i := 0; i < 6; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + assert.Nil(t, err, "Round [%v]", i) + fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. 
+ v.ResetDirty() + } + fmt.Printf("\n") + + if i == 5 { + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase) + } else { + assert.NotEqual(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase, "For Round [%v] got phase [%v]", i, w.Status.Phase.String()) + } + + } + + assert.Equal(t, v1alpha1.WorkflowPhaseFailed.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } + assert.True(t, recordedRunning) + assert.True(t, recordedFailing) + assert.True(t, recordedFailed) +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, promutils.NewTestScope()) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createHappyPathTaskExecutor(t, store, true) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + recordedRunning := false + recordedSuccess := false + recordedFailing := true + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + e, ok := message.(*event.WorkflowExecutionEvent) + if ok { + switch e.Phase { + case core.WorkflowExecution_RUNNING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedRunning = true + case core.WorkflowExecution_SUCCEEDING: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedFailing = true + case core.WorkflowExecution_SUCCEEDED: + occuredAt, err := ptypes.Timestamp(e.OccurredAt) + assert.NoError(t, err) + + assert.WithinDuration(t, occuredAt, time.Now(), time.Millisecond*5) + recordedSuccess = true + default: + return fmt.Errorf("MockWorkflowRecorder should not have entered into any other states, received [%v]", e.Phase.String()) + } + } + return nil + } + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, executor.Initialize(ctx)) + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + if assert.NoError(t, err) { + w := &v1alpha1.FlyteWorkflow{} + if assert.NoError(t, json.Unmarshal(wJSON, w)) { + // For benchmark workflow, we know how many rounds it needs + // Number of rounds = 12 ? + for i := 0; i < 12; i++ { + err := executor.HandleFlyteWorkflow(ctx, w) + assert.NoError(t, err) + fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String()) + for k, v := range w.Status.NodeStatus { + fmt.Printf("Node[%v=%v],", k, v.Phase.String()) + // Reset dirty manually for tests. 
+ v.ResetDirty() + } + fmt.Printf("\n") + } + + assert.Equal(t, v1alpha1.WorkflowPhaseSuccess.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message) + } + } + assert.True(t, recordedRunning) + assert.True(t, recordedFailing) + assert.True(t, recordedSuccess) +} + +func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { + ctx := context.Background() + store := createInmemoryDataStore(t, promutils.NewTestScope()) + recorder := StdOutEventRecorder() + _, err := events.ConstructEventSink(ctx, events.GetConfig(ctx)) + assert.NoError(t, err) + + te := createHappyPathTaskExecutor(t, store, true) + tf := createSingletonTaskExecutorFactory(te) + task.SetTestFactory(tf) + assert.True(t, task.IsTestModeEnabled()) + + enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} + + wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") + assert.NoError(t, err) + + nodeEventSink := events.NewMockEventSink() + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, nodeEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + assert.NoError(t, err) + + t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { + + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, + Cause: errors.New("already exists"), + } + } + executor, err := NewExecutor(ctx, store, enqueueWorkflow, mockSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + assert.NoError(t, json.Unmarshal(wJSON, w)) + + assert.NoError(t, executor.Initialize(ctx)) + err = executor.HandleFlyteWorkflow(ctx, w) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed.String(), w.Status.Phase.String()) + + assert.NoError(t, err) + }) + + t.Run("EventSinkAlreadyExistsError", func(t *testing.T) { + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.AlreadyExists, + Cause: errors.New("already exists"), + } + } + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + assert.NoError(t, json.Unmarshal(wJSON, w)) + + err = executor.HandleFlyteWorkflow(ctx, w) + assert.NoError(t, err) + }) + + t.Run("EventSinkGenericError", func(t *testing.T) { + eventSink := events.NewMockEventSink() + mockSink := eventSink.(*events.MockEventSink) + mockSink.SinkCb = func(ctx context.Context, message proto.Message) error { + return &eventsErr.EventError{Code: eventsErr.EventSinkError, + Cause: errors.New("generic exists"), + } + } + executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) + assert.NoError(t, err) + w := &v1alpha1.FlyteWorkflow{} + assert.NoError(t, json.Unmarshal(wJSON, w)) + + err = executor.HandleFlyteWorkflow(ctx, w) + assert.Error(t, err) + assert.True(t, wfErrors.Matches(err, wfErrors.EventRecordingError)) + }) + +} diff --git a/flytepropeller/pkg/controller/workflow/testdata/benchmark_wf.yaml b/flytepropeller/pkg/controller/workflow/testdata/benchmark_wf.yaml new file mode 100644 index 
0000000000..d53059b330 --- /dev/null +++ b/flytepropeller/pkg/controller/workflow/testdata/benchmark_wf.yaml @@ -0,0 +1,378 @@ +kind: flyteworkflow +metadata: + creationTimestamp: null + generateName: dummy-workflow-1-0- + labels: + execution-id: "exec-id" + workflow-id: dummy-workflow-1-0 + namespace: myflytenamespace + name: "test-wf" +inputs: + literals: + triggered_date: + scalar: + primitive: + datetime: 2018-08-08T22:16:36.860016587Z +spec: + connections: + add-one-and-print-0: + - sum-non-none-0 + add-one-and-print-1: + - add-one-and-print-2 + - add-one-and-print-2 + - sum-and-print-0 + - sum-and-print-0 + add-one-and-print-2: + - sum-and-print-0 + - sum-and-print-0 + add-one-and-print-3: + - sum-non-none-0 + - sum-non-none-0 + start-node: + - print-every-time-0 + - add-one-and-print-0 + - add-one-and-print-3 + sum-and-print-0: + - print-every-time-0 + - print-every-time-0 + - print-every-time-0 + - print-every-time-0 + sum-non-none-0: + - add-one-and-print-1 + - add-one-and-print-1 + - sum-and-print-0 + id: dummy-workflow-1-0 + nodes: + add-one-and-print-0: + activeDeadlineSeconds: 0 + id: add-one-and-print-0 + inputBindings: + - binding: + scalar: + primitive: + integer: "3" + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + add-one-and-print-1: + activeDeadlineSeconds: 0 + id: add-one-and-print-1 + inputBindings: + - binding: + promise: + nodeId: sum-non-none-0 + var: out + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + add-one-and-print-2: + activeDeadlineSeconds: 0 + id: add-one-and-print-2 + inputBindings: + - binding: + promise: + nodeId: add-one-and-print-1 + var: out + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + add-one-and-print-3: + activeDeadlineSeconds: 0 + id: add-one-and-print-3 + inputBindings: + - binding: + scalar: + primitive: + integer: "101" + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: add-one-and-print + end-node: + id: end-node + kind: end + resources: {} + status: + phase: 0 + print-every-time-0: + activeDeadlineSeconds: 0 + id: print-every-time-0 + inputBindings: + - binding: + promise: + nodeId: start-node + var: triggered_date + var: date_triggered + - binding: + promise: + nodeId: sum-and-print-0 + var: out_blob + var: in_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: multi_blob + var: multi_blob + - binding: + promise: + nodeId: sum-and-print-0 + var: out + var: value_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: print-every-time + start-node: + id: start-node + kind: start + resources: {} + status: + phase: 0 + sum-and-print-0: + activeDeadlineSeconds: 0 + id: sum-and-print-0 + inputBindings: + - binding: + collection: + bindings: + - promise: + nodeId: sum-non-none-0 + var: out + - promise: + nodeId: add-one-and-print-1 + var: out + - promise: + nodeId: add-one-and-print-2 + var: out + - scalar: + primitive: + integer: "100" + var: values_to_add + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: sum-and-print + sum-non-none-0: + activeDeadlineSeconds: 0 + id: sum-non-none-0 + inputBindings: + - binding: + collection: + bindings: + - promise: + nodeId: add-one-and-print-0 + var: out + - promise: + nodeId: add-one-and-print-3 
+ var: out + var: values_to_print + kind: task + resources: + requests: + cpu: "2" + memory: 2Gi + status: + phase: 0 + task: sum-non-none +status: + phase: 0 +tasks: + add-one-and-print: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=add_one_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: add-one-and-print + interface: + inputs: + variables: + value_to_print: + type: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + print-every-time: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=print_every_time + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: print-every-time + interface: + inputs: + variables: + date_triggered: + type: + simple: DATETIME + in_blob: + type: + blob: + dimensionality: SINGLE + multi_blob: + type: + blob: + dimensionality: 1 + value_to_print: + type: + simple: INTEGER + outputs: + variables: {} + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + sum-and-print: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=sum_and_print + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: sum-and-print + interface: + inputs: + variables: + values_to_add: + type: + collectionType: + simple: INTEGER + outputs: + variables: + multi_blob: + type: + blob: + dimensionality: 1 + out: + type: + blob: + dimensionality: 0 + out_blob: + type: + blob: + dimensionality: 0 + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + sum-non-none: + container: + args: + - --task-module=flytekit.examples.tasks + - --task-name=sum_non_none + - --inputs={{$input}} + - --output-prefix={{$output}} + command: + - flyte-python-entrypoint + image: myflytecontainer:abc123 + resources: + requests: + - name: 1 + value: "2.000" + - name: 3 + value: 2048Mi + - name: 2 + value: "0.000" + id: + name: sum-non-none + interface: + inputs: + variables: + values_to_print: + type: + collectionType: + simple: INTEGER + outputs: + variables: + out: + type: + simple: INTEGER + metadata: + runtime: + type: 1 + version: 1.19.0b10 + timeout: 0s + type: "7" + diff --git a/flytepropeller/pkg/controller/workflowstore/errors.go b/flytepropeller/pkg/controller/workflowstore/errors.go new file mode 100644 index 0000000000..572508850a --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/errors.go @@ -0,0 +1,18 @@ +package workflowstore + +import ( + "fmt" + + "github.com/pkg/errors" +) + +var errStaleWorkflowError = fmt.Errorf("stale Workflow Found error") +var errWorkflowNotFound = fmt.Errorf("workflow not-found error") + +func IsNotFound(err error) bool { + return errors.Cause(err) == errWorkflowNotFound +} + +func IsWorkflowStale(err error) bool { + return errors.Cause(err) == errStaleWorkflowError +} diff --git a/flytepropeller/pkg/controller/workflowstore/iface.go 
b/flytepropeller/pkg/controller/workflowstore/iface.go new file mode 100644 index 0000000000..7e40f26b5c --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/iface.go @@ -0,0 +1,20 @@ +package workflowstore + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type PriorityClass int + +const ( + PriorityClassCritical PriorityClass = iota + PriorityClassRegular +) + +type FlyteWorkflow interface { + Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) + UpdateStatus(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error + Update(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error +} diff --git a/flytepropeller/pkg/controller/workflowstore/inmemory.go b/flytepropeller/pkg/controller/workflowstore/inmemory.go new file mode 100644 index 0000000000..906e1cc756 --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/inmemory.go @@ -0,0 +1,70 @@ +package workflowstore + +import ( + "context" + "fmt" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +type InmemoryWorkflowStore struct { + store map[string]map[string]*v1alpha1.FlyteWorkflow +} + +func (i *InmemoryWorkflowStore) Create(ctx context.Context, w *v1alpha1.FlyteWorkflow) error { + if w != nil { + if w.Name != "" && w.Namespace != "" { + if _, ok := i.store[w.Namespace]; !ok { + i.store[w.Namespace] = map[string]*v1alpha1.FlyteWorkflow{} + } + i.store[w.Namespace][w.Name] = w + return nil + } + } + return kubeerrors.NewBadRequest(fmt.Sprintf("Workflow object with Namespace [%v] & Name [%v] is required", w.Namespace, w.Name)) +} + +func (i *InmemoryWorkflowStore) Delete(ctx context.Context, namespace, name string) error { + if m, ok := i.store[namespace]; ok { + if _, ok := m[name]; ok { + delete(m, name) + return nil + } + } + return nil +} + +func (i *InmemoryWorkflowStore) Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) { + if m, ok := i.store[namespace]; ok { + if v, ok := m[name]; ok { + return v, nil + } + } + return nil, errWorkflowNotFound +} + +func (i *InmemoryWorkflowStore) UpdateStatus(ctx context.Context, w *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + if w != nil { + if w.Name != "" && w.Namespace != "" { + if m, ok := i.store[w.Namespace]; ok { + if _, ok := m[w.Name]; ok { + m[w.Name] = w + return nil + } + } + return nil + } + } + return kubeerrors.NewBadRequest("Workflow object with Namespace & Name is required") +} + +func (i *InmemoryWorkflowStore) Update(ctx context.Context, w *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + return i.UpdateStatus(ctx, w, priorityClass) +} + +func NewInMemoryWorkflowStore() *InmemoryWorkflowStore { + return &InmemoryWorkflowStore{ + store: map[string]map[string]*v1alpha1.FlyteWorkflow{}, + } +} diff --git a/flytepropeller/pkg/controller/workflowstore/passthrough.go b/flytepropeller/pkg/controller/workflowstore/passthrough.go new file mode 100644 index 0000000000..882242b347 --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/passthrough.go @@ -0,0 +1,106 @@ +package workflowstore + +import ( + "context" + "time" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + v1alpha12 "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1" + listers "github.com/lyft/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" + 
"github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +type workflowstoreMetrics struct { + workflowUpdateCount prometheus.Counter + workflowUpdateFailedCount prometheus.Counter + workflowUpdateSuccessCount prometheus.Counter + workflowUpdateConflictCount prometheus.Counter + workflowUpdateLatency promutils.StopWatch +} + +type passthroughWorkflowStore struct { + wfLister listers.FlyteWorkflowLister + wfClientSet v1alpha12.FlyteworkflowV1alpha1Interface + metrics *workflowstoreMetrics +} + +func (p *passthroughWorkflowStore) Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) { + w, err := p.wfLister.FlyteWorkflows(namespace).Get(name) + if err != nil { + // The FlyteWorkflow resource may no longer exist, in which case we stop + // processing. + if kubeerrors.IsNotFound(err) { + logger.Warningf(ctx, "Workflow not found in cache.") + return nil, errWorkflowNotFound + } + return nil, err + } + return w, nil +} + +func (p *passthroughWorkflowStore) UpdateStatus(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + p.metrics.workflowUpdateCount.Inc() + // Something has changed. Lets save + logger.Debugf(ctx, "Observed FlyteWorkflow State change. [%v] -> [%v]", workflow.Status.Phase.String(), workflow.Status.Phase.String()) + t := p.metrics.workflowUpdateLatency.Start() + _, err := p.wfClientSet.FlyteWorkflows(workflow.Namespace).Update(workflow) + if err != nil { + if kubeerrors.IsNotFound(err) { + return nil + } + if kubeerrors.IsConflict(err) { + p.metrics.workflowUpdateConflictCount.Inc() + } + p.metrics.workflowUpdateFailedCount.Inc() + logger.Errorf(ctx, "Failed to update workflow status. Error [%v]", err) + return err + } + t.Stop() + p.metrics.workflowUpdateSuccessCount.Inc() + logger.Debugf(ctx, "Updated workflow status.") + return nil +} + +func (p *passthroughWorkflowStore) Update(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + p.metrics.workflowUpdateCount.Inc() + // Something has changed. Lets save + logger.Debugf(ctx, "Observed FlyteWorkflow Update (maybe finalizer)") + t := p.metrics.workflowUpdateLatency.Start() + _, err := p.wfClientSet.FlyteWorkflows(workflow.Namespace).Update(workflow) + if err != nil { + if kubeerrors.IsNotFound(err) { + return nil + } + if kubeerrors.IsConflict(err) { + p.metrics.workflowUpdateConflictCount.Inc() + } + p.metrics.workflowUpdateFailedCount.Inc() + logger.Errorf(ctx, "Failed to update workflow. 
Error [%v]", err) + return err + } + t.Stop() + p.metrics.workflowUpdateSuccessCount.Inc() + logger.Debugf(ctx, "Updated workflow.") + return nil +} + +func NewPassthroughWorkflowStore(_ context.Context, scope promutils.Scope, wfClient v1alpha12.FlyteworkflowV1alpha1Interface, + flyteworkflowLister listers.FlyteWorkflowLister) FlyteWorkflow { + + metrics := &workflowstoreMetrics{ + workflowUpdateCount: scope.MustNewCounter("wf_updated", "Total number of status updates"), + workflowUpdateFailedCount: scope.MustNewCounter("wf_update_failed", "Failure to update ETCd"), + workflowUpdateConflictCount: scope.MustNewCounter("wf_update_conflict", "Failure to update ETCd because of conflict"), + workflowUpdateSuccessCount: scope.MustNewCounter("wf_update_success", "Success in updating ETCd"), + workflowUpdateLatency: scope.MustNewStopWatch("wf_update_latency", "Time taken to complete update/updatestatus", time.Millisecond), + } + + return &passthroughWorkflowStore{ + wfLister: flyteworkflowLister, + wfClientSet: wfClient, + metrics: metrics, + } +} diff --git a/flytepropeller/pkg/controller/workflowstore/passthrough_test.go b/flytepropeller/pkg/controller/workflowstore/passthrough_test.go new file mode 100644 index 0000000000..9d77993e08 --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/passthrough_test.go @@ -0,0 +1,130 @@ +package workflowstore + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + listers "github.com/lyft/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/fake" + + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +type mockWFNamespaceLister struct { + listers.FlyteWorkflowNamespaceLister + GetCb func(name string) (*v1alpha1.FlyteWorkflow, error) +} + +func (m *mockWFNamespaceLister) Get(name string) (*v1alpha1.FlyteWorkflow, error) { + return m.GetCb(name) +} + +type mockWFLister struct { + listers.FlyteWorkflowLister + V listers.FlyteWorkflowNamespaceLister +} + +func (m *mockWFLister) FlyteWorkflows(namespace string) listers.FlyteWorkflowNamespaceLister { + return m.V +} + +func TestPassthroughWorkflowStore_Get(t *testing.T) { + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + l := &mockWFNamespaceLister{} + wfStore := NewPassthroughWorkflowStore(ctx, promutils.NewTestScope(), mockClient, &mockWFLister{V: l}) + + t.Run("notFound", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewNotFound(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.True(t, IsNotFound(err)) + assert.Nil(t, w) + }) + + t.Run("alreadyExists?", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewAlreadyExists(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("unknownError", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, fmt.Errorf("error") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("success", func(t *testing.T) { + expW := &v1alpha1.FlyteWorkflow{} + l.GetCb = 
func(name string) (*v1alpha1.FlyteWorkflow, error) { + return expW, nil + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.NoError(t, err) + assert.Equal(t, expW, w) + }) +} + +func dummyWf(namespace, name string) *v1alpha1.FlyteWorkflow { + return &v1alpha1.FlyteWorkflow{ + ObjectMeta: v1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + } +} + +func TestPassthroughWorkflowStore_UpdateStatus(t *testing.T) { + + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + l := &mockWFNamespaceLister{} + wfStore := NewPassthroughWorkflowStore(ctx, promutils.NewTestScope(), mockClient, &mockWFLister{V: l}) + + const namespace = "test-ns" + t.Run("notFound", func(t *testing.T) { + wf := dummyWf(namespace, "x") + err := wfStore.UpdateStatus(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + updated, err := mockClient.FlyteWorkflows(namespace).Get("x", v1.GetOptions{}) + assert.Error(t, err) + assert.Nil(t, updated) + }) + + t.Run("Found-Updated", func(t *testing.T) { + n := mockClient.FlyteWorkflows(namespace) + wf := dummyWf(namespace, "x") + wf.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseSucceeding, "") + wf.ResourceVersion = "r1" + _, err := n.Create(wf) + assert.NoError(t, err) + updated, err := n.Get("x", v1.GetOptions{}) + if assert.NoError(t, err) { + assert.Equal(t, v1alpha1.WorkflowPhaseSucceeding, updated.GetExecutionStatus().GetPhase()) + wf.GetExecutionStatus().UpdatePhase(v1alpha1.WorkflowPhaseFailed, "") + err := wfStore.UpdateStatus(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + newVal, err := n.Get("x", v1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, v1alpha1.WorkflowPhaseFailed, newVal.GetExecutionStatus().GetPhase()) + } + }) + +} diff --git a/flytepropeller/pkg/controller/workflowstore/resource_version_caching.go b/flytepropeller/pkg/controller/workflowstore/resource_version_caching.go new file mode 100644 index 0000000000..2bb84854cf --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/resource_version_caching.go @@ -0,0 +1,96 @@ +package workflowstore + +import ( + "context" + "fmt" + "sync" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" +) + +// TODO - optimization maybe? we can move this to predicate check, before we add it to the queue? 
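The resource-version caching store defined just below wraps another FlyteWorkflow store and short-circuits workflows whose resource version has not changed since the last successful update. A minimal usage sketch, not part of this patch, assuming ctx, wfClient and wfLister are already wired up elsewhere:

    // Illustrative sketch only; wiring of ctx, wfClient and wfLister is assumed.
    scope := promutils.NewTestScope()
    base := workflowstore.NewPassthroughWorkflowStore(ctx, scope, wfClient, wfLister)
    cached := workflowstore.NewResourceVersionCachingStore(ctx, scope, base)

    wf, err := cached.Get(ctx, "my-namespace", "my-workflow")
    switch {
    case workflowstore.IsWorkflowStale(err):
        // Resource version unchanged since the last update; skip this round.
    case workflowstore.IsNotFound(err):
        // Workflow no longer exists in the informer cache.
    case err != nil:
        // Genuine error; surface it.
    default:
        _ = wf // proceed with evaluating the workflow
    }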
+type resourceVersionMetrics struct { + workflowStaleCount prometheus.Counter + workflowEvictedCount prometheus.Counter +} + +// Simple function that converts the namespace and name to a single string key +func resourceVersionKey(namespace, name string) string { + return fmt.Sprintf("%s/%s", namespace, name) +} + +// A specialized store that keeps an in-memory cache of all currently executing workflows and their last observed resource versions. +// If the resource version has not changed between the last update and the next Get, Get returns a nil workflow and errStaleWorkflowError, +// so the next propeller round simply ignores the workflow. +type resourceVersionCaching struct { + w FlyteWorkflow + metrics *resourceVersionMetrics + lastUpdatedResourceVersionCache sync.Map +} + +func (r *resourceVersionCaching) updateRevisionCache(_ context.Context, namespace, name, resourceVersion string, isTerminated bool) { + if isTerminated { + r.metrics.workflowEvictedCount.Inc() + r.lastUpdatedResourceVersionCache.Delete(resourceVersionKey(namespace, name)) + } else { + r.lastUpdatedResourceVersionCache.Store(resourceVersionKey(namespace, name), resourceVersion) + } +} + +func (r *resourceVersionCaching) isResourceVersionSameAsPrevious(namespace, name, resourceVersion string) bool { + if v, ok := r.lastUpdatedResourceVersionCache.Load(resourceVersionKey(namespace, name)); ok { + strV := v.(string) + if strV == resourceVersion { + r.metrics.workflowStaleCount.Inc() + return true + } + } + return false +} + +func (r *resourceVersionCaching) Get(ctx context.Context, namespace, name string) (*v1alpha1.FlyteWorkflow, error) { + w, err := r.w.Get(ctx, namespace, name) + if err != nil { + return nil, err + } + if w != nil { + if r.isResourceVersionSameAsPrevious(namespace, name, w.ResourceVersion) { + return nil, errStaleWorkflowError + } + } + return w, nil +} + +func (r *resourceVersionCaching) UpdateStatus(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + err := r.w.UpdateStatus(ctx, workflow, priorityClass) + if err != nil { + return err + } + + r.updateRevisionCache(ctx, workflow.Namespace, workflow.Name, workflow.ResourceVersion, workflow.Status.IsTerminated()) + return nil +} + +func (r *resourceVersionCaching) Update(ctx context.Context, workflow *v1alpha1.FlyteWorkflow, priorityClass PriorityClass) error { + err := r.w.Update(ctx, workflow, priorityClass) + if err != nil { + return err + } + + r.updateRevisionCache(ctx, workflow.Namespace, workflow.Name, workflow.ResourceVersion, workflow.Status.IsTerminated()) + return nil +} + +func NewResourceVersionCachingStore(ctx context.Context, scope promutils.Scope, workflowStore FlyteWorkflow) FlyteWorkflow { + + return &resourceVersionCaching{ + w: workflowStore, + metrics: &resourceVersionMetrics{ + workflowStaleCount: scope.MustNewCounter("wf_stale", "Found stale workflow in cache"), + workflowEvictedCount: scope.MustNewCounter("wf_evict", "removed workflow from resource version cache"), + }, + lastUpdatedResourceVersionCache: sync.Map{}, + } +} diff --git a/flytepropeller/pkg/controller/workflowstore/resource_version_caching_test.go b/flytepropeller/pkg/controller/workflowstore/resource_version_caching_test.go new file mode 100644 index 0000000000..b09847e0f9 --- /dev/null +++ b/flytepropeller/pkg/controller/workflowstore/resource_version_caching_test.go @@ -0,0 +1,153 @@ +package workflowstore + +import ( + "context" + "fmt" + "testing" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +
"github.com/lyft/flytepropeller/pkg/client/clientset/versioned/fake" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + kubeerrors "k8s.io/apimachinery/pkg/api/errors" +) + +func TestResourceVersionCaching_Get_NotInCache(t *testing.T) { + ctx := context.TODO() + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + + t.Run("notFound", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewNotFound(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.True(t, IsNotFound(err)) + assert.Nil(t, w) + }) + + t.Run("alreadyExists?", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, kubeerrors.NewAlreadyExists(v1alpha1.Resource(v1alpha1.FlyteWorkflowKind), "name") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("unknownError", func(t *testing.T) { + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return nil, fmt.Errorf("error") + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.Error(t, err) + assert.Nil(t, w) + }) + + t.Run("success", func(t *testing.T) { + expW := &v1alpha1.FlyteWorkflow{} + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return expW, nil + } + w, err := wfStore.Get(ctx, "ns", "name") + assert.NoError(t, err) + assert.Equal(t, expW, w) + }) +} + +func TestResourceVersionCaching_Get_UpdateAndRead(t *testing.T) { + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + namespace := "ns" + name := "name" + resourceVersion := "r1" + + wf := dummyWf(namespace, name) + wf.ResourceVersion = resourceVersion + + t.Run("Stale", func(t *testing.T) { + + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + // Insert a new workflow with R1 + err := wfStore.Update(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + + // Return the same workflow + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + + return wf, nil + } + + w, err := wfStore.Get(ctx, namespace, name) + assert.Error(t, err) + assert.False(t, IsNotFound(err)) + assert.True(t, IsWorkflowStale(err)) + assert.Nil(t, w) + }) + + t.Run("Updated", func(t *testing.T) { + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + // Insert a new workflow with R1 + err := wfStore.Update(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + + // Update the workflow version + wf2 := wf.DeepCopy() + wf2.ResourceVersion = "r2" + + // Return updated workflow + l.GetCb = func(name string) (*v1alpha1.FlyteWorkflow, error) { + return wf2, nil + } + + w, err := wfStore.Get(ctx, namespace, name) + assert.NoError(t, err) + assert.NotNil(t, w) + assert.Equal(t, "r2", w.ResourceVersion) + }) +} + +func TestResourceVersionCaching_UpdateTerminated(t *testing.T) { + ctx := context.TODO() + + mockClient := fake.NewSimpleClientset().FlyteworkflowV1alpha1() + + namespace := "ns" + name := "name" + 
resourceVersion := "r1" + + wf := dummyWf(namespace, name) + wf.ResourceVersion = resourceVersion + + scope := promutils.NewTestScope() + l := &mockWFNamespaceLister{} + wfStore := NewResourceVersionCachingStore(ctx, scope, NewPassthroughWorkflowStore(ctx, scope, mockClient, &mockWFLister{V: l})) + // Insert a new workflow with R1 + err := wfStore.Update(ctx, wf, PriorityClassCritical) + assert.NoError(t, err) + + rvStore := wfStore.(*resourceVersionCaching) + v, ok := rvStore.lastUpdatedResourceVersionCache.Load(resourceVersionKey(namespace, name)) + assert.True(t, ok) + assert.Equal(t, resourceVersion, v.(string)) + + wf2 := wf.DeepCopy() + wf2.Status.Phase = v1alpha1.WorkflowPhaseAborted + err = wfStore.Update(ctx, wf2, PriorityClassCritical) + assert.NoError(t, err) + + v, ok = rvStore.lastUpdatedResourceVersionCache.Load(resourceVersionKey(namespace, name)) + assert.False(t, ok) + assert.Nil(t, v) + +} diff --git a/flytepropeller/pkg/controller/workqueue.go b/flytepropeller/pkg/controller/workqueue.go new file mode 100644 index 0000000000..051333a8ce --- /dev/null +++ b/flytepropeller/pkg/controller/workqueue.go @@ -0,0 +1,47 @@ +package controller + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/controller/config" + + "golang.org/x/time/rate" + + "github.com/lyft/flytestdlib/logger" + "k8s.io/client-go/util/workqueue" +) + +func NewWorkQueue(ctx context.Context, cfg config.WorkqueueConfig, name string) (workqueue.RateLimitingInterface, error) { + // TODO introduce bounds checks + logger.Infof(ctx, "WorkQueue type [%v] configured", cfg.Type) + switch cfg.Type { + case config.WorkqueueTypeBucketRateLimiter: + logger.Infof(ctx, "Using Bucket Ratelimited Workqueue, Rate [%v] Capacity [%v]", cfg.Rate, cfg.Capacity) + return workqueue.NewNamedRateLimitingQueue( + // 10 qps, 100 bucket size. 
This is only for retry speed and its only the overall factor (not per item) + &workqueue.BucketRateLimiter{ + Limiter: rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Capacity), + }, name), nil + case config.WorkqueueTypeExponentialFailureRateLimiter: + logger.Infof(ctx, "Using Exponential failure backoff Ratelimited Workqueue, Base Delay [%v], max Delay [%v]", cfg.BaseDelay, cfg.MaxDelay) + return workqueue.NewNamedRateLimitingQueue( + workqueue.NewItemExponentialFailureRateLimiter(cfg.BaseDelay.Duration, cfg.MaxDelay.Duration), + name), nil + case config.WorkqueueTypeMaxOfRateLimiter: + logger.Infof(ctx, "Using Max-of Ratelimited Workqueue, Bucket {Rate [%v] Capacity [%v]} | FailureBackoff {Base Delay [%v], max Delay [%v]}", cfg.Rate, cfg.Capacity, cfg.BaseDelay, cfg.MaxDelay) + return workqueue.NewNamedRateLimitingQueue( + workqueue.NewMaxOfRateLimiter( + &workqueue.BucketRateLimiter{ + Limiter: rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Capacity), + }, + workqueue.NewItemExponentialFailureRateLimiter(cfg.BaseDelay.Duration, + cfg.MaxDelay.Duration), + ), name), nil + + case config.WorkqueueTypeDefault: + fallthrough + default: + logger.Infof(ctx, "Using Default Workqueue") + return workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), name), nil + } +} diff --git a/flytepropeller/pkg/controller/workqueue_test.go b/flytepropeller/pkg/controller/workqueue_test.go new file mode 100644 index 0000000000..888c93b1f2 --- /dev/null +++ b/flytepropeller/pkg/controller/workqueue_test.go @@ -0,0 +1,54 @@ +package controller + +import ( + "context" + "testing" + "time" + + config2 "github.com/lyft/flytepropeller/pkg/controller/config" + + "github.com/lyft/flytestdlib/config" + "github.com/stretchr/testify/assert" +) + +func TestNewWorkQueue(t *testing.T) { + ctx := context.TODO() + + t.Run("emptyConfig", func(t *testing.T) { + cfg := config2.WorkqueueConfig{} + w, err := NewWorkQueue(ctx, cfg, "q_test1") + assert.NoError(t, err) + assert.NotNil(t, w) + }) + + t.Run("simpleConfig", func(t *testing.T) { + cfg := config2.WorkqueueConfig{ + Type: config2.WorkqueueTypeDefault, + } + w, err := NewWorkQueue(ctx, cfg, "q_test2") + assert.NoError(t, err) + assert.NotNil(t, w) + }) + + t.Run("bucket", func(t *testing.T) { + cfg := config2.WorkqueueConfig{ + Type: config2.WorkqueueTypeBucketRateLimiter, + Capacity: 5, + Rate: 1, + } + w, err := NewWorkQueue(ctx, cfg, "q_test3") + assert.NoError(t, err) + assert.NotNil(t, w) + }) + + t.Run("expfailure", func(t *testing.T) { + cfg := config2.WorkqueueConfig{ + Type: config2.WorkqueueTypeExponentialFailureRateLimiter, + MaxDelay: config.Duration{Duration: time.Second * 10}, + BaseDelay: config.Duration{Duration: time.Second * 1}, + } + w, err := NewWorkQueue(ctx, cfg, "q_test4") + assert.NoError(t, err) + assert.NotNil(t, w) + }) +} diff --git a/flytepropeller/pkg/signals/signal.go b/flytepropeller/pkg/signals/signal.go new file mode 100644 index 0000000000..2fe649c135 --- /dev/null +++ b/flytepropeller/pkg/signals/signal.go @@ -0,0 +1,46 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +// This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/00e875e461860a3b584a2e891485e5b331473ec5/pkg/signals/signal.go + +package signals + +import ( + "context" + "os" + "os/signal" +) + +var onlyOneSignalHandler = make(chan struct{}) + +// SetupSignalHandler registers for SIGTERM and SIGINT. A child context is returned +// which is cancelled on one of these signals. If a second signal is caught, the program +// is terminated with exit code 1. +func SetupSignalHandler(ctx context.Context) context.Context { + close(onlyOneSignalHandler) // panics when called twice + + childCtx, cancel := context.WithCancel(ctx) + c := make(chan os.Signal, 2) + signal.Notify(c, shutdownSignals...) + go func() { + <-c + cancel() + <-c + os.Exit(1) // second signal. Exit directly. + }() + + return childCtx +} diff --git a/flytepropeller/pkg/signals/signal_posix.go b/flytepropeller/pkg/signals/signal_posix.go new file mode 100644 index 0000000000..0c4cd6007d --- /dev/null +++ b/flytepropeller/pkg/signals/signal_posix.go @@ -0,0 +1,28 @@ +// +build !windows + +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/00e875e461860a3b584a2e891485e5b331473ec5/pkg/signals/signal_posix.go + +package signals + +import ( + "os" + "syscall" +) + +var shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM} diff --git a/flytepropeller/pkg/signals/signal_windows.go b/flytepropeller/pkg/signals/signal_windows.go new file mode 100644 index 0000000000..3440028810 --- /dev/null +++ b/flytepropeller/pkg/signals/signal_windows.go @@ -0,0 +1,25 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +// This file was derived from https://raw.githubusercontent.com/kubernetes/sample-controller/00e875e461860a3b584a2e891485e5b331473ec5/pkg/signals/signal_windows.go + +package signals + +import ( + "os" +) + +var shutdownSignals = []os.Signal{os.Interrupt} diff --git a/flytepropeller/pkg/utils/assert/literals.go b/flytepropeller/pkg/utils/assert/literals.go new file mode 100644 index 0000000000..caf915c13c --- /dev/null +++ b/flytepropeller/pkg/utils/assert/literals.go @@ -0,0 +1,74 @@ +package assert + +import ( + "reflect" + "testing" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func EqualPrimitive(t *testing.T, p1 *core.Primitive, p2 *core.Primitive) { + if p1 != nil { + assert.NotNil(t, p2) + } + assert.Equal(t, reflect.TypeOf(p1.Value), reflect.TypeOf(p2.Value)) + switch p1.Value.(type) { + case *core.Primitive_Integer: + assert.Equal(t, p1.GetInteger(), p2.GetInteger()) + case *core.Primitive_StringValue: + assert.Equal(t, p1.GetStringValue(), p2.GetStringValue()) + default: + assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(p1.Value)) + } +} + +func EqualScalar(t *testing.T, p1 *core.Scalar, p2 *core.Scalar) { + if p1 != nil { + assert.NotNil(t, p2) + } + assert.Equal(t, reflect.TypeOf(p1.Value), reflect.TypeOf(p2.Value)) + switch p1.Value.(type) { + case *core.Scalar_Primitive: + EqualPrimitive(t, p1.GetPrimitive(), p2.GetPrimitive()) + default: + assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(p1.Value)) + } +} + +func EqualLiterals(t *testing.T, l1 *core.Literal, l2 *core.Literal) { + if l1 != nil { + assert.NotNil(t, l2) + } else { + assert.FailNow(t, "expected value is nil") + } + assert.Equal(t, reflect.TypeOf(l1.Value), reflect.TypeOf(l2.Value)) + switch l1.Value.(type) { + case *core.Literal_Scalar: + EqualScalar(t, l1.GetScalar(), l2.GetScalar()) + case *core.Literal_Map: + EqualLiteralMap(t, l1.GetMap(), l2.GetMap()) + default: + assert.FailNow(t, "Not supported test type") + } +} + +func EqualLiteralMap(t *testing.T, l1 *core.LiteralMap, l2 *core.LiteralMap) { + if assert.NotNil(t, l1, "l1 is nil") && assert.NotNil(t, l2, "l2 is nil") { + assert.Equal(t, len(l1.Literals), len(l2.Literals)) + for k, v := range l1.Literals { + actual, ok := l2.Literals[k] + assert.True(t, ok) + EqualLiterals(t, v, actual) + } + } +} + +func EqualLiteralCollection(t *testing.T, l1 *core.LiteralCollection, l2 *core.LiteralCollection) { + if assert.NotNil(t, l2) { + assert.Equal(t, len(l1.Literals), len(l2.Literals)) + for i, v := range l1.Literals { + EqualLiterals(t, v, l2.Literals[i]) + } + } +} diff --git a/flytepropeller/pkg/utils/bindings.go b/flytepropeller/pkg/utils/bindings.go new file mode 100644 index 0000000000..4a1bdeef22 --- /dev/null +++ b/flytepropeller/pkg/utils/bindings.go @@ -0,0 +1,85 @@ +package utils + +import "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + +func MakeBindingDataPromise(fromNode, fromVar string) *core.BindingData { + return &core.BindingData{ + Value: &core.BindingData_Promise{ + Promise: &core.OutputReference{ + Var: fromVar, + NodeId: fromNode, + }, + }, + } +} + +func MakeBindingPromise(fromNode, fromVar, toVar string) *core.Binding { + return &core.Binding{ + Var: toVar, + Binding: MakeBindingDataPromise(fromNode, fromVar), + } +} + +func MakeBindingDataCollection(bindings ...*core.BindingData) *core.BindingData { + return &core.BindingData{ + Value: &core.BindingData_Collection{ + Collection: &core.BindingDataCollection{ + Bindings: bindings, + 
}, + }, + } +} + +type Pair struct { + K string + V *core.BindingData +} + +func NewPair(k string, v *core.BindingData) Pair { + return Pair{K: k, V: v} +} + +func MakeBindingDataMap(pairs ...Pair) *core.BindingData { + bindingsMap := map[string]*core.BindingData{} + for _, p := range pairs { + bindingsMap[p.K] = p.V + } + return &core.BindingData{ + Value: &core.BindingData_Map{ + Map: &core.BindingDataMap{ + Bindings: bindingsMap, + }, + }, + } +} + +func MakePrimitiveBindingData(v interface{}) (*core.BindingData, error) { + p, err := MakePrimitive(v) + if err != nil { + return nil, err + } + return &core.BindingData{ + Value: &core.BindingData_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: p, + }, + }, + }, + }, nil +} + +func MustMakePrimitiveBindingData(v interface{}) *core.BindingData { + p, err := MakePrimitiveBindingData(v) + if err != nil { + panic(err) + } + return p +} + +func MakeBinding(variable string, b *core.BindingData) *core.Binding { + return &core.Binding{ + Var: variable, + Binding: b, + } +} diff --git a/flytepropeller/pkg/utils/bindings_test.go b/flytepropeller/pkg/utils/bindings_test.go new file mode 100644 index 0000000000..c6cb5fcc12 --- /dev/null +++ b/flytepropeller/pkg/utils/bindings_test.go @@ -0,0 +1,156 @@ +package utils + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +const primitiveString = "hello" + +func TestMakePrimitiveBinding(t *testing.T) { + { + v := 1.0 + xb, err := MakePrimitiveBindingData(v) + x := MakeBinding("x", xb) + assert.NoError(t, err) + assert.Equal(t, "x", x.GetVar()) + p := x.GetBinding() + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + } + { + v := struct { + }{} + _, err := MakePrimitiveBindingData(v) + assert.Error(t, err) + } +} + +func TestMustMakePrimitiveBinding(t *testing.T) { + { + v := 1.0 + x := MakeBinding("x", MustMakePrimitiveBindingData(v)) + assert.Equal(t, "x", x.GetVar()) + p := x.GetBinding() + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + } + { + v := struct { + }{} + assert.Panics(t, func() { + MustMakePrimitiveBindingData(v) + }) + } +} + +func TestMakeBindingDataCollection(t *testing.T) { + v1 := int64(1) + v2 := primitiveString + c := MakeBindingDataCollection( + MustMakePrimitiveBindingData(v1), + MustMakePrimitiveBindingData(v2), + ) + + c2 := MakeBindingDataCollection( + MustMakePrimitiveBindingData(v1), + c, + ) + + assert.NotNil(t, c.GetCollection()) + assert.Equal(t, 2, len(c.GetCollection().Bindings)) + { + p := c.GetCollection().GetBindings()[0] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := c.GetCollection().GetBindings()[1] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v2, p.GetScalar().GetPrimitive().GetStringValue()) + } + + assert.NotNil(t, c2.GetCollection()) + assert.Equal(t, 2, len(c2.GetCollection().Bindings)) + { + p := c2.GetCollection().GetBindings()[0] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, 
"*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := c2.GetCollection().GetBindings()[1] + assert.NotNil(t, p.GetCollection()) + assert.Equal(t, c.GetCollection(), p.GetCollection()) + } +} + +func TestMakeBindingDataMap(t *testing.T) { + v1 := int64(1) + v2 := primitiveString + c := MakeBindingDataCollection( + MustMakePrimitiveBindingData(v1), + MustMakePrimitiveBindingData(v2), + ) + + m := MakeBindingDataMap( + NewPair("x", MustMakePrimitiveBindingData(v1)), + NewPair("y", c), + ) + + m2 := MakeBindingDataMap( + NewPair("x", MustMakePrimitiveBindingData(v1)), + NewPair("y", m), + ) + assert.NotNil(t, m.GetMap()) + assert.Equal(t, 2, len(m.GetMap().GetBindings())) + { + p := m.GetMap().GetBindings()["x"] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := m.GetMap().GetBindings()["y"] + assert.NotNil(t, p.GetCollection()) + assert.Equal(t, c.GetCollection(), p.GetCollection()) + } + + assert.NotNil(t, m2.GetMap()) + assert.Equal(t, 2, len(m2.GetMap().GetBindings())) + { + p := m2.GetMap().GetBindings()["x"] + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v1, p.GetScalar().GetPrimitive().GetInteger()) + } + { + p := m2.GetMap().GetBindings()["y"] + assert.NotNil(t, p.GetMap()) + assert.Equal(t, m.GetMap(), p.GetMap()) + } + +} + +func TestMakeBindingPromise(t *testing.T) { + p := MakeBindingPromise("n1", "x", "y") + assert.NotNil(t, p) + assert.Equal(t, "y", p.GetVar()) + assert.NotNil(t, p.GetBinding().GetPromise()) + assert.Equal(t, "n1", p.GetBinding().GetPromise().GetNodeId()) + assert.Equal(t, "x", p.GetBinding().GetPromise().GetVar()) +} + +func TestMakeBindingDataPromise(t *testing.T) { + p := MakeBindingDataPromise("n1", "x") + assert.NotNil(t, p) + assert.NotNil(t, p.GetPromise()) + assert.Equal(t, "n1", p.GetPromise().GetNodeId()) + assert.Equal(t, "x", p.GetPromise().GetVar()) +} diff --git a/flytepropeller/pkg/utils/encoder.go b/flytepropeller/pkg/utils/encoder.go new file mode 100644 index 0000000000..8d92d92f76 --- /dev/null +++ b/flytepropeller/pkg/utils/encoder.go @@ -0,0 +1,55 @@ +package utils + +import ( + "encoding/base32" + "fmt" + "hash/fnv" + "strings" +) + +const specialEncoderKey = "abcdefghijklmnopqrstuvwxyz123456" + +var base32Encoder = base32.NewEncoding(specialEncoderKey).WithPadding(base32.NoPadding) + +// Creates a new UniqueID that is based on the inputID and of a specified length, if the given id is longer than the +// maxLength. +func FixedLengthUniqueID(inputID string, maxLength int) (string, error) { + if len(inputID) <= maxLength { + return inputID, nil + } + + hasher := fnv.New32a() + _, err := hasher.Write([]byte(inputID)) + if err != nil { + return "", err + } + b := hasher.Sum(nil) + // expected length after this step is 8 chars (1 + 7 chars from base32Encoder.EncodeToString(b)) + finalStr := "f" + base32Encoder.EncodeToString(b) + if len(finalStr) > maxLength { + return finalStr, fmt.Errorf("max Length is too small, cannot create an encoded string that is so small") + } + return finalStr, nil +} + +// Creates a new uniqueID using the parts concatenated using `-` and ensures that the uniqueID is not longer than the +// maxLength. 
In case a simple concatenation yields a longer string, a new hashed ID is created which is always +// around 8 characters in length +func FixedLengthUniqueIDForParts(maxLength int, parts ...string) (string, error) { + b := strings.Builder{} + for i, p := range parts { + if i > 0 { + _, err := b.WriteRune('-') + if err != nil { + return "", err + } + } + + _, err := b.WriteString(p) + if err != nil { + return "", err + } + } + + return FixedLengthUniqueID(b.String(), maxLength) +} diff --git a/flytepropeller/pkg/utils/encoder_test.go b/flytepropeller/pkg/utils/encoder_test.go new file mode 100644 index 0000000000..d4ae8aae78 --- /dev/null +++ b/flytepropeller/pkg/utils/encoder_test.go @@ -0,0 +1,61 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFixedLengthUniqueID(t *testing.T) { + tests := []struct { + name string + input string + maxLength int + output string + expectError bool + }{ + {"smallerThanMax", "x", 5, "x", false}, + {"veryLowLimit", "xx", 1, "flfryc2i", true}, + {"highLimit", "xxxxxx", 5, "fufiti6i", true}, + {"higherLimit", "xxxxx", 10, "xxxxx", false}, + {"largeID", "xxxxxxxxxxx", 10, "fggddjly", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + i, err := FixedLengthUniqueID(test.input, test.maxLength) + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, i, test.output) + }) + } +} + +func TestFixedLengthUniqueIDForParts(t *testing.T) { + tests := []struct { + name string + parts []string + maxLength int + output string + expectError bool + }{ + {"smallerThanMax", []string{"x", "y", "z"}, 10, "x-y-z", false}, + {"veryLowLimit", []string{"x", "y"}, 1, "fz2jizji", true}, + {"fittingID", []string{"x"}, 2, "x", false}, + {"highLimit", []string{"x", "y", "z"}, 4, "fxzsoqrq", true}, + {"largeID", []string{"x", "y", "z", "m", "n"}, 8, "fsigbmty", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + i, err := FixedLengthUniqueIDForParts(test.maxLength, test.parts...) + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, i, test.output) + }) + } +} diff --git a/flytepropeller/pkg/utils/event_helpers.go b/flytepropeller/pkg/utils/event_helpers.go new file mode 100644 index 0000000000..61af9f01d9 --- /dev/null +++ b/flytepropeller/pkg/utils/event_helpers.go @@ -0,0 +1,32 @@ +package utils + +import ( + "context" + + "github.com/lyft/flyteidl/clients/go/events" + eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/lyft/flytestdlib/logger" +) + +// Construct task event recorder to pass down to plugin. This is a just a wrapper around the normal +// taskEventRecorder that can encapsulate logic to validate and handle errors. 
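The plugin task event recorder implemented just below swallows "already exists" errors from the event sink so that a retried task event does not fail the node. A minimal caller-side sketch, not part of this patch, assuming ctx, baseRecorder (an events.TaskEventRecorder) and taskEvent (an *event.TaskExecutionEvent) exist:

    // Illustrative sketch only; baseRecorder and taskEvent are assumed to come from elsewhere.
    rec := utils.NewPluginTaskEventRecorder(baseRecorder)
    if err := rec.RecordTaskEvent(ctx, taskEvent); err != nil {
        // Duplicate (already-exists) events are logged and dropped inside the wrapper;
        // only genuine recording failures reach this branch.
        return err
    }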
+func NewPluginTaskEventRecorder(taskEventRecorder events.TaskEventRecorder) events.TaskEventRecorder { + return &pluginTaskEventRecorder{ + taskEventRecorder: taskEventRecorder, + } +} + +type pluginTaskEventRecorder struct { + taskEventRecorder events.TaskEventRecorder +} + +func (r pluginTaskEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent) error { + err := r.taskEventRecorder.RecordTaskEvent(ctx, event) + if err != nil && eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Task event phase: %s, taskId %s, retry attempt %d - already exists", + event.Phase.String(), event.GetTaskId(), event.RetryAttempt) + return nil + } + return err +} diff --git a/flytepropeller/pkg/utils/failing_datastore.go b/flytepropeller/pkg/utils/failing_datastore.go new file mode 100644 index 0000000000..5dbff06656 --- /dev/null +++ b/flytepropeller/pkg/utils/failing_datastore.go @@ -0,0 +1,32 @@ +package utils + +import ( + "context" + "fmt" + "io" + + "github.com/lyft/flytestdlib/storage" +) + +type FailingRawStore struct { +} + +func (FailingRawStore) CopyRaw(ctx context.Context, source, destination storage.DataReference, opts storage.Options) error { + return fmt.Errorf("failed to copy raw") +} + +func (FailingRawStore) GetBaseContainerFQN(ctx context.Context) storage.DataReference { + return "" +} + +func (FailingRawStore) Head(ctx context.Context, reference storage.DataReference) (storage.Metadata, error) { + return nil, fmt.Errorf("failed metadata fetch") +} + +func (FailingRawStore) ReadRaw(ctx context.Context, reference storage.DataReference) (io.ReadCloser, error) { + return nil, fmt.Errorf("failed read raw") +} + +func (FailingRawStore) WriteRaw(ctx context.Context, reference storage.DataReference, size int64, opts storage.Options, raw io.Reader) error { + return fmt.Errorf("failed write raw") +} diff --git a/flytepropeller/pkg/utils/failing_datastore_test.go b/flytepropeller/pkg/utils/failing_datastore_test.go new file mode 100644 index 0000000000..9adb3e8854 --- /dev/null +++ b/flytepropeller/pkg/utils/failing_datastore_test.go @@ -0,0 +1,26 @@ +package utils + +import ( + "bytes" + "context" + "testing" + + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +func TestFailingRawStore(t *testing.T) { + ctx := context.TODO() + f := FailingRawStore{} + _, err := f.Head(ctx, "") + assert.Error(t, err) + + c := f.GetBaseContainerFQN(ctx) + assert.Equal(t, storage.DataReference(""), c) + + _, err = f.ReadRaw(ctx, "") + assert.Error(t, err) + + err = f.WriteRaw(ctx, "", 0, storage.Options{}, bytes.NewReader(nil)) + assert.Error(t, err) +} diff --git a/flytepropeller/pkg/utils/helpers.go b/flytepropeller/pkg/utils/helpers.go new file mode 100644 index 0000000000..2889164fe1 --- /dev/null +++ b/flytepropeller/pkg/utils/helpers.go @@ -0,0 +1,12 @@ +package utils + +func CopyMap(o map[string]string) (r map[string]string) { + if o == nil { + return nil + } + r = make(map[string]string, len(o)) + for k, v := range o { + r[k] = v + } + return +} diff --git a/flytepropeller/pkg/utils/helpers_test.go b/flytepropeller/pkg/utils/helpers_test.go new file mode 100644 index 0000000000..a9f693ae29 --- /dev/null +++ b/flytepropeller/pkg/utils/helpers_test.go @@ -0,0 +1,19 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCopyMap(t *testing.T) { + m := map[string]string{ + "k1": "v1", + "k2": "v2", + } + co := CopyMap(m) + assert.NotNil(t, co) + assert.Equal(t, m, co) + + assert.Nil(t, CopyMap(nil)) +} diff 
--git a/flytepropeller/pkg/utils/k8s.go b/flytepropeller/pkg/utils/k8s.go new file mode 100644 index 0000000000..5d45ac8689 --- /dev/null +++ b/flytepropeller/pkg/utils/k8s.go @@ -0,0 +1,105 @@ +package utils + +import ( + "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/pkg/errors" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var NotTheOwnerError = errors.Errorf("FlytePropeller is not the owner") + +// ResourceNvidiaGPU is the name of the Nvidia GPU resource. +const ResourceNvidiaGPU = "nvidia.com/gpu" + +func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar { + envVars := make([]v1.EnvVar, 0, len(env)) + for _, kv := range env { + envVars = append(envVars, v1.EnvVar{Name: kv.Key, Value: kv.Value}) + } + return envVars +} + +// TODO we should modify the container resources to contain a map of enum values? +// Also we should probably create tolerations / taints, but we could do that as a post process +func ToK8sResourceList(resources []*core.Resources_ResourceEntry) (v1.ResourceList, error) { + k8sResources := make(v1.ResourceList, len(resources)) + for _, r := range resources { + rVal := r.Value + v, err := resource.ParseQuantity(rVal) + if err != nil { + return nil, errors.Wrap(err, "Failed to parse resource as a valid quantity.") + } + switch r.Name { + case core.Resources_CPU: + if !v.IsZero() { + k8sResources[v1.ResourceCPU] = v + } + case core.Resources_MEMORY: + if !v.IsZero() { + k8sResources[v1.ResourceMemory] = v + } + case core.Resources_STORAGE: + if !v.IsZero() { + k8sResources[v1.ResourceStorage] = v + } + case core.Resources_GPU: + if !v.IsZero() { + k8sResources[ResourceNvidiaGPU] = v + } + } + } + return k8sResources, nil +} + +func ToK8sResourceRequirements(resources *core.Resources) (*v1.ResourceRequirements, error) { + res := &v1.ResourceRequirements{} + if resources == nil { + return res, nil + } + req, err := ToK8sResourceList(resources.Requests) + if err != nil { + return res, err + } + lim, err := ToK8sResourceList(resources.Limits) + if err != nil { + return res, err + } + res.Limits = lim + res.Requests = req + return res, nil +} + +func GetWorkflowIDFromObject(obj metav1.Object) (v1alpha1.WorkflowID, error) { + controller := metav1.GetControllerOf(obj) + if controller == nil { + return "", NotTheOwnerError + } + if controller.Kind == v1alpha1.FlyteWorkflowKind { + return obj.GetNamespace() + "/" + controller.Name, nil + } + return "", NotTheOwnerError +} + +func GetWorkflowIDFromOwner(reference *metav1.OwnerReference, namespace string) (v1alpha1.WorkflowID, error) { + if reference == nil { + return "", NotTheOwnerError + } + if reference.Kind == v1alpha1.FlyteWorkflowKind { + return namespace + "/" + reference.Name, nil + } + return "", NotTheOwnerError +} +func GetProtoTime(t *metav1.Time) *timestamp.Timestamp { + if t != nil { + pTime, err := ptypes.TimestampProto(t.Time) + if err == nil { + return pTime + } + } + return ptypes.TimestampNow() +} diff --git a/flytepropeller/pkg/utils/k8s_test.go b/flytepropeller/pkg/utils/k8s_test.go new file mode 100644 index 0000000000..806d4dd493 --- /dev/null +++ b/flytepropeller/pkg/utils/k8s_test.go @@ -0,0 +1,194 @@ +package utils + +import ( + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + 
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/stretchr/testify/assert" + v12 "k8s.io/api/batch/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v13 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestToK8sEnvVar(t *testing.T) { + e := ToK8sEnvVar([]*core.KeyValuePair{ + {Key: "k1", Value: "v1"}, + {Key: "k2", Value: "v2"}, + }) + + assert.NotEmpty(t, e) + assert.Equal(t, []v1.EnvVar{ + {Name: "k1", Value: "v1"}, + {Name: "k2", Value: "v2"}, + }, e) + + e = ToK8sEnvVar(nil) + assert.Empty(t, e) +} + +func TestToK8sResourceList(t *testing.T) { + { + r, err := ToK8sResourceList([]*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_GPU, Value: "1"}, + {Name: core.Resources_MEMORY, Value: "1024Mi"}, + {Name: core.Resources_STORAGE, Value: "1024Mi"}, + }) + + assert.NoError(t, err) + assert.NotEmpty(t, r) + assert.NotNil(t, r[v1.ResourceCPU]) + assert.Equal(t, resource.MustParse("250m"), r[v1.ResourceCPU]) + assert.Equal(t, resource.MustParse("1"), r[ResourceNvidiaGPU]) + assert.Equal(t, resource.MustParse("1024Mi"), r[v1.ResourceMemory]) + assert.Equal(t, resource.MustParse("1024Mi"), r[v1.ResourceStorage]) + } + { + r, err := ToK8sResourceList([]*core.Resources_ResourceEntry{}) + assert.NoError(t, err) + assert.Empty(t, r) + } + { + _, err := ToK8sResourceList([]*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250x"}, + }) + assert.Error(t, err) + } + +} + +func TestToK8sResourceRequirements(t *testing.T) { + + { + r, err := ToK8sResourceRequirements(nil) + assert.NoError(t, err) + assert.NotNil(t, r) + assert.Empty(t, r.Limits) + assert.Empty(t, r.Requests) + } + { + r, err := ToK8sResourceRequirements(&core.Resources{ + Requests: nil, + Limits: nil, + }) + assert.NoError(t, err) + assert.NotNil(t, r) + assert.Empty(t, r.Limits) + assert.Empty(t, r.Requests) + } + { + r, err := ToK8sResourceRequirements(&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, r) + assert.Equal(t, resource.MustParse("250m"), r.Requests[v1.ResourceCPU]) + assert.Equal(t, resource.MustParse("1024m"), r.Limits[v1.ResourceCPU]) + } + { + _, err := ToK8sResourceRequirements(&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "blah"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + }) + assert.Error(t, err) + } + { + _, err := ToK8sResourceRequirements(&core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "blah"}, + }, + }) + assert.Error(t, err) + } +} + +func TestGetWorkflowIDFromObject(t *testing.T) { + { + b := true + j := &v12.Job{ + ObjectMeta: v13.ObjectMeta{ + Namespace: "ns", + OwnerReferences: []v13.OwnerReference{ + { + APIVersion: "test", + Kind: v1alpha1.FlyteWorkflowKind, + Name: "my-id", + UID: "blah", + BlockOwnerDeletion: &b, + Controller: &b, + }, + }, + }, + } + w, err := GetWorkflowIDFromObject(j) + assert.NoError(t, err) + assert.Equal(t, "ns/my-id", w) + } + { + b := true + j := &v12.Job{ + ObjectMeta: v13.ObjectMeta{ + Namespace: "ns", + OwnerReferences: []v13.OwnerReference{ + { + APIVersion: 
"test", + Kind: "some-other", + Name: "my-id", + UID: "blah", + BlockOwnerDeletion: &b, + Controller: &b, + }, + }, + }, + } + _, err := GetWorkflowIDFromObject(j) + assert.Error(t, err) + } + +} + +func TestGetProtoTime(t *testing.T) { + assert.NotNil(t, GetProtoTime(nil)) + n := time.Now() + nproto, err := ptypes.TimestampProto(n) + assert.NoError(t, err) + assert.Equal(t, nproto, GetProtoTime(&metav1.Time{Time: n})) +} + +func TestGetWorkflowIDFromOwner(t *testing.T) { + tests := []struct { + name string + reference *metav1.OwnerReference + namespace string + expectedOwner string + expectedErr error + }{ + {"nilReference", nil, "", "", NotTheOwnerError}, + {"badReference", &metav1.OwnerReference{Kind: "x"}, "", "", NotTheOwnerError}, + {"wfReference", &metav1.OwnerReference{Kind: v1alpha1.FlyteWorkflowKind, Name: "x"}, "ns", "ns/x", nil}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + o, e := GetWorkflowIDFromOwner(test.reference, test.namespace) + assert.Equal(t, test.expectedOwner, o) + assert.Equal(t, test.expectedErr, e) + }) + } +} diff --git a/flytepropeller/pkg/utils/literals.go b/flytepropeller/pkg/utils/literals.go new file mode 100644 index 0000000000..d882a3f145 --- /dev/null +++ b/flytepropeller/pkg/utils/literals.go @@ -0,0 +1,276 @@ +package utils + +import ( + "reflect" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/pkg/errors" +) + +func MakePrimitive(v interface{}) (*core.Primitive, error) { + switch p := v.(type) { + case int: + return &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: int64(p), + }, + }, nil + case int64: + return &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: p, + }, + }, nil + case float64: + return &core.Primitive{ + Value: &core.Primitive_FloatValue{ + FloatValue: p, + }, + }, nil + case time.Time: + t, err := ptypes.TimestampProto(p) + if err != nil { + return nil, err + } + return &core.Primitive{ + Value: &core.Primitive_Datetime{ + Datetime: t, + }, + }, nil + case time.Duration: + d := ptypes.DurationProto(p) + return &core.Primitive{ + Value: &core.Primitive_Duration{ + Duration: d, + }, + }, nil + case string: + return &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: p, + }, + }, nil + case bool: + return &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: p, + }, + }, nil + } + return nil, errors.Errorf("Failed to convert to a known primitive type. 
Input Type [%v] not supported", reflect.TypeOf(v).String()) +} + +func MustMakePrimitive(v interface{}) *core.Primitive { + f, err := MakePrimitive(v) + if err != nil { + panic(err) + } + return f +} + +func MakePrimitiveLiteral(v interface{}) (*core.Literal, error) { + p, err := MakePrimitive(v) + if err != nil { + return nil, err + } + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: p, + }, + }, + }, + }, nil +} + +func MustMakePrimitiveLiteral(v interface{}) *core.Literal { + p, err := MakePrimitiveLiteral(v) + if err != nil { + panic(err) + } + return p +} + +func MakeLiteralForMap(v map[string]interface{}) (*core.Literal, error) { + m, err := MakeLiteralMap(v) + if err != nil { + return nil, err + } + + return &core.Literal{ + Value: &core.Literal_Map{ + Map: m, + }, + }, nil +} + +func MakeLiteralForCollection(v []interface{}) (*core.Literal, error) { + literals := make([]*core.Literal, 0, len(v)) + for _, val := range v { + l, err := MakeLiteral(val) + if err != nil { + return nil, err + } + + literals = append(literals, l) + } + + return &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literals, + }, + }, + }, nil +} + +func MakeBinaryLiteral(v []byte) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: v, + }, + }, + }, + }, + } +} + +func MakeLiteral(v interface{}) (*core.Literal, error) { + if v == nil { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_NoneType{ + NoneType: nil, + }, + }, + }, + }, nil + } + switch o := v.(type) { + case *core.Literal: + return o, nil + case []interface{}: + return MakeLiteralForCollection(o) + case map[string]interface{}: + return MakeLiteralForMap(o) + case []byte: + return MakeBinaryLiteral(v.([]byte)), nil + default: + return MakePrimitiveLiteral(o) + } +} + +func MustMakeDefaultLiteralForType(typ *core.LiteralType) *core.Literal { + if res, err := MakeDefaultLiteralForType(typ); err != nil { + panic(err) + } else { + return res + } +} + +func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { + switch t := typ.GetType().(type) { + case *core.LiteralType_Simple: + switch t.Simple { + case core.SimpleType_NONE: + return MakeLiteral(nil) + case core.SimpleType_INTEGER: + return MakeLiteral(int(0)) + case core.SimpleType_FLOAT: + return MakeLiteral(float64(0)) + case core.SimpleType_STRING: + return MakeLiteral("") + case core.SimpleType_BOOLEAN: + return MakeLiteral(false) + case core.SimpleType_DATETIME: + return MakeLiteral(time.Now()) + case core.SimpleType_DURATION: + return MakeLiteral(time.Second) + case core.SimpleType_BINARY: + return MakeLiteral([]byte{}) + //case core.SimpleType_WAITABLE: + //case core.SimpleType_ERROR: + } + return nil, errors.Errorf("Not yet implemented. Default creation is not yet implemented. 
") + + case *core.LiteralType_Blob: + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Blob{ + Blob: &core.Blob{ + Metadata: &core.BlobMetadata{ + Type: t.Blob, + }, + Uri: "/tmp/somepath", + }, + }, + }, + }, + }, nil + case *core.LiteralType_CollectionType: + single, err := MakeDefaultLiteralForType(t.CollectionType) + if err != nil { + return nil, err + } + + return &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{single}, + }, + }, + }, nil + case *core.LiteralType_MapValueType: + single, err := MakeDefaultLiteralForType(t.MapValueType) + if err != nil { + return nil, err + } + + return &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "itemKey": single, + }, + }, + }, + }, nil + //case *core.LiteralType_Schema: + } + + return nil, errors.Errorf("Failed to convert to a known Literal. Input Type [%v] not supported", typ.String()) +} + +func MustMakeLiteral(v interface{}) *core.Literal { + p, err := MakeLiteral(v) + if err != nil { + panic(err) + } + + return p +} + +func MakeLiteralMap(v map[string]interface{}) (*core.LiteralMap, error) { + + literals := make(map[string]*core.Literal, len(v)) + for key, val := range v { + l, err := MakeLiteral(val) + if err != nil { + return nil, err + } + + literals[key] = l + } + + return &core.LiteralMap{ + Literals: literals, + }, nil +} diff --git a/flytepropeller/pkg/utils/literals_test.go b/flytepropeller/pkg/utils/literals_test.go new file mode 100644 index 0000000000..05dd11ff77 --- /dev/null +++ b/flytepropeller/pkg/utils/literals_test.go @@ -0,0 +1,200 @@ +package utils + +import ( + "reflect" + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" +) + +func TestMakePrimitive(t *testing.T) { + { + v := 1 + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.Value).String()) + assert.Equal(t, int64(v), p.GetInteger()) + } + { + v := int64(1) + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetInteger()) + } + { + v := 1.0 + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetFloatValue()) + } + { + v := "blah" + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetStringValue()) + } + { + v := true + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Boolean", reflect.TypeOf(p.Value).String()) + assert.Equal(t, v, p.GetBoolean()) + } + { + v := time.Now() + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Datetime", reflect.TypeOf(p.Value).String()) + j, err := ptypes.TimestampProto(v) + assert.NoError(t, err) + assert.Equal(t, j, p.GetDatetime()) + } + { + v := time.Second * 10 + p, err := MakePrimitive(v) + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_Duration", reflect.TypeOf(p.Value).String()) + assert.Equal(t, ptypes.DurationProto(v), p.GetDuration()) + } + { + v := struct { + }{} + _, err := MakePrimitive(v) + assert.Error(t, err) + } +} + +func TestMustMakePrimitive(t *testing.T) { + { + v := struct { + 
}{} + assert.Panics(t, func() { + MustMakePrimitive(v) + }) + } + { + v := time.Second * 10 + p := MustMakePrimitive(v) + assert.Equal(t, "*core.Primitive_Duration", reflect.TypeOf(p.Value).String()) + assert.Equal(t, ptypes.DurationProto(v), p.GetDuration()) + } +} + +func TestMakePrimitiveLiteral(t *testing.T) { + { + v := 1.0 + p, err := MakePrimitiveLiteral(v) + assert.NoError(t, err) + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + } + { + v := struct { + }{} + _, err := MakePrimitiveLiteral(v) + assert.Error(t, err) + } +} + +func TestMustMakePrimitiveLiteral(t *testing.T) { + t.Run("Panic", func(t *testing.T) { + v := struct { + }{} + assert.Panics(t, func() { + MustMakePrimitiveLiteral(v) + }) + }) + t.Run("FloatValue", func(t *testing.T) { + v := 1.0 + p := MustMakePrimitiveLiteral(v) + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) + assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) + }) +} + +func TestMakeLiteral(t *testing.T) { + t.Run("Primitive", func(t *testing.T) { + lit, err := MakeLiteral("test_string") + assert.NoError(t, err) + assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(lit.GetScalar().GetPrimitive().Value).String()) + }) + + t.Run("Array", func(t *testing.T) { + lit, err := MakeLiteral([]interface{}{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, "*core.Literal_Collection", reflect.TypeOf(lit.GetValue()).String()) + assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(lit.GetCollection().Literals[0].GetScalar().GetPrimitive().Value).String()) + }) + + t.Run("Map", func(t *testing.T) { + lit, err := MakeLiteral(map[string]interface{}{ + "key1": []interface{}{1, 2, 3}, + "key2": []interface{}{5}, + }) + assert.NoError(t, err) + assert.Equal(t, "*core.Literal_Map", reflect.TypeOf(lit.GetValue()).String()) + assert.Equal(t, "*core.Literal_Collection", reflect.TypeOf(lit.GetMap().Literals["key1"].GetValue()).String()) + }) + + t.Run("Binary", func(t *testing.T) { + s := MakeBinaryLiteral([]byte{'h'}) + assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) + }) + + t.Run("NoneType", func(t *testing.T) { + p, err := MakeLiteral(nil) + assert.NoError(t, err) + assert.NotNil(t, p.GetScalar()) + assert.Equal(t, "*core.Scalar_NoneType", reflect.TypeOf(p.GetScalar().Value).String()) + }) +} + +func TestMustMakeLiteral(t *testing.T) { + v := "hello" + l := MustMakeLiteral(v) + assert.NotNil(t, l.GetScalar()) + assert.Equal(t, v, l.GetScalar().GetPrimitive().GetStringValue()) +} + +func TestMakeBinaryLiteral(t *testing.T) { + s := MakeBinaryLiteral([]byte{'h'}) + assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) +} + +func TestMakeDefaultLiteralForType(t *testing.T) { + + tests := [][]interface{}{ + {"Integer", core.SimpleType_INTEGER, "*core.Primitive_Integer"}, + {"Float", core.SimpleType_FLOAT, "*core.Primitive_FloatValue"}, + {"String", core.SimpleType_STRING, "*core.Primitive_StringValue"}, + {"Boolean", core.SimpleType_BOOLEAN, "*core.Primitive_Boolean"}, + {"Duration", core.SimpleType_DURATION, "*core.Primitive_Duration"}, + {"Datetime", core.SimpleType_DATETIME, "*core.Primitive_Datetime"}, + } + + for i := range tests { + name := tests[i][0].(string) + ty := tests[i][1].(core.SimpleType) + tyName := tests[i][2].(string) + + t.Run(name, func(t 
*testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{Simple: ty}}) + assert.NoError(t, err) + assert.Equal(t, tyName, reflect.TypeOf(l.GetScalar().GetPrimitive().Value).String()) + }) + } + + t.Run("Binary", func(t *testing.T) { + s, err := MakeLiteral([]byte{'h'}) + assert.NoError(t, err) + assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) + }) +} diff --git a/flytepropeller/pkg/visualize/nodeq.go b/flytepropeller/pkg/visualize/nodeq.go new file mode 100644 index 0000000000..5b479a0790 --- /dev/null +++ b/flytepropeller/pkg/visualize/nodeq.go @@ -0,0 +1,35 @@ +package visualize + +import "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + +type NodeQ []v1alpha1.NodeID + +func (s *NodeQ) Enqueue(items ...v1alpha1.NodeID) { + *s = append(*s, items...) +} + +func (s NodeQ) HasNext() bool { + return len(s) > 0 +} + +func (s NodeQ) Remaining() int { + return len(s) +} + +func (s *NodeQ) Peek() v1alpha1.NodeID { + if s.HasNext() { + return (*s)[0] + } + + return "" +} + +func (s *NodeQ) Deque() v1alpha1.NodeID { + item := s.Peek() + *s = (*s)[1:] + return item +} + +func NewNodeNameQ(items ...v1alpha1.NodeID) NodeQ { + return NodeQ(items) +} diff --git a/flytepropeller/pkg/visualize/sort.go b/flytepropeller/pkg/visualize/sort.go new file mode 100644 index 0000000000..7b4f62559b --- /dev/null +++ b/flytepropeller/pkg/visualize/sort.go @@ -0,0 +1,72 @@ +package visualize + +import ( + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/pkg/errors" +) + +type VisitStatus int8 + +const ( + NotVisited VisitStatus = iota + Visited + Completed +) + +type NodeVisitor map[v1alpha1.NodeID]VisitStatus + +func NewNodeVisitor(nodes []v1alpha1.NodeID) NodeVisitor { + v := make(NodeVisitor, len(nodes)) + for _, n := range nodes { + v[n] = NotVisited + } + return v +} + +func tsortHelper(g v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, visited NodeVisitor, reverseSortedNodes *[]v1alpha1.ExecutableNode) error { + if visited[currentNode.GetID()] == NotVisited { + visited[currentNode.GetID()] = Visited + defer func() { + visited[currentNode.GetID()] = Completed + }() + nodes, err := g.FromNode(currentNode.GetID()) + if err != nil { + return err + } + for _, childID := range nodes { + child, ok := g.GetNode(childID) + if !ok { + return errors.Errorf("Unable to find Node [%s] in Workflow [%s]", childID, g.GetID()) + } + if err := tsortHelper(g, child, visited, reverseSortedNodes); err != nil { + return err + } + } + + *reverseSortedNodes = append(*reverseSortedNodes, currentNode) + return nil + } + // Node was successfully visited previously + if visited[currentNode.GetID()] == Completed { + return nil + } + // Node was partially visited and we are in the subgraph and that reached back to the parent + return errors.Errorf("Cycle detected. 
Node [%v]", currentNode.GetID()) +} + +func reverseSlice(sl []v1alpha1.ExecutableNode) []v1alpha1.ExecutableNode { + for i := len(sl)/2 - 1; i >= 0; i-- { + opp := len(sl) - 1 - i + sl[i], sl[opp] = sl[opp], sl[i] + } + return sl +} + +func TopologicalSort(g v1alpha1.ExecutableWorkflow) ([]v1alpha1.ExecutableNode, error) { + reverseSortedNodes := make([]v1alpha1.ExecutableNode, 0, 25) + visited := NewNodeVisitor(g.GetNodes()) + if err := tsortHelper(g, g.StartNode(), visited, &reverseSortedNodes); err != nil { + return nil, err + } + return reverseSlice(reverseSortedNodes), nil +} diff --git a/flytepropeller/pkg/visualize/visualize.go b/flytepropeller/pkg/visualize/visualize.go new file mode 100644 index 0000000000..358f6e5cbc --- /dev/null +++ b/flytepropeller/pkg/visualize/visualize.go @@ -0,0 +1,244 @@ +package visualize + +import ( + "fmt" + "reflect" + "strings" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/compiler/common" + "k8s.io/apimachinery/pkg/util/sets" +) + +const executionEdgeLabel = "execution" + +type edgeStyle = string + +const ( + styleSolid edgeStyle = "solid" + styleDashed edgeStyle = "dashed" +) + +const staticNodeID = "static" + +func flatten(binding *core.BindingData, flatMap map[common.NodeID]sets.String) { + switch binding.GetValue().(type) { + case *core.BindingData_Collection: + for _, v := range binding.GetCollection().GetBindings() { + flatten(v, flatMap) + } + case *core.BindingData_Map: + for _, v := range binding.GetMap().GetBindings() { + flatten(v, flatMap) + } + case *core.BindingData_Promise: + if _, ok := flatMap[binding.GetPromise().NodeId]; !ok { + flatMap[binding.GetPromise().NodeId] = sets.String{} + } + + flatMap[binding.GetPromise().NodeId].Insert(binding.GetPromise().GetVar()) + case *core.BindingData_Scalar: + if _, ok := flatMap[staticNodeID]; !ok { + flatMap[staticNodeID] = sets.NewString() + } + } +} + +// Returns GraphViz https://www.graphviz.org/ representation of the current state of the state machine. 
+func WorkflowToGraphViz(g *v1alpha1.FlyteWorkflow) string { + res := fmt.Sprintf("digraph G {rankdir=TB;workflow[label=\"Workflow Id: %v\"];node[style=filled];", + g.ID) + + nodeFinder := func(nodeId common.NodeID) *v1alpha1.NodeSpec { + for _, n := range g.Nodes { + if n.ID == nodeId { + return n + } + } + + return nil + } + + nodeLabel := func(nodeId common.NodeID) string { + node := nodeFinder(nodeId) + return fmt.Sprintf("%v(%v)", node.ID, node.Kind) + } + + edgeLabel := func(nodeFromId, nodeToId common.NodeID) string { + flatMap := make(map[common.NodeID]sets.String) + nodeFrom := nodeFinder(nodeFromId) + nodeTo := nodeFinder(nodeToId) + for _, binding := range nodeTo.GetInputBindings() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[nodeFrom.ID]; found { + return strings.Join(vars.List(), ",") + } else if vars, found := flatMap[""]; found && nodeFromId == common.StartNodeID { + return strings.Join(vars.List(), ",") + } else { + return executionEdgeLabel + } + } + + style := func(edgeLabel string) string { + if edgeLabel == executionEdgeLabel { + return styleDashed + } + + return styleSolid + } + + start := nodeFinder(common.StartNodeID) + res += fmt.Sprintf("\"%v\" [shape=Msquare];", nodeLabel(start.ID)) + visitedNodes := sets.NewString(start.ID) + createdEdges := sets.NewString() + + for nodesToVisit := NewNodeNameQ(start.ID); nodesToVisit.HasNext(); { + node := nodesToVisit.Deque() + nodes, found := g.GetConnections().DownstreamEdges[node] + if found { + nodesToVisit.Enqueue(nodes...) + + for _, child := range nodes { + label := edgeLabel(node, child) + edge := fmt.Sprintf("\"%v\" -> \"%v\" [label=\"%v\",style=\"%v\"];", + nodeLabel(node), + nodeLabel(child), + label, + style(label), + ) + + if !createdEdges.Has(edge) { + res += edge + createdEdges.Insert(edge) + } + } + } + + // add static bindings' links + flatMap := make(common.StringAdjacencyList) + n := nodeFinder(node) + for _, binding := range n.GetInputBindings() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[staticNodeID]; found { + res += fmt.Sprintf("\"static\" -> \"%v\" [label=\"%v\"];", + nodeLabel(node), + strings.Join(vars.List(), ","), + ) + } + + visitedNodes.Insert(node) + } + + res += "}" + + return res +} + +func ToGraphViz(g *core.CompiledWorkflow) string { + res := fmt.Sprintf("digraph G {rankdir=TB;workflow[label=\"Workflow Id: %v\"];node[style=filled];", + g.Template.GetId()) + + nodeFinder := func(nodeId common.NodeID) *core.Node { + for _, n := range g.Template.Nodes { + if n.Id == nodeId { + return n + } + } + + return nil + } + + nodeKind := func(nodeId common.NodeID) string { + node := nodeFinder(nodeId) + if nodeId == common.StartNodeID { + return "start" + } else if nodeId == common.EndNodeID { + return "end" + } else { + return reflect.TypeOf(node.GetTarget()).Name() + } + } + + nodeLabel := func(nodeId common.NodeID) string { + node := nodeFinder(nodeId) + return fmt.Sprintf("%v(%v)", node.GetId(), nodeKind(nodeId)) + } + + edgeLabel := func(nodeFromId, nodeToId common.NodeID) string { + flatMap := make(map[common.NodeID]sets.String) + nodeFrom := nodeFinder(nodeFromId) + nodeTo := nodeFinder(nodeToId) + for _, binding := range nodeTo.GetInputs() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[nodeFrom.GetId()]; found { + return strings.Join(vars.List(), ",") + } else if vars, found := flatMap[""]; found && nodeFromId == common.StartNodeID { + return strings.Join(vars.List(), ",") + } else { + return 
executionEdgeLabel + } + } + + style := func(edgeLabel string) string { + if edgeLabel == executionEdgeLabel { + return styleDashed + } + + return styleSolid + } + + start := nodeFinder(common.StartNodeID) + res += fmt.Sprintf("\"%v\" [shape=Msquare];", nodeLabel(start.GetId())) + visitedNodes := sets.NewString(start.GetId()) + createdEdges := sets.NewString() + + for nodesToVisit := NewNodeNameQ(start.GetId()); nodesToVisit.HasNext(); { + node := nodesToVisit.Deque() + nodes, found := g.GetConnections().GetDownstream()[node] + if found { + nodesToVisit.Enqueue(nodes.Ids...) + + for _, child := range nodes.Ids { + label := edgeLabel(node, child) + edge := fmt.Sprintf("\"%v\" -> \"%v\" [label=\"%v\",style=\"%v\"];", + nodeLabel(node), + nodeLabel(child), + label, + style(label), + ) + + if !createdEdges.Has(edge) { + res += edge + createdEdges.Insert(edge) + } + } + } + + // add static bindings' links + flatMap := make(common.StringAdjacencyList) + n := nodeFinder(node) + for _, binding := range n.GetInputs() { + flatten(binding.GetBinding(), flatMap) + } + + if vars, found := flatMap[staticNodeID]; found { + res += fmt.Sprintf("\"static\" -> \"%v\" [label=\"%v\"];", + nodeLabel(node), + strings.Join(vars.List(), ","), + ) + } + + visitedNodes.Insert(node) + } + + res += "}" + + return res +} diff --git a/flytepropeller/raw_examples/README.md b/flytepropeller/raw_examples/README.md new file mode 100644 index 0000000000..dad5d2fbcf --- /dev/null +++ b/flytepropeller/raw_examples/README.md @@ -0,0 +1,3 @@ +The intention of these examples is to test basic functionality of propeller. +Usually users should be using the flyteidl interface to interact with propeller +through flytectl diff --git a/flytepropeller/raw_examples/example-condition.yaml b/flytepropeller/raw_examples/example-condition.yaml new file mode 100644 index 0000000000..3af1f14b1a --- /dev/null +++ b/flytepropeller/raw_examples/example-condition.yaml @@ -0,0 +1,104 @@ +apiVersion: flyte.lyft.com/v1alpha1 +kind: FlyteWorkflow +metadata: + name: test-branch + namespace: default +tasks: + foo: + id: foo + category: 0 + type: container + metadata: + runtime: + type: 0 + version: "1.18.0" + flavor: python + discoverable: true + interface: + inputs: + - name: x + type: + simple: INTEGER + outputs: + - name: "y" + type: + simple: INTEGER + container: + image: alpine + command: ["echo", "Hello", "{{$input}}", "{{$output}}"] +spec: + id: test-branch + nodes: + start: + id: start + kind: start + input_bindings: + - var: x + binding: + scalar: + primitive: + integer: 5 + - var: "y" + binding: + scalar: + primitive: + integer: 10 + foo1: + id: foo1 + kind: task + input_bindings: + - var: x + binding: + promise: + node_id: start + var: "y" + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + foo2: + id: foo2 + kind: task + input_bindings: + - var: x + binding: + promise: + node_id: start + var: x + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + foobranch: + id: foobranch + kind: branch + input_bindings: + - var: x + binding: + promise: + node_id: start + var: x + branch_node: + if: + condition: + comparison: + left_value: + var: x + operator: GT + right_value: + primitive: + integer: 5 + then: foo1 + else: foo2 + connections: + start: + - foobranch + - foo1 + - foo2 + foobranch: + - foo1 + - foo2 diff --git a/flytepropeller/raw_examples/example-inputs.yaml b/flytepropeller/raw_examples/example-inputs.yaml new file mode 100644 index 0000000000..b3a641cc70 --- /dev/null +++ 
b/flytepropeller/raw_examples/example-inputs.yaml @@ -0,0 +1,61 @@ +apiVersion: flyte.lyft.com/v1alpha1 +kind: FlyteWorkflow +metadata: + name: test-wf-inputs + namespace: default +tasks: + foo: + id: foo + category: 0 + type: container + metadata: + name: foo + runtime: + # Enums are ints + type: 0 + version: "1.18.0" + flavor: python + discoverable: true + description: "Test Task" + interface: + inputs: + - name: x + type: + simple: INTEGER + outputs: + - name: "y" + type: + simple: INTEGER + container: + image: alpine + command: ["echo", "Hello", "{{$input}}", "{{$output}}"] +spec: + id: test-wf + nodes: + start: + id: start + kind: start + input_bindings: + - var: x + binding: + scalar: + primitive: + integer: 5 + foo1: + id: foo1 + kind: task + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + input_bindings: + - var: x + binding: + promise: + node_id: start + var: x + connections: + start: + - foo1 diff --git a/flytepropeller/raw_examples/example-noinputs.yaml b/flytepropeller/raw_examples/example-noinputs.yaml new file mode 100644 index 0000000000..719b582dcc --- /dev/null +++ b/flytepropeller/raw_examples/example-noinputs.yaml @@ -0,0 +1,41 @@ +apiVersion: flyte.lyft.com/v1alpha1 +kind: FlyteWorkflow +metadata: + name: test-fg + namespace: default +tasks: + foo: + id: foo + # Enums are ints + category: 0 + type: container + metadata: + name: foo + runtime: + # Enums are ints + type: 0 + version: "1.18.0" + flavor: python + discoverable: true + description: "Test Task" + container: + image: alpine + command: ["ls", "${{inputs}}", "${{outputs}}"] +spec: + id: test-wf + nodes: + start: + id: start + kind: start + foo1: + id: foo1 + kind: task + resources: + requests: + cpu: 250m + limits: + cpu: 250m + task_ref: foo + connections: + start: + - foo1 From e1e7b5c0591d2e695e8085fa14841940fb5070d8 Mon Sep 17 00:00:00 2001 From: Chetan Raj Date: Mon, 9 Sep 2019 10:36:20 -0700 Subject: [PATCH 0084/1918] Generailize HiveExecutor construcor method to accept custom executor client --- .../go/tasks/v1/qubole/client/qubole_client.go | 4 ++-- .../go/tasks/v1/qubole/client/qubole_status.go | 2 +- flyteplugins/go/tasks/v1/qubole/hive_executor.go | 10 ++++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go index 3bb8470739..4936de0864 100755 --- a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go +++ b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go @@ -194,7 +194,7 @@ func (q *quboleClient) ExecuteHiveCommand( return nil, err } - status := newQuboleStatus(ctx, cmd.Status) + status := NewQuboleStatus(ctx, cmd.Status) return &QuboleCommandDetails{ID: cmd.ID, Status: status}, nil } @@ -242,7 +242,7 @@ func (q *quboleClient) GetCommandStatus(ctx context.Context, commandID string, a return QuboleStatusUnknown, err } - cmdStatus := newQuboleStatus(ctx, cmd.Status) + cmdStatus := NewQuboleStatus(ctx, cmd.Status) return cmdStatus, nil } diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go index bcafa6b007..2cf80e5a44 100755 --- a/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go +++ b/flyteplugins/go/tasks/v1/qubole/client/qubole_status.go @@ -28,7 +28,7 @@ var QuboleStatuses = map[QuboleStatus]struct{}{ QuboleStatusCancelled: {}, } -func newQuboleStatus(ctx context.Context, status string) QuboleStatus { +func NewQuboleStatus(ctx context.Context, status 
string) QuboleStatus { upperCased := strings.ToUpper(status) if _, ok := QuboleStatuses[QuboleStatus(upperCased)]; ok { return QuboleStatus(upperCased) diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor.go b/flyteplugins/go/tasks/v1/qubole/hive_executor.go index 814ae6769b..c386f9176f 100755 --- a/flyteplugins/go/tasks/v1/qubole/hive_executor.go +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor.go @@ -562,6 +562,16 @@ func NewHiveTaskExecutorWithCache(ctx context.Context) (*HiveExecutor, error) { return &hiveExecutor, nil } +func NewHiveTaskExecutor(ctx context.Context, executorId string, executorClient client.QuboleClient) (*HiveExecutor, error) { + hiveExecutor := HiveExecutor{ + id: executorId, + secretsManager: NewSecretsManager(), + quboleClient: executorClient, + } + + return &hiveExecutor, nil +} + func init() { tasksV1.RegisterLoader(func(ctx context.Context) error { hiveExecutor, err := NewHiveTaskExecutorWithCache(ctx) From 9e35745e730745dd860d3277e481437e840c88ad Mon Sep 17 00:00:00 2001 From: Chetan Raj Date: Mon, 9 Sep 2019 11:41:37 -0700 Subject: [PATCH 0085/1918] Customize Qubole Job Uri --- flyteplugins/go/tasks/v1/qubole/client/qubole_client.go | 5 +++-- flyteplugins/go/tasks/v1/qubole/hive_executor.go | 1 + flyteplugins/go/tasks/v1/qubole/qubole_work.go | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go index 4936de0864..10c603a758 100755 --- a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go +++ b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go @@ -17,7 +17,7 @@ import ( const url = "https://api.qubole.com/api" const apiPath = "/v1.2/commands" -const QuboleLogLinkFormat = "https://api.qubole.com/v2/analyze?command_id=%s" +const QuboleLogLinkFormat = "https://api.qubole.com/v2/analyze?command_id=%d" const tokenKeyForAth = "X-AUTH-TOKEN" const acceptHeaderKey = "Accept" @@ -38,6 +38,7 @@ type quboleCmdDetailsInternal struct { type QuboleCommandDetails struct { ID int64 Status QuboleStatus + JobUri string } // QuboleClient API Request Body, meant to be passed into JSON.marshal @@ -195,7 +196,7 @@ func (q *quboleClient) ExecuteHiveCommand( } status := NewQuboleStatus(ctx, cmd.Status) - return &QuboleCommandDetails{ID: cmd.ID, Status: status}, nil + return &QuboleCommandDetails{ID: cmd.ID, Status: status, JobUri: fmt.Sprintf(QuboleLogLinkFormat, cmd.ID)}, nil } /* diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor.go b/flyteplugins/go/tasks/v1/qubole/hive_executor.go index c386f9176f..0f6545b3c9 100755 --- a/flyteplugins/go/tasks/v1/qubole/hive_executor.go +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor.go @@ -265,6 +265,7 @@ func (h HiveExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskCon commandId := strconv.FormatInt(cmdDetails.ID, 10) logger.Infof(ctx, "Created Qubole ID %s for %s", commandId, workCacheKey) item.CommandId = commandId + item.JobUri = cmdDetails.JobUri item.Status = QuboleWorkRunning item.Query = "" // Clear the query to save space in etcd once we've successfully launched err := h.executionBuffer.ConfirmExecution(ctx, workCacheKey, commandId) diff --git a/flyteplugins/go/tasks/v1/qubole/qubole_work.go b/flyteplugins/go/tasks/v1/qubole/qubole_work.go index bcfebf6995..84bde9dc5a 100755 --- a/flyteplugins/go/tasks/v1/qubole/qubole_work.go +++ b/flyteplugins/go/tasks/v1/qubole/qubole_work.go @@ -43,6 +43,8 @@ type QuboleWorkItem struct { Query string `json:"query,omitempty"` 
TimeoutSec uint32 `json:"timeout,omitempty"` + + JobUri string `json:"job_uri,omitempty"` } // This ID will be used in a process-wide cache, so it needs to be unique across all concurrent work being done by @@ -151,7 +153,7 @@ func constructEventInfoFromQuboleWorkItems(taskCtx types.TaskContext, quboleWork Name: fmt.Sprintf("Retry: %d Status: %s [%s]", taskCtx.GetTaskExecutionID().GetID().RetryAttempt, workItem.Status, workItem.CommandId), MessageFormat: core.TaskLog_UNKNOWN, - Uri: fmt.Sprintf(client.QuboleLogLinkFormat, workItem.CommandId), + Uri: workItem.JobUri, }) } } From b82c51924d40628b3730508146d20357ea629b5a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 9 Sep 2019 13:02:03 -0700 Subject: [PATCH 0086/1918] bump versions --- flyteplugins/.gitignore | 5 ++ flyteplugins/Gopkg.lock | 162 ++++++++++++++++++++-------------------- flyteplugins/Gopkg.toml | 2 +- 3 files changed, 89 insertions(+), 80 deletions(-) create mode 100644 flyteplugins/.gitignore diff --git a/flyteplugins/.gitignore b/flyteplugins/.gitignore new file mode 100644 index 0000000000..b74da52e1e --- /dev/null +++ b/flyteplugins/.gitignore @@ -0,0 +1,5 @@ +.idea/ +vendor/ +.DS_Store +*.pyc +*.swp diff --git a/flyteplugins/Gopkg.lock b/flyteplugins/Gopkg.lock index a0b6d14518..019954ae25 100755 --- a/flyteplugins/Gopkg.lock +++ b/flyteplugins/Gopkg.lock @@ -2,16 +2,16 @@ [[projects]] - digest = "1:f45299a845f297e482104076e5ae4b1b0885cafb227098d2d5675b2cc65084a5" + digest = "1:e1549ae10031ac55dd7d26ac4d480130ddbdf97f9a26ebbedff089aa0335798f" name = "github.com/GoogleCloudPlatform/spark-on-k8s-operator" packages = [ "pkg/apis/sparkoperator.k8s.io", "pkg/apis/sparkoperator.k8s.io/v1beta1", ] pruneopts = "" - revision = "21894ac2fe2a4e64632ef620c8a4da776a7b6b87" + revision = "5306d013b4dbd6a9c75879c1643c7fcb237560ec" source = "https://github.com/lyft/spark-on-k8s-operator" - version = "v0.1.1" + version = "v0.1.3" [[projects]] digest = "1:60942d250d0e06d3722ddc8e22bc52f8cef7961ba6d8d3e95327a32b6b024a7b" @@ -22,7 +22,7 @@ version = "1.0.0" [[projects]] - digest = "1:e54184af8a1457b632aae19f35b241b4fe48f18765f7c80d55d7ef2c0d19d774" + digest = "1:3b037e5e14b77258878f05904573aeb0d4a1839c6815580ad6b91cb42333f735" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -46,6 +46,7 @@ "internal/ini", "internal/s3err", "internal/sdkio", + "internal/sdkmath", "internal/sdkrand", "internal/sdkuri", "internal/shareddefaults", @@ -60,26 +61,27 @@ "private/protocol/xml/xmlutil", "service/s3", "service/sts", + "service/sts/stsiface", ] pruneopts = "" - revision = "eb8216aeaa74d4010569c51ae6238919c172ed82" - version = "v1.19.44" + revision = "db133b72163d045aea9419c8e95b3d312d590135" + version = "v1.23.18" [[projects]] - digest = "1:0d3deb8a6da8ffba5635d6fb1d2144662200def6c9d82a35a6d05d6c2d4a48f9" + digest = "1:ac2a05be7167c495fe8aaf8aaf62ecf81e78d2180ecb04e16778dc6c185c96a5" name = "github.com/beorn7/perks" packages = ["quantile"] pruneopts = "" - revision = "4b2b341e8d7715fae06375aa633dbb6e91b3fb46" - version = "v1.0.0" + revision = "37c8de3658fcb183f997c4e13e8337516ab753e6" + version = "v1.0.1" [[projects]] - digest = "1:f6485831252319cd6ca29fc170adecf1eb81bf1e805f62f44eb48564ce2485fe" + digest = "1:545ae40d6dde46043a71bdfd7f9a17f2353ce16277c83ac685af231b4b7c4beb" name = "github.com/cespare/xxhash" packages = ["."] pruneopts = "" - revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" - version = "v2.0.0" + revision = "de209a9ffae3256185a6bb135d1a0ada7b2b5f09" + version = "v2.1.0" [[projects]] digest = 
"1:193f6d32d751f26540aa8eeedc114ce0a51f9e77b6c22dda3a4db4e5f65aec66" @@ -138,7 +140,7 @@ version = "v0.1.1" [[projects]] - digest = "1:c2db84082861ca42d0b00580d28f4b31aceec477a00a38e1a057fb3da75c8adc" + digest = "1:b994001ce7517c69ce8173ba12463f329c3fe8aecae84be6a334eae89a99f2fd" name = "github.com/go-redis/redis" packages = [ ".", @@ -150,30 +152,30 @@ "internal/util", ] pruneopts = "" - revision = "75795aa4236dc7341eefac3bbe945e68c99ef9df" - version = "v6.15.3" + revision = "17c058513b3e03c5e23136c18582fdd8ac8a3645" + version = "v6.15.5" [[projects]] - digest = "1:fd53b471edb4c28c7d297f617f4da0d33402755f58d6301e7ca1197ef0a90937" + digest = "1:8a7fe65e9ac2612c4df602cc9f014a92406776d993ff0f28335e5a8831d87c53" name = "github.com/gogo/protobuf" packages = [ "proto", "sortkeys", ] pruneopts = "" - revision = "ba06b47c162d49f2af050fb4c75bcbc86a159d5c" - version = "v1.2.1" + revision = "0ca988a254f991240804bf9821f3450d87ccbb1b" + version = "v1.3.0" [[projects]] branch = "master" - digest = "1:f9714c0c017f2b821bccceeec2c7a93d29638346bb546c36ca5f90e751f91b9e" + digest = "1:e1822d37be8e11e101357a27170527b1056c99182407f270e080f76409adbd9a" name = "github.com/golang/groupcache" packages = ["lru"] pruneopts = "" - revision = "5b532d6fd5efaf7fa130d4e859a2fde0fc3a9e1b" + revision = "869f871628b6baa9cfbc11732cdf6546b17c1298" [[projects]] - digest = "1:529d738b7976c3848cae5cf3a8036440166835e389c1f617af701eeb12a0518d" + digest = "1:b852d2b62be24e445fcdbad9ce3015b44c207815d631230dfce3f14e7803f5bf" name = "github.com/golang/protobuf" packages = [ "jsonpb", @@ -190,8 +192,8 @@ "ptypes/wrappers", ] pruneopts = "" - revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" - version = "v1.3.1" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" [[projects]] digest = "1:1e5b1e14524ed08301977b7b8e10c719ed853cbf3f24ecb66fae783a46f207a6" @@ -210,7 +212,7 @@ version = "v1.0.0" [[projects]] - digest = "1:16b2837c8b3cf045fa2cdc82af0cf78b19582701394484ae76b2c3bc3c99ad73" + digest = "1:728f28282e0edc47e2d8f41c9ec1956ad645ad6b15e6376ab31e2c3b094fc38f" name = "github.com/googleapis/gnostic" packages = [ "OpenAPIv2", @@ -218,8 +220,8 @@ "extensions", ] pruneopts = "" - revision = "7c663266750e7d82587642f65e60bc4083f1f84e" - version = "v0.2.0" + revision = "ab0dd09aa10e2952b28e12ecd35681b20463ebab" + version = "v0.3.1" [[projects]] digest = "1:94697ef521414e9814e038c512699e3ef984519a301b7a499b00cf851c928b29" @@ -234,14 +236,14 @@ [[projects]] branch = "master" - digest = "1:326d7083af3723768cd8150db99b8ac730837b05ef290d5a042562905cc26210" + digest = "1:e1fd67b5695fb12f54f979606c5d650a5aa72ef242f8e71072bfd4f7b5a141a0" name = "github.com/gregjones/httpcache" packages = [ ".", "diskcache", ] pruneopts = "" - revision = "3befbb6ad0cc97d4c25d851e9528915809e1a22f" + revision = "901d90724c7919163f472a9812253fb26761123d" [[projects]] digest = "1:9a0b2dd1f882668a3d7fbcd424eed269c383a16f1faa3a03d14e0dd5fba571b1" @@ -265,7 +267,7 @@ version = "v1.2.0" [[projects]] - digest = "1:dee8ec16fa714522c6cad579dfeeba3caf9644d93b8b452cd7138584402c81f7" + digest = "1:0ebfd2f00a84ee4fb31913b49011b7fa2fb6b12040991d8b948db821a15f7f77" name = "github.com/grpc-ecosystem/grpc-gateway" packages = [ "internal", @@ -274,19 +276,19 @@ "utilities", ] pruneopts = "" - revision = "8fd5fd9d19ce68183a6b0934519dfe7fe6269612" - version = "v1.9.0" + revision = "471f45a5a99a578de7a8638dc7ed29e245bde097" + version = "v1.11.1" [[projects]] - digest = "1:85f8f8d390a03287a563e215ea6bd0610c858042731a8b42062435a0dcbc485f" + digest = 
"1:7f6f07500a0b7d3766b00fa466040b97f2f5b5f3eef2ecabfe516e703b05119a" name = "github.com/hashicorp/golang-lru" packages = [ ".", "simplelru", ] pruneopts = "" - revision = "7087cb70de9f7a8bc0a10c375cb0d2280a8edf9c" - version = "v0.5.1" + revision = "7f827b33c0f158ec5dfbba01bb0b14a4541fd81d" + version = "v0.5.3" [[projects]] digest = "1:d14365c51dd1d34d5c79833ec91413bfbb166be978724f15701e17080dc06dec" @@ -331,12 +333,12 @@ revision = "c2b33e84" [[projects]] - digest = "1:12d3de2c11e54ea37d7f00daf85088ad5e61ec4e8a1f828d6c8b657976856be7" + digest = "1:e716a02584d94519e2ccf7ac461c4028da736d41a58c1ed95e641c1603bdb056" name = "github.com/json-iterator/go" packages = ["."] pruneopts = "" - revision = "0ff49de124c6f76f8494e194af75bde0f1a49a29" - version = "v1.1.6" + revision = "27518f6661eba504be5a7a9a9f6d9460d892ade3" + version = "v1.1.7" [[projects]] digest = "1:0f51cee70b0d254dbc93c22666ea2abf211af81c1701a96d04e2284b408621db" @@ -347,7 +349,7 @@ version = "v1.0.2" [[projects]] - digest = "1:2b5f0e6bc8fb862fed5bccf9fbb1ab819c8b3f8a21e813fe442c06aec3bb3e86" + digest = "1:ef7b24655c09b19a0b397e8a58f8f15fc402b349484afad6ce1de0a8f21bb292" name = "github.com/lyft/flyteidl" packages = [ "clients/go/admin", @@ -362,9 +364,9 @@ "gen/pb-go/flyteidl/service", ] pruneopts = "" - revision = "793b09d190148236f41ad8160b5cec9a3325c16f" + revision = "211b8fe8c2c1d9ab168afd078b62d4f7834171d3" source = "https://github.com/lyft/flyteidl" - version = "v0.1.0" + version = "v0.14.0" [[projects]] digest = "1:c368fe9a00a38c8702e24475dd3a8348d2a191892ef9030aceb821f8c035b737" @@ -409,12 +411,12 @@ version = "v0.0.9" [[projects]] - digest = "1:dbfae9da5a674236b914e486086671145b37b5e3880a38da906665aede3c9eab" + digest = "1:afc2714dedf683e571932f94f8a8ec444679eb84e076e021f63de871c5bc6cb1" name = "github.com/mattn/go-isatty" packages = ["."] pruneopts = "" - revision = "1311e847b0cb909da63b5fecfb5370aa66236465" - version = "v0.0.8" + revision = "e1f7b56ace729e4a73a29a6b4fac6cd5fcda7ab3" + version = "v0.0.9" [[projects]] digest = "1:63722a4b1e1717be7b98fc686e0b30d5e7f734b9e93d7dee86293b6deab7ea28" @@ -489,26 +491,26 @@ version = "v1.0.0" [[projects]] - digest = "1:6894aa393989b5e59d9936b8b1197dc261c2c200057b92dec34007b06e9856ae" + digest = "1:c826496cad27bd9a7644a01230a79d472b4093dd33587236e8f8369bb1d8534e" name = "github.com/prometheus/client_golang" packages = [ "prometheus", "prometheus/internal", ] pruneopts = "" - revision = "50c4339db732beb2165735d2cde0bff78eb3c5a5" - version = "v0.9.3" + revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" + version = "v0.9.4" [[projects]] branch = "master" - digest = "1:cd67319ee7536399990c4b00fae07c3413035a53193c644549a676091507cadc" + digest = "1:0a565f69553dd41b3de790fde3532e9237142f2637899e20cd3e7396f0c4f2f7" name = "github.com/prometheus/client_model" packages = ["go"] pruneopts = "" - revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + revision = "14fe0d1b01d4d5fc031dd4bec1823bd3ebbe8016" [[projects]] - digest = "1:e6315869762add748defb9e0fcc537738f78cabeaf70b2788aba9db13097b6e9" + digest = "1:0f2cee44695a3208fe5d6926076641499c72304e6f015348c9ab2df90a202cdf" name = "github.com/prometheus/common" packages = [ "expfmt", @@ -516,19 +518,20 @@ "model", ] pruneopts = "" - revision = "17f5ca1748182ddf24fc33a5a7caaaf790a52fcc" - version = "v0.4.1" + revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" + version = "v0.6.0" [[projects]] - digest = "1:fea688256dfff79e9a0e24be47c4acf51347fcff52a5dfca7b251932a52c67e0" + digest = 
"1:e010d89927008cac947ad9650f643a1b2b668dde47adfe56664da8694c1541d1" name = "github.com/prometheus/procfs" packages = [ ".", "internal/fs", + "internal/util", ] pruneopts = "" - revision = "833678b5bb319f2d20a475cb165c6cc59c2cc77c" - version = "v0.0.2" + revision = "00ec24a6a2d86e7074629c8384715dbb05adccd8" + version = "v0.0.4" [[projects]] digest = "1:1a405cddcf3368445051fb70ab465ae99da56ad7be8d8ca7fc52159d1c2d873c" @@ -558,12 +561,12 @@ version = "v1.3.0" [[projects]] - digest = "1:78715f4ed019d19795e67eed1dc63f525461d925616b1ed02b72582c01362440" + digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe" name = "github.com/spf13/cobra" packages = ["."] pruneopts = "" - revision = "67fc4837d267bc9bfd6e47f77783fcc3dffc68de" - version = "v0.0.4" + revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" + version = "v0.0.5" [[projects]] digest = "1:cc15ae4fbdb02ce31f3392361a70ac041f4f02e0485de8ffac92bd8033e3d26e" @@ -597,15 +600,15 @@ version = "v0.1.1" [[projects]] - digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c" + digest = "1:f7b541897bcde05a04a044c342ddc7425aab7e331f37b47fbb486cd16324b48e" name = "github.com/stretchr/testify" packages = [ "assert", "mock", ] pruneopts = "" - revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" - version = "v1.3.0" + revision = "221dbe5ed46703ee255b1da0dec05086f5035f62" + version = "v1.4.0" [[projects]] digest = "1:e6ff7840319b6fda979a918a8801005ec2049abca62af19211d96971d8ec3327" @@ -640,15 +643,15 @@ [[projects]] branch = "master" - digest = "1:9d150270ca2c3356f2224a0878daa1652e4d0b25b345f18b4f6e156cc4b8ec5e" + digest = "1:c3b737126c08fa82da30169dff76cee24db554013cd051e9cc888dcc079292e6" name = "golang.org/x/crypto" packages = ["ssh/terminal"] pruneopts = "" - revision = "f99c8df09eb5bff426315721bfa5f16a99cad32c" + revision = "094676da4a83be5288d281081bba63a173ce6772" [[projects]] branch = "master" - digest = "1:d168befeef1eb51a25ab229b1bb411ae07c7bef22ebee588c290faf3bdf4ae27" + digest = "1:3feebd8c7f8c56efb8dd591ccb3227ba5e05863ead67a7e64cbba4c3957f61b4" name = "golang.org/x/net" packages = [ "context", @@ -661,7 +664,7 @@ "trace", ] pruneopts = "" - revision = "1492cefac77f61bc789c00f41ead8f8d7307cd21" + revision = "a7b16738d86b947dd0fadb08ca2c2342b51958b6" [[projects]] branch = "master" @@ -676,14 +679,14 @@ [[projects]] branch = "master" - digest = "1:4b923bc8024a3154f2c1e072d37133d17326e8a6a61bb03102e2f14b8af7a067" + digest = "1:185dabf552bab58db3ec09d4b734902a7650dd0d380defefdcccd6d3ae6a7ee3" name = "golang.org/x/sys" packages = [ "unix", "windows", ] pruneopts = "" - revision = "5da285871e9c6a1c3acade75bea3282d33f55ebd" + revision = "f460065e899abee61eb86816fd1fd684ea5a0f26" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -719,7 +722,7 @@ revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" [[projects]] - digest = "1:47f391ee443f578f01168347818cb234ed819521e49e4d2c8dd2fb80d48ee41a" + digest = "1:0568e577f790e9bd0420521cff50580f9b38165a38f217ce68f55c4bbaa97066" name = "google.golang.org/appengine" packages = [ "internal", @@ -731,12 +734,12 @@ "urlfetch", ] pruneopts = "" - revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" - version = "v1.6.1" + revision = "5f2a59506353b8d5ba8cbbcd9f3c1f41f1eaf079" + version = "v1.6.2" [[projects]] branch = "master" - digest = "1:52c6a188a1f16480607d7753e02cbd5ad43089d5550c725f32f46d376c946f37" + digest = "1:d57fc4649c80e946db4ae2064153e5896dc45936752db1ade3d94c52734116c2" name = "google.golang.org/genproto" 
packages = [ "googleapis/api/annotations", @@ -745,10 +748,10 @@ "protobuf/field_mask", ] pruneopts = "" - revision = "eb0b1bdb6ae60fcfc41b8d907b50dfb346112301" + revision = "92dd089d5514cecde7a465f807541f83dfd486d5" [[projects]] - digest = "1:6881653b963cd12dc1a9824aed5e122d0ff38e53e3ee07862f969a56ad2f2e9c" + digest = "1:e8a4007e58ea9431f6460d1bc5c7f9dd29fdc0211ab780b6ad3d0581478f2076" name = "google.golang.org/grpc" packages = [ ".", @@ -780,13 +783,14 @@ "resolver", "resolver/dns", "resolver/passthrough", + "serviceconfig", "stats", "status", "tap", ] pruneopts = "" - revision = "501c41df7f472c740d0674ff27122f3f48c80ce7" - version = "v1.21.1" + revision = "6eaf6f47437a6b4e2153a190160ef39a92c7eceb" + version = "v1.23.0" [[projects]] digest = "1:75fb3fcfc73a8c723efde7777b40e8e8ff9babf30d8c56160d01beffea8a95a6" @@ -939,20 +943,20 @@ version = "kubernetes-1.13.1" [[projects]] - digest = "1:9eaf86f4f6fb4a8f177220d488ef1e3255d06a691cca95f14ef085d4cd1cef3c" + digest = "1:3063061b6514ad2666c4fa292451685884cacf77c803e1b10b4a4fa23f7787fb" name = "k8s.io/klog" packages = ["."] pruneopts = "" - revision = "d98d8acdac006fb39831f1b25640813fef9c314f" - version = "v0.3.3" + revision = "3ca30a56d8a775276f9cdae009ba326fdc05af7f" + version = "v0.4.0" [[projects]] branch = "master" - digest = "1:d2aae07aa745223592ae668f6eb6c2ca0242d66a6dcf16b1e8e2711a79aad0f1" + digest = "1:71e59e355758d825c891c77bfe3ec2c0b2523b05076e96b2a2bfa804e6ac576a" name = "k8s.io/kube-openapi" packages = ["pkg/util/proto"] pruneopts = "" - revision = "db7b694dc208eead64d38030265f702db593fcf2" + revision = "743ec37842bffe49dd4221d9026f30fb1d5adbc4" [[projects]] digest = "1:5c1664b5783da5772e29bc7c2fbe369dc0b1d2f11b7935c6adc283d9aa839355" diff --git a/flyteplugins/Gopkg.toml b/flyteplugins/Gopkg.toml index 99b9a880f5..4f99d8086a 100755 --- a/flyteplugins/Gopkg.toml +++ b/flyteplugins/Gopkg.toml @@ -10,7 +10,7 @@ ignored = ["k8s.io/spark-on-k8s-operator", [[constraint]] name = "github.com/lyft/flyteidl" source = "https://github.com/lyft/flyteidl" - version = "^0.1.x" + version = "^0.14.x" [[constraint]] name = "github.com/lyft/flytestdlib" From 06b59c41f39c710b93a8a7b93f01d2e463ce578c Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 9 Sep 2019 13:03:40 -0700 Subject: [PATCH 0087/1918] probably don't care about python files --- flyteplugins/.gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/flyteplugins/.gitignore b/flyteplugins/.gitignore index b74da52e1e..b8d9753633 100644 --- a/flyteplugins/.gitignore +++ b/flyteplugins/.gitignore @@ -1,5 +1,4 @@ .idea/ vendor/ .DS_Store -*.pyc *.swp From 9d10397480a0ac1897e6ac8e63c232b98f44d0d2 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 9 Sep 2019 15:47:26 -0700 Subject: [PATCH 0088/1918] use direct options --- .../go/tasks/v1/utils/marshal_utils.go | 6 +++- .../go/tasks/v1/utils/marshal_utils_test.go | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 flyteplugins/go/tasks/v1/utils/marshal_utils_test.go diff --git a/flyteplugins/go/tasks/v1/utils/marshal_utils.go b/flyteplugins/go/tasks/v1/utils/marshal_utils.go index e5d5d9c8b6..ceb563b5d0 100755 --- a/flyteplugins/go/tasks/v1/utils/marshal_utils.go +++ b/flyteplugins/go/tasks/v1/utils/marshal_utils.go @@ -3,6 +3,7 @@ package utils import ( "encoding/json" "fmt" + "strings" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -10,6 +11,9 @@ import ( ) var jsonPbMarshaler = jsonpb.Marshaler{} +var jsonPbUnmarshaler = &jsonpb.Unmarshaler{ + AllowUnknownFields: 
true, +} func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { if structObj == nil { @@ -21,7 +25,7 @@ func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { return err } - if err = jsonpb.UnmarshalString(jsonObj, msg); err != nil { + if err = jsonPbUnmarshaler.Unmarshal(strings.NewReader(jsonObj), msg); err != nil { return err } diff --git a/flyteplugins/go/tasks/v1/utils/marshal_utils_test.go b/flyteplugins/go/tasks/v1/utils/marshal_utils_test.go new file mode 100644 index 0000000000..406cfefa73 --- /dev/null +++ b/flyteplugins/go/tasks/v1/utils/marshal_utils_test.go @@ -0,0 +1,28 @@ +package utils + +import ( + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestBackwardsCompatibility(t *testing.T) { + sidecarProtoMessage := plugins.SidecarJob{ + PrimaryContainerName: "primary", + } + + sidecarStruct, err := MarshalObjToStruct(sidecarProtoMessage) + assert.NoError(t, err) + + // Set a new field in the struct to mimic what happens when we add new fields to protobuf messages + sidecarStruct.Fields["hello"] = &structpb.Value{ + Kind: &structpb.Value_StringValue{ + StringValue: "world", + }, + } + + newSidecarJob := plugins.SidecarJob{} + err = UnmarshalStruct(sidecarStruct, &newSidecarJob) + assert.NoError(t, err) +} From 4d3cccc5578b54480b1a636b2d4f8bd1a686c447 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 9 Sep 2019 16:11:37 -0700 Subject: [PATCH 0089/1918] Bump IDL/Plugins versions (#4) Bumping versions to pick up Flyte IDL change deprecating Hive query collections. https://github.com/lyft/flyteplugins/releases/tag/v0.1.2 https://github.com/lyft/flyteplugins/releases/tag/v0.1.3 https://github.com/lyft/flyteidl/releases/tag/v0.14.0 --- flytepropeller/Gopkg.lock | 12 ++++++------ flytepropeller/Gopkg.toml | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index 08a0aa8c73..8ac4731381 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -448,7 +448,7 @@ version = "v1.0.2" [[projects]] - digest = "1:2b5f0e6bc8fb862fed5bccf9fbb1ab819c8b3f8a21e813fe442c06aec3bb3e86" + digest = "1:ef7b24655c09b19a0b397e8a58f8f15fc402b349484afad6ce1de0a8f21bb292" name = "github.com/lyft/flyteidl" packages = [ "clients/go/admin", @@ -466,12 +466,12 @@ "gen/pb-go/flyteidl/service", ] pruneopts = "" - revision = "793b09d190148236f41ad8160b5cec9a3325c16f" + revision = "211b8fe8c2c1d9ab168afd078b62d4f7834171d3" source = "https://github.com/lyft/flyteidl" - version = "v0.1.0" + version = "v0.14.0" [[projects]] - digest = "1:500471ee50c4141d3523c79615cc90529b3152f8aa5924b63122df6bf201a7a0" + digest = "1:1f4d377eba88e89d78761ad01caa205b214aeed0db4ead48f424538c9a8f7bcf" name = "github.com/lyft/flyteplugins" packages = [ "go/tasks", @@ -493,9 +493,9 @@ "go/tasks/v1/utils", ] pruneopts = "" - revision = "8c85a7c9f19de4df4767de329c56a7f09d0a7bbc" + revision = "99d622cf0f0ca5041e46aad172101536079ea22a" source = "https://github.com/lyft/flyteplugins" - version = "v0.1.0" + version = "v0.1.3" [[projects]] digest = "1:77615fd7bcfd377c5e898c41d33be827f108166691cb91257d27113ee5d08650" diff --git a/flytepropeller/Gopkg.toml b/flytepropeller/Gopkg.toml index 1a0e8974e5..7d7e504096 100644 --- a/flytepropeller/Gopkg.toml +++ b/flytepropeller/Gopkg.toml @@ -49,12 +49,12 @@ required = [ [[constraint]] name = "github.com/lyft/flyteidl" source = 
"https://github.com/lyft/flyteidl" - version = "^0.1.x" + version = "^0.14.x" [[constraint]] name = "github.com/lyft/flyteplugins" source = "https://github.com/lyft/flyteplugins" - version = "^0.1.0" + version = "^0.1.3" [[override]] name = "github.com/lyft/flytestdlib" From 94c1c97a9fef7b1fdf87aadb311f623974358fa3 Mon Sep 17 00:00:00 2001 From: Chetan Raj Date: Mon, 9 Sep 2019 16:59:04 -0700 Subject: [PATCH 0090/1918] DRY: Make use of NewHiveTaskExecutor function in NewHiveTaskExecutorWithCache --- flyteplugins/go/tasks/v1/qubole/hive_executor.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor.go b/flyteplugins/go/tasks/v1/qubole/hive_executor.go index 0f6545b3c9..8fb1301e0d 100755 --- a/flyteplugins/go/tasks/v1/qubole/hive_executor.go +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor.go @@ -554,13 +554,7 @@ func (h *HiveExecutor) SyncQuboleQuery(ctx context.Context, obj utils2.CacheItem } func NewHiveTaskExecutorWithCache(ctx context.Context) (*HiveExecutor, error) { - hiveExecutor := HiveExecutor{ - id: hiveExecutorId, - secretsManager: NewSecretsManager(), - quboleClient: client.NewQuboleClient(), - } - - return &hiveExecutor, nil + return NewHiveTaskExecutor(ctx, hiveExecutorId, client.NewQuboleClient()) } func NewHiveTaskExecutor(ctx context.Context, executorId string, executorClient client.QuboleClient) (*HiveExecutor, error) { From 7bbf545e5bbfef358a6ce92c062f93b8b1bf98e6 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Mon, 9 Sep 2019 12:30:37 -0700 Subject: [PATCH 0091/1918] Cache task executions to DataCatalog --- flytepropeller/Gopkg.lock | 14 + flytepropeller/Gopkg.toml | 5 + flytepropeller/config.yaml | 3 + .../pkg/controller/catalog/catalog_client.go | 20 +- .../{discovery_config.go => config.go} | 6 +- .../pkg/controller/catalog/config_flags.go | 38 +- .../controller/catalog/config_flags_test.go | 46 +- .../catalog/datacatalog/datacatalog.go | 291 ++++++++++++ .../catalog/datacatalog/datacatalog_test.go | 427 ++++++++++++++++++ .../datacatalog/mocks/DataCatalogClient.go | 163 +++++++ .../transformer/datacatalog_transformer.go | 144 ++++++ .../datacatalog_transformer_test.go | 132 ++++++ flytepropeller/pkg/controller/controller.go | 5 +- .../pkg/controller/nodes/executor_test.go | 16 +- .../pkg/controller/nodes/task/handler.go | 19 +- .../pkg/controller/workflow/executor_test.go | 18 +- 16 files changed, 1265 insertions(+), 82 deletions(-) rename flytepropeller/pkg/controller/catalog/{discovery_config.go => config.go} (73%) create mode 100644 flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go create mode 100644 flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go create mode 100644 flytepropeller/pkg/controller/catalog/datacatalog/mocks/DataCatalogClient.go create mode 100644 flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go create mode 100644 flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer_test.go diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index 8ac4731381..513d0366e3 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -447,6 +447,15 @@ revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" version = "v1.0.2" +[[projects]] + digest = "1:eaef68aaa87572012b236975f558582a043044c21be3fda97921c4871fb4298f" + name = "github.com/lyft/datacatalog" + packages = ["protos/gen"] + pruneopts = "" + revision = "0da0ffbb4705efd5d5ecd04ea560b35d968beb86" + source = 
"https://github.com/lyft/datacatalog" + version = "v0.1.0" + [[projects]] digest = "1:ef7b24655c09b19a0b397e8a58f8f15fc402b349484afad6ce1de0a8f21bb292" name = "github.com/lyft/flyteidl" @@ -1338,12 +1347,15 @@ "github.com/DiSiqueira/GoTree", "github.com/fatih/color", "github.com/ghodss/yaml", + "github.com/gogo/protobuf/proto", "github.com/golang/protobuf/jsonpb", "github.com/golang/protobuf/proto", "github.com/golang/protobuf/ptypes", "github.com/golang/protobuf/ptypes/struct", "github.com/golang/protobuf/ptypes/timestamp", + "github.com/google/uuid", "github.com/grpc-ecosystem/go-grpc-middleware/retry", + "github.com/lyft/datacatalog/protos/gen", "github.com/lyft/flyteidl/clients/go/admin", "github.com/lyft/flyteidl/clients/go/admin/mocks", "github.com/lyft/flyteidl/clients/go/coreutils", @@ -1384,6 +1396,7 @@ "golang.org/x/time/rate", "google.golang.org/grpc", "google.golang.org/grpc/codes", + "google.golang.org/grpc/credentials", "google.golang.org/grpc/status", "k8s.io/api/batch/v1", "k8s.io/api/core/v1", @@ -1399,6 +1412,7 @@ "k8s.io/apimachinery/pkg/util/rand", "k8s.io/apimachinery/pkg/util/runtime", "k8s.io/apimachinery/pkg/util/sets", + "k8s.io/apimachinery/pkg/util/uuid", "k8s.io/apimachinery/pkg/util/wait", "k8s.io/apimachinery/pkg/watch", "k8s.io/client-go/discovery", diff --git a/flytepropeller/Gopkg.toml b/flytepropeller/Gopkg.toml index 7d7e504096..c908a00657 100644 --- a/flytepropeller/Gopkg.toml +++ b/flytepropeller/Gopkg.toml @@ -51,6 +51,11 @@ required = [ source = "https://github.com/lyft/flyteidl" version = "^0.14.x" +[[constraint]] + name = "github.com/lyft/datacatalog" + source = "https://github.com/lyft/datacatalog" + version = "^0.1.x" + [[constraint]] name = "github.com/lyft/flyteplugins" source = "https://github.com/lyft/flyteplugins" diff --git a/flytepropeller/config.yaml b/flytepropeller/config.yaml index c50dbc4b31..6bb16eea30 100644 --- a/flytepropeller/config.yaml +++ b/flytepropeller/config.yaml @@ -88,6 +88,9 @@ event: admin: endpoint: localhost:8089 insecure: true +catalog-cache: + type: catalog + endpoint: datacatalog:8089 errors: show-source: true logger: diff --git a/flytepropeller/pkg/controller/catalog/catalog_client.go b/flytepropeller/pkg/controller/catalog/catalog_client.go index 53a9b8aa5a..af24c2b4d0 100644 --- a/flytepropeller/pkg/controller/catalog/catalog_client.go +++ b/flytepropeller/pkg/controller/catalog/catalog_client.go @@ -3,7 +3,10 @@ package catalog import ( "context" + "fmt" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/controller/catalog/datacatalog" "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/storage" ) @@ -13,16 +16,25 @@ type Client interface { Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error } -func NewCatalogClient(store storage.ProtobufStore) Client { +func NewCatalogClient(ctx context.Context, store storage.ProtobufStore) (Client, error) { catalogConfig := GetConfig() var catalogClient Client - if catalogConfig.Type == LegacyDiscoveryType { + var err error + switch catalogConfig.Type { + case LegacyDiscoveryType: catalogClient = NewLegacyDiscovery(catalogConfig.Endpoint, store) - } else if catalogConfig.Type == NoOpDiscoveryType { + case DataCatalogType: + catalogClient, err = datacatalog.NewDataCatalog(ctx, catalogConfig.Endpoint, catalogConfig.Secure, store) + if err != nil { + return nil, err + } + case NoOpDiscoveryType, "": catalogClient = 
NewNoOpDiscovery() + default: + return nil, fmt.Errorf("No such catalog type available: %v", catalogConfig.Type) } logger.Infof(context.Background(), "Created Catalog client, type: %v", catalogConfig.Type) - return catalogClient + return catalogClient, nil } diff --git a/flytepropeller/pkg/controller/catalog/discovery_config.go b/flytepropeller/pkg/controller/catalog/config.go similarity index 73% rename from flytepropeller/pkg/controller/catalog/discovery_config.go rename to flytepropeller/pkg/controller/catalog/config.go index bbc1bab8ff..6d4f2c3376 100644 --- a/flytepropeller/pkg/controller/catalog/discovery_config.go +++ b/flytepropeller/pkg/controller/catalog/config.go @@ -21,11 +21,13 @@ type DiscoveryType = string const ( NoOpDiscoveryType DiscoveryType = "noop" LegacyDiscoveryType DiscoveryType = "legacy" + DataCatalogType DiscoveryType = "datacatalog" ) type Config struct { - Type DiscoveryType `json:"type" pflag:"\"noop\",Discovery Implementation to use"` - Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for discovery service"` + Type DiscoveryType `json:"type" pflag:"\"noop\", Catalog Implementation to use"` + Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for catalog service"` + Secure bool `json:"secure" pflag:"true, Connect with TSL/SSL"` } // Gets loaded config for Discovery diff --git a/flytepropeller/pkg/controller/catalog/config_flags.go b/flytepropeller/pkg/controller/catalog/config_flags.go index d67dd751ac..b2359a462d 100755 --- a/flytepropeller/pkg/controller/catalog/config_flags.go +++ b/flytepropeller/pkg/controller/catalog/config_flags.go @@ -1,47 +1,21 @@ // Code generated by go generate; DO NOT EDIT. -// This file was generated by robots. +// This file was generated by robots at +// 2019-09-05 05:37:07.301294018 -0700 PDT m=+14.696460456 package catalog import ( - "encoding/json" - "reflect" - "fmt" "github.com/spf13/pflag" ) -// If v is a pointer, it will get its element value or the zero value of the element type. -// If v is not a pointer, it will return it as is. -func (Config) elemValueOrNil(v interface{}) interface{} { - if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { - if reflect.ValueOf(v).IsNil() { - return reflect.Zero(t.Elem()).Interface() - } else { - return reflect.ValueOf(v).Interface() - } - } else if v == nil { - return reflect.Zero(t).Interface() - } - - return v -} - -func (Config) mustMarshalJSON(v json.Marshaler) string { - raw, err := v.MarshalJSON() - if err != nil { - panic(err) - } - - return string(raw) -} - // GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the // flags is json-name.json-sub-name... etc. 
-func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { +func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), defaultConfig.Type, "Discovery Implementation to use") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), defaultConfig.Endpoint, " Endpoint for discovery service") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), "noop", " Catalog Implementation to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), "", " Endpoint for catalog service") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "secure"), true, " Connect with TSL/SSL") return cmdFlags } diff --git a/flytepropeller/pkg/controller/catalog/config_flags_test.go b/flytepropeller/pkg/controller/catalog/config_flags_test.go index a2538822b8..e5ce16b044 100755 --- a/flytepropeller/pkg/controller/catalog/config_flags_test.go +++ b/flytepropeller/pkg/controller/catalog/config_flags_test.go @@ -1,5 +1,6 @@ // Code generated by go generate; DO NOT EDIT. -// This file was generated by robots. +// This file was generated by robots at +// 2019-09-05 05:37:07.301294018 -0700 PDT m=+14.696460456 package catalog @@ -7,7 +8,6 @@ import ( "encoding/json" "fmt" "reflect" - "strings" "testing" "github.com/mitchellh/mapstructure" @@ -70,16 +70,6 @@ func decode_Config(input, result interface{}) error { return decoder.Decode(input) } -func join_Config(arr interface{}, sep string) string { - listValue := reflect.ValueOf(arr) - strs := make([]string, 0, listValue.Len()) - for i := 0; i < listValue.Len(); i++ { - strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) - } - - return strings.Join(strs, sep) -} - func testDecodeJson_Config(t *testing.T, val, result interface{}) { assert.NoError(t, decode_Config(val, result)) } @@ -103,16 +93,14 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("type"); err == nil { - assert.Equal(t, string(defaultConfig.Type), vString) + assert.Equal(t, "noop", vString) } else { assert.FailNow(t, err.Error()) } }) t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("type", testValue) + cmdFlags.Set("type", "1") if vString, err := cmdFlags.GetString("type"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Type) @@ -125,16 +113,14 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("endpoint"); err == nil { - assert.Equal(t, string(defaultConfig.Endpoint), vString) + assert.Equal(t, "", vString) } else { assert.FailNow(t, err.Error()) } }) t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("endpoint", testValue) + cmdFlags.Set("endpoint", "1") if vString, err := cmdFlags.GetString("endpoint"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Endpoint) @@ -143,4 +129,24 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_secure", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("secure"); err == nil { + assert.Equal(t, true, vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + cmdFlags.Set("secure", "1") + if vBool, err := cmdFlags.GetBool("secure"); err == nil { + testDecodeJson_Config(t, 
fmt.Sprintf("%v", vBool), &actual.Secure) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go new file mode 100644 index 0000000000..672da3f3f5 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -0,0 +1,291 @@ +package datacatalog + +import ( + "context" + "crypto/x509" + "time" + + "fmt" + + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/controller/catalog/datacatalog/transformer" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/util/uuid" +) + +const ( + taskVersionKey = "task-version" + taskExecKey = "execution-name" + taskExecVersion = "execution-version" +) + +type CatalogClient struct { + client datacatalog.DataCatalogClient + store storage.ProtobufStore +} + +func (m *CatalogClient) getArtifactByTag(ctx context.Context, tagName string, dataset *datacatalog.Dataset) (*datacatalog.Artifact, error) { + logger.Debugf(ctx, "Get Artifact by tag %v", tagName) + artifactQuery := &datacatalog.GetArtifactRequest{ + Dataset: dataset.Id, + QueryHandle: &datacatalog.GetArtifactRequest_TagName{ + TagName: tagName, + }, + } + response, err := m.client.GetArtifact(ctx, artifactQuery) + if err != nil { + return nil, err + } + + return response.Artifact, nil +} + +func (m *CatalogClient) getDataset(ctx context.Context, task *core.TaskTemplate) (*datacatalog.Dataset, error) { + datasetID, err := transformer.GenerateDatasetIDForTask(ctx, task) + if err != nil { + return nil, err + } + logger.Debugf(ctx, "Get Dataset %v", datasetID) + + dsQuery := &datacatalog.GetDatasetRequest{ + Dataset: datasetID, + } + + datasetResponse, err := m.client.GetDataset(ctx, dsQuery) + if err != nil { + return nil, err + } + + return datasetResponse.Dataset, nil +} + +func (m *CatalogClient) validateTask(task *core.TaskTemplate) error { + taskInterface := task.Interface + if taskInterface == nil { + return fmt.Errorf("Task interface cannot be nil, task: [%+v]", task) + } + + if task.Id == nil { + return fmt.Errorf("Task ID cannot be nil, task: [%+v]", task) + } + + if task.Metadata == nil { + return fmt.Errorf("Task metadata cannot be nil, task: [%+v]", task) + } + + return nil +} + +// Get the cached task execution from Catalog. 
+// These are the steps taken: +// - Verify there is a Dataset created for the Task +// - Lookup the Artifact that is tagged with the hash of the input values +// - The artifactData contains the literal values that serve as the task outputs +func (m *CatalogClient) Get(ctx context.Context, task *core.TaskTemplate, inputPath storage.DataReference) (*core.LiteralMap, error) { + inputs := &core.LiteralMap{} + + if err := m.validateTask(task); err != nil { + logger.Errorf(ctx, "DataCatalog task validation failed %+v, err: %+v", task, err) + return nil, err + } + + if task.Interface.Inputs != nil && len(task.Interface.Inputs.Variables) != 0 { + if err := m.store.ReadProtobuf(ctx, inputPath, inputs); err != nil { + logger.Errorf(ctx, "DataCatalog failed to read inputs %+v, err: %+v", inputPath, err) + return nil, err + } + logger.Debugf(ctx, "DataCatalog read inputs from %v", inputPath) + } + + dataset, err := m.getDataset(ctx, task) + if err != nil { + logger.Errorf(ctx, "DataCatalog failed to get dataset for task %+v, err: %+v", task, err) + return nil, err + } + + tag, err := transformer.GenerateArtifactTagName(ctx, inputs) + if err != nil { + logger.Errorf(ctx, "DataCatalog failed to generate tag for inputs %+v, err: %+v", inputs, err) + return nil, err + } + + artifact, err := m.getArtifactByTag(ctx, tag, dataset) + if err != nil { + logger.Errorf(ctx, "DataCatalog failed to get artifact by tag %+v, err: %+v", tag, err) + return nil, err + } + logger.Debugf(ctx, "Artifact found %v from tag %v", artifact, tag) + + outputs, err := transformer.GenerateTaskOutputsFromArtifact(task, artifact) + if err != nil { + logger.Errorf(ctx, "DataCatalog failed to get outputs from artifact %+v, err: %+v", artifact.Id, err) + return nil, err + } + + logger.Debugf(ctx, "Cached %v artifact outputs from artifact %v", len(outputs.Literals), artifact.Id) + return outputs, nil +} + +// Catalog the task execution as a cached Artifact. We associate an Artifact as the cached data by tagging the Artifact +// with the hash of the input values. +// +// The steps taken to cache an execution: +// - Ensure a Dataset exists for the Artifact. 
The Dataset represents the proj/domain/name/version of the task +// - Create an Artifact with the execution data that belongs to the dataset +// - Tag the Artifact with a hash generated by the input values +func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID *core.TaskExecutionIdentifier, inputPath storage.DataReference, outputPath storage.DataReference) error { + inputs := &core.LiteralMap{} + outputs := &core.LiteralMap{} + + if err := m.validateTask(task); err != nil { + logger.Errorf(ctx, "DataCatalog task validation failed %+v, err: %+v", task, err) + return err + } + + if task.Interface.Inputs != nil && len(task.Interface.Inputs.Variables) != 0 { + if err := m.store.ReadProtobuf(ctx, inputPath, inputs); err != nil { + logger.Errorf(ctx, "DataCatalog failed to read inputs %+v, err: %+v", inputPath, err) + return err + } + logger.Debugf(ctx, "DataCatalog read inputs from %v", inputPath) + } + + if task.Interface.Outputs != nil && len(task.Interface.Outputs.Variables) != 0 { + if err := m.store.ReadProtobuf(ctx, outputPath, outputs); err != nil { + logger.Errorf(ctx, "DataCatalog failed to read outputs %+v, err: %+v", outputPath, err) + return err + } + logger.Debugf(ctx, "DataCatalog read outputs from %v", outputPath) + } + + datasetID, err := transformer.GenerateDatasetIDForTask(ctx, task) + if err != nil { + logger.Errorf(ctx, "DataCatalog failed to generate dataset for task %+v, err: %+v", task, err) + return err + } + + logger.Debugf(ctx, "DataCatalog put into Catalog for DataSet %v", datasetID) + + // Try creating the dataset in case it doesn't exist + newDataset := &datacatalog.Dataset{ + Id: datasetID, + Metadata: &datacatalog.Metadata{ + KeyMap: map[string]string{ + taskVersionKey: task.Id.Version, + taskExecKey: execID.TaskId.Name, + }, + }, + } + + _, err = m.client.CreateDataset(ctx, &datacatalog.CreateDatasetRequest{Dataset: newDataset}) + if err != nil { + logger.Debugf(ctx, "Create dataset %v return err %v", datasetID, err) + + if status.Code(err) == codes.AlreadyExists { + logger.Debugf(ctx, "Create Dataset for task %v already exists", task.Id) + } else { + logger.Errorf(ctx, "Unable to create dataset %+v, err: %+v", datasetID, err) + return err + } + } + + // Create the artifact for the execution that belongs in the task + artifactDataList := make([]*datacatalog.ArtifactData, 0, len(outputs.Literals)) + for name, value := range outputs.Literals { + artifactData := &datacatalog.ArtifactData{ + Name: name, + Value: value, + } + artifactDataList = append(artifactDataList, artifactData) + } + + artifactMetadata := &datacatalog.Metadata{ + KeyMap: map[string]string{ + taskExecVersion: execID.TaskId.Version, + taskExecKey: execID.TaskId.Name, + }, + } + + cachedArtifact := &datacatalog.Artifact{ + Id: string(uuid.NewUUID()), + Dataset: datasetID, + Data: artifactDataList, + Metadata: artifactMetadata, + } + + createArtifactRequest := &datacatalog.CreateArtifactRequest{Artifact: cachedArtifact} + _, err = m.client.CreateArtifact(ctx, createArtifactRequest) + if err != nil { + logger.Errorf(ctx, "Failed to create Artifact %+v, err: %v", cachedArtifact, err) + return err + } + logger.Debugf(ctx, "Created artifact: %v, with %v outputs from execution %v", cachedArtifact.Id, len(artifactDataList), execID.TaskId.Name) + + // Tag the artifact since it is the cached artifact + tagName, err := transformer.GenerateArtifactTagName(ctx, inputs) + if err != nil { + logger.Errorf(ctx, "Failed to create tag for artifact %+v, err: %+v", cachedArtifact.Id, err) + 
return err + } + logger.Debugf(ctx, "Created tag: %v, for task: %v", tagName, task.Id) + + // TODO: We should create the artifact + tag in a transaction when the service supports that + tag := &datacatalog.Tag{ + Name: tagName, + Dataset: datasetID, + ArtifactId: cachedArtifact.Id, + } + _, err = m.client.AddTag(ctx, &datacatalog.AddTagRequest{Tag: tag}) + if err != nil { + if status.Code(err) == codes.AlreadyExists { + logger.Errorf(ctx, "Tag %v already exists for Artifact %v (idempotent)", tagName, cachedArtifact.Id) + } + + logger.Errorf(ctx, "Failed to add tag %+v for artifact %+v, err: %+v", tagName, cachedArtifact.Id, err) + return err + } + + return nil +} + +func NewDataCatalog(ctx context.Context, endpoint string, secureConnection bool, datastore storage.ProtobufStore) (*CatalogClient, error) { + var opts []grpc.DialOption + + grpcOptions := []grpc_retry.CallOption{ + grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)), + grpc_retry.WithCodes(codes.DeadlineExceeded, codes.Unavailable, codes.Canceled), + grpc_retry.WithMax(5), + } + + if secureConnection { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + creds := credentials.NewClientTLSFromCert(pool, "") + opts = append(opts, grpc.WithTransportCredentials(creds)) + } + + retryInterceptor := grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(grpcOptions...)) + + opts = append(opts, retryInterceptor) + clientConn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return nil, err + } + + client := datacatalog.NewDataCatalogClient(clientConn) + + return &CatalogClient{ + client: client, + store: datastore, + }, nil +} diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go new file mode 100644 index 0000000000..d25b94165a --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go @@ -0,0 +1,427 @@ +package datacatalog + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/google/uuid" + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/controller/catalog/datacatalog/mocks" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + +func createInmemoryStore(t testing.TB) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + + d, err := storage.NewDataStore(&cfg, promutils.NewTestScope()) + assert.NoError(t, err) + + return d +} + +func newStringLiteral(value string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: value, + }, + }, + }, + }, + }, + } +} + +var sampleParameters = &core.LiteralMap{Literals: map[string]*core.Literal{ + "out1": newStringLiteral("output1-stringval"), +}} + +var variableMap = &core.VariableMap{ + Variables: map[string]*core.Variable{ + "test": &core.Variable{ + Type: &core.LiteralType{ + Type: 
&core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, +} + +var typedInterface = &core.TypedInterface{ + Inputs: variableMap, + Outputs: variableMap, +} + +var sampleTask = &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Interface: typedInterface, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "1.0.0", + }, +} + +var noInputOutputTask = &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "1.0.0", + }, + Interface: &core.TypedInterface{}, +} + +var datasetID = &datacatalog.DatasetID{ + Project: "project", + Domain: "domain", + Name: "flyte_task-name", + Version: "1.0.0-ue5g6uuI-ue5g6uuI", +} + +func assertGrpcErr(t *testing.T, err error, code codes.Code) { + assert.Equal(t, code, status.Code(err)) +} + +func TestCatalog_Get(t *testing.T) { + + ctx := context.Background() + testFile := storage.DataReference("test-data.pb") + + t.Run("Empty interface returns err", func(t *testing.T) { + dataStore := createInmemoryStore(t) + + mockClient := &mocks.DataCatalogClient{} + catalogClient := &CatalogClient{ + client: mockClient, + store: dataStore, + } + taskWithoutInterface := &core.TaskTemplate{ + Id: &core.Identifier{Project: "project", Domain: "domain", Name: "name"}, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "1.0.0", + }, + } + _, err := catalogClient.Get(ctx, taskWithoutInterface, testFile) + assert.Error(t, err) + }) + + t.Run("No results, no Dataset", func(t *testing.T) { + dataStore := createInmemoryStore(t) + err := dataStore.WriteProtobuf(ctx, testFile, storage.Options{}, newStringLiteral("output")) + assert.NoError(t, err) + + mockClient := &mocks.DataCatalogClient{} + catalogClient := &CatalogClient{ + client: mockClient, + store: dataStore, + } + mockClient.On("GetDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.GetDatasetRequest) bool { + assert.EqualValues(t, datasetID, o.Dataset) + return true + }), + ).Return(nil, status.Error(codes.NotFound, "test not found")) + resp, err := catalogClient.Get(ctx, sampleTask, testFile) + assert.Error(t, err) + + assertGrpcErr(t, err, codes.NotFound) + assert.Nil(t, resp) + }) + + t.Run("No results, no Artifact", func(t *testing.T) { + dataStore := createInmemoryStore(t) + + err := dataStore.WriteProtobuf(ctx, testFile, storage.Options{}, sampleParameters) + assert.NoError(t, err) + + mockClient := &mocks.DataCatalogClient{} + discovery := &CatalogClient{ + client: mockClient, + store: dataStore, + } + + sampleDataSet := &datacatalog.Dataset{ + Id: datasetID, + } + mockClient.On("GetDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.GetDatasetRequest) bool { + assert.EqualValues(t, datasetID, o.Dataset) + return true + }), + ).Return(&datacatalog.GetDatasetResponse{Dataset: sampleDataSet}, nil, "") + + mockClient.On("GetArtifact", + ctx, + mock.MatchedBy(func(o *datacatalog.GetArtifactRequest) bool { + return true + }), + ).Return(nil, status.Error(codes.NotFound, "")) + + outputs, err := discovery.Get(ctx, sampleTask, testFile) + assert.Nil(t, outputs) + assert.Error(t, err) + }) + + t.Run("Found w/ tag", func(t *testing.T) { + dataStore := createInmemoryStore(t) + + err := dataStore.WriteProtobuf(ctx, testFile, storage.Options{}, sampleParameters) + assert.NoError(t, err) + + mockClient := &mocks.DataCatalogClient{} + discovery := &CatalogClient{ + client: mockClient, + store: dataStore, + } + + sampleDataSet := &datacatalog.Dataset{ + Id: datasetID, + } 
+ + mockClient.On("GetDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.GetDatasetRequest) bool { + assert.EqualValues(t, datasetID, o.Dataset) + return true + }), + ).Return(&datacatalog.GetDatasetResponse{Dataset: sampleDataSet}, nil) + + sampleArtifactData := &datacatalog.ArtifactData{ + Name: "test", + Value: newStringLiteral("output1-stringval"), + } + sampleArtifact := &datacatalog.Artifact{ + Id: "test-artifact", + Dataset: sampleDataSet.Id, + Data: []*datacatalog.ArtifactData{sampleArtifactData}, + } + mockClient.On("GetArtifact", + ctx, + mock.MatchedBy(func(o *datacatalog.GetArtifactRequest) bool { + assert.EqualValues(t, datasetID, o.Dataset) + assert.Equal(t, "flyte_cached-BE6CZsMk6N3ExR_4X9EuwBgj2Jh2UwasXK3a_pM9xlY", o.GetTagName()) + return true + }), + ).Return(&datacatalog.GetArtifactResponse{Artifact: sampleArtifact}, nil) + + resp, err := discovery.Get(ctx, sampleTask, testFile) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("Found w/ tag no inputs or outputs", func(t *testing.T) { + dataStore := createInmemoryStore(t) + + mockClient := &mocks.DataCatalogClient{} + discovery := &CatalogClient{ + client: mockClient, + store: dataStore, + } + + sampleDataSet := &datacatalog.Dataset{ + Id: &datacatalog.DatasetID{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "1.0.0-V-K42BDF-V-K42BDF", + }, + } + + mockClient.On("GetDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.GetDatasetRequest) bool { + assert.EqualValues(t, "1.0.0-V-K42BDF-V-K42BDF", o.Dataset.Version) + return true + }), + ).Return(&datacatalog.GetDatasetResponse{Dataset: sampleDataSet}, nil) + + sampleArtifact := &datacatalog.Artifact{ + Id: "test-artifact", + Dataset: sampleDataSet.Id, + Data: []*datacatalog.ArtifactData{}, + } + mockClient.On("GetArtifact", + ctx, + mock.MatchedBy(func(o *datacatalog.GetArtifactRequest) bool { + assert.EqualValues(t, "1.0.0-V-K42BDF-V-K42BDF", o.Dataset.Version) + assert.Equal(t, "flyte_cached-m4vFNUOHOFEFIiZSyOyid92TkWFFBDha4UOkkBb47XU", o.GetTagName()) + return true + }), + ).Return(&datacatalog.GetArtifactResponse{Artifact: sampleArtifact}, nil) + + resp, err := discovery.Get(ctx, noInputOutputTask, testFile) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Len(t, resp.Literals, 0) + }) +} + +func TestCatalog_Put(t *testing.T) { + ctx := context.Background() + + execID := &core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "runID", + }, + }, + TaskId: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "taskRunName", + Version: "taskRunVersion", + }, + } + + testFile := storage.DataReference("test-data.pb") + + t.Run("Create new cached execution", func(t *testing.T) { + dataStore := createInmemoryStore(t) + + err := dataStore.WriteProtobuf(ctx, testFile, storage.Options{}, sampleParameters) + assert.NoError(t, err) + + mockClient := &mocks.DataCatalogClient{} + discovery := &CatalogClient{ + client: mockClient, + store: dataStore, + } + + mockClient.On("CreateDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateDatasetRequest) bool { + assert.True(t, proto.Equal(o.Dataset.Id, datasetID)) + return true + }), + ).Return(&datacatalog.CreateDatasetResponse{}, nil) + + mockClient.On("CreateArtifact", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateArtifactRequest) bool { + _, parseErr := uuid.Parse(o.Artifact.Id) + assert.NoError(t, parseErr) + assert.EqualValues(t, 1, 
len(o.Artifact.Data)) + assert.EqualValues(t, "out1", o.Artifact.Data[0].Name) + assert.EqualValues(t, newStringLiteral("output1-stringval"), o.Artifact.Data[0].Value) + return true + }), + ).Return(&datacatalog.CreateArtifactResponse{}, nil) + + mockClient.On("AddTag", + ctx, + mock.MatchedBy(func(o *datacatalog.AddTagRequest) bool { + assert.EqualValues(t, "flyte_cached-BE6CZsMk6N3ExR_4X9EuwBgj2Jh2UwasXK3a_pM9xlY", o.Tag.Name) + return true + }), + ).Return(&datacatalog.AddTagResponse{}, nil) + err = discovery.Put(ctx, sampleTask, execID, testFile, testFile) + assert.NoError(t, err) + }) + + t.Run("Create new cached execution with no inputs/outputs", func(t *testing.T) { + dataStore := createInmemoryStore(t) + mockClient := &mocks.DataCatalogClient{} + catalogClient := &CatalogClient{ + client: mockClient, + store: dataStore, + } + + mockClient.On("CreateDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateDatasetRequest) bool { + assert.Equal(t, "1.0.0-V-K42BDF-V-K42BDF", o.Dataset.Id.Version) + return true + }), + ).Return(&datacatalog.CreateDatasetResponse{}, nil) + + mockClient.On("CreateArtifact", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateArtifactRequest) bool { + assert.EqualValues(t, 0, len(o.Artifact.Data)) + return true + }), + ).Return(&datacatalog.CreateArtifactResponse{}, nil) + + mockClient.On("AddTag", + ctx, + mock.MatchedBy(func(o *datacatalog.AddTagRequest) bool { + assert.EqualValues(t, "flyte_cached-m4vFNUOHOFEFIiZSyOyid92TkWFFBDha4UOkkBb47XU", o.Tag.Name) + return true + }), + ).Return(&datacatalog.AddTagResponse{}, nil) + err := catalogClient.Put(ctx, noInputOutputTask, execID, "", "") + assert.NoError(t, err) + }) + + t.Run("Create new cached execution with existing dataset", func(t *testing.T) { + dataStore := createInmemoryStore(t) + + err := dataStore.WriteProtobuf(ctx, testFile, storage.Options{}, sampleParameters) + assert.NoError(t, err) + + mockClient := &mocks.DataCatalogClient{} + discovery := &CatalogClient{ + client: mockClient, + store: dataStore, + } + + createDatasetCalled := false + mockClient.On("CreateDataset", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateDatasetRequest) bool { + createDatasetCalled = true + return true + }), + ).Return(nil, status.Error(codes.AlreadyExists, "test dataset already exists")) + + createArtifactCalled := false + mockClient.On("CreateArtifact", + ctx, + mock.MatchedBy(func(o *datacatalog.CreateArtifactRequest) bool { + _, parseErr := uuid.Parse(o.Artifact.Id) + assert.NoError(t, parseErr) + assert.EqualValues(t, 1, len(o.Artifact.Data)) + assert.EqualValues(t, "out1", o.Artifact.Data[0].Name) + assert.EqualValues(t, newStringLiteral("output1-stringval"), o.Artifact.Data[0].Value) + createArtifactCalled = true + return true + }), + ).Return(&datacatalog.CreateArtifactResponse{}, nil) + + addTagCalled := false + mockClient.On("AddTag", + ctx, + mock.MatchedBy(func(o *datacatalog.AddTagRequest) bool { + assert.EqualValues(t, "flyte_cached-BE6CZsMk6N3ExR_4X9EuwBgj2Jh2UwasXK3a_pM9xlY", o.Tag.Name) + addTagCalled = true + return true + }), + ).Return(&datacatalog.AddTagResponse{}, nil) + err = discovery.Put(ctx, sampleTask, execID, testFile, testFile) + assert.NoError(t, err) + assert.True(t, createDatasetCalled) + assert.True(t, createArtifactCalled) + assert.True(t, addTagCalled) + }) + +} diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/mocks/DataCatalogClient.go b/flytepropeller/pkg/controller/catalog/datacatalog/mocks/DataCatalogClient.go new file mode 100644 index 0000000000..ae5a7f8e4b 
--- /dev/null +++ b/flytepropeller/pkg/controller/catalog/datacatalog/mocks/DataCatalogClient.go @@ -0,0 +1,163 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import datacatalog "github.com/lyft/datacatalog/protos/gen" +import grpc "google.golang.org/grpc" +import mock "github.com/stretchr/testify/mock" + +// DataCatalogClient is an autogenerated mock type for the DataCatalogClient type +type DataCatalogClient struct { + mock.Mock +} + +// AddTag provides a mock function with given fields: ctx, in, opts +func (_m *DataCatalogClient) AddTag(ctx context.Context, in *datacatalog.AddTagRequest, opts ...grpc.CallOption) (*datacatalog.AddTagResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datacatalog.AddTagResponse + if rf, ok := ret.Get(0).(func(context.Context, *datacatalog.AddTagRequest, ...grpc.CallOption) *datacatalog.AddTagResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datacatalog.AddTagResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *datacatalog.AddTagRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateArtifact provides a mock function with given fields: ctx, in, opts +func (_m *DataCatalogClient) CreateArtifact(ctx context.Context, in *datacatalog.CreateArtifactRequest, opts ...grpc.CallOption) (*datacatalog.CreateArtifactResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datacatalog.CreateArtifactResponse + if rf, ok := ret.Get(0).(func(context.Context, *datacatalog.CreateArtifactRequest, ...grpc.CallOption) *datacatalog.CreateArtifactResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datacatalog.CreateArtifactResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *datacatalog.CreateArtifactRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateDataset provides a mock function with given fields: ctx, in, opts +func (_m *DataCatalogClient) CreateDataset(ctx context.Context, in *datacatalog.CreateDatasetRequest, opts ...grpc.CallOption) (*datacatalog.CreateDatasetResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datacatalog.CreateDatasetResponse + if rf, ok := ret.Get(0).(func(context.Context, *datacatalog.CreateDatasetRequest, ...grpc.CallOption) *datacatalog.CreateDatasetResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datacatalog.CreateDatasetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *datacatalog.CreateDatasetRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetArtifact provides a mock function with given fields: ctx, in, opts +func (_m *DataCatalogClient) GetArtifact(ctx context.Context, in *datacatalog.GetArtifactRequest, opts ...grpc.CallOption) (*datacatalog.GetArtifactResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datacatalog.GetArtifactResponse + if rf, ok := ret.Get(0).(func(context.Context, *datacatalog.GetArtifactRequest, ...grpc.CallOption) *datacatalog.GetArtifactResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datacatalog.GetArtifactResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *datacatalog.GetArtifactRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetDataset provides a mock function with given fields: ctx, in, opts +func (_m *DataCatalogClient) GetDataset(ctx context.Context, in *datacatalog.GetDatasetRequest, opts ...grpc.CallOption) (*datacatalog.GetDatasetResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datacatalog.GetDatasetResponse + if rf, ok := ret.Get(0).(func(context.Context, *datacatalog.GetDatasetRequest, ...grpc.CallOption) *datacatalog.GetDatasetResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datacatalog.GetDatasetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *datacatalog.GetDatasetRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go b/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go new file mode 100644 index 0000000000..6f66b24f40 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go @@ -0,0 +1,144 @@ +package transformer + +import ( + "context" + "fmt" + "reflect" + + "encoding/base64" + + datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/compiler/validators" + "github.com/lyft/flytestdlib/pbhash" +) + +const cachedTaskTag = "flyte_cached" +const taskNamespace = "flyte_task" +const maxParamHashLength = 8 + +// Declare the definition of empty literal and variable maps. This is important because we hash against +// the literal and variable maps. So Nil and empty literals and variable maps should translate to these defintions +// in order to have a consistent hash. 
+var emptyLiteralMap = core.LiteralMap{Literals: map[string]*core.Literal{}} +var emptyVariableMap = core.VariableMap{Variables: map[string]*core.Variable{}} + +func getDatasetNameFromTask(task *core.TaskTemplate) string { + return fmt.Sprintf("%s-%s", taskNamespace, task.Id.Name) +} + +func GenerateTaskOutputsFromArtifact(task *core.TaskTemplate, artifact *datacatalog.Artifact) (*core.LiteralMap, error) { + // if there are no outputs in the task, return empty map + if task.Interface.Outputs == nil || len(task.Interface.Outputs.Variables) == 0 { + return &emptyLiteralMap, nil + } + + outputVariables := task.Interface.Outputs.Variables + artifactDataList := artifact.Data + + // verify the task outputs matches what is stored in ArtifactData + if len(outputVariables) != len(artifactDataList) { + return nil, fmt.Errorf("The task %v with %d outputs, should have %d artifactData for artifact %v", + task.Id, len(outputVariables), len(artifactDataList), artifact.Id) + } + + outputs := make(map[string]*core.Literal, len(artifactDataList)) + for _, artifactData := range artifactDataList { + // verify that the name and type of artifactData matches what is expected from the interface + if _, ok := outputVariables[artifactData.Name]; !ok { + return nil, fmt.Errorf("Unexpected artifactData with name [%v] does not match any task output variables %v", artifactData.Name, reflect.ValueOf(outputVariables).MapKeys()) + } + + expectedVarType := outputVariables[artifactData.Name].GetType() + inputType := validators.LiteralTypeForLiteral(artifactData.Value) + if !validators.AreTypesCastable(inputType, expectedVarType) { + return nil, fmt.Errorf("Unexpected artifactData: [%v] type: [%v] does not match any task output type: [%v]", artifactData.Name, inputType, expectedVarType) + } + + outputs[artifactData.Name] = artifactData.Value + } + + return &core.LiteralMap{Literals: outputs}, nil +} + +func generateDataSetVersionFromTask(ctx context.Context, task *core.TaskTemplate) (string, error) { + signatureHash, err := generateTaskSignatureHash(ctx, task) + if err != nil { + return "", err + } + + cacheVersion := task.Metadata.DiscoveryVersion + if len(cacheVersion) == 0 { + return "", fmt.Errorf("Task cannot have an empty discoveryVersion %v", cacheVersion) + } + return fmt.Sprintf("%s-%s", cacheVersion, signatureHash), nil +} + +func generateTaskSignatureHash(ctx context.Context, task *core.TaskTemplate) (string, error) { + taskInputs := &emptyVariableMap + taskOutputs := &emptyVariableMap + + if task.Interface.Inputs != nil && len(task.Interface.Inputs.Variables) != 0 { + taskInputs = task.Interface.Inputs + } + + if task.Interface.Outputs != nil && len(task.Interface.Outputs.Variables) != 0 { + taskOutputs = task.Interface.Outputs + } + + inputHash, err := pbhash.ComputeHash(ctx, taskInputs) + if err != nil { + return "", err + } + + outputHash, err := pbhash.ComputeHash(ctx, taskOutputs) + if err != nil { + return "", err + } + + inputHashString := base64.RawURLEncoding.EncodeToString(inputHash) + + if len(inputHashString) > maxParamHashLength { + inputHashString = inputHashString[0:maxParamHashLength] + } + + outputHashString := base64.RawURLEncoding.EncodeToString(outputHash) + if len(outputHashString) > maxParamHashLength { + outputHashString = outputHashString[0:maxParamHashLength] + } + + return fmt.Sprintf("%v-%v", inputHashString, outputHashString), nil +} + +func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap) (string, error) { + if inputs == nil || len(inputs.Literals) == 0 { + inputs 
= &emptyLiteralMap + } + + inputsHash, err := pbhash.ComputeHash(ctx, inputs) + if err != nil { + return "", err + } + + hashString := base64.RawURLEncoding.EncodeToString(inputsHash) + tag := fmt.Sprintf("%s-%s", cachedTaskTag, hashString) + return tag, nil +} + +// Get the DataSetID for a task. +// NOTE: the version of the task is a combination of both the discoverable_version and the task signature. +// This is because the interfact may of changed even if the discoverable_version hadn't. +func GenerateDatasetIDForTask(ctx context.Context, task *core.TaskTemplate) (*datacatalog.DatasetID, error) { + datasetVersion, err := generateDataSetVersionFromTask(ctx, task) + if err != nil { + return nil, err + } + + datasetID := &datacatalog.DatasetID{ + Project: task.Id.Project, + Domain: task.Id.Domain, + Name: getDatasetNameFromTask(task), + Version: datasetVersion, + } + return datasetID, nil +} diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer_test.go b/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer_test.go new file mode 100644 index 0000000000..39cd764618 --- /dev/null +++ b/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer_test.go @@ -0,0 +1,132 @@ +package transformer + +import ( + "context" + "testing" + + "github.com/gogo/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/utils" + "github.com/stretchr/testify/assert" +) + +// add test for raarranged Literal maps for input values + +func TestNilParamTask(t *testing.T) { + task := &core.TaskTemplate{ + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "1.0.0", + }, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "1.0.0", + }, + Interface: &core.TypedInterface{ + Inputs: nil, + Outputs: nil, + }, + } + datasetID, err := GenerateDatasetIDForTask(context.TODO(), task) + assert.NoError(t, err) + assert.NotEmpty(t, datasetID.Version) + assert.Equal(t, "1.0.0-V-K42BDF-V-K42BDF", datasetID.Version) +} + +// Ensure that empty parameters generate the same dataset as nil parameters +func TestEmptyParamTask(t *testing.T) { + task := &core.TaskTemplate{ + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "1.0.0", + }, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "1.0.0", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{}, + Outputs: &core.VariableMap{}, + }, + } + datasetID, err := GenerateDatasetIDForTask(context.TODO(), task) + assert.NoError(t, err) + assert.NotEmpty(t, datasetID.Version) + assert.Equal(t, "1.0.0-V-K42BDF-V-K42BDF", datasetID.Version) + + task.Interface.Inputs = nil + task.Interface.Outputs = nil + datasetIDDupe, err := GenerateDatasetIDForTask(context.TODO(), task) + assert.NoError(t, err) + assert.True(t, proto.Equal(datasetIDDupe, datasetID)) +} + +// Ensure the key order on the map generates the same dataset +func TestVariableMapOrder(t *testing.T) { + task := &core.TaskTemplate{ + Id: &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "1.0.0", + }, + Metadata: &core.TaskMetadata{ + DiscoveryVersion: "1.0.0", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "1": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}, + "2": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: 
core.SimpleType_INTEGER}}}, + }, + }, + }, + } + datasetID, err := GenerateDatasetIDForTask(context.TODO(), task) + assert.NoError(t, err) + assert.NotEmpty(t, datasetID.Version) + assert.Equal(t, "1.0.0-UxVtPm0k-V-K42BDF", datasetID.Version) + + task.Interface.Inputs = &core.VariableMap{ + Variables: map[string]*core.Variable{ + "2": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}, + "1": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}, + }, + } + datasetIDDupe, err := GenerateDatasetIDForTask(context.TODO(), task) + assert.NoError(t, err) + + assert.Equal(t, "1.0.0-UxVtPm0k-V-K42BDF", datasetIDDupe.Version) + assert.True(t, proto.Equal(datasetID, datasetIDDupe)) +} + +// Ensure the key order on the inputs generates the same tag +func TestInputValueSorted(t *testing.T) { + literalMap, err := utils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2}) + assert.NoError(t, err) + + tag, err := GenerateArtifactTagName(context.TODO(), literalMap) + assert.NoError(t, err) + assert.Equal(t, "flyte_cached-GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", tag) + + literalMap, err = utils.MakeLiteralMap(map[string]interface{}{"2": 2, "1": 1}) + assert.NoError(t, err) + + tagDupe, err := GenerateArtifactTagName(context.TODO(), literalMap) + assert.NoError(t, err) + assert.Equal(t, tagDupe, tag) +} + +// Ensure that empty inputs are hashed the same way +func TestNoInputValues(t *testing.T) { + tag, err := GenerateArtifactTagName(context.TODO(), nil) + assert.NoError(t, err) + assert.Equal(t, "flyte_cached-m4vFNUOHOFEFIiZSyOyid92TkWFFBDha4UOkkBb47XU", tag) + + tagDupe, err := GenerateArtifactTagName(context.TODO(), &core.LiteralMap{Literals: nil}) + assert.NoError(t, err) + assert.Equal(t, "flyte_cached-m4vFNUOHOFEFIiZSyOyid92TkWFFBDha4UOkkBb47XU", tagDupe) + assert.Equal(t, tagDupe, tag) +} diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index 25d8fb4316..d805efee49 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -274,7 +274,10 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter } logger.Info(ctx, "Setting up Catalog client.") - catalogClient := catalog.NewCatalogClient(store) + catalogClient, err := catalog.NewCatalogClient(ctx, store) + if err != nil { + return nil, errors.Wrapf(err, "Failed to create datacatalog client") + } workQ, err := NewCompositeWorkQueue(ctx, cfg.Queue, scope) if err != nil { diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index 2cfcc4caff..56d7d8c53f 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -59,7 +59,7 @@ func init() { func TestSetInputsForStartNode(t *testing.T) { ctx := context.Background() mockStorage := createInmemoryDataStore(t, testScope.NewSubScope("f")) - catalogClient := catalog.NewCatalogClient(mockStorage) + catalogClient, _ := catalog.NewCatalogClient(ctx, mockStorage) enQWf := func(workflowID v1alpha1.WorkflowID) {} factory := createSingletonTaskExecutorFactory() @@ -138,7 +138,7 @@ func TestNodeExecutor_TransitionToPhase(t *testing.T) { memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) - catalogClient := catalog.NewCatalogClient(memStore) + catalogClient, _ := catalog.NewCatalogClient(ctx, 
memStore) execIface, err := NewExecutor(ctx, memStore, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -349,7 +349,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) - catalogClient := catalog.NewCatalogClient(memStore) + catalogClient, _ := catalog.NewCatalogClient(ctx, memStore) execIface, err := NewExecutor(ctx, memStore, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) @@ -368,7 +368,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { assert.True(t, task.IsTestModeEnabled()) store := createInmemoryDataStore(t, promutils.NewTestScope()) - catalogClient := catalog.NewCatalogClient(store) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) @@ -472,7 +472,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { assert.True(t, task.IsTestModeEnabled()) store := createInmemoryDataStore(t, promutils.NewTestScope()) - catalogClient := catalog.NewCatalogClient(store) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) @@ -660,7 +660,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { assert.True(t, task.IsTestModeEnabled()) store := createInmemoryDataStore(t, promutils.NewTestScope()) - catalogClient := catalog.NewCatalogClient(store) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) @@ -961,7 +961,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { assert.True(t, task.IsTestModeEnabled()) store := createInmemoryDataStore(t, promutils.NewTestScope()) - catalogClient := catalog.NewCatalogClient(store) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) @@ -1334,7 +1334,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { assert.True(t, task.IsTestModeEnabled()) store := createInmemoryDataStore(t, promutils.NewTestScope()) - catalogClient := catalog.NewCatalogClient(store) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) execIface, err := NewExecutor(ctx, store, enQWf, time.Second, mockEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index b118107edd..dc0776a17d 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go 
@@ -215,25 +215,26 @@ func (h *taskHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkfl } logger.Infof(ctx, "Executor type: [%v]. Properties: finalizer[%v]. disable[%v].", reflect.TypeOf(t).String(), t.GetProperties().RequiresFinalizer, t.GetProperties().DisableNodeLevelCaching) - if task.CoreTask().Metadata.Discoverable { + if iface := task.CoreTask().Interface; task.CoreTask().Metadata.Discoverable && iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { if t.GetProperties().DisableNodeLevelCaching { logger.Infof(ctx, "Executor has Node-Level caching disabled. Skipping.") } else if resp, err := h.catalogClient.Get(ctx, task.CoreTask(), taskCtx.GetInputsFile()); err != nil { if taskStatus, ok := status.FromError(err); ok && taskStatus.Code() == codes.NotFound { h.metrics.discoveryMissCount.Inc(ctx) - logger.Infof(ctx, "Artifact not found in Discovery. Executing Task.") + logger.Infof(ctx, "Artifact not found in cache. Executing Task.") } else { h.metrics.discoveryGetFailureCount.Inc(ctx) - logger.Errorf(ctx, "Discovery check failed. Executing Task. Err: %v", err.Error()) + logger.Errorf(ctx, "Catalog cache check failed. Executing Task. Err: %v", err.Error()) } } else if resp != nil { h.metrics.discoveryHitCount.Inc(ctx) - if iface := task.CoreTask().Interface; iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { - if err := h.store.WriteProtobuf(ctx, taskCtx.GetOutputsFile(), storage.Options{}, resp); err != nil { - logger.Errorf(ctx, "failed to write data to Storage, err: %v", err.Error()) - return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to copy cached results for task.") - } + + logger.Debugf(ctx, "Outputs found in Catalog cache %+v", resp) + if err := h.store.WriteProtobuf(ctx, taskCtx.GetOutputsFile(), storage.Options{}, resp); err != nil { + logger.Errorf(ctx, "failed to write data to Storage, err: %v", err.Error()) + return handler.StatusUndefined, errors.Wrapf(errors.CausedByError, node.GetID(), err, "failed to copy cached results for task.") } + // SetCached. w.GetNodeExecutionStatus(node.GetID()).SetCached() return handler.StatusSuccess, nil @@ -344,7 +345,7 @@ func (h *taskHandler) HandleNodeSuccess(ctx context.Context, w v1alpha1.Executab h.metrics.discoveryPutFailureCount.Inc(ctx) logger.Errorf(ctx, "Failed to write results to catalog. 
Err: %v", err2) } else { - logger.Debugf(ctx, "Successfully cached results to discovery - Task [%s]", task.CoreTask().GetId()) + logger.Debugf(ctx, "Successfully cached results - Task [%s]", task.CoreTask().GetId()) } } } diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 492e4067dd..39bd812917 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -231,7 +231,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} eventSink := events.NewMockEventSink() - nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) assert.NoError(t, err) @@ -280,7 +281,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} eventSink := events.NewMockEventSink() - nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) @@ -330,7 +332,8 @@ func BenchmarkWorkflowExecutor(b *testing.B) { enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} eventSink := events.NewMockEventSink() - nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, scope) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, scope) assert.NoError(b, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) @@ -415,7 +418,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { } return nil } - nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, promutils.NewTestScope()) assert.NoError(t, 
err) @@ -502,7 +506,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { } return nil } - nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, eventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, promutils.NewTestScope()) assert.NoError(t, err) @@ -553,7 +558,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { assert.NoError(t, err) nodeEventSink := events.NewMockEventSink() - nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, nodeEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalog.NewCatalogClient(store), fakeKubeClient, promutils.NewTestScope()) + catalogClient, _ := catalog.NewCatalogClient(ctx, store) + nodeExec, err := nodes.NewExecutor(ctx, store, enqueueWorkflow, time.Second, nodeEventSink, launchplan.NewFailFastLaunchPlanExecutor(), catalogClient, fakeKubeClient, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { From 382026f15b328069c490ab65b5a378f0c3261763 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Mon, 9 Sep 2019 17:07:36 -0700 Subject: [PATCH 0092/1918] Correct metadata for execution name --- .../catalog/datacatalog/datacatalog.go | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go index 672da3f3f5..1cdea9b816 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -21,11 +21,11 @@ import ( ) const ( - taskVersionKey = "task-version" - taskExecKey = "execution-name" - taskExecVersion = "execution-version" + taskVersionKey = "task-version" + taskExecKey = "execution-name" ) +// This is the client that caches task executions to DataCatalog service. 
type CatalogClient struct { client datacatalog.DataCatalogClient store storage.ProtobufStore @@ -174,15 +174,17 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID logger.Debugf(ctx, "DataCatalog put into Catalog for DataSet %v", datasetID) // Try creating the dataset in case it doesn't exist - newDataset := &datacatalog.Dataset{ - Id: datasetID, - Metadata: &datacatalog.Metadata{ - KeyMap: map[string]string{ - taskVersionKey: task.Id.Version, - taskExecKey: execID.TaskId.Name, - }, + + metadata := &datacatalog.Metadata{ + KeyMap: map[string]string{ + taskVersionKey: task.Id.Version, + taskExecKey: execID.NodeExecutionId.NodeId, }, } + newDataset := &datacatalog.Dataset{ + Id: datasetID, + Metadata: metadata, + } _, err = m.client.CreateDataset(ctx, &datacatalog.CreateDatasetRequest{Dataset: newDataset}) if err != nil { @@ -206,18 +208,11 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID artifactDataList = append(artifactDataList, artifactData) } - artifactMetadata := &datacatalog.Metadata{ - KeyMap: map[string]string{ - taskExecVersion: execID.TaskId.Version, - taskExecKey: execID.TaskId.Name, - }, - } - cachedArtifact := &datacatalog.Artifact{ Id: string(uuid.NewUUID()), Dataset: datasetID, Data: artifactDataList, - Metadata: artifactMetadata, + Metadata: metadata, } createArtifactRequest := &datacatalog.CreateArtifactRequest{Artifact: cachedArtifact} From c2655ec3b428c77f6b3b717f517014eb573dfc6d Mon Sep 17 00:00:00 2001 From: Chetan Raj Date: Tue, 10 Sep 2019 09:55:33 -0700 Subject: [PATCH 0093/1918] Rename JobUri to CommandUri --- .../go/tasks/v1/qubole/client/qubole_client.go | 6 ++++-- flyteplugins/go/tasks/v1/qubole/hive_executor.go | 2 +- flyteplugins/go/tasks/v1/qubole/qubole_work.go | 11 ++++++----- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go index 10c603a758..6baa94b16e 100755 --- a/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go +++ b/flyteplugins/go/tasks/v1/qubole/client/qubole_client.go @@ -35,10 +35,12 @@ type quboleCmdDetailsInternal struct { Status string } +type QuboleUri = string + type QuboleCommandDetails struct { ID int64 Status QuboleStatus - JobUri string + Uri QuboleUri } // QuboleClient API Request Body, meant to be passed into JSON.marshal @@ -196,7 +198,7 @@ func (q *quboleClient) ExecuteHiveCommand( } status := NewQuboleStatus(ctx, cmd.Status) - return &QuboleCommandDetails{ID: cmd.ID, Status: status, JobUri: fmt.Sprintf(QuboleLogLinkFormat, cmd.ID)}, nil + return &QuboleCommandDetails{ID: cmd.ID, Status: status, Uri: fmt.Sprintf(QuboleLogLinkFormat, cmd.ID)}, nil } /* diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor.go b/flyteplugins/go/tasks/v1/qubole/hive_executor.go index 8fb1301e0d..8e85e30e02 100755 --- a/flyteplugins/go/tasks/v1/qubole/hive_executor.go +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor.go @@ -265,7 +265,7 @@ func (h HiveExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.TaskCon commandId := strconv.FormatInt(cmdDetails.ID, 10) logger.Infof(ctx, "Created Qubole ID %s for %s", commandId, workCacheKey) item.CommandId = commandId - item.JobUri = cmdDetails.JobUri + item.CommandUri = cmdDetails.Uri item.Status = QuboleWorkRunning item.Query = "" // Clear the query to save space in etcd once we've successfully launched err := h.executionBuffer.ConfirmExecution(ctx, workCacheKey, commandId) diff --git 
a/flyteplugins/go/tasks/v1/qubole/qubole_work.go b/flyteplugins/go/tasks/v1/qubole/qubole_work.go index 84bde9dc5a..8bb8ada446 100755 --- a/flyteplugins/go/tasks/v1/qubole/qubole_work.go +++ b/flyteplugins/go/tasks/v1/qubole/qubole_work.go @@ -1,8 +1,9 @@ package qubole import ( - "fmt" "encoding/json" + "fmt" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flyteplugins/go/tasks/v1/events" "github.com/lyft/flyteplugins/go/tasks/v1/qubole/client" @@ -44,7 +45,7 @@ type QuboleWorkItem struct { TimeoutSec uint32 `json:"timeout,omitempty"` - JobUri string `json:"job_uri,omitempty"` + CommandUri string `json:"command_uri,omitempty"` } // This ID will be used in a process-wide cache, so it needs to be unique across all concurrent work being done by @@ -150,10 +151,10 @@ func constructEventInfoFromQuboleWorkItems(taskCtx types.TaskContext, quboleWork workItem := v.(QuboleWorkItem) if workItem.CommandId != "" { logs = append(logs, &core.TaskLog{ - Name: fmt.Sprintf("Retry: %d Status: %s [%s]", + Name: fmt.Sprintf("Retry: %d Status: %s [%s]", taskCtx.GetTaskExecutionID().GetID().RetryAttempt, workItem.Status, workItem.CommandId), MessageFormat: core.TaskLog_UNKNOWN, - Uri: workItem.JobUri, + Uri: workItem.CommandUri, }) } } @@ -201,4 +202,4 @@ func InterfaceConverter(cachedInterface interface{}) (QuboleWorkItem, error) { } return *item, nil -} \ No newline at end of file +} From 529136b633ae134f2c569e3ea03fa0054d12a76d Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 10 Sep 2019 10:19:21 -0700 Subject: [PATCH 0094/1918] Specify insecure connection with config --- .../pkg/controller/catalog/catalog_client.go | 2 +- flytepropeller/pkg/controller/catalog/config.go | 2 +- .../pkg/controller/catalog/config_flags.go | 4 ++-- .../pkg/controller/catalog/config_flags_test.go | 14 +++++++------- .../controller/catalog/datacatalog/datacatalog.go | 8 ++++++-- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/flytepropeller/pkg/controller/catalog/catalog_client.go b/flytepropeller/pkg/controller/catalog/catalog_client.go index af24c2b4d0..67a196898b 100644 --- a/flytepropeller/pkg/controller/catalog/catalog_client.go +++ b/flytepropeller/pkg/controller/catalog/catalog_client.go @@ -25,7 +25,7 @@ func NewCatalogClient(ctx context.Context, store storage.ProtobufStore) (Client, case LegacyDiscoveryType: catalogClient = NewLegacyDiscovery(catalogConfig.Endpoint, store) case DataCatalogType: - catalogClient, err = datacatalog.NewDataCatalog(ctx, catalogConfig.Endpoint, catalogConfig.Secure, store) + catalogClient, err = datacatalog.NewDataCatalog(ctx, catalogConfig.Endpoint, catalogConfig.Insecure, store) if err != nil { return nil, err } diff --git a/flytepropeller/pkg/controller/catalog/config.go b/flytepropeller/pkg/controller/catalog/config.go index 6d4f2c3376..d067876f03 100644 --- a/flytepropeller/pkg/controller/catalog/config.go +++ b/flytepropeller/pkg/controller/catalog/config.go @@ -27,7 +27,7 @@ const ( type Config struct { Type DiscoveryType `json:"type" pflag:"\"noop\", Catalog Implementation to use"` Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for catalog service"` - Secure bool `json:"secure" pflag:"true, Connect with TSL/SSL"` + Insecure bool `json:"insecure" pflag:"false, Use insecure grpc connection"` } // Gets loaded config for Discovery diff --git a/flytepropeller/pkg/controller/catalog/config_flags.go b/flytepropeller/pkg/controller/catalog/config_flags.go index b2359a462d..731c2347f1 100755 --- 
a/flytepropeller/pkg/controller/catalog/config_flags.go +++ b/flytepropeller/pkg/controller/catalog/config_flags.go @@ -1,6 +1,6 @@ // Code generated by go generate; DO NOT EDIT. // This file was generated by robots at -// 2019-09-05 05:37:07.301294018 -0700 PDT m=+14.696460456 +// 2019-09-10 10:40:36.580780957 -0700 PDT m=+11.375426414 package catalog @@ -16,6 +16,6 @@ func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), "noop", " Catalog Implementation to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), "", " Endpoint for catalog service") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "secure"), true, " Connect with TSL/SSL") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "insecure"), false, " Use insecure grpc connection") return cmdFlags } diff --git a/flytepropeller/pkg/controller/catalog/config_flags_test.go b/flytepropeller/pkg/controller/catalog/config_flags_test.go index e5ce16b044..5d01a62377 100755 --- a/flytepropeller/pkg/controller/catalog/config_flags_test.go +++ b/flytepropeller/pkg/controller/catalog/config_flags_test.go @@ -1,6 +1,6 @@ // Code generated by go generate; DO NOT EDIT. // This file was generated by robots at -// 2019-09-05 05:37:07.301294018 -0700 PDT m=+14.696460456 +// 2019-09-10 10:40:36.580780957 -0700 PDT m=+11.375426414 package catalog @@ -129,20 +129,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_secure", func(t *testing.T) { + t.Run("Test_insecure", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vBool, err := cmdFlags.GetBool("secure"); err == nil { - assert.Equal(t, true, vBool) + if vBool, err := cmdFlags.GetBool("insecure"); err == nil { + assert.Equal(t, false, vBool) } else { assert.FailNow(t, err.Error()) } }) t.Run("Override", func(t *testing.T) { - cmdFlags.Set("secure", "1") - if vBool, err := cmdFlags.GetBool("secure"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Secure) + cmdFlags.Set("insecure", "1") + if vBool, err := cmdFlags.GetBool("insecure"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Insecure) } else { assert.FailNow(t, err.Error()) diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go index 1cdea9b816..73db43b221 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -250,7 +250,7 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID return nil } -func NewDataCatalog(ctx context.Context, endpoint string, secureConnection bool, datastore storage.ProtobufStore) (*CatalogClient, error) { +func NewDataCatalog(ctx context.Context, endpoint string, insecureConnection bool, datastore storage.ProtobufStore) (*CatalogClient, error) { var opts []grpc.DialOption grpcOptions := []grpc_retry.CallOption{ @@ -259,7 +259,11 @@ func NewDataCatalog(ctx context.Context, endpoint string, secureConnection bool, grpc_retry.WithMax(5), } - if secureConnection { + if insecureConnection { + logger.Debug(ctx, "Establishing insecure connection to DataCatalog") + opts = append(opts, grpc.WithInsecure()) + } else { + logger.Debug(ctx, "Establishing secure connection to DataCatalog") pool, err := x509.SystemCertPool() if err != nil { return nil, err From 
4668e7e8464e4c8692923d4bafb17aee685e4495 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 10 Sep 2019 11:28:38 -0700 Subject: [PATCH 0095/1918] Add more comments --- flytepropeller/config.yaml | 1 + .../pkg/controller/catalog/datacatalog/datacatalog.go | 4 +++- .../datacatalog/transformer/datacatalog_transformer.go | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/flytepropeller/config.yaml b/flytepropeller/config.yaml index 6bb16eea30..cbfa17fe95 100644 --- a/flytepropeller/config.yaml +++ b/flytepropeller/config.yaml @@ -91,6 +91,7 @@ admin: catalog-cache: type: catalog endpoint: datacatalog:8089 + insecure: true errors: show-source: true logger: diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go index 73db43b221..f8aa16883f 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -31,6 +31,7 @@ type CatalogClient struct { store storage.ProtobufStore } +// Helper method to retrieve an artifact by the tag func (m *CatalogClient) getArtifactByTag(ctx context.Context, tagName string, dataset *datacatalog.Dataset) (*datacatalog.Artifact, error) { logger.Debugf(ctx, "Get Artifact by tag %v", tagName) artifactQuery := &datacatalog.GetArtifactRequest{ @@ -47,6 +48,7 @@ func (m *CatalogClient) getArtifactByTag(ctx context.Context, tagName string, da return response.Artifact, nil } +// Helper method to retrieve a dataset that is associated with the task func (m *CatalogClient) getDataset(ctx context.Context, task *core.TaskTemplate) (*datacatalog.Dataset, error) { datasetID, err := transformer.GenerateDatasetIDForTask(ctx, task) if err != nil { @@ -174,7 +176,6 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID logger.Debugf(ctx, "DataCatalog put into Catalog for DataSet %v", datasetID) // Try creating the dataset in case it doesn't exist - metadata := &datacatalog.Metadata{ KeyMap: map[string]string{ taskVersionKey: task.Id.Version, @@ -250,6 +251,7 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID return nil } +// Create a new Datacatalog client for task execution caching func NewDataCatalog(ctx context.Context, endpoint string, insecureConnection bool, datastore storage.ProtobufStore) (*CatalogClient, error) { var opts []grpc.DialOption diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go b/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go index 6f66b24f40..c5f643cfdd 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/transformer/datacatalog_transformer.go @@ -27,7 +27,9 @@ func getDatasetNameFromTask(task *core.TaskTemplate) string { return fmt.Sprintf("%s-%s", taskNamespace, task.Id.Name) } +// Transform the artifact Data into task execution outputs as a literal map func GenerateTaskOutputsFromArtifact(task *core.TaskTemplate, artifact *datacatalog.Artifact) (*core.LiteralMap, error) { + // if there are no outputs in the task, return empty map if task.Interface.Outputs == nil || len(task.Interface.Outputs.Variables) == 0 { return &emptyLiteralMap, nil @@ -110,6 +112,7 @@ func generateTaskSignatureHash(ctx context.Context, task *core.TaskTemplate) (st return fmt.Sprintf("%v-%v", inputHashString, outputHashString), nil } +// 
Generate a tag by hashing the input values func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap) (string, error) { if inputs == nil || len(inputs.Literals) == 0 { inputs = &emptyLiteralMap From 9fc6b2e5d1b1c9e181dd9d995d4a482c9398e31f Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 10 Sep 2019 11:33:48 -0700 Subject: [PATCH 0096/1918] go-lint cleanup --- .../pkg/controller/catalog/datacatalog/datacatalog_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go index d25b94165a..ea3bef477b 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog_test.go @@ -388,7 +388,7 @@ func TestCatalog_Put(t *testing.T) { createDatasetCalled := false mockClient.On("CreateDataset", ctx, - mock.MatchedBy(func(o *datacatalog.CreateDatasetRequest) bool { + mock.MatchedBy(func(_ *datacatalog.CreateDatasetRequest) bool { createDatasetCalled = true return true }), From f1f7feeae935134299b04becb3e58815f711aabf Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 10 Sep 2019 14:27:38 -0700 Subject: [PATCH 0097/1918] Support deleting resources on KillTask --- .../v1/flytek8s/mocks/K8sResourceHandler.go | 14 ++++++++++++ .../go/tasks/v1/flytek8s/plugin_executor.go | 14 +++++++++++- .../tasks/v1/flytek8s/plugin_executor_test.go | 10 ++++++--- .../go/tasks/v1/flytek8s/plugin_iface.go | 2 ++ .../go/tasks/v1/k8splugins/container.go | 6 ++++- .../go/tasks/v1/k8splugins/sidecar.go | 8 ++++++- flyteplugins/go/tasks/v1/k8splugins/spark.go | 6 ++++- .../go/tasks/v1/k8splugins/waitable_task.go | 4 ++++ .../go/tasks/v1/qubole/config/config_flags.go | 1 + .../v1/qubole/config/config_flags_test.go | 22 +++++++++++++++++++ flyteplugins/go/tasks/v1/types/task.go | 4 ++++ 11 files changed, 84 insertions(+), 7 deletions(-) diff --git a/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go b/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go index 292d056245..195d6a335b 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go +++ b/flyteplugins/go/tasks/v1/flytek8s/mocks/K8sResourceHandler.go @@ -60,6 +60,20 @@ func (_m *K8sResourceHandler) BuildResource(ctx context.Context, taskCtx types.T return r0, r1 } +// GetProperties provides a mock function with given fields: +func (_m *K8sResourceHandler) GetProperties() types.ExecutorProperties { + ret := _m.Called() + + var r0 types.ExecutorProperties + if rf, ok := ret.Get(0).(func() types.ExecutorProperties); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.ExecutorProperties) + } + + return r0 +} + // GetTaskStatus provides a mock function with given fields: ctx, taskCtx, resource func (_m *K8sResourceHandler) GetTaskStatus(ctx context.Context, taskCtx types.TaskContext, resource flytek8s.K8sResource) (types.TaskStatus, *events.TaskEventInfo, error) { ret := _m.Called(ctx, taskCtx, resource) diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go index 652e9775f2..2389d24362 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go @@ -312,14 +312,26 @@ func (e *K8sTaskExecutor) KillTask(ctx context.Context, taskCtx types.TaskContex logger.Warningf(ctx, "Failed to build the Resource with name: %v. 
Error: %v", taskCtx.GetTaskExecutionID().GetGeneratedName(), err) return err } + AddObjectMetadata(taskCtx, o) + // Retrieve the object from cache/etcd to get the last known version. _, _, err = e.getResource(ctx, taskCtx, o) if err != nil { return err } - return e.ClearFinalizers(ctx, o) + // Clear finalizers + err = e.ClearFinalizers(ctx, o) + if err != nil { + return err + } + + if e.handler.GetProperties().DeleteResourceOnAbort { + return instance.kubeClient.Delete(ctx, o) + } + + return nil } func (e *K8sTaskExecutor) ClearFinalizers(ctx context.Context, o K8sResource) error { diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go index 95015beb50..d2d9814bc9 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor_test.go @@ -4,10 +4,11 @@ import ( "bytes" "errors" "fmt" - eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" "testing" "time" + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" + k8serrs "k8s.io/apimachinery/pkg/api/errors" taskerrs "github.com/lyft/flyteplugins/go/tasks/v1/errors" @@ -21,7 +22,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/v1/flytek8s/mocks" "github.com/lyft/flytestdlib/storage" "github.com/stretchr/testify/mock" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" v12 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -52,6 +53,10 @@ func (k8sSampleHandler) GetTaskStatus(ctx context.Context, taskCtx types.TaskCon panic("implement me") } +func (k8sSampleHandler) GetProperties() types.ExecutorProperties { + panic("implement me") +} + func ExampleNewK8sTaskExecutorForResource() { exec := flytek8s.NewK8sTaskExecutorForResource("SampleHandler", &k8sBatch.Job{}, k8sSampleHandler{}, time.Second*1) fmt.Printf("Created executor: %v\n", exec.GetID()) @@ -291,7 +296,6 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { assert.Equal(t, expectedNewStatus, s) }) - t.Run("PhaseMismatch", func(t *testing.T) { evRecorder := &mocks2.EventRecorder{} diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go index 844970fd31..9a0bc6e552 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_iface.go @@ -29,6 +29,8 @@ type K8sResourceHandler interface { // Analyses the k8s resource and reports the status as TaskPhase. 
GetTaskStatus(ctx context.Context, taskCtx types.TaskContext, resource K8sResource) (types.TaskStatus, *events.TaskEventInfo, error) + + GetProperties() types.ExecutorProperties } type K8sResource interface { diff --git a/flyteplugins/go/tasks/v1/k8splugins/container.go b/flyteplugins/go/tasks/v1/k8splugins/container.go index 3223b1f209..c9548cddfb 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/container.go +++ b/flyteplugins/go/tasks/v1/k8splugins/container.go @@ -5,7 +5,7 @@ import ( "time" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" tasksV1 "github.com/lyft/flyteplugins/go/tasks/v1" @@ -105,6 +105,10 @@ func (containerTaskExecutor) BuildIdentityResource(_ context.Context, taskCtx ty return flytek8s.BuildIdentityPod(), nil } +func (containerTaskExecutor) GetProperties() types.ExecutorProperties { + return types.ExecutorProperties{} +} + func init() { tasksV1.RegisterLoader(func(ctx context.Context) error { return tasksV1.K8sRegisterAsDefault(containerTaskType, &v1.Pod{}, flytek8s.DefaultInformerResyncDuration, diff --git a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go index 29c8adc627..e111ef9589 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go +++ b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go @@ -88,7 +88,7 @@ func (sidecarResourceHandler) BuildResource( err := utils.UnmarshalStruct(task.GetCustom(), &sidecarJob) if err != nil { return nil, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) } pod := flytek8s.BuildPodWithSpec(sidecarJob.PodSpec) @@ -188,6 +188,12 @@ func (sidecarResourceHandler) GetTaskStatus( return status, info, err } +func (sidecarResourceHandler) GetProperties() types.ExecutorProperties { + return types.ExecutorProperties{ + DeleteResourceOnAbort: true, + } +} + func init() { v1.RegisterLoader(func(ctx context.Context) error { return v1.K8sRegisterForTaskTypes(sidecarTaskType, &k8sv1.Pod{}, flytek8s.DefaultInformerResyncDuration, diff --git a/flyteplugins/go/tasks/v1/k8splugins/spark.go b/flyteplugins/go/tasks/v1/k8splugins/spark.go index de19a37828..d601dd18de 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/spark.go +++ b/flyteplugins/go/tasks/v1/k8splugins/spark.go @@ -54,6 +54,10 @@ func setSparkConfig(cfg *SparkConfig) error { type sparkResourceHandler struct { } +func (sparkResourceHandler) GetProperties() types.ExecutorProperties { + return types.ExecutorProperties{} +} + // Creates a new Job that will execute the main container as well as any generated types the result from the execution. func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) { @@ -147,7 +151,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.Tas HadoopConf: sparkJob.GetHadoopConf(), // SubmissionFailures handled here. Task Failures handled at Propeller/Job level. 
RestartPolicy: sparkOp.RestartPolicy{ - Type: sparkOp.OnFailure, + Type: sparkOp.OnFailure, OnSubmissionFailureRetries: &submissionFailureRetries, }, }, diff --git a/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go b/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go index 2a5f82eeef..632c58fa30 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go +++ b/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go @@ -524,6 +524,10 @@ func (w waitableTaskExecutor) syncItem(ctx context.Context, obj utils2.CacheItem return waitable, utils2.Unchanged, nil } +func (waitableTaskExecutor) GetProperties() types.ExecutorProperties { + return types.ExecutorProperties{} +} + func newWaitableTaskExecutor(ctx context.Context) (executor *waitableTaskExecutor, err error) { waitableExec := &waitableTaskExecutor{ containerTaskExecutor: containerTaskExecutor{}, diff --git a/flyteplugins/go/tasks/v1/qubole/config/config_flags.go b/flyteplugins/go/tasks/v1/qubole/config/config_flags.go index 8fb6b6ee11..dfe19903b7 100755 --- a/flyteplugins/go/tasks/v1/qubole/config/config_flags.go +++ b/flyteplugins/go/tasks/v1/qubole/config/config_flags.go @@ -45,6 +45,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "resourceManagerType"), defaultConfig.ResourceManagerType, "Which resource manager to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redisHostPath"), defaultConfig.RedisHostPath, "Redis host location") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "redisHostKey"), defaultConfig.RedisHostKey, "Key for local Redis access") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "redisMaxRetries"), defaultConfig.RedisMaxRetries, "See Redis client options for more info") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "quboleLimit"), defaultConfig.QuboleLimit, "Global limit for concurrent Qubole queries") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "lruCacheSize"), defaultConfig.LruCacheSize, "Size of the AutoRefreshCache") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "lookasideBufferPrefix"), defaultConfig.LookasideBufferPrefix, "Prefix used for lookaside buffer") diff --git a/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go b/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go index 336b875795..24c3f58b83 100755 --- a/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go +++ b/flyteplugins/go/tasks/v1/qubole/config/config_flags_test.go @@ -187,6 +187,28 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_redisMaxRetries", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("redisMaxRetries"); err == nil { + assert.Equal(t, int(defaultConfig.RedisMaxRetries), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("redisMaxRetries", testValue) + if vInt, err := cmdFlags.GetInt("redisMaxRetries"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RedisMaxRetries) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_quboleLimit", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly diff --git a/flyteplugins/go/tasks/v1/types/task.go b/flyteplugins/go/tasks/v1/types/task.go index 346630f68f..2a9dc6e661 100755 --- a/flyteplugins/go/tasks/v1/types/task.go +++ b/flyteplugins/go/tasks/v1/types/task.go @@ -31,6 +31,10 @@ type ExecutorProperties struct { // If set, the 
execution engine will not perform node-level task caching and retrieval. This can be useful for more // fine-grained executors that implement their own logic for caching. DisableNodeLevelCaching bool + + // Determines if resources should be actively deleted when abort is attempted. The default behavior is to clear + // finalizers only. If a plugin's resource will automatically be freed by K8s, it should NOT set this field. + DeleteResourceOnAbort bool } // Defines the exposed interface for plugins to record task events. From 275a84d019b252d424a4531d78e82e1dbcd989b5 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 10 Sep 2019 14:50:11 -0700 Subject: [PATCH 0098/1918] Fix unit tests --- .../go/tasks/v1/k8splugins/sidecar.go | 25 +++++++++++++++++-- flyteplugins/go/tasks/v1/k8splugins/spark.go | 20 +++++++++++++-- .../go/tasks/v1/k8splugins/waitable_task.go | 25 +++++++++++++++++++ .../go/tasks/v1/qubole/hive_executor.go | 24 +++++++++++++++--- .../go/tasks/v1/qubole/qubole_work_test.go | 14 ++++++----- 5 files changed, 95 insertions(+), 13 deletions(-) diff --git a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go index 29c8adc627..6ab8202011 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go +++ b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go @@ -81,14 +81,35 @@ func validateAndFinalizeContainers( return &pod, nil } +func validateSidecarJob(sidecarJob *plugins.SidecarJob) error { + if sidecarJob == nil { + return fmt.Errorf("empty sidecarjob") + } + + if sidecarJob.PodSpec == nil { + return fmt.Errorf("empty podspec") + } + + if len(sidecarJob.PodSpec.Containers) == 0 { + return fmt.Errorf("empty containers") + } + + return nil +} + func (sidecarResourceHandler) BuildResource( ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) ( flytek8s.K8sResource, error) { sidecarJob := plugins.SidecarJob{} err := utils.UnmarshalStruct(task.GetCustom(), &sidecarJob) if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + return nil, errors.Wrapf(errors.BadTaskSpecification, err, + "invalid TaskSpecification [%v], failed to unmarshal", task.GetCustom()) + } + + if err = validateSidecarJob(&sidecarJob); err != nil { + return nil, errors.Wrapf(errors.BadTaskSpecification, err, + "invalid TaskSpecification [%v]", task.GetCustom()) } pod := flytek8s.BuildPodWithSpec(sidecarJob.PodSpec) diff --git a/flyteplugins/go/tasks/v1/k8splugins/spark.go b/flyteplugins/go/tasks/v1/k8splugins/spark.go index de19a37828..e46728e1bb 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/spark.go +++ b/flyteplugins/go/tasks/v1/k8splugins/spark.go @@ -54,13 +54,29 @@ func setSparkConfig(cfg *SparkConfig) error { type sparkResourceHandler struct { } +func validateSparkJob(sparkJob *plugins.SparkJob) error { + if sparkJob == nil { + return fmt.Errorf("empty sparkJob") + } + + if len(sparkJob.MainApplicationFile) == 0 && len(sparkJob.MainClass) == 0 { + return fmt.Errorf("either MainApplicationFile or MainClass must be set") + } + + return nil +} + // Creates a new Job that will execute the main container as well as any generated types the result from the execution. 
func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.TaskContext, task *core.TaskTemplate, inputs *core.LiteralMap) (flytek8s.K8sResource, error) { sparkJob := plugins.SparkJob{} err := utils.UnmarshalStruct(task.GetCustom(), &sparkJob) if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v], failed to unmarshal", task.GetCustom()) + } + + if err = validateSparkJob(&sparkJob); err != nil { + return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v].", task.GetCustom()) } annotations := flytek8s.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.GetAnnotations())) @@ -147,7 +163,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.Tas HadoopConf: sparkJob.GetHadoopConf(), // SubmissionFailures handled here. Task Failures handled at Propeller/Job level. RestartPolicy: sparkOp.RestartPolicy{ - Type: sparkOp.OnFailure, + Type: sparkOp.OnFailure, OnSubmissionFailureRetries: &submissionFailureRetries, }, }, diff --git a/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go b/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go index 2a5f82eeef..bd5e1a1430 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go +++ b/flyteplugins/go/tasks/v1/k8splugins/waitable_task.go @@ -137,6 +137,11 @@ func discoverWaitableInputs(l *core.Literal) (literals []*core.Literal, waitable return []*core.Literal{}, []*waitableWrapper{} } + if err = validateWaitable(waitable); err != nil { + // skip, it's just a different type? + return []*core.Literal{}, []*waitableWrapper{} + } + return []*core.Literal{l}, []*waitableWrapper{{Waitable: waitable}} } } @@ -309,6 +314,22 @@ func (w waitableTaskExecutor) getUpdatedWaitables(ctx context.Context, taskCtx t return updatedWaitables, allDone, hasChanged, nil } +func validateWaitable(waitable *plugins.Waitable) error { + if waitable == nil { + return fmt.Errorf("empty waitable") + } + + if waitable.WfExecId == nil { + return fmt.Errorf("empty executionID") + } + + if len(waitable.WfExecId.Name) == 0 { + return fmt.Errorf("empty executionID Name") + } + + return nil +} + func updateWaitableLiterals(literals []*core.Literal, waitables []*waitableWrapper) error { index := make(map[string]*plugins.Waitable, len(waitables)) for _, w := range waitables { @@ -321,6 +342,10 @@ func updateWaitableLiterals(literals []*core.Literal, waitables []*waitableWrapp return err } + if err := validateWaitable(orig); err != nil { + return err + } + newW, found := index[orig.WfExecId.String()] if !found { return fmt.Errorf("couldn't find a waitable corresponding to literal WfID: %v", orig.WfExecId.String()) diff --git a/flyteplugins/go/tasks/v1/qubole/hive_executor.go b/flyteplugins/go/tasks/v1/qubole/hive_executor.go index 8e85e30e02..53d213773d 100755 --- a/flyteplugins/go/tasks/v1/qubole/hive_executor.go +++ b/flyteplugins/go/tasks/v1/qubole/hive_executor.go @@ -3,10 +3,11 @@ package qubole import ( "context" "fmt" - "github.com/go-redis/redis" "strconv" "time" + "github.com/go-redis/redis" + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" "github.com/lyft/flyteplugins/go/tasks/v1/events" @@ -126,6 +127,18 @@ func (h HiveExecutor) getUniqueCacheKey(taskCtx types.TaskContext, idx int) stri return fmt.Sprintf("%s_%d", taskCtx.GetTaskExecutionID().GetGeneratedName(), idx) } 
+func validateQuboleHiveJob(job *plugins.QuboleHiveJob) error { + if job == nil { + return fmt.Errorf("empty job") + } + + if job.Query == nil && job.QueryCollection == nil { + return fmt.Errorf("either query or queryCollection must be set") + } + + return nil +} + // This function is only ever called once, assuming it doesn't return in error. // Essentially, what this function does is translate the task's custom field into the TaskContext's CustomState // that's stored back into etcd @@ -135,8 +148,13 @@ func (h HiveExecutor) StartTask(ctx context.Context, taskCtx types.TaskContext, hiveJob := plugins.QuboleHiveJob{} err := utils.UnmarshalStruct(task.GetCustom(), &hiveJob) if err != nil { - return types.TaskStatusPermanentFailure(errors.Errorf(errors.BadTaskSpecification, - "Invalid Job Specification in task: [%v]. Err: [%v]", task.GetCustom(), err)), nil + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.BadTaskSpecification, err, + "Invalid Job Specification in task: [%v], failed to unmarshal", task.GetCustom())), nil + } + + if err = validateQuboleHiveJob(&hiveJob); err != nil { + return types.TaskStatusPermanentFailure(errors.Wrapf(errors.BadTaskSpecification, err, + "Invalid Job Specification in task: [%v]", task.GetCustom())), nil } // TODO: Asserts around queries, like len > 0 or something. diff --git a/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go b/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go index af4daa313f..f466dc9a51 100755 --- a/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go +++ b/flyteplugins/go/tasks/v1/qubole/qubole_work_test.go @@ -2,11 +2,12 @@ package qubole import ( "encoding/json" + "strings" + "testing" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" tasksMocks "github.com/lyft/flyteplugins/go/tasks/v1/types/mocks" "github.com/stretchr/testify/assert" - "strings" - "testing" ) func getMockTaskContext() *tasksMocks.TaskContext { @@ -31,6 +32,7 @@ func TestConstructEventInfoFromQuboleWorkItems(t *testing.T) { Status: QuboleWorkSucceeded, ClusterLabel: "default", Tags: []string{}, + CommandUri: "https://api.qubole.com/command/", }, } @@ -156,12 +158,12 @@ func TestInterfaceConverter(t *testing.T) { // This is a complicated step to reproduce what will ultimately be given to the function at runtime, the values // inside the CustomState item := QuboleWorkItem{ - Status: QuboleWorkRunning, - CommandId: "123456", - Query: "", + Status: QuboleWorkRunning, + CommandId: "123456", + Query: "", UniqueWorkCacheKey: "fjdsakfjd", } - raw, err := json.Marshal(map[string]interface{}{"":item}) + raw, err := json.Marshal(map[string]interface{}{"": item}) assert.NoError(t, err) // We can't unmarshal into a interface{} but we can unmarhsal into a interface{} if it's the value of a map. 
From 42de9d36c8e3c140fe9bdca7afad44e666a0dc69 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Wed, 11 Sep 2019 11:19:06 -0700 Subject: [PATCH 0099/1918] Update flyteplugins to 0.1.5 --- flytepropeller/Gopkg.lock | 183 ++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 107 deletions(-) diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index 8ac4731381..3f3b4bfa3b 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -2,20 +2,12 @@ [[projects]] - digest = "1:80e5d0810f1448259385b08f381852a83f87b6c958d8500e621db821e15c3771" + digest = "1:62010c37b093a63520b29cefd17edb3aa43735b265d2f2ede22e85fa86bf4cf0" name = "cloud.google.com/go" packages = ["compute/metadata"] pruneopts = "" - revision = "cdaaf98f9226c39dc162b8e55083b2fbc67b4674" - version = "v0.43.0" - -[[projects]] - digest = "1:6158256042564abf0da300ea7cb016f79ddaf24fdda2cc06c9712b0c2e06dd2a" - name = "contrib.go.opencensus.io/exporter/ocagent" - packages = ["."] - pruneopts = "" - revision = "dcb33c7f3b7cfe67e8a2cea10207ede1b7c40764" - version = "v0.4.12" + revision = "d1af076dc3f6a314e9dd6610fe83b0f7afaff6d6" + version = "v0.45.1" [[projects]] digest = "1:9a11be778d5fcb8e4873e64a097dfd2862d8665d9e2d969b90810d5272e51acb" @@ -26,19 +18,18 @@ version = "v10.2.1-beta" [[projects]] - digest = "1:5cb9540799639936e705a6ac54cfb6744b598519485fb357acb6e3285f43fbfb" + digest = "1:49d1526a17b46732f253a580a8c978bf8519cfa76e8075d9c8093080360aa104" name = "github.com/Azure/go-autorest" packages = [ "autorest", "autorest/adal", "autorest/azure", - "autorest/date", "logger", "tracing", ] pruneopts = "" - revision = "7166fb346dbf8978ad28211a1937b20fdabc08c8" - version = "v12.4.2" + revision = "69b4126ece6b5257e2f9b0017007d2334153655f" + version = "v13.0.1" [[projects]] digest = "1:558b53577dc0c9fde49b08405d706b202bcac3064320e9be53a75fc866280ee3" @@ -69,7 +60,7 @@ version = "1.0.0" [[projects]] - digest = "1:cfe39a015adcf9cc2bce0e8bd38ecf041cb516b8ab7a2ecb11b1c84a4b8acabf" + digest = "1:95314c6de2df70e7564c960f7edd6c128c9bb2e46800bc40ca48462057648b13" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -93,6 +84,7 @@ "internal/ini", "internal/s3err", "internal/sdkio", + "internal/sdkmath", "internal/sdkrand", "internal/sdkuri", "internal/shareddefaults", @@ -112,8 +104,8 @@ "service/sts/stsiface", ] pruneopts = "" - revision = "14379de571db1ac1b08f2f723a1acc1810c4dd0d" - version = "v1.22.2" + revision = "e51c7493711a6595b86496dc3f6229bae1354c1c" + version = "v1.23.19" [[projects]] branch = "master" @@ -132,27 +124,12 @@ version = "v1.0.1" [[projects]] - digest = "1:ad70cf78ff17abf96d92a6082f4d3241fef8f149118f87c3a267ed47a08be603" - name = "github.com/census-instrumentation/opencensus-proto" - packages = [ - "gen-go/agent/common/v1", - "gen-go/agent/metrics/v1", - "gen-go/agent/trace/v1", - "gen-go/metrics/v1", - "gen-go/resource/v1", - "gen-go/trace/v1", - ] - pruneopts = "" - revision = "d89fa54de508111353cb0b06403c00569be780d8" - version = "v0.2.1" - -[[projects]] - digest = "1:f6485831252319cd6ca29fc170adecf1eb81bf1e805f62f44eb48564ce2485fe" + digest = "1:545ae40d6dde46043a71bdfd7f9a17f2353ce16277c83ac685af231b4b7c4beb" name = "github.com/cespare/xxhash" packages = ["."] pruneopts = "" - revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" - version = "v2.0.0" + revision = "de209a9ffae3256185a6bb135d1a0ada7b2b5f09" + version = "v2.1.0" [[projects]] digest = "1:193f6d32d751f26540aa8eeedc114ce0a51f9e77b6c22dda3a4db4e5f65aec66" @@ -170,14 +147,6 @@ revision = 
"8991bc29aa16c548c550c7ff78260e27b9ab7c73" version = "v1.1.1" -[[projects]] - digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6" - name = "github.com/dgrijalva/jwt-go" - packages = ["."] - pruneopts = "" - revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" - version = "v3.2.0" - [[projects]] digest = "1:46ddeb9dd35d875ac7568c4dc1fc96ce424e034bdbb984239d8ffc151398ec01" name = "github.com/evanphx/json-patch" @@ -227,7 +196,7 @@ version = "v0.1.1" [[projects]] - digest = "1:c2db84082861ca42d0b00580d28f4b31aceec477a00a38e1a057fb3da75c8adc" + digest = "1:b994001ce7517c69ce8173ba12463f329c3fe8aecae84be6a334eae89a99f2fd" name = "github.com/go-redis/redis" packages = [ ".", @@ -239,19 +208,19 @@ "internal/util", ] pruneopts = "" - revision = "75795aa4236dc7341eefac3bbe945e68c99ef9df" - version = "v6.15.3" + revision = "17c058513b3e03c5e23136c18582fdd8ac8a3645" + version = "v6.15.5" [[projects]] - digest = "1:fd53b471edb4c28c7d297f617f4da0d33402755f58d6301e7ca1197ef0a90937" + digest = "1:8a7fe65e9ac2612c4df602cc9f014a92406776d993ff0f28335e5a8831d87c53" name = "github.com/gogo/protobuf" packages = [ "proto", "sortkeys", ] pruneopts = "" - revision = "ba06b47c162d49f2af050fb4c75bcbc86a159d5c" - version = "v1.2.1" + revision = "0ca988a254f991240804bf9821f3450d87ccbb1b" + version = "v1.3.0" [[projects]] branch = "master" @@ -307,7 +276,15 @@ version = "v1.1.1" [[projects]] - digest = "1:5facc3828b6a56f9aec988433ea33fb4407a89460952ed75be5347cec07318c0" + digest = "1:6120f027b9d68ef460b8731e27b0dcf2017f80605c17eb0b4cc151866ba38f6d" + name = "github.com/googleapis/gax-go" + packages = ["v2"] + pruneopts = "" + revision = "bd5b16380fd03dc758d11cef74ba2e3bc8b0e8c2" + version = "v2.0.5" + +[[projects]] + digest = "1:728f28282e0edc47e2d8f41c9ec1956ad645ad6b15e6376ab31e2c3b094fc38f" name = "github.com/googleapis/gnostic" packages = [ "OpenAPIv2", @@ -315,8 +292,8 @@ "extensions", ] pruneopts = "" - revision = "e73c7ec21d36ddb0711cb36d1502d18363b5c2c9" - version = "v0.3.0" + revision = "ab0dd09aa10e2952b28e12ecd35681b20463ebab" + version = "v0.3.1" [[projects]] digest = "1:1ea91d049b6a609f628ecdfda32e85f445a0d3671980dcbf7cbe1bbd7ee6aabc" @@ -366,7 +343,7 @@ version = "v1.2.0" [[projects]] - digest = "1:4ab82898193e99be9d4f1f1eb4ca3b1113ab6b7b2ff4605198ae305de864f05e" + digest = "1:0ebfd2f00a84ee4fb31913b49011b7fa2fb6b12040991d8b948db821a15f7f77" name = "github.com/grpc-ecosystem/grpc-gateway" packages = [ "internal", @@ -375,8 +352,8 @@ "utilities", ] pruneopts = "" - revision = "ad529a448ba494a88058f9e5be0988713174ac86" - version = "v1.9.5" + revision = "471f45a5a99a578de7a8638dc7ed29e245bde097" + version = "v1.11.1" [[projects]] digest = "1:7f6f07500a0b7d3766b00fa466040b97f2f5b5f3eef2ecabfe516e703b05119a" @@ -471,7 +448,7 @@ version = "v0.14.0" [[projects]] - digest = "1:1f4d377eba88e89d78761ad01caa205b214aeed0db4ead48f424538c9a8f7bcf" + digest = "1:77ccfc9618a05ffc14788ee562beb07c86f413fef7fb5421352437d18957eaef" name = "github.com/lyft/flyteplugins" packages = [ "go/tasks", @@ -493,9 +470,9 @@ "go/tasks/v1/utils", ] pruneopts = "" - revision = "99d622cf0f0ca5041e46aad172101536079ea22a" + revision = "01416c973482766e1fca181bbe61e42a6100e9df" source = "https://github.com/lyft/flyteplugins" - version = "v0.1.3" + version = "v0.1.5" [[projects]] digest = "1:77615fd7bcfd377c5e898c41d33be827f108166691cb91257d27113ee5d08650" @@ -544,12 +521,12 @@ version = "v0.0.9" [[projects]] - digest = "1:dbfae9da5a674236b914e486086671145b37b5e3880a38da906665aede3c9eab" + digest = 
"1:afc2714dedf683e571932f94f8a8ec444679eb84e076e021f63de871c5bc6cb1" name = "github.com/mattn/go-isatty" packages = ["."] pruneopts = "" - revision = "1311e847b0cb909da63b5fecfb5370aa66236465" - version = "v0.0.8" + revision = "e1f7b56ace729e4a73a29a6b4fac6cd5fcda7ab3" + version = "v0.0.9" [[projects]] digest = "1:63722a4b1e1717be7b98fc686e0b30d5e7f734b9e93d7dee86293b6deab7ea28" @@ -661,11 +638,11 @@ [[projects]] branch = "master" - digest = "1:cd67319ee7536399990c4b00fae07c3413035a53193c644549a676091507cadc" + digest = "1:0a565f69553dd41b3de790fde3532e9237142f2637899e20cd3e7396f0c4f2f7" name = "github.com/prometheus/client_model" packages = ["go"] pruneopts = "" - revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + revision = "14fe0d1b01d4d5fc031dd4bec1823bd3ebbe8016" [[projects]] digest = "1:0f2cee44695a3208fe5d6926076641499c72304e6f015348c9ab2df90a202cdf" @@ -680,15 +657,16 @@ version = "v0.6.0" [[projects]] - digest = "1:9b33e539d6bf6e4453668a847392d1e9e6345225ea1426f9341212c652bcbee4" + digest = "1:e010d89927008cac947ad9650f643a1b2b668dde47adfe56664da8694c1541d1" name = "github.com/prometheus/procfs" packages = [ ".", "internal/fs", + "internal/util", ] pruneopts = "" - revision = "3f98efb27840a48a7a2898ec80be07674d19f9c8" - version = "v0.0.3" + revision = "00ec24a6a2d86e7074629c8384715dbb05adccd8" + version = "v0.0.4" [[projects]] digest = "1:7f569d906bdd20d906b606415b7d794f798f91a62fcfb6a4daa6d50690fb7a3f" @@ -765,28 +743,27 @@ version = "v0.1.1" [[projects]] - digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c" + digest = "1:f7b541897bcde05a04a044c342ddc7425aab7e331f37b47fbb486cd16324b48e" name = "github.com/stretchr/testify" packages = [ "assert", "mock", ] pruneopts = "" - revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" - version = "v1.3.0" + revision = "221dbe5ed46703ee255b1da0dec05086f5035f62" + version = "v1.4.0" [[projects]] - digest = "1:98f63c8942146f9bf4b3925db1d96637b86c1d83693a894a244eae54aa53bb40" + digest = "1:1967fb934ef747bf690fcc56487a06c46bf674bd91cb3381a78a7e4d5c2e1a82" name = "go.opencensus.io" packages = [ ".", - "exemplar", "internal", "internal/tagencoding", - "plugin/ocgrpc", + "metric/metricdata", + "metric/metricproducer", "plugin/ochttp", "plugin/ochttp/propagation/b3", - "plugin/ochttp/propagation/tracecontext", "resource", "stats", "stats/internal", @@ -798,7 +775,8 @@ "trace/tracestate", ] pruneopts = "" - revision = "aab39bd6a98b853ab66c8a564f5d6cfcad59ce8a" + revision = "59d1ce35d30f3c25ba762169da2a37eab6ffa041" + version = "v0.22.1" [[projects]] digest = "1:e6ff7840319b6fda979a918a8801005ec2049abca62af19211d96971d8ec3327" @@ -833,15 +811,15 @@ [[projects]] branch = "master" - digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328" + digest = "1:4a4c4edf69b61bf98e98d22696aa80c4059384895920de4d5e2fe696068d5f13" name = "golang.org/x/crypto" packages = ["ssh/terminal"] pruneopts = "" - revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" + revision = "227b76d455e791cb042b03e633e2f7fbcfdf74a5" [[projects]] branch = "master" - digest = "1:955694a7c42527d7fb188505a22f10b3e158c6c2cf31fe64b1e62c9ab7b18401" + digest = "1:3feebd8c7f8c56efb8dd591ccb3227ba5e05863ead67a7e64cbba4c3957f61b4" name = "golang.org/x/net" packages = [ "context", @@ -854,7 +832,7 @@ "trace", ] pruneopts = "" - revision = "ca1201d0de80cfde86cb01aea620983605dfe99b" + revision = "a7b16738d86b947dd0fadb08ca2c2342b51958b6" [[projects]] branch = "master" @@ -872,22 +850,14 @@ [[projects]] branch = "master" - digest = 
"1:9f6efefb4e401a4f699a295d14518871368eb89403f2dd23ec11dfcd2c0836ba" - name = "golang.org/x/sync" - packages = ["semaphore"] - pruneopts = "" - revision = "112230192c580c3556b8cee6403af37a4fc5f28c" - -[[projects]] - branch = "master" - digest = "1:0b5c2207c72f2d13995040f176feb6e3f453d6b01af2b9d57df76b05ded2e926" + digest = "1:7340151bcf028f4deae54d53a3ab97840214b4b5dabaf1be7c9e16fae9c714ff" name = "golang.org/x/sys" packages = [ "unix", "windows", ] pruneopts = "" - revision = "51ab0e2deafac1f46c46ad59cf0921be2f180c3d" + revision = "bbd175535a8b9969bb0b33f1502f681ee0b122bd" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -924,7 +894,7 @@ [[projects]] branch = "master" - digest = "1:3f52587092bc722a3c3843989e6b88ec26924dc4b7b9c971095b7e93a11e0eff" + digest = "1:012cf0ba83736a2f73f6d35e228a74834cfa66c894922c02b266cb1f4c98b348" name = "golang.org/x/tools" packages = [ "go/ast/astutil", @@ -941,29 +911,28 @@ "internal/semver", ] pruneopts = "" - revision = "e713427fea3f98cb070e72a058c557a1a560cf22" + revision = "4f2ddba30aff720d7cb90a510a695e622273ce77" [[projects]] branch = "master" - digest = "1:f77558501305be5977ac30110f9820d21c5f1a89328667dc82db0bd9ebaab4c4" + digest = "1:b00308a37526131ce24245e25c247112c86344d8cf9608c2f35d4812a09e99c6" name = "google.golang.org/api" packages = [ - "gensupport", "googleapi", "googleapi/internal/uritemplates", "googleapi/transport", "internal", + "internal/gensupport", "option", "storage/v1", - "support/bundler", "transport/http", "transport/http/internal/propagation", ] pruneopts = "" - revision = "6f3912904777a209e099b9dbda3ed7bcb4e25ad7" + revision = "a93ff9fe2e76564e1c26f4691c5a75450d0955f0" [[projects]] - digest = "1:47f391ee443f578f01168347818cb234ed819521e49e4d2c8dd2fb80d48ee41a" + digest = "1:0568e577f790e9bd0420521cff50580f9b38165a38f217ce68f55c4bbaa97066" name = "google.golang.org/appengine" packages = [ ".", @@ -978,12 +947,12 @@ "urlfetch", ] pruneopts = "" - revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" - version = "v1.6.1" + revision = "5f2a59506353b8d5ba8cbbcd9f3c1f41f1eaf079" + version = "v1.6.2" [[projects]] branch = "master" - digest = "1:95b0a53d4d31736b2483a8c41667b2bd83f303706106f81bd2f54e3f9c24eaf4" + digest = "1:f9e92b6d2b267abfae825d2a674c5d18a8a5c05354c428bff7b9e8536a23a2b6" name = "google.golang.org/genproto" packages = [ "googleapis/api/annotations", @@ -992,10 +961,10 @@ "protobuf/field_mask", ] pruneopts = "" - revision = "fa694d86fc64c7654a660f8908de4e879866748d" + revision = "1774047e7e5133fa3573a4e51b37a586b6b0360c" [[projects]] - digest = "1:425ee670b3e8b6562e31754021a82d78aa46b9281247827376616c8aa78f4687" + digest = "1:e8a4007e58ea9431f6460d1bc5c7f9dd29fdc0211ab780b6ad3d0581478f2076" name = "google.golang.org/grpc" packages = [ ".", @@ -1033,8 +1002,8 @@ "tap", ] pruneopts = "" - revision = "045159ad57f3781d409358e3ade910a018c16b30" - version = "v1.22.1" + revision = "6eaf6f47437a6b4e2153a190160ef39a92c7eceb" + version = "v1.23.0" [[projects]] digest = "1:75fb3fcfc73a8c723efde7777b40e8e8ff9babf30d8c56160d01beffea8a95a6" @@ -1256,7 +1225,7 @@ [[projects]] branch = "master" - digest = "1:6a2a63e09a59caff3fd2d36d69b7b92c2fe7cf783390f0b7349fb330820f9a8e" + digest = "1:a8c60fdc825be548a5cd9f599565d2c0dceff8d77a2b5ea00ff891e4aca33de5" name = "k8s.io/gengo" packages = [ "args", @@ -1269,7 +1238,7 @@ "types", ] pruneopts = "" - revision = "e17681d19d3ac4837a019ece36c2a0ec31ffe985" + revision = "ebc107f98eab922ef99d645781b87caca01f4f48" [[projects]] digest = 
"1:3063061b6514ad2666c4fa292451685884cacf77c803e1b10b4a4fa23f7787fb" @@ -1281,11 +1250,11 @@ [[projects]] branch = "master" - digest = "1:3176cac3365c8442ab92d465e69e05071b0dbc0d715e66b76059b04611811dff" + digest = "1:71e59e355758d825c891c77bfe3ec2c0b2523b05076e96b2a2bfa804e6ac576a" name = "k8s.io/kube-openapi" packages = ["pkg/util/proto"] pruneopts = "" - revision = "5e22f3d471e6f24ca20becfdffdc6206c7cecac8" + revision = "743ec37842bffe49dd4221d9026f30fb1d5adbc4" [[projects]] digest = "1:77629c3c036454b4623e99e20f5591b9551dd81d92db616384af92435b52e9b6" From 2aa1fa538b004b3f07b2bb223bb8e7988f9fe640 Mon Sep 17 00:00:00 2001 From: Viktor Barinov Date: Wed, 11 Sep 2019 16:16:55 -0700 Subject: [PATCH 0100/1918] flyte-idl-v0.14.1 --- flyteplugins/Gopkg.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/Gopkg.lock b/flyteplugins/Gopkg.lock index 019954ae25..065621edf0 100755 --- a/flyteplugins/Gopkg.lock +++ b/flyteplugins/Gopkg.lock @@ -366,7 +366,7 @@ pruneopts = "" revision = "211b8fe8c2c1d9ab168afd078b62d4f7834171d3" source = "https://github.com/lyft/flyteidl" - version = "v0.14.0" + version = "v0.14.1" [[projects]] digest = "1:c368fe9a00a38c8702e24475dd3a8348d2a191892ef9030aceb821f8c035b737" From fc4129b4378ca45c05df3e770b0924f355f4fe77 Mon Sep 17 00:00:00 2001 From: Viktor Barinov Date: Wed, 11 Sep 2019 16:24:35 -0700 Subject: [PATCH 0101/1918] bump-versions flyteidl and flyteplugins --- flytepropeller/Gopkg.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index 6a3182f971..0df859eb1c 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -454,7 +454,7 @@ pruneopts = "" revision = "211b8fe8c2c1d9ab168afd078b62d4f7834171d3" source = "https://github.com/lyft/flyteidl" - version = "v0.14.0" + version = "v0.14.1" [[projects]] digest = "1:77ccfc9618a05ffc14788ee562beb07c86f413fef7fb5421352437d18957eaef" @@ -481,7 +481,7 @@ pruneopts = "" revision = "01416c973482766e1fca181bbe61e42a6100e9df" source = "https://github.com/lyft/flyteplugins" - version = "v0.1.5" + version = "v0.1.6" [[projects]] digest = "1:77615fd7bcfd377c5e898c41d33be827f108166691cb91257d27113ee5d08650" From c353571279787853d4e446d39aaa2c89fb94073f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2019 16:47:36 -0700 Subject: [PATCH 0102/1918] github.com/dgrijalva/jwt-go was missing (#8) Could not compile. Apparently dep ensure doesn't update transitive dependencies. This just runs dep ensure -update. 
--- flytepropeller/Gopkg.lock | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index 0df859eb1c..a34e7a2852 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -24,6 +24,7 @@ "autorest", "autorest/adal", "autorest/azure", + "autorest/date", "logger", "tracing", ] @@ -60,7 +61,7 @@ version = "1.0.0" [[projects]] - digest = "1:95314c6de2df70e7564c960f7edd6c128c9bb2e46800bc40ca48462057648b13" + digest = "1:9f8615d7907ba737b89895391626f81e9dd7776f48ea062d5bea77b46bdcfc89" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -104,8 +105,8 @@ "service/sts/stsiface", ] pruneopts = "" - revision = "e51c7493711a6595b86496dc3f6229bae1354c1c" - version = "v1.23.19" + revision = "8c6586204ba7e9a887fb4cb6ffca2428e6c6dc7c" + version = "v1.23.20" [[projects]] branch = "master" @@ -147,6 +148,14 @@ revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" version = "v1.1.1" +[[projects]] + digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6" + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + pruneopts = "" + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + [[projects]] digest = "1:46ddeb9dd35d875ac7568c4dc1fc96ce424e034bdbb984239d8ffc151398ec01" name = "github.com/evanphx/json-patch" @@ -434,7 +443,7 @@ version = "v0.1.0" [[projects]] - digest = "1:ef7b24655c09b19a0b397e8a58f8f15fc402b349484afad6ce1de0a8f21bb292" + digest = "1:dec8d616f023c717c476cad2b30d5afebb4b6ed2f23d215bf309b9889c2a377b" name = "github.com/lyft/flyteidl" packages = [ "clients/go/admin", @@ -452,12 +461,12 @@ "gen/pb-go/flyteidl/service", ] pruneopts = "" - revision = "211b8fe8c2c1d9ab168afd078b62d4f7834171d3" + revision = "fac461ff1d6926e98b48e2f3374419afdb69de18" source = "https://github.com/lyft/flyteidl" version = "v0.14.1" [[projects]] - digest = "1:77ccfc9618a05ffc14788ee562beb07c86f413fef7fb5421352437d18957eaef" + digest = "1:c6181f0bd353d2558f8b3b0128b0625f0982f6a428664c5fc4f3a69118172799" name = "github.com/lyft/flyteplugins" packages = [ "go/tasks", @@ -479,7 +488,7 @@ "go/tasks/v1/utils", ] pruneopts = "" - revision = "01416c973482766e1fca181bbe61e42a6100e9df" + revision = "fcad09176702bc061d6e0422593834f266a88c13" source = "https://github.com/lyft/flyteplugins" version = "v0.1.6" @@ -859,14 +868,14 @@ [[projects]] branch = "master" - digest = "1:7340151bcf028f4deae54d53a3ab97840214b4b5dabaf1be7c9e16fae9c714ff" + digest = "1:0c8be8c385496c91dd86d6cc727041eafdc2f42537c9f58ec992e8efad0fa923" name = "golang.org/x/sys" packages = [ "unix", "windows", ] pruneopts = "" - revision = "bbd175535a8b9969bb0b33f1502f681ee0b122bd" + revision = "7ad0cfa0b7b5a50bdf0fb49923febdf3742a975c" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -903,7 +912,7 @@ [[projects]] branch = "master" - digest = "1:012cf0ba83736a2f73f6d35e228a74834cfa66c894922c02b266cb1f4c98b348" + digest = "1:a779e74cb881ebe8626d4fee58ad262c4b0d41b5d97a8454e4dda97b638f6177" name = "golang.org/x/tools" packages = [ "go/ast/astutil", @@ -920,11 +929,11 @@ "internal/semver", ] pruneopts = "" - revision = "4f2ddba30aff720d7cb90a510a695e622273ce77" + revision = "6bfd74cf029c99138aa1bb5b7e0d6b57c9d4eb49" [[projects]] branch = "master" - digest = "1:b00308a37526131ce24245e25c247112c86344d8cf9608c2f35d4812a09e99c6" + digest = "1:d8b3f0a0769b9c2694d4ca4fb4b2b6d971d1af008ce9ea9c3b188df76d9d77d6" name = "google.golang.org/api" packages = [ 
"googleapi", @@ -938,7 +947,7 @@ "transport/http/internal/propagation", ] pruneopts = "" - revision = "a93ff9fe2e76564e1c26f4691c5a75450d0955f0" + revision = "28eb3b1f27f6a4e7185acb5b5074a90b4cc04cc4" [[projects]] digest = "1:0568e577f790e9bd0420521cff50580f9b38165a38f217ce68f55c4bbaa97066" @@ -973,7 +982,7 @@ revision = "1774047e7e5133fa3573a4e51b37a586b6b0360c" [[projects]] - digest = "1:e8a4007e58ea9431f6460d1bc5c7f9dd29fdc0211ab780b6ad3d0581478f2076" + digest = "1:7ed022f305690d5843ba3f0bb94f9890a1f9c9459f0158a28304798213326d88" name = "google.golang.org/grpc" packages = [ ".", @@ -1011,8 +1020,8 @@ "tap", ] pruneopts = "" - revision = "6eaf6f47437a6b4e2153a190160ef39a92c7eceb" - version = "v1.23.0" + revision = "39e8a7b072a67ca2a75f57fa2e0d50995f5b22f6" + version = "v1.23.1" [[projects]] digest = "1:75fb3fcfc73a8c723efde7777b40e8e8ff9babf30d8c56160d01beffea8a95a6" From 7eca17c6673915f08966c38d16ee18c8efeab6ef Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 12 Sep 2019 20:53:00 -0700 Subject: [PATCH 0103/1918] Adjust logging levels --- .../pkg/controller/catalog/datacatalog/datacatalog.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go index f8aa16883f..f133b57664 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -108,7 +108,7 @@ func (m *CatalogClient) Get(ctx context.Context, task *core.TaskTemplate, inputP dataset, err := m.getDataset(ctx, task) if err != nil { - logger.Errorf(ctx, "DataCatalog failed to get dataset for task %+v, err: %+v", task, err) + logger.Debugf(ctx, "DataCatalog failed to get dataset for task %+v, err: %+v", task.Id, err) return nil, err } @@ -120,7 +120,7 @@ func (m *CatalogClient) Get(ctx context.Context, task *core.TaskTemplate, inputP artifact, err := m.getArtifactByTag(ctx, tag, dataset) if err != nil { - logger.Errorf(ctx, "DataCatalog failed to get artifact by tag %+v, err: %+v", tag, err) + logger.Debugf(ctx, "DataCatalog failed to get artifact by tag %+v, err: %+v", tag, err) return nil, err } logger.Debugf(ctx, "Artifact found %v from tag %v", artifact, tag) @@ -131,7 +131,7 @@ func (m *CatalogClient) Get(ctx context.Context, task *core.TaskTemplate, inputP return nil, err } - logger.Debugf(ctx, "Cached %v artifact outputs from artifact %v", len(outputs.Literals), artifact.Id) + logger.Debugf(ctx, "Retrieved %v artifact outputs from artifact %v", len(outputs.Literals), artifact.Id) return outputs, nil } @@ -227,7 +227,7 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID // Tag the artifact since it is the cached artifact tagName, err := transformer.GenerateArtifactTagName(ctx, inputs) if err != nil { - logger.Errorf(ctx, "Failed to create tag for artifact %+v, err: %+v", cachedArtifact.Id, err) + logger.Errorf(ctx, "Failed to generate tag for artifact %+v, err: %+v", cachedArtifact.Id, err) return err } logger.Debugf(ctx, "Created tag: %v, for task: %v", tagName, task.Id) @@ -241,7 +241,7 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID _, err = m.client.AddTag(ctx, &datacatalog.AddTagRequest{Tag: tag}) if err != nil { if status.Code(err) == codes.AlreadyExists { - logger.Errorf(ctx, "Tag %v already exists for Artifact %v (idempotent)", tagName, cachedArtifact.Id) + logger.Warnf(ctx, "Tag %v already exists for Artifact 
%v (idempotent)", tagName, cachedArtifact.Id) } logger.Errorf(ctx, "Failed to add tag %+v for artifact %+v, err: %+v", tagName, cachedArtifact.Id, err) From 5acf2863e779f24aae052dd84a62ba8816dcdd77 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 12 Sep 2019 21:32:24 -0700 Subject: [PATCH 0104/1918] Mocks for Autorefresh cache and its parts (#36) --- flytestdlib/utils/auto_refresh_cache.go | 2 + flytestdlib/utils/mocks/auto_refresh_cache.go | 56 +++++++++++++++++++ flytestdlib/utils/mocks/cache_item.go | 24 ++++++++ flytestdlib/utils/mocks/rate_limiter.go | 25 +++++++++ flytestdlib/utils/mocks/sequencer.go | 38 +++++++++++++ 5 files changed, 145 insertions(+) create mode 100644 flytestdlib/utils/mocks/auto_refresh_cache.go create mode 100644 flytestdlib/utils/mocks/cache_item.go create mode 100644 flytestdlib/utils/mocks/rate_limiter.go create mode 100644 flytestdlib/utils/mocks/sequencer.go diff --git a/flytestdlib/utils/auto_refresh_cache.go b/flytestdlib/utils/auto_refresh_cache.go index 3453483914..d7d07a18bd 100644 --- a/flytestdlib/utils/auto_refresh_cache.go +++ b/flytestdlib/utils/auto_refresh_cache.go @@ -11,6 +11,8 @@ import ( "k8s.io/apimachinery/pkg/util/wait" ) +//go:generate mockery -all -case=underscore + // AutoRefreshCache with regular GetOrCreate and Delete along with background asynchronous refresh. Caller provides // callbacks for create, refresh and delete item. // The cache doesn't provide apis to update items. diff --git a/flytestdlib/utils/mocks/auto_refresh_cache.go b/flytestdlib/utils/mocks/auto_refresh_cache.go new file mode 100644 index 0000000000..9c6f13efc2 --- /dev/null +++ b/flytestdlib/utils/mocks/auto_refresh_cache.go @@ -0,0 +1,56 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import utils "github.com/lyft/flytestdlib/utils" + +// AutoRefreshCache is an autogenerated mock type for the AutoRefreshCache type +type AutoRefreshCache struct { + mock.Mock +} + +// Get provides a mock function with given fields: id +func (_m *AutoRefreshCache) Get(id string) utils.CacheItem { + ret := _m.Called(id) + + var r0 utils.CacheItem + if rf, ok := ret.Get(0).(func(string) utils.CacheItem); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(utils.CacheItem) + } + } + + return r0 +} + +// GetOrCreate provides a mock function with given fields: item +func (_m *AutoRefreshCache) GetOrCreate(item utils.CacheItem) (utils.CacheItem, error) { + ret := _m.Called(item) + + var r0 utils.CacheItem + if rf, ok := ret.Get(0).(func(utils.CacheItem) utils.CacheItem); ok { + r0 = rf(item) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(utils.CacheItem) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(utils.CacheItem) error); ok { + r1 = rf(item) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Start provides a mock function with given fields: ctx +func (_m *AutoRefreshCache) Start(ctx context.Context) { + _m.Called(ctx) +} diff --git a/flytestdlib/utils/mocks/cache_item.go b/flytestdlib/utils/mocks/cache_item.go new file mode 100644 index 0000000000..b802e5bbab --- /dev/null +++ b/flytestdlib/utils/mocks/cache_item.go @@ -0,0 +1,24 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +// CacheItem is an autogenerated mock type for the CacheItem type +type CacheItem struct { + mock.Mock +} + +// ID provides a mock function with given fields: +func (_m *CacheItem) ID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/flytestdlib/utils/mocks/rate_limiter.go b/flytestdlib/utils/mocks/rate_limiter.go new file mode 100644 index 0000000000..7a10626e92 --- /dev/null +++ b/flytestdlib/utils/mocks/rate_limiter.go @@ -0,0 +1,25 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" + +// RateLimiter is an autogenerated mock type for the RateLimiter type +type RateLimiter struct { + mock.Mock +} + +// Wait provides a mock function with given fields: ctx +func (_m *RateLimiter) Wait(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flytestdlib/utils/mocks/sequencer.go b/flytestdlib/utils/mocks/sequencer.go new file mode 100644 index 0000000000..0fc724b412 --- /dev/null +++ b/flytestdlib/utils/mocks/sequencer.go @@ -0,0 +1,38 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Sequencer is an autogenerated mock type for the Sequencer type +type Sequencer struct { + mock.Mock +} + +// GetCur provides a mock function with given fields: +func (_m *Sequencer) GetCur() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// GetNext provides a mock function with given fields: +func (_m *Sequencer) GetNext() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} From 4057acb9dc30a92b40d673d7b2555aa6af6786e9 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 12 Sep 2019 21:58:55 -0700 Subject: [PATCH 0105/1918] correct metadata to pass workflow exec name --- .../pkg/controller/catalog/datacatalog/datacatalog.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go index f133b57664..bc39a1e87b 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -179,7 +179,7 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID metadata := &datacatalog.Metadata{ KeyMap: map[string]string{ taskVersionKey: task.Id.Version, - taskExecKey: execID.NodeExecutionId.NodeId, + taskExecKey: execID.NodeExecutionId.ExecutionId.Name, }, } newDataset := &datacatalog.Dataset{ From fffbc2c5522b90f67cbf19b8c7b484bb1cf35045 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Fri, 13 Sep 2019 09:22:09 -0700 Subject: [PATCH 0106/1918] Clarify key name for execution name --- .../pkg/controller/catalog/datacatalog/datacatalog.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go 
b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go index bc39a1e87b..5728536d9b 100644 --- a/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/catalog/datacatalog/datacatalog.go @@ -22,7 +22,7 @@ import ( const ( taskVersionKey = "task-version" - taskExecKey = "execution-name" + wfExecNameKey = "execution-name" ) // This is the client that caches task executions to DataCatalog service. @@ -179,7 +179,7 @@ func (m *CatalogClient) Put(ctx context.Context, task *core.TaskTemplate, execID metadata := &datacatalog.Metadata{ KeyMap: map[string]string{ taskVersionKey: task.Id.Version, - taskExecKey: execID.NodeExecutionId.ExecutionId.Name, + wfExecNameKey: execID.NodeExecutionId.ExecutionId.Name, }, } newDataset := &datacatalog.Dataset{ From ba33672a566805518cfd3671f85c4f68d57fb56f Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Mon, 16 Sep 2019 13:43:32 -0700 Subject: [PATCH 0107/1918] Let the first metric scope be namespaced something other than datacatalog --- datacatalog/pkg/rpc/datacatalogservice/service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datacatalog/pkg/rpc/datacatalogservice/service.go b/datacatalog/pkg/rpc/datacatalogservice/service.go index ce3cfc79c3..0c6b12bb26 100644 --- a/datacatalog/pkg/rpc/datacatalogservice/service.go +++ b/datacatalog/pkg/rpc/datacatalogservice/service.go @@ -68,7 +68,7 @@ func (s *DataCatalogService) AddTag(ctx context.Context, request *catalog.AddTag func NewDataCatalogService() *DataCatalogService { configProvider := runtime.NewConfigurationProvider() dataCatalogConfig := configProvider.ApplicationConfiguration().GetDataCatalogConfig() - catalogScope := promutils.NewScope(dataCatalogConfig.MetricsScope).NewSubScope("service") + catalogScope := promutils.NewScope(dataCatalogConfig.MetricsScope).NewSubScope("datacatalog") ctx := contextutils.WithAppName(context.Background(), "datacatalog") // Set Keys From 4e828d621312d3fd6d6e3fc1e4bcbd0af738a6d3 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 17 Sep 2019 14:35:41 -0700 Subject: [PATCH 0108/1918] Filter informer events (#13) * Filter informer events * PR comments --- .../go/tasks/v1/flytek8s/mux_handler.go | 39 +++++++++++-------- .../go/tasks/v1/flytek8s/plugin_executor.go | 3 +- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go b/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go index 0e7cd948d9..885a4ef5dd 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go +++ b/flyteplugins/go/tasks/v1/flytek8s/mux_handler.go @@ -6,6 +6,11 @@ import ( "sync" "time" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/logger" "sigs.k8s.io/controller-runtime/pkg/cache/informertest" @@ -111,7 +116,7 @@ func Initialize(ctx context.Context, watchNamespace string, resyncPeriod time.Du return nil } -func RegisterResource(ctx context.Context, resourceToWatch runtime.Object, handler Handler) error { +func RegisterResource(_ context.Context, resourceToWatch runtime.Object, handler Handler, metricsScope promutils.Scope) error { if instance == nil { return fmt.Errorf("instance not initialized") } @@ -132,30 +137,32 @@ func RegisterResource(ctx context.Context, resourceToWatch runtime.Object, handl q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), resourceToWatch.GetObjectKind().GroupVersionKind().Kind) + 
updateCount := labeled.NewCounter("informer_update", "Update events from informer", metricsScope) + droppedUpdateCount := labeled.NewCounter("informer_update_dropped", "Update events from informer that have the same resource version", metricsScope) + err := src.Start(ctrlHandler.Funcs{ CreateFunc: func(evt event.CreateEvent, q2 workqueue.RateLimitingInterface) { - err := handler.Handle(ctx, evt.Object) - if err != nil { - logger.Warnf(ctx, "Failed to handle Create event for object [%v]", evt.Object) - } }, UpdateFunc: func(evt event.UpdateEvent, q2 workqueue.RateLimitingInterface) { - err := handler.Handle(ctx, evt.ObjectNew) - if err != nil { - logger.Warnf(ctx, "Failed to handle Update event for object [%v]", evt.ObjectNew) + if evt.MetaNew == nil { + logger.Warn(context.Background(), "Received an Update event with nil MetaNew.") + } else if evt.MetaOld == nil || evt.MetaOld.GetResourceVersion() != evt.MetaNew.GetResourceVersion() { + newCtx := contextutils.WithNamespace(context.Background(), evt.MetaNew.GetNamespace()) + updateCount.Inc(newCtx) + + logger.Debugf(newCtx, "Enqueueing owner for updated object [%v/%v]", evt.MetaNew.GetNamespace(), evt.MetaNew.GetName()) + err := handler.Handle(newCtx, evt.ObjectNew) + if err != nil { + logger.Warnf(newCtx, "Failed to handle Update event for object [%v]", evt.ObjectNew) + } + } else { + newCtx := contextutils.WithNamespace(context.Background(), evt.MetaNew.GetNamespace()) + droppedUpdateCount.Inc(newCtx) } }, DeleteFunc: func(evt event.DeleteEvent, q2 workqueue.RateLimitingInterface) { - err := handler.Handle(ctx, evt.Object) - if err != nil { - logger.Warnf(ctx, "Failed to handle Delete event for object [%v]", evt.Object) - } }, GenericFunc: func(evt event.GenericEvent, q2 workqueue.RateLimitingInterface) { - err := handler.Handle(ctx, evt.Object) - if err != nil { - logger.Warnf(ctx, "Failed to handle Generic event for object [%v]", evt.Object) - } }, }, q) diff --git a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go index 2389d24362..e964137de7 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go +++ b/flyteplugins/go/tasks/v1/flytek8s/plugin_executor.go @@ -51,6 +51,7 @@ type K8sTaskExecutorMetrics struct { type ownerRegisteringHandler struct { ownerKind string enqueueOwner types.EnqueueOwner + metricsScope promutils.Scope } // A common handle for all k8s-resource reliant task executors that push workflow id on the work queue. 
@@ -113,7 +114,7 @@ func (e *K8sTaskExecutor) Initialize(ctx context.Context, params types.ExecutorI return RegisterResource(ctx, e.resourceToWatch, ownerRegisteringHandler{ enqueueOwner: params.EnqueueOwner, ownerKind: params.OwnerKind, - }) + }, metricScope) } func (e K8sTaskExecutor) HandleTaskSuccess(ctx context.Context, taskCtx types.TaskContext) (types.TaskStatus, error) { From b16160c3be62be8ab8099480b34c28a94e44c0d4 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 17 Sep 2019 16:07:20 -0700 Subject: [PATCH 0109/1918] Update plugins to 0.1.7 (#11) --- flytepropeller/Gopkg.lock | 74 +++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index a34e7a2852..184534eed4 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -2,12 +2,12 @@ [[projects]] - digest = "1:62010c37b093a63520b29cefd17edb3aa43735b265d2f2ede22e85fa86bf4cf0" + digest = "1:fd319661e53f3607b8ddbded6121f8e4fe42978cb66968b20fdee68e10d10f9f" name = "cloud.google.com/go" packages = ["compute/metadata"] pruneopts = "" - revision = "d1af076dc3f6a314e9dd6610fe83b0f7afaff6d6" - version = "v0.45.1" + revision = "264def2dd949cdb8a803bb9f50fa29a67b798a6a" + version = "v0.46.3" [[projects]] digest = "1:9a11be778d5fcb8e4873e64a097dfd2862d8665d9e2d969b90810d5272e51acb" @@ -61,7 +61,7 @@ version = "1.0.0" [[projects]] - digest = "1:9f8615d7907ba737b89895391626f81e9dd7776f48ea062d5bea77b46bdcfc89" + digest = "1:3cb5a7438cd06b18f192433cdb375f557c3c2eb6d97b88f0eb27c42f65c6b4dd" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -105,8 +105,8 @@ "service/sts/stsiface", ] pruneopts = "" - revision = "8c6586204ba7e9a887fb4cb6ffca2428e6c6dc7c" - version = "v1.23.20" + revision = "b4225bdde03b756685c89b7db31fe2ffac6d9234" + version = "v1.24.0" [[projects]] branch = "master" @@ -331,7 +331,7 @@ revision = "901d90724c7919163f472a9812253fb26761123d" [[projects]] - digest = "1:9a0b2dd1f882668a3d7fbcd424eed269c383a16f1faa3a03d14e0dd5fba571b1" + digest = "1:cbea643bd7f1c76bb6e48ab08f3dd01456602ab7b252f3c85133d7a1a6413a18" name = "github.com/grpc-ecosystem/go-grpc-middleware" packages = [ ".", @@ -340,8 +340,8 @@ "util/metautils", ] pruneopts = "" - revision = "c250d6563d4d4c20252cd865923440e829844f4e" - version = "v1.0.0" + revision = "dd15ed025b6054e5253963e355991f3070d4e593" + version = "v1.1.0" [[projects]] digest = "1:e24dc5ef44694848785de507f439a24e9e6d96d7b43b8cf3d6cfa857aa1e2186" @@ -434,13 +434,13 @@ version = "v1.0.2" [[projects]] - digest = "1:eaef68aaa87572012b236975f558582a043044c21be3fda97921c4871fb4298f" + digest = "1:697f2e18b181eeb64fbd3ff601636816439d89535e00dd6fd34ea1b117543cd9" name = "github.com/lyft/datacatalog" packages = ["protos/gen"] pruneopts = "" - revision = "0da0ffbb4705efd5d5ecd04ea560b35d968beb86" + revision = "2ee1d756da821e8616019b8e5310119c01b611ba" source = "https://github.com/lyft/datacatalog" - version = "v0.1.0" + version = "v0.1.1" [[projects]] digest = "1:dec8d616f023c717c476cad2b30d5afebb4b6ed2f23d215bf309b9889c2a377b" @@ -466,7 +466,7 @@ version = "v0.14.1" [[projects]] - digest = "1:c6181f0bd353d2558f8b3b0128b0625f0982f6a428664c5fc4f3a69118172799" + digest = "1:ab33b506292441cab3455961ad86e9e5e81989b2f6089d4437d2b15a6d417dc2" name = "github.com/lyft/flyteplugins" packages = [ "go/tasks", @@ -488,12 +488,12 @@ "go/tasks/v1/utils", ] pruneopts = "" - revision = "fcad09176702bc061d6e0422593834f266a88c13" + revision = "31b0fb21c6bf86f9e5e81297d9ad3de3feaea6e6" source = 
"https://github.com/lyft/flyteplugins" - version = "v0.1.6" + version = "v0.1.7" [[projects]] - digest = "1:77615fd7bcfd377c5e898c41d33be827f108166691cb91257d27113ee5d08650" + digest = "1:b8860863eb1eb7fe4f7d2b3c6b34b89aaedeaae88d761d2c9a4eccc46262abcc" name = "github.com/lyft/flytestdlib" packages = [ "atomic", @@ -515,9 +515,9 @@ "yamlutils", ] pruneopts = "" - revision = "7292f20ec17b42f104fd61d7f0120e17bcacf751" + revision = "a758bbb91a09d6e9474fa160a633fdb05f3abc03" source = "https://github.com/lyft/flytestdlib" - version = "v0.2.16" + version = "v0.2.17" [[projects]] digest = "1:ae39921edb7f801f7ce1b6b5484f9715a1dd2b52cb645daef095cd10fd6ee774" @@ -663,7 +663,7 @@ revision = "14fe0d1b01d4d5fc031dd4bec1823bd3ebbe8016" [[projects]] - digest = "1:0f2cee44695a3208fe5d6926076641499c72304e6f015348c9ab2df90a202cdf" + digest = "1:8904acfa3ef080005c1fc0670ed0471739d1e211be5638cfa6af536b701942ae" name = "github.com/prometheus/common" packages = [ "expfmt", @@ -671,11 +671,11 @@ "model", ] pruneopts = "" - revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" - version = "v0.6.0" + revision = "287d3e634a1e550c9e463dd7e5a75a422c614505" + version = "v0.7.0" [[projects]] - digest = "1:e010d89927008cac947ad9650f643a1b2b668dde47adfe56664da8694c1541d1" + digest = "1:af5cd8219fd15c06eadaab455c0beb72f2f7bb32d298acb401d30c452a8dbd7e" name = "github.com/prometheus/procfs" packages = [ ".", @@ -683,8 +683,8 @@ "internal/util", ] pruneopts = "" - revision = "00ec24a6a2d86e7074629c8384715dbb05adccd8" - version = "v0.0.4" + revision = "499c85531f756d1129edd26485a5f73871eeb308" + version = "v0.0.5" [[projects]] digest = "1:7f569d906bdd20d906b606415b7d794f798f91a62fcfb6a4daa6d50690fb7a3f" @@ -738,12 +738,12 @@ version = "v1.1.0" [[projects]] - digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66" + digest = "1:1bff633980ce46a718d753a3bc7e3eb6d7e4df34bc7e4a9869e71ccb4314dc40" name = "github.com/spf13/pflag" packages = ["."] pruneopts = "" - revision = "298182f68c66c05229eb03ac171abe6e309ee79a" - version = "v1.0.3" + revision = "e8f29969b682c41a730f8f08b76033b120498464" + version = "v1.0.4" [[projects]] digest = "1:c25a789c738f7cc8ec7f34026badd4e117853f329334a5aa45cf5d0727d7d442" @@ -837,7 +837,7 @@ [[projects]] branch = "master" - digest = "1:3feebd8c7f8c56efb8dd591ccb3227ba5e05863ead67a7e64cbba4c3957f61b4" + digest = "1:af83d44f1195692d76697f62af82329be4d55b100d33b9b3db8f5d0f44563fb9" name = "golang.org/x/net" packages = [ "context", @@ -850,7 +850,7 @@ "trace", ] pruneopts = "" - revision = "a7b16738d86b947dd0fadb08ca2c2342b51958b6" + revision = "c8589233b77dde5edd2205ba8a4fb5c9c2472556" [[projects]] branch = "master" @@ -868,14 +868,14 @@ [[projects]] branch = "master" - digest = "1:0c8be8c385496c91dd86d6cc727041eafdc2f42537c9f58ec992e8efad0fa923" + digest = "1:d75336f9fd966011ede7a692794c112e81f0cb80d6e0082ad352e1986fc7b5ee" name = "golang.org/x/sys" packages = [ "unix", "windows", ] pruneopts = "" - revision = "7ad0cfa0b7b5a50bdf0fb49923febdf3742a975c" + revision = "b4ddaad3f8a36719f2b8bc6486c14cc468ca2bb5" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -912,7 +912,7 @@ [[projects]] branch = "master" - digest = "1:a779e74cb881ebe8626d4fee58ad262c4b0d41b5d97a8454e4dda97b638f6177" + digest = "1:3eccfc7625ff5137b398e7b16323d272f2423cd04985de69c79c3e0793494e8f" name = "golang.org/x/tools" packages = [ "go/ast/astutil", @@ -929,11 +929,11 @@ "internal/semver", ] pruneopts = "" - revision = "6bfd74cf029c99138aa1bb5b7e0d6b57c9d4eb49" + revision 
= "1cc945182204dc50f5760d9859d043a3d4e27047" [[projects]] branch = "master" - digest = "1:d8b3f0a0769b9c2694d4ca4fb4b2b6d971d1af008ce9ea9c3b188df76d9d77d6" + digest = "1:1163141bda0af433265042eb4345e4433f1302991b6be9bfa623bd0316004e32" name = "google.golang.org/api" packages = [ "googleapi", @@ -947,7 +947,7 @@ "transport/http/internal/propagation", ] pruneopts = "" - revision = "28eb3b1f27f6a4e7185acb5b5074a90b4cc04cc4" + revision = "7439972e83a764c2a87f6fa78c4a37e8ed2db615" [[projects]] digest = "1:0568e577f790e9bd0420521cff50580f9b38165a38f217ce68f55c4bbaa97066" @@ -970,7 +970,7 @@ [[projects]] branch = "master" - digest = "1:f9e92b6d2b267abfae825d2a674c5d18a8a5c05354c428bff7b9e8536a23a2b6" + digest = "1:6112f28261ff024a08e50da048719d3be7bd39397ec64aab2aa25aac97e668c7" name = "google.golang.org/genproto" packages = [ "googleapis/api/annotations", @@ -979,7 +979,7 @@ "protobuf/field_mask", ] pruneopts = "" - revision = "1774047e7e5133fa3573a4e51b37a586b6b0360c" + revision = "f660b865573183437d2d868f703fe88bb8af0b55" [[projects]] digest = "1:7ed022f305690d5843ba3f0bb94f9890a1f9c9459f0158a28304798213326d88" From 29d13a324b7fc466111d4d3b62c3318f43effc65 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 18 Sep 2019 15:05:51 -0700 Subject: [PATCH 0110/1918] Metrics captured for offending workflows --- flytepropeller/pkg/controller/handler.go | 30 ++++++++++--------- .../pkg/controller/nodes/executor.go | 6 +++- .../pkg/controller/nodes/task/handler.go | 6 ++++ 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/flytepropeller/pkg/controller/handler.go b/flytepropeller/pkg/controller/handler.go index 798d7b6c0a..ab7d8bb2c9 100644 --- a/flytepropeller/pkg/controller/handler.go +++ b/flytepropeller/pkg/controller/handler.go @@ -6,10 +6,12 @@ import ( "runtime/debug" "time" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/lyft/flytepropeller/pkg/controller/config" "github.com/lyft/flytepropeller/pkg/controller/workflowstore" - "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" @@ -22,10 +24,10 @@ import ( type propellerMetrics struct { Scope promutils.Scope DeepCopyTime promutils.StopWatch - RawWorkflowTraversalTime promutils.StopWatch - SystemError prometheus.Counter - AbortError prometheus.Counter - PanicObserved prometheus.Counter + RawWorkflowTraversalTime labeled.StopWatch + SystemError labeled.Counter + AbortError labeled.Counter + PanicObserved labeled.Counter RoundSkipped prometheus.Counter WorkflowNotFound prometheus.Counter } @@ -35,10 +37,10 @@ func newPropellerMetrics(scope promutils.Scope) *propellerMetrics { return &propellerMetrics{ Scope: scope, DeepCopyTime: roundScope.MustNewStopWatch("deepcopy", "Total time to deep copy wf object", time.Millisecond), - RawWorkflowTraversalTime: roundScope.MustNewStopWatch("raw", "Total time to traverse the workflow", time.Millisecond), - SystemError: roundScope.MustNewCounter("system_error", "Failure to reconcile a workflow, system error"), - AbortError: roundScope.MustNewCounter("abort_error", "Failure to abort a workflow, system error"), - PanicObserved: roundScope.MustNewCounter("panic", "Panic during handling or aborting workflow"), + RawWorkflowTraversalTime: labeled.NewStopWatch("raw", "Total time to traverse the workflow", time.Millisecond, roundScope, labeled.EmitUnlabeledMetric), + SystemError: labeled.NewCounter("system_error", "Failure to 
reconcile a workflow, system error", roundScope, labeled.EmitUnlabeledMetric), + AbortError: labeled.NewCounter("abort_error", "Failure to abort a workflow, system error", roundScope, labeled.EmitUnlabeledMetric), + PanicObserved: labeled.NewCounter("panic", "Panic during handling or aborting workflow", roundScope, labeled.EmitUnlabeledMetric), RoundSkipped: roundScope.MustNewCounter("skipped", "Round Skipped because of stale workflow"), WorkflowNotFound: roundScope.MustNewCounter("not_found", "workflow not found in the cache"), } @@ -111,13 +113,13 @@ func (p *Propeller) Handle(ctx context.Context, namespace, name string) error { if r := recover(); r != nil { stack := debug.Stack() err = fmt.Errorf("panic when aborting workflow, Stack: [%s]", string(stack)) - p.metrics.PanicObserved.Inc() + p.metrics.PanicObserved.Inc(ctx) } }() err = p.workflowExecutor.HandleAbortedWorkflow(ctx, wfDeepCopy, maxRetries) }() if err != nil { - p.metrics.AbortError.Inc() + p.metrics.AbortError.Inc(ctx) return err } } else { @@ -133,13 +135,13 @@ func (p *Propeller) Handle(ctx context.Context, namespace, name string) error { SetFinalizerIfEmpty(wfDeepCopy, FinalizerKey) func() { - t := p.metrics.RawWorkflowTraversalTime.Start() + t := p.metrics.RawWorkflowTraversalTime.Start(ctx) defer func() { t.Stop() if r := recover(); r != nil { stack := debug.Stack() err = fmt.Errorf("panic when aborting workflow, Stack: [%s]", string(stack)) - p.metrics.PanicObserved.Inc() + p.metrics.PanicObserved.Inc(ctx) } }() err = p.workflowExecutor.HandleFlyteWorkflow(ctx, wfDeepCopy) @@ -152,7 +154,7 @@ func (p *Propeller) Handle(ctx context.Context, namespace, name string) error { wfDeepCopy = w.DeepCopy() wfDeepCopy.GetExecutionStatus().IncFailedAttempts() wfDeepCopy.GetExecutionStatus().SetMessage(err.Error()) - p.metrics.SystemError.Inc() + p.metrics.SystemError.Inc(ctx) } else { // No updates in the status we detected, we will skip writing to KubeAPI if wfDeepCopy.Status.Equals(&w.Status) { diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 7d0ea5ee32..9af320fec4 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -36,7 +36,8 @@ type nodeMetrics struct { TransitionLatency labeled.StopWatch // Measures the latency between the time a node's been queued to the time the handler reported the executable moved // to running state - QueuingLatency labeled.StopWatch + QueuingLatency labeled.StopWatch + NodeExecutionTime labeled.StopWatch } type nodeExecutor struct { @@ -455,6 +456,8 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.Exec switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseSucceeding: logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) + t := c.metrics.NodeExecutionTime.Start(ctx) + defer t.Stop() return c.executeNode(currentNodeCtx, w, currentNode) // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped // Currently we treat either Skip or Success the same way. 
In this approach only one node will be skipped @@ -521,6 +524,7 @@ func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1al ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), }, } nodeHandlerFactory, err := NewHandlerFactory( diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index dc0776a17d..25318f6537 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -126,6 +126,7 @@ type metrics struct { discoveryGetFailureCount labeled.Counter discoveryMissCount labeled.Counter discoveryHitCount labeled.Counter + pluginExecutionLatency labeled.StopWatch // TODO We should have a metric to capture custom state size } @@ -255,7 +256,9 @@ func (h *taskHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkfl logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) } }() + t := h.metrics.pluginExecutionLatency.Start(ctx) taskStatus, err = t.StartTask(ctx, taskCtx, task.CoreTask(), nodeInputs) + t.Stop() }() if err != nil { @@ -306,7 +309,9 @@ func (h *taskHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.Executable logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) } }() + t := h.metrics.pluginExecutionLatency.Start(ctx) taskStatus, err = t.CheckTaskStatus(ctx, taskCtx, task.CoreTask()) + t.Stop() }() if err != nil { @@ -428,6 +433,7 @@ func NewTaskHandlerForFactory(eventSink events.EventSink, store *storage.DataSto discoveryMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope), discoveryPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope), discoveryGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", scope), + pluginExecutionLatency: labeled.NewStopWatch("plugin_exec_latecny", "Time taken to invoke plugin for one round", time.Microsecond, scope), }, } } From c20ec36535847f89e4160ac5d450ecd8aadb4b4b Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 18 Sep 2019 15:16:14 -0700 Subject: [PATCH 0111/1918] Another metric to measure time taken to gather inputs --- .../pkg/controller/nodes/executor.go | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 9af320fec4..3f7205e6ef 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -36,8 +36,9 @@ type nodeMetrics struct { TransitionLatency labeled.StopWatch // Measures the latency between the time a node's been queued to the time the handler reported the 
executable moved // to running state - QueuingLatency labeled.StopWatch - NodeExecutionTime labeled.StopWatch + QueuingLatency labeled.StopWatch + NodeExecutionTime labeled.StopWatch + NodeInputGatherLatency labeled.StopWatch } type nodeExecutor struct { @@ -104,6 +105,8 @@ func (c *nodeExecutor) startNode(ctx context.Context, w v1alpha1.ExecutableWorkf dataDir := nodeStatus.GetDataDir() var nodeInputs *handler.Data if !node.IsStartNode() { + t := c.metrics.NodeInputGatherLatency.Start(ctx) + defer t.Stop() // Can execute var err error nodeInputs, err = Resolve(ctx, c.nodeHandlerFactory, w, node.GetID(), node.GetInputBindings(), c.store) @@ -518,13 +521,14 @@ func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1al enqueueWorkflow: enQWorkflow, nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope), metrics: &nodeMetrics{ - FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), - ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), - TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), + ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), + TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeInputGatherLatency: 
labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), }, } nodeHandlerFactory, err := NewHandlerFactory( From e8594d3e02704556e474b661f5168029e61ebefe Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 18 Sep 2019 16:12:52 -0700 Subject: [PATCH 0112/1918] Variable created would shadow external variable --- flytepropeller/pkg/controller/nodes/task/handler.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 25318f6537..3b94d1903c 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -256,9 +256,9 @@ func (h *taskHandler) StartNode(ctx context.Context, w v1alpha1.ExecutableWorkfl logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) } }() - t := h.metrics.pluginExecutionLatency.Start(ctx) + s := h.metrics.pluginExecutionLatency.Start(ctx) taskStatus, err = t.StartTask(ctx, taskCtx, task.CoreTask(), nodeInputs) - t.Stop() + s.Stop() }() if err != nil { @@ -309,9 +309,9 @@ func (h *taskHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.Executable logger.Errorf(ctx, "Panic in plugin for TaskType [%s]", task.TaskType()) } }() - t := h.metrics.pluginExecutionLatency.Start(ctx) + s := h.metrics.pluginExecutionLatency.Start(ctx) taskStatus, err = t.CheckTaskStatus(ctx, taskCtx, task.CoreTask()) - t.Stop() + s.Stop() }() if err != nil { From 457699280104bd98356260ed28526e72dcaa7285 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 18 Sep 2019 21:48:16 -0700 Subject: [PATCH 0113/1918] Advanced options to configure kube client for large scale deployments --- flytepropeller/cmd/controller/cmd/root.go | 10 ++++++++-- flytepropeller/pkg/controller/config/config.go | 12 ++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/flytepropeller/cmd/controller/cmd/root.go b/flytepropeller/cmd/controller/cmd/root.go index 55923eb790..36720fc504 100644 --- a/flytepropeller/cmd/controller/cmd/root.go +++ b/flytepropeller/cmd/controller/cmd/root.go @@ -6,9 +6,10 @@ import ( "fmt" "os" - "github.com/lyft/flytepropeller/pkg/controller/executors" "sigs.k8s.io/controller-runtime/pkg/cache" + "github.com/lyft/flytepropeller/pkg/controller/executors" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/manager" @@ -36,11 +37,12 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/clientcmd" + restclient "k8s.io/client-go/rest" + clientset "github.com/lyft/flytepropeller/pkg/client/clientset/versioned" informers "github.com/lyft/flytepropeller/pkg/client/informers/externalversions" "github.com/lyft/flytepropeller/pkg/controller" "github.com/lyft/flytepropeller/pkg/signals" - restclient "k8s.io/client-go/rest" ) const ( @@ -134,6 +136,10 @@ func getKubeConfig(_ context.Context, cfg *config2.Config) (*kubernetes.Clientse } } + kubecfg.QPS = cfg.KubeConfig.QPS + kubecfg.Burst = cfg.KubeConfig.Burst + kubecfg.Timeout = cfg.KubeConfig.Timeout.Duration + kubeClient, err := kubernetes.NewForConfig(kubecfg) if err != nil { return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index 3062906341..7074fe8256 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ 
b/flytepropeller/pkg/controller/config/config.go @@ -31,6 +31,18 @@ type Config struct { GCInterval config.Duration `json:"gc-interval" pflag:"\"30m\",Run periodic GC every 30 minutes"` LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` + KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` +} + +type KubeClientConfig struct { + // QPS indicates the maximum QPS to the master from this client. + // If it's zero, the created RESTClient will use DefaultQPS: 5 + QPS float32 `json:"qps" pflag:",Max QPS to the master for requests to KubeAPI. 0 defaults to 5."` + // Maximum burst for throttle. + // If it's zero, the created RESTClient will use DefaultBurst: 10. + Burst int `json:"burst" pflag:",Max burst rate for throttle. 0 defaults to 10"` + // The maximum length of time to wait before giving up on a server request. A value of zero means no timeout. + Timeout config.Duration `json:"timeout" pflag:",Max duration allowed for every request to KubeAPI before giving up. 0 implies no timeout."` } type CompositeQueueType = string From 8212a4315aa2eae45c67554d9b27fb94d824d5be Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 19 Sep 2019 16:09:53 -0700 Subject: [PATCH 0114/1918] Optimize Command line template parsing (#15) * Template is more optimal to lazily serialize input parameters to command line * Fixed the usage * return an error in case of group match but no literal --- .../go/tasks/v1/flytek8s/container_helper.go | 2 +- .../go/tasks/v1/k8splugins/sidecar.go | 4 +- flyteplugins/go/tasks/v1/k8splugins/spark.go | 4 +- flyteplugins/go/tasks/v1/utils/template.go | 84 ++++++------- .../go/tasks/v1/utils/template_test.go | 111 +++++++++++------- 5 files changed, 108 insertions(+), 97 deletions(-) diff --git a/flyteplugins/go/tasks/v1/flytek8s/container_helper.go b/flyteplugins/go/tasks/v1/flytek8s/container_helper.go index 5c0f6dbf7b..8f0d6597ac 100755 --- a/flyteplugins/go/tasks/v1/flytek8s/container_helper.go +++ b/flyteplugins/go/tasks/v1/flytek8s/container_helper.go @@ -78,7 +78,7 @@ func ToK8sContainer(ctx context.Context, taskCtx types.TaskContext, taskContaine cmdLineArgs := utils.CommandLineTemplateArgs{ Input: inputFile.String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + Inputs: inputs, } modifiedCommand, err := utils.ReplaceTemplateCommandArgs(ctx, taskContainer.GetCommand(), cmdLineArgs) diff --git a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go index 3f3d7b2659..bff90a4656 100755 --- a/flyteplugins/go/tasks/v1/k8splugins/sidecar.go +++ b/flyteplugins/go/tasks/v1/k8splugins/sidecar.go @@ -45,7 +45,7 @@ func validateAndFinalizeContainers( utils.CommandLineTemplateArgs{ Input: taskCtx.GetInputsFile().String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + Inputs: inputs, }) if err != nil { @@ -58,7 +58,7 @@ func validateAndFinalizeContainers( utils.CommandLineTemplateArgs{ Input: taskCtx.GetInputsFile().String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + Inputs: inputs, }) if err != nil { diff --git a/flyteplugins/go/tasks/v1/k8splugins/spark.go b/flyteplugins/go/tasks/v1/k8splugins/spark.go index 17672cb40e..5f9ad35fd5 100755 --- 
a/flyteplugins/go/tasks/v1/k8splugins/spark.go +++ b/flyteplugins/go/tasks/v1/k8splugins/spark.go @@ -113,12 +113,12 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.Tas }, } - modifiedArgs, err := utils.ReplaceTemplateCommandArgs(context.TODO(), + modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, task.GetContainer().GetArgs(), utils.CommandLineTemplateArgs{ Input: taskCtx.GetInputsFile().String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(context.TODO(), inputs), + Inputs: inputs, }) if err != nil { diff --git a/flyteplugins/go/tasks/v1/utils/template.go b/flyteplugins/go/tasks/v1/utils/template.go index 1d88ec6f0e..84a9523657 100755 --- a/flyteplugins/go/tasks/v1/utils/template.go +++ b/flyteplugins/go/tasks/v1/utils/template.go @@ -9,7 +9,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flytestdlib/logger" + "github.com/pkg/errors" ) var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`) @@ -18,9 +18,9 @@ var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P[^} // Contains arguments passed down to command line templates. type CommandLineTemplateArgs struct { - Input string `json:"input"` - OutputPrefix string `json:"output"` - Inputs map[string]string `json:"inputs"` + Input string `json:"input"` + OutputPrefix string `json:"output"` + Inputs *core.LiteralMap `json:"inputs"` } // Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive @@ -48,7 +48,7 @@ func ReplaceTemplateCommandArgs(ctx context.Context, command []string, args Comm return res, nil } -func replaceTemplateCommandArgs(_ context.Context, commandTemplate string, args *CommandLineTemplateArgs) (string, error) { +func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, args *CommandLineTemplateArgs) (string, error) { val := inputFileRegex.ReplaceAllString(commandTemplate, args.Input) val = outputRegex.ReplaceAllString(val, args.OutputPrefix) groupMatches := inputVarRegex.FindAllStringSubmatchIndex(val, -1) @@ -64,89 +64,75 @@ func replaceTemplateCommandArgs(_ context.Context, commandTemplate string, args inputStartIdx := groupMatches[0][2] inputEndIdx := groupMatches[0][3] inputName := val[inputStartIdx:inputEndIdx] - inputVal, exists := args.Inputs[inputName] + + if args.Inputs == nil || args.Inputs.Literals == nil { + return val, fmt.Errorf("no inputs provided, cannot bind input name [%s]", inputName) + } + inputVal, exists := args.Inputs.Literals[inputName] if !exists { return val, fmt.Errorf("requested input is not found [%v] while processing template [%v]", inputName, commandTemplate) } + v, err := serializeLiteral(ctx, inputVal) + if err != nil { + return val, errors.Wrapf(err, "failed to bind a value to inputName [%s]", inputName) + } if endIdx >= len(val) { - return val[:startIdx] + inputVal, nil + return val[:startIdx] + v, nil } - return val[:startIdx] + inputVal + val[endIdx:], nil - } -} - -// Converts a literal map to a go map that can be used in templates. It drops literals that don't have a defined way to -// be safely serialized into a string. 
-func LiteralMapToTemplateArgs(ctx context.Context, m *core.LiteralMap) map[string]string { - if m == nil { - return map[string]string{} - } - - res := make(map[string]string, len(m.Literals)) - - for key, val := range m.Literals { - serialized, ok := serializeLiteral(ctx, val) - if ok { - res[key] = serialized - } + return val[:startIdx] + v + val[endIdx:], nil } - - return res } -func serializePrimitive(ctx context.Context, p *core.Primitive) (string, bool) { +func serializePrimitive(p *core.Primitive) (string, error) { switch o := p.Value.(type) { case *core.Primitive_Integer: - return fmt.Sprintf("%v", o.Integer), true + return fmt.Sprintf("%v", o.Integer), nil case *core.Primitive_Boolean: - return fmt.Sprintf("%v", o.Boolean), true + return fmt.Sprintf("%v", o.Boolean), nil case *core.Primitive_Datetime: - return ptypes.TimestampString(o.Datetime), true + return ptypes.TimestampString(o.Datetime), nil case *core.Primitive_Duration: - return o.Duration.String(), true + return o.Duration.String(), nil case *core.Primitive_FloatValue: - return fmt.Sprintf("%v", o.FloatValue), true + return fmt.Sprintf("%v", o.FloatValue), nil case *core.Primitive_StringValue: - return o.StringValue, true + return o.StringValue, nil default: - logger.Warnf(ctx, "Received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) - return "", false + return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) } } -func serializeLiteralScalar(ctx context.Context, l *core.Scalar) (string, bool) { +func serializeLiteralScalar(l *core.Scalar) (string, error) { switch o := l.Value.(type) { case *core.Scalar_Primitive: - return serializePrimitive(ctx, o.Primitive) + return serializePrimitive(o.Primitive) case *core.Scalar_Blob: - return o.Blob.Uri, true + return o.Blob.Uri, nil default: - logger.Warnf(ctx, "Received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) - return "", false + return "", fmt.Errorf("received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) } } -func serializeLiteral(ctx context.Context, l *core.Literal) (string, bool) { +func serializeLiteral(ctx context.Context, l *core.Literal) (string, error) { switch o := l.Value.(type) { case *core.Literal_Collection: res := make([]string, 0, len(o.Collection.Literals)) for _, sub := range o.Collection.Literals { - s, ok := serializeLiteral(ctx, sub) - if !ok { - return "", false + s, err := serializeLiteral(ctx, sub) + if err != nil { + return "", err } res = append(res, s) } - return fmt.Sprintf("[%v]", strings.Join(res, ",")), true + return fmt.Sprintf("[%v]", strings.Join(res, ",")), nil case *core.Literal_Scalar: - return serializeLiteralScalar(ctx, o.Scalar) + return serializeLiteralScalar(o.Scalar) default: - logger.Warnf(ctx, "Received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) - return "", false + return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) } } diff --git a/flyteplugins/go/tasks/v1/utils/template_test.go b/flyteplugins/go/tasks/v1/utils/template_test.go index 5060927db6..b8c2369bc0 100755 --- a/flyteplugins/go/tasks/v1/utils/template_test.go +++ b/flyteplugins/go/tasks/v1/utils/template_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lyft/flyteidl/clients/go/coreutils" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) @@ -27,8 +28,10 @@ func BenchmarkReplacements(b *testing.B) { cmdTemplate := `abc {{ index .Inputs "x" }}` cmdArgs := CommandLineTemplateArgs{ Input: 
"inputfile.pb", - Inputs: map[string]string{ - "x": "1", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": coreutils.MustMakePrimitiveLiteral(1), + }, }, } @@ -180,8 +183,16 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }, CommandLineTemplateArgs{ Input: "input/blah", OutputPrefix: "output/blah", - Inputs: map[string]string{ - "arr": "[a,b]", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "arr": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{coreutils.MustMakeLiteral("a"), coreutils.MustMakeLiteral("b")}, + }, + }, + }, + }, }}) assert.NoError(t, err) assert.Equal(t, []string{ @@ -191,49 +202,63 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "output/blah", }, actual) }) -} - -func TestLiteralMapToTemplateArgs(t *testing.T) { - t.Run("Scalars", func(t *testing.T) { - expected := map[string]string{ - "str": "blah", - "int": "5", - "date": "1900-01-01T01:01:01.000000001Z", - } - - dd := time.Date(1900, 1, 1, 1, 1, 1, 1, time.UTC) - lit := coreutils.MustMakeLiteral(map[string]interface{}{ - "str": "blah", - "int": 5, - "date": dd, - }) - - actual := LiteralMapToTemplateArgs(context.TODO(), lit.GetMap()) - assert.Equal(t, expected, actual) + t.Run("Date", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.date }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "date": coreutils.MustMakeLiteral(time.Date(1900, 01, 01, 01, 01, 01, 000000001, time.UTC)), + }, + }}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "--someArg 1900-01-01T01:01:01.000000001Z", + "output/blah", + }, actual) }) - t.Run("1d array", func(t *testing.T) { - expected := map[string]string{ - "arr": "[a,b]", - } - - actual := LiteralMapToTemplateArgs(context.TODO(), coreutils.MustMakeLiteral(map[string]interface{}{ - "arr": []interface{}{"a", "b"}, - }).GetMap()) - - assert.Equal(t, expected, actual) + t.Run("2d Array arg", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.arr }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), + }, + }}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "--someArg [[a,b],[1,2]]", + "output/blah", + }, actual) }) - t.Run("2d array", func(t *testing.T) { - expected := map[string]string{ - "arr": "[[a,b],[1,2]]", - } - - actual := LiteralMapToTemplateArgs(context.TODO(), coreutils.MustMakeLiteral(map[string]interface{}{ - "arr": []interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}, - }).GetMap()) - - assert.Equal(t, expected, actual) + t.Run("nil input", func(t *testing.T) { + _, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.arr }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: &core.LiteralMap{Literals: nil}}) + assert.Error(t, err) }) } From edb198bfd4ce3647b72fa3194bb756458114927a Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Fri, 20 Sep 2019 10:22:08 
-0700 Subject: [PATCH 0115/1918] Cache Dynamic WF Spec (#16) * Cache Dynamic WF Spec * Cleanup * Update Deps --- flytepropeller/Gopkg.lock | 44 +++---- .../pkg/apis/flyteworkflow/v1alpha1/iface.go | 4 + .../pkg/controller/nodes/dynamic/handler.go | 110 ++++++++++++++---- .../pkg/controller/nodes/dynamic/utils.go | 35 ++++++ .../controller/nodes/dynamic/utils_test.go | 37 ++++++ 5 files changed, 186 insertions(+), 44 deletions(-) diff --git a/flytepropeller/Gopkg.lock b/flytepropeller/Gopkg.lock index 184534eed4..496ae1dc2b 100644 --- a/flytepropeller/Gopkg.lock +++ b/flytepropeller/Gopkg.lock @@ -61,7 +61,7 @@ version = "1.0.0" [[projects]] - digest = "1:3cb5a7438cd06b18f192433cdb375f557c3c2eb6d97b88f0eb27c42f65c6b4dd" + digest = "1:43785d5148b719a3f30dbda83a2e2f46d56ac6c5de83a14802e3fea7ffd0a5b9" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -105,8 +105,8 @@ "service/sts/stsiface", ] pruneopts = "" - revision = "b4225bdde03b756685c89b7db31fe2ffac6d9234" - version = "v1.24.0" + revision = "86666b1e15ce3072f0a5e22d7988d8d195c44611" + version = "v1.24.2" [[projects]] branch = "master" @@ -466,7 +466,7 @@ version = "v0.14.1" [[projects]] - digest = "1:ab33b506292441cab3455961ad86e9e5e81989b2f6089d4437d2b15a6d417dc2" + digest = "1:b8a5dd502461caca270a001089dd5dda700a409c79748cef84742e27f03ed83a" name = "github.com/lyft/flyteplugins" packages = [ "go/tasks", @@ -488,9 +488,9 @@ "go/tasks/v1/utils", ] pruneopts = "" - revision = "31b0fb21c6bf86f9e5e81297d9ad3de3feaea6e6" + revision = "5592b9fb09a8f67cd0086ef1c623aa1cb1a2030f" source = "https://github.com/lyft/flyteplugins" - version = "v0.1.7" + version = "v0.1.8" [[projects]] digest = "1:b8860863eb1eb7fe4f7d2b3c6b34b89aaedeaae88d761d2c9a4eccc46262abcc" @@ -738,12 +738,12 @@ version = "v1.1.0" [[projects]] - digest = "1:1bff633980ce46a718d753a3bc7e3eb6d7e4df34bc7e4a9869e71ccb4314dc40" + digest = "1:688428eeb1ca80d92599eb3254bdf91b51d7e232fead3a73844c1f201a281e51" name = "github.com/spf13/pflag" packages = ["."] pruneopts = "" - revision = "e8f29969b682c41a730f8f08b76033b120498464" - version = "v1.0.4" + revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab" + version = "v1.0.5" [[projects]] digest = "1:c25a789c738f7cc8ec7f34026badd4e117853f329334a5aa45cf5d0727d7d442" @@ -837,7 +837,7 @@ [[projects]] branch = "master" - digest = "1:af83d44f1195692d76697f62af82329be4d55b100d33b9b3db8f5d0f44563fb9" + digest = "1:ce26d94b8841936fff59bb524f4b96ac434f411b780b3aa784da90ee96ae2367" name = "golang.org/x/net" packages = [ "context", @@ -850,7 +850,7 @@ "trace", ] pruneopts = "" - revision = "c8589233b77dde5edd2205ba8a4fb5c9c2472556" + revision = "a8b05e9114ab0cb08faec337c959aed24b68bf50" [[projects]] branch = "master" @@ -868,14 +868,14 @@ [[projects]] branch = "master" - digest = "1:d75336f9fd966011ede7a692794c112e81f0cb80d6e0082ad352e1986fc7b5ee" + digest = "1:c04d252619a11f0ba51313ad9ee728c0b7bb61c34c7ab65e841a05ac350e65a0" name = "golang.org/x/sys" packages = [ "unix", "windows", ] pruneopts = "" - revision = "b4ddaad3f8a36719f2b8bc6486c14cc468ca2bb5" + revision = "0c1ff786ef13daa914a3351c5e6b0321aed5960e" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -912,7 +912,7 @@ [[projects]] branch = "master" - digest = "1:3eccfc7625ff5137b398e7b16323d272f2423cd04985de69c79c3e0793494e8f" + digest = "1:2420676cd09e5b3bb8ffab905e706818ffa5d91af9e70a202eb36d251fed8c64" name = "golang.org/x/tools" packages = [ "go/ast/astutil", @@ -929,11 +929,11 @@ "internal/semver", ] pruneopts = "" - revision = 
"1cc945182204dc50f5760d9859d043a3d4e27047" + revision = "db1d4edb46856964c77d7931f8076747b3015980" [[projects]] branch = "master" - digest = "1:1163141bda0af433265042eb4345e4433f1302991b6be9bfa623bd0316004e32" + digest = "1:8fe044fdd4c04c27426bc4d3a2914fa0a03d7199c546563d04ea0d9544af54ba" name = "google.golang.org/api" packages = [ "googleapi", @@ -947,7 +947,7 @@ "transport/http/internal/propagation", ] pruneopts = "" - revision = "7439972e83a764c2a87f6fa78c4a37e8ed2db615" + revision = "634b73c1f50be990f1ba97c3f325fb7f88b13775" [[projects]] digest = "1:0568e577f790e9bd0420521cff50580f9b38165a38f217ce68f55c4bbaa97066" @@ -1259,20 +1259,20 @@ revision = "ebc107f98eab922ef99d645781b87caca01f4f48" [[projects]] - digest = "1:3063061b6514ad2666c4fa292451685884cacf77c803e1b10b4a4fa23f7787fb" + digest = "1:7ce71844fcaaabcbe09a392902edb5790ddca3a7070ae8d20830dc6dbe2751af" name = "k8s.io/klog" packages = ["."] pruneopts = "" - revision = "3ca30a56d8a775276f9cdae009ba326fdc05af7f" - version = "v0.4.0" + revision = "2ca9ad30301bf30a8a6e0fa2110db6b8df699a91" + version = "v1.0.0" [[projects]] branch = "master" - digest = "1:71e59e355758d825c891c77bfe3ec2c0b2523b05076e96b2a2bfa804e6ac576a" + digest = "1:ad13d36fb31a3e590b143439610f1a35b4033437ebf565dbc14a72ed4bd61dfb" name = "k8s.io/kube-openapi" packages = ["pkg/util/proto"] pruneopts = "" - revision = "743ec37842bffe49dd4221d9026f30fb1d5adbc4" + revision = "0270cf2f1c1d995d34b36019a6f65d58e6e33ad4" [[projects]] digest = "1:77629c3c036454b4623e99e20f5591b9551dd81d92db616384af92435b52e9b6" diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index deea5db502..45ca152a4e 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -389,3 +389,7 @@ func GetOutputErrorFile(inputDir DataReference) DataReference { func GetFutureFile() string { return "futures.pb" } + +func GetCompiledFutureFile() string { + return "futures_compiled.pb" +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler.go b/flytepropeller/pkg/controller/nodes/dynamic/handler.go index fe8af1cb36..6769838de5 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/handler.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler.go @@ -37,12 +37,16 @@ type dynamicNodeHandler struct { type metrics struct { buildDynamicWorkflow labeled.StopWatch retrieveDynamicJobSpec labeled.StopWatch + CacheHit labeled.StopWatch + CacheError labeled.Counter } func newMetrics(scope promutils.Scope) metrics { return metrics{ buildDynamicWorkflow: labeled.NewStopWatch("build_dynamic_workflow", "Overhead for building a dynamic workflow in memory.", time.Microsecond, scope), retrieveDynamicJobSpec: labeled.NewStopWatch("retrieve_dynamic_spec", "Overhead of downloading and unmarshaling dynamic job spec", time.Microsecond, scope), + CacheHit: labeled.NewStopWatch("dynamic_workflow_cache_hit", "A dynamic workflow was loaded from store.", time.Microsecond, scope), + CacheError: labeled.NewCounter("cache_err", "A dynamic workflow failed to store or load from data store.", scope), } } @@ -56,11 +60,11 @@ func (e dynamicNodeHandler) ExtractOutput(ctx context.Context, w v1alpha1.Execut return outputResolver.ExtractOutput(ctx, w, n, bindToVar) } -func (e dynamicNodeHandler) getDynamicJobSpec(ctx context.Context, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (*core.DynamicJobSpec, error) { +func (e dynamicNodeHandler) 
getDynamicJobSpec(ctx context.Context, node v1alpha1.ExecutableNode, dataDir storage.DataReference) (*core.DynamicJobSpec, error) { t := e.metrics.retrieveDynamicJobSpec.Start(ctx) defer t.Stop() - futuresFilePath, err := e.store.ConstructReference(ctx, nodeStatus.GetDataDir(), v1alpha1.GetFutureFile()) + futuresFilePath, err := e.store.ConstructReference(ctx, dataDir, v1alpha1.GetFutureFile()) if err != nil { logger.Warnf(ctx, "Failed to construct data path for futures file. Error: %v", err) return nil, err @@ -218,22 +222,7 @@ func (e dynamicNodeHandler) CheckNodeStatus(ctx context.Context, w v1alpha1.Exec func (e dynamicNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, previousNodeStatus v1alpha1.ExecutableNodeStatus) (dynamicWf v1alpha1.ExecutableWorkflow, status v1alpha1.ExecutableNodeStatus, isDynamic bool, err error) { - t := e.metrics.buildDynamicWorkflow.Start(ctx) - defer t.Stop() - var nStatus v1alpha1.ExecutableNodeStatus - // We will only get here if the Phase is success. The downside is that this is an overhead for all nodes that are - // not dynamic. But given that we will only check once, it should be ok. - // TODO: Check for node.is_dynamic once the IDL changes are in and SDK migration has happened. - djSpec, err := e.getDynamicJobSpec(ctx, node, previousNodeStatus) - if err != nil { - return nil, nil, false, err - } - - if djSpec == nil { - return nil, status, false, nil - } - rootNodeStatus := w.GetNodeExecutionStatus(node.GetID()) if node.GetTaskID() != nil { // TODO: This is a hack to set parent task execution id, we should move to node-node relationship. @@ -253,29 +242,106 @@ func (e dynamicNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nStatus = w.GetNodeExecutionStatus(node.GetID()) } + subwf, isDynamic, err := e.loadOrBuildDynamicWorkflow(ctx, w, node, previousNodeStatus.GetDataDir(), nStatus) + if err != nil { + return nil, nStatus, false, err + } + + if !isDynamic { + return nil, nil, false, nil + } + + return newContextualWorkflow(w, subwf, nStatus, subwf.Tasks, subwf.SubWorkflows), nStatus, isDynamic, nil +} + +func (e dynamicNodeHandler) buildFlyteWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, + dataDir storage.DataReference, nStatus v1alpha1.ExecutableNodeStatus) (compiledWf *v1alpha1.FlyteWorkflow, isDynamic bool, err error) { + t := e.metrics.buildDynamicWorkflow.Start(ctx) + defer t.Stop() + + // We will only get here if the Phase is success. The downside is that this is an overhead for all nodes that are + // not dynamic. But given that we will only check once, it should be ok. + // TODO: Check for node.is_dynamic once the IDL changes are in and SDK migration has happened. 
+ djSpec, err := e.getDynamicJobSpec(ctx, node, dataDir) + if err != nil { + return nil, false, err + } + + if djSpec == nil { + return nil, false, nil + } + var closure *core.CompiledWorkflowClosure wf, err := e.buildDynamicWorkflowTemplate(ctx, djSpec, w, node, nStatus) if err != nil { - return nil, nil, true, err + return nil, true, err } compiledTasks, err := compileTasks(ctx, djSpec.Tasks) if err != nil { - return nil, nil, true, err + return nil, true, err } // TODO: This will currently fail if the WF references any launch plans closure, err = compiler.CompileWorkflow(wf, djSpec.Subworkflows, compiledTasks, []common2.InterfaceProvider{}) if err != nil { - return nil, nil, true, err + return nil, true, err } subwf, err := k8s.BuildFlyteWorkflow(closure, nil, nil, "") if err != nil { - return nil, nil, true, err + return nil, false, err + } + + return subwf, true, nil +} + +func (e dynamicNodeHandler) loadOrBuildDynamicWorkflow(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, + dataDir storage.DataReference, nodeStatus v1alpha1.ExecutableNodeStatus) (compiledWf *v1alpha1.FlyteWorkflow, isDynamic bool, err error) { + + cacheHitStopWatch := e.metrics.CacheHit.Start(ctx) + // Check if we have compiled the workflow before: + compiledFuturesFilePath, err := e.store.ConstructReference(ctx, nodeStatus.GetDataDir(), v1alpha1.GetCompiledFutureFile()) + if err != nil { + logger.Warnf(ctx, "Failed to construct data path for futures file. Error: %v", err) + return nil, false, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to construct data path for futures file.") + } + + // If there is a cached compiled Workflow, load and return it. + if metadata, err := e.store.Head(ctx, compiledFuturesFilePath); err != nil { + logger.Warnf(ctx, "Failed to call head on compiled futures file. Error: %v", err) + return nil, false, errors.Wrapf(errors.CausedByError, node.GetID(), err, "Failed to do HEAD on compiled futures file.") + } else if metadata.Exists() { + // It exists, load and return it + compiledWf, err = loadCachedFlyteWorkflow(ctx, e.store, compiledFuturesFilePath) + if err != nil { + logger.Warnf(ctx, "Failed to load cached flyte workflow from [%v], this will cause the dynamic workflow to be recompiled. Error: %v", + compiledFuturesFilePath, err) + e.metrics.CacheError.Inc(ctx) + } else { + cacheHitStopWatch.Stop() + return compiledWf, true, nil + } + } + + // If we have not build this spec before, build it now and cache it. + compiledWf, isDynamic, err = e.buildFlyteWorkflow(ctx, w, node, dataDir, nodeStatus) + if err != nil { + return compiledWf, isDynamic, err + } + + if !isDynamic { + return compiledWf, isDynamic, err + } + + // Cache the built WF. Errors are swallowed. + err = cacheFlyteWorkflow(ctx, e.store, compiledWf, compiledFuturesFilePath) + if err != nil { + logger.Warnf(ctx, "Failed to cache flyte workflow, this will cause a cache miss next time and cause the dynamic workflow to be recompiled. 
Error: %v", err) + e.metrics.CacheError.Inc(ctx) } - return newContextualWorkflow(w, subwf, nStatus, subwf.Tasks, subwf.SubWorkflows), nStatus, true, nil + return compiledWf, true, nil } func (e dynamicNodeHandler) progressDynamicWorkflow(ctx context.Context, parentNodeStatus v1alpha1.ExecutableNodeStatus, diff --git a/flytepropeller/pkg/controller/nodes/dynamic/utils.go b/flytepropeller/pkg/controller/nodes/dynamic/utils.go index 42c9e081c0..913d603751 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/utils.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/utils.go @@ -1,8 +1,11 @@ package dynamic import ( + "bytes" "context" + "encoding/json" + "github.com/lyft/flytestdlib/storage" "k8s.io/apimachinery/pkg/util/sets" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" @@ -103,3 +106,35 @@ func compileTasks(_ context.Context, tasks []*core.TaskTemplate) ([]*core.Compil return compiledTasks, nil } + +func cacheFlyteWorkflow(ctx context.Context, store storage.RawStore, wf *v1alpha1.FlyteWorkflow, target storage.DataReference) error { + raw, err := json.Marshal(wf) + if err != nil { + return err + } + + return store.WriteRaw(ctx, target, int64(len(raw)), storage.Options{}, bytes.NewReader(raw)) +} + +func loadCachedFlyteWorkflow(ctx context.Context, store storage.RawStore, source storage.DataReference) ( + *v1alpha1.FlyteWorkflow, error) { + + rawReader, err := store.ReadRaw(ctx, source) + if err != nil { + return nil, err + } + + buf := bytes.NewBuffer(nil) + _, err = buf.ReadFrom(rawReader) + if err != nil { + return nil, err + } + + err = rawReader.Close() + if err != nil { + return nil, err + } + + wf := &v1alpha1.FlyteWorkflow{} + return wf, json.Unmarshal(buf.Bytes(), wf) +} diff --git a/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go b/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go index e5df771a64..475b903514 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/utils_test.go @@ -1,8 +1,13 @@ package dynamic import ( + "context" "testing" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" @@ -75,3 +80,35 @@ func TestUnderlyingInterface(t *testing.T) { assert.NotNil(t, iface) assert.Equal(t, expectedIface, iface) } + +func createInmemoryStore(t testing.TB) *storage.DataStore { + cfg := storage.Config{ + Type: storage.TypeMemory, + } + + d, err := storage.NewDataStore(&cfg, promutils.NewTestScope()) + assert.NoError(t, err) + + return d +} + +func Test_cacheFlyteWorkflow(t *testing.T) { + store := createInmemoryStore(t) + expected := &v1alpha1.FlyteWorkflow{ + TypeMeta: v1.TypeMeta{}, + ObjectMeta: v1.ObjectMeta{}, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "abc", + Connections: v1alpha1.Connections{ + DownstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{}, + UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{}, + }, + }, + } + + location := storage.DataReference("somekey/file.json") + assert.NoError(t, cacheFlyteWorkflow(context.TODO(), store, expected, location)) + actual, err := loadCachedFlyteWorkflow(context.TODO(), store, location) + assert.NoError(t, err) + assert.Equal(t, expected, actual) +} From 0f79e1e86b3b2305638a9aada31151414536f987 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 24 Sep 2019 11:16:15 -0700 Subject: [PATCH 0116/1918] Initial Commit --- 
flyteadmin/.dockerignore | 1 + flyteadmin/.gitignore | 8 + flyteadmin/.golangci.yml | 27 + flyteadmin/.travis.yml | 27 + flyteadmin/CODE_OF_CONDUCT.md | 3 + flyteadmin/Dockerfile | 33 + flyteadmin/Gopkg.lock | 1191 +++++++++++ flyteadmin/Gopkg.toml | 110 + flyteadmin/LICENSE | 202 ++ flyteadmin/Makefile | 37 + flyteadmin/NOTICE | 4 + flyteadmin/README.rst | 8 + .../boilerplate/lyft/docker_build/Makefile | 12 + .../boilerplate/lyft/docker_build/Readme.rst | 23 + .../lyft/docker_build/docker_build.sh | 67 + .../golang_dockerfile/Dockerfile.GoTemplate | 33 + .../lyft/golang_dockerfile/Readme.rst | 16 + .../lyft/golang_dockerfile/update.sh | 13 + .../lyft/golang_test_targets/Makefile | 38 + .../lyft/golang_test_targets/Readme.rst | 31 + .../lyft/golang_test_targets/goimports | 8 + flyteadmin/boilerplate/update.cfg | 3 + flyteadmin/boilerplate/update.sh | 53 + flyteadmin/cmd/entrypoints/clusterresource.go | 111 + flyteadmin/cmd/entrypoints/migrate.go | 133 ++ flyteadmin/cmd/entrypoints/root.go | 83 + flyteadmin/cmd/entrypoints/serve.go | 219 ++ flyteadmin/cmd/main.go | 14 + flyteadmin/flyteadmin_config.yaml | 141 ++ flyteadmin/pkg/async/notifications/email.go | 54 + .../pkg/async/notifications/email_test.go | 79 + flyteadmin/pkg/async/notifications/factory.go | 111 + .../implementations/aws_emailer.go | 86 + .../implementations/aws_emailer_test.go | 121 ++ .../implementations/noop_notifications.go | 54 + .../implementations/processor.go | 160 ++ .../implementations/processor_test.go | 140 ++ .../implementations/publisher.go | 52 + .../implementations/publisher_test.go | 76 + .../async/notifications/interfaces/emailer.go | 13 + .../notifications/interfaces/processor.go | 17 + .../notifications/interfaces/publisher.go | 27 + .../pkg/async/notifications/mocks/emailer.go | 24 + .../async/notifications/mocks/processor.go | 47 + .../async/notifications/mocks/publisher.go | 24 + .../schedule/aws/cloud_watch_scheduler.go | 263 +++ .../aws/cloud_watch_scheduler_test.go | 272 +++ .../interfaces/cloud_watch_event_client.go | 11 + .../mocks/mock_cloud_watch_event_client.go | 71 + .../pkg/async/schedule/aws/serialization.go | 86 + .../async/schedule/aws/serialization_test.go | 61 + flyteadmin/pkg/async/schedule/aws/shared.go | 24 + .../pkg/async/schedule/aws/shared_test.go | 18 + .../async/schedule/aws/workflow_executor.go | 312 +++ .../schedule/aws/workflow_executor_test.go | 307 +++ flyteadmin/pkg/async/schedule/factory.go | 98 + .../schedule/interfaces/event_scheduler.go | 25 + .../schedule/interfaces/workflow_executor.go | 7 + .../schedule/mocks/mock_event_scheduler.go | 42 + .../schedule/mocks/mock_workflow_executor.go | 28 + .../async/schedule/noop/event_scheduler.go | 30 + .../async/schedule/noop/workflow_executor.go | 15 + flyteadmin/pkg/clusterresource/controller.go | 348 +++ .../pkg/clusterresource/controller_test.go | 108 + flyteadmin/pkg/common/cloud.go | 10 + flyteadmin/pkg/common/constants.go | 6 + flyteadmin/pkg/common/entity.go | 13 + flyteadmin/pkg/common/executions.go | 59 + flyteadmin/pkg/common/executions_test.go | 17 + flyteadmin/pkg/common/filters.go | 280 +++ flyteadmin/pkg/common/filters_test.go | 131 ++ flyteadmin/pkg/common/mocks/storage.go | 71 + flyteadmin/pkg/common/sorting.go | 39 + flyteadmin/pkg/common/sorting_test.go | 26 + flyteadmin/pkg/config/config.go | 44 + flyteadmin/pkg/config/config_flags.go | 22 + flyteadmin/pkg/config/config_flags_test.go | 191 ++ flyteadmin/pkg/data/factory.go | 54 + .../data/implementations/aws_remote_url.go | 104 + 
.../implementations/aws_remote_url_test.go | 85 + .../data/implementations/noop_remote_url.go | 34 + .../implementations/noop_remote_url_test.go | 42 + flyteadmin/pkg/data/interfaces/remote.go | 12 + flyteadmin/pkg/data/mocks/remote.go | 24 + flyteadmin/pkg/errors/errors.go | 93 + flyteadmin/pkg/errors/errors_test.go | 29 + flyteadmin/pkg/flytek8s/client.go | 76 + .../pkg/manager/impl/execution_manager.go | 810 +++++++ .../manager/impl/execution_manager_test.go | 1887 +++++++++++++++++ .../pkg/manager/impl/executions/queues.go | 248 +++ .../manager/impl/executions/queues_test.go | 179 ++ .../pkg/manager/impl/launch_plan_manager.go | 560 +++++ .../manager/impl/launch_plan_manager_test.go | 1322 ++++++++++++ .../manager/impl/node_execution_manager.go | 377 ++++ .../impl/node_execution_manager_test.go | 805 +++++++ .../pkg/manager/impl/project_manager.go | 69 + .../pkg/manager/impl/project_manager_test.go | 125 ++ .../pkg/manager/impl/shared/constants.go | 33 + flyteadmin/pkg/manager/impl/shared/errors.go | 21 + .../manager/impl/task_execution_manager.go | 293 +++ .../impl/task_execution_manager_test.go | 905 ++++++++ flyteadmin/pkg/manager/impl/task_manager.go | 255 +++ .../pkg/manager/impl/task_manager_test.go | 400 ++++ .../pkg/manager/impl/testutils/config.go | 29 + .../pkg/manager/impl/testutils/constants.go | 8 + .../manager/impl/testutils/mock_closures.go | 50 + .../manager/impl/testutils/mock_requests.go | 290 +++ .../pkg/manager/impl/testutils/repository.go | 28 + flyteadmin/pkg/manager/impl/util/digests.go | 51 + .../pkg/manager/impl/util/digests_test.go | 174 ++ flyteadmin/pkg/manager/impl/util/filters.go | 267 +++ .../pkg/manager/impl/util/filters_test.go | 194 ++ flyteadmin/pkg/manager/impl/util/shared.go | 215 ++ .../pkg/manager/impl/util/shared_test.go | 385 ++++ .../manager/impl/util/testdata/workflow.json | 722 +++++++ .../impl/validation/execution_validator.go | 149 ++ .../validation/execution_validator_test.go | 224 ++ .../impl/validation/launch_plan_validator.go | 156 ++ .../validation/launch_plan_validator_test.go | 370 ++++ .../validation/node_execution_validator.go | 41 + .../node_execution_validator_test.go | 131 ++ .../impl/validation/project_validator.go | 60 + .../impl/validation/project_validator_test.go | 116 + .../validation/task_execution_validator.go | 56 + .../task_execution_validator_test.go | 171 ++ .../manager/impl/validation/task_validator.go | 339 +++ .../impl/validation/task_validator_test.go | 557 +++++ .../pkg/manager/impl/validation/validation.go | 181 ++ .../impl/validation/validation_test.go | 290 +++ .../impl/validation/workflow_validator.go | 63 + .../validation/workflow_validator_test.go | 81 + .../pkg/manager/impl/workflow_manager.go | 353 +++ .../pkg/manager/impl/workflow_manager_test.go | 575 +++++ .../pkg/manager/interfaces/execution.go | 24 + .../pkg/manager/interfaces/launch_plan.go | 26 + .../pkg/manager/interfaces/node_execution.go | 18 + flyteadmin/pkg/manager/interfaces/project.go | 13 + flyteadmin/pkg/manager/interfaces/task.go | 16 + .../pkg/manager/interfaces/task_execution.go | 17 + flyteadmin/pkg/manager/interfaces/workflow.go | 16 + flyteadmin/pkg/manager/mocks/execution.go | 120 ++ flyteadmin/pkg/manager/mocks/launch_plan.go | 115 + .../pkg/manager/mocks/node_execution.go | 86 + flyteadmin/pkg/manager/mocks/project.go | 39 + flyteadmin/pkg/manager/mocks/task.go | 50 + .../pkg/manager/mocks/task_execution.go | 75 + flyteadmin/pkg/manager/mocks/workflow.go | 41 + .../pkg/repositories/config/database.go | 12 + 
.../repositories/config/migration_models.go | 82 + .../pkg/repositories/config/migrations.go | 112 + .../pkg/repositories/config/postgres.go | 81 + .../pkg/repositories/config/postgres_test.go | 21 + .../pkg/repositories/config/seed_data.go | 26 + flyteadmin/pkg/repositories/database_test.go | 30 + .../repositories/errors/error_transformer.go | 10 + flyteadmin/pkg/repositories/errors/errors.go | 26 + .../pkg/repositories/errors/postgres.go | 95 + .../pkg/repositories/errors/postgres_test.go | 43 + .../errors/test_error_transformer.go | 31 + flyteadmin/pkg/repositories/factory.go | 48 + .../pkg/repositories/gormimpl/common.go | 116 + .../repositories/gormimpl/execution_repo.go | 150 ++ .../gormimpl/execution_repo_test.go | 409 ++++ .../repositories/gormimpl/launch_plan_repo.go | 196 ++ .../gormimpl/launch_plan_repo_test.go | 494 +++++ .../pkg/repositories/gormimpl/metrics.go | 33 + .../gormimpl/node_execution_repo.go | 179 ++ .../gormimpl/node_execution_repo_test.go | 423 ++++ .../pkg/repositories/gormimpl/project_repo.go | 72 + .../gormimpl/project_repo_test.go | 83 + .../gormimpl/task_execution_repo.go | 140 ++ .../gormimpl/task_execution_repo_test.go | 230 ++ .../pkg/repositories/gormimpl/task_repo.go | 135 ++ .../repositories/gormimpl/task_repo_test.go | 264 +++ .../pkg/repositories/gormimpl/test_utils.go | 30 + .../repositories/gormimpl/workflow_repo.go | 134 ++ .../gormimpl/workflow_repo_test.go | 270 +++ .../pkg/repositories/interfaces/common.go | 29 + .../repositories/interfaces/execution_repo.go | 29 + .../interfaces/launch_plan_repo.go | 37 + .../interfaces/node_execution_repo.go | 37 + .../repositories/interfaces/project_repo.go | 17 + .../interfaces/task_execution_repo.go | 29 + .../pkg/repositories/interfaces/task_repo.go | 25 + .../repositories/interfaces/workflow_repo.go | 23 + .../pkg/repositories/mocks/execution_repo.go | 96 + .../repositories/mocks/launch_plan_repo.go | 107 + .../repositories/mocks/node_execution_repo.go | 85 + .../pkg/repositories/mocks/project_repo.go | 44 + .../pkg/repositories/mocks/repository.go | 56 + .../repositories/mocks/task_execution_repo.go | 68 + .../pkg/repositories/mocks/task_repo.go | 79 + .../pkg/repositories/mocks/workflow_repo.go | 80 + .../pkg/repositories/models/base_model.go | 13 + .../pkg/repositories/models/execution.go | 46 + .../repositories/models/execution_event.go | 13 + .../pkg/repositories/models/launch_plan.go | 35 + .../pkg/repositories/models/node_execution.go | 36 + .../models/node_execution_event.go | 13 + flyteadmin/pkg/repositories/models/project.go | 7 + flyteadmin/pkg/repositories/models/task.go | 21 + .../pkg/repositories/models/task_execution.go | 39 + .../pkg/repositories/models/workflow.go | 21 + flyteadmin/pkg/repositories/postgres_repo.go | 60 + .../repositories/transformers/execution.go | 194 ++ .../transformers/execution_event.go | 27 + .../transformers/execution_event_test.go | 45 + .../transformers/execution_test.go | 470 ++++ .../repositories/transformers/launch_plan.go | 133 ++ .../transformers/launch_plan_test.go | 267 +++ .../transformers/node_execution.go | 200 ++ .../transformers/node_execution_event.go | 30 + .../transformers/node_execution_event_test.go | 49 + .../transformers/node_execution_test.go | 257 +++ .../pkg/repositories/transformers/project.go | 36 + .../repositories/transformers/project_test.go | 74 + .../pkg/repositories/transformers/task.go | 81 + .../transformers/task_execution.go | 211 ++ .../transformers/task_execution_test.go | 503 +++++ .../repositories/transformers/task_test.go | 
148 ++ .../pkg/repositories/transformers/workflow.go | 82 + .../transformers/workflow_test.go | 139 ++ flyteadmin/pkg/rpc/adminservice/base.go | 163 ++ flyteadmin/pkg/rpc/adminservice/execution.go | 139 ++ .../pkg/rpc/adminservice/launch_plan.go | 154 ++ flyteadmin/pkg/rpc/adminservice/metrics.go | 152 ++ .../pkg/rpc/adminservice/node_execution.go | 110 + flyteadmin/pkg/rpc/adminservice/project.go | 46 + flyteadmin/pkg/rpc/adminservice/task.go | 93 + .../pkg/rpc/adminservice/task_execution.go | 113 + .../rpc/adminservice/tests/execution_test.go | 325 +++ .../adminservice/tests/launch_plan_test.go | 162 ++ .../adminservice/tests/node_execution_test.go | 252 +++ .../rpc/adminservice/tests/project_test.go | 55 + .../adminservice/tests/task_execution_test.go | 349 +++ .../pkg/rpc/adminservice/tests/task_test.go | 94 + flyteadmin/pkg/rpc/adminservice/tests/util.go | 31 + .../rpc/adminservice/tests/workflow_test.go | 64 + .../pkg/rpc/adminservice/util/metrics.go | 67 + .../pkg/rpc/adminservice/util/transformers.go | 20 + .../adminservice/util/transformers_test.go | 40 + flyteadmin/pkg/rpc/adminservice/workflow.go | 94 + .../runtime/application_config_provider.go | 76 + .../pkg/runtime/cluster_config_provider.go | 45 + .../pkg/runtime/cluster_resource_provider.go | 48 + .../runtime/cluster_resource_provider_test.go | 40 + .../pkg/runtime/config_provider_test.go | 42 + .../pkg/runtime/configuration_provider.go | 56 + .../pkg/runtime/execution_queue_provider.go | 38 + .../interfaces/application_configuration.go | 104 + .../interfaces/cluster_configuration.go | 51 + .../cluster_resource_configuration.go | 32 + .../pkg/runtime/interfaces/configuration.go | 12 + .../runtime/interfaces/queue_configuration.go | 35 + .../registration_validation_provider.go | 16 + .../interfaces/task_resource_configuration.go | 14 + .../pkg/runtime/interfaces/whitelist.go | 14 + .../mocks/mock_application_provider.go | 62 + .../mocks/mock_cluster_resource_provider.go | 28 + .../mocks/mock_configuration_provider.go | 64 + .../mocks/mock_execution_queue_provider.go | 25 + .../mock_registration_validation_provider.go | 30 + .../mocks/mock_task_resource_provider.go | 22 + .../runtime/mocks/mock_whitelist_provider.go | 15 + .../registration_validation_provider.go | 53 + .../pkg/runtime/task_resource_provider.go | 41 + .../testdata/cluster_resource_config.yaml | 11 + .../pkg/runtime/testdata/clusters_config.yaml | 15 + flyteadmin/pkg/runtime/testdata/config.yaml | 20 + flyteadmin/pkg/runtime/whitelist_provider.go | 29 + .../pkg/workflowengine/impl/compiler.go | 46 + .../workflowengine/impl/interface_provider.go | 49 + .../impl/interface_provider_test.go | 76 + .../workflowengine/impl/propeller_executor.go | 193 ++ .../impl/propeller_executor_test.go | 397 ++++ .../pkg/workflowengine/interfaces/compiler.go | 16 + .../pkg/workflowengine/interfaces/executor.go | 32 + .../pkg/workflowengine/mocks/mock_compiler.go | 62 + .../pkg/workflowengine/mocks/mock_executor.go | 44 + .../sampleresourcetemplates/docker.yaml | 8 + .../imagepullsecrets.yaml | 7 + .../sampleresourcetemplates/namespace.yaml | 7 + .../script/integration/k8s/integration.yaml | 433 ++++ flyteadmin/script/integration/k8s/main.sh | 46 + flyteadmin/script/integration/launch.sh | 40 + flyteadmin/tests/bootstrap.go | 66 + flyteadmin/tests/execution_test.go | 280 +++ flyteadmin/tests/helpers.go | 28 + flyteadmin/tests/launch_plan_test.go | 898 ++++++++ flyteadmin/tests/node_execution_test.go | 283 +++ flyteadmin/tests/project.go | 45 + flyteadmin/tests/shared.go | 6 + 
flyteadmin/tests/task_execution_test.go | 382 ++++ flyteadmin/tests/task_test.go | 540 +++++ flyteadmin/tests/workflow_test.go | 344 +++ 295 files changed, 38548 insertions(+) create mode 100644 flyteadmin/.dockerignore create mode 100644 flyteadmin/.gitignore create mode 100644 flyteadmin/.golangci.yml create mode 100644 flyteadmin/.travis.yml create mode 100644 flyteadmin/CODE_OF_CONDUCT.md create mode 100644 flyteadmin/Dockerfile create mode 100644 flyteadmin/Gopkg.lock create mode 100644 flyteadmin/Gopkg.toml create mode 100644 flyteadmin/LICENSE create mode 100644 flyteadmin/Makefile create mode 100644 flyteadmin/NOTICE create mode 100644 flyteadmin/README.rst create mode 100644 flyteadmin/boilerplate/lyft/docker_build/Makefile create mode 100644 flyteadmin/boilerplate/lyft/docker_build/Readme.rst create mode 100755 flyteadmin/boilerplate/lyft/docker_build/docker_build.sh create mode 100644 flyteadmin/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate create mode 100644 flyteadmin/boilerplate/lyft/golang_dockerfile/Readme.rst create mode 100755 flyteadmin/boilerplate/lyft/golang_dockerfile/update.sh create mode 100644 flyteadmin/boilerplate/lyft/golang_test_targets/Makefile create mode 100644 flyteadmin/boilerplate/lyft/golang_test_targets/Readme.rst create mode 100755 flyteadmin/boilerplate/lyft/golang_test_targets/goimports create mode 100644 flyteadmin/boilerplate/update.cfg create mode 100755 flyteadmin/boilerplate/update.sh create mode 100644 flyteadmin/cmd/entrypoints/clusterresource.go create mode 100644 flyteadmin/cmd/entrypoints/migrate.go create mode 100644 flyteadmin/cmd/entrypoints/root.go create mode 100644 flyteadmin/cmd/entrypoints/serve.go create mode 100644 flyteadmin/cmd/main.go create mode 100644 flyteadmin/flyteadmin_config.yaml create mode 100644 flyteadmin/pkg/async/notifications/email.go create mode 100644 flyteadmin/pkg/async/notifications/email_test.go create mode 100644 flyteadmin/pkg/async/notifications/factory.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/aws_emailer.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/aws_emailer_test.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/noop_notifications.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/processor.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/processor_test.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/publisher.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/publisher_test.go create mode 100644 flyteadmin/pkg/async/notifications/interfaces/emailer.go create mode 100644 flyteadmin/pkg/async/notifications/interfaces/processor.go create mode 100644 flyteadmin/pkg/async/notifications/interfaces/publisher.go create mode 100644 flyteadmin/pkg/async/notifications/mocks/emailer.go create mode 100644 flyteadmin/pkg/async/notifications/mocks/processor.go create mode 100644 flyteadmin/pkg/async/notifications/mocks/publisher.go create mode 100644 flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler.go create mode 100644 flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler_test.go create mode 100644 flyteadmin/pkg/async/schedule/aws/interfaces/cloud_watch_event_client.go create mode 100644 flyteadmin/pkg/async/schedule/aws/mocks/mock_cloud_watch_event_client.go create mode 100644 flyteadmin/pkg/async/schedule/aws/serialization.go create mode 100644 flyteadmin/pkg/async/schedule/aws/serialization_test.go create mode 100644 
flyteadmin/pkg/async/schedule/aws/shared.go create mode 100644 flyteadmin/pkg/async/schedule/aws/shared_test.go create mode 100644 flyteadmin/pkg/async/schedule/aws/workflow_executor.go create mode 100644 flyteadmin/pkg/async/schedule/aws/workflow_executor_test.go create mode 100644 flyteadmin/pkg/async/schedule/factory.go create mode 100644 flyteadmin/pkg/async/schedule/interfaces/event_scheduler.go create mode 100644 flyteadmin/pkg/async/schedule/interfaces/workflow_executor.go create mode 100644 flyteadmin/pkg/async/schedule/mocks/mock_event_scheduler.go create mode 100644 flyteadmin/pkg/async/schedule/mocks/mock_workflow_executor.go create mode 100644 flyteadmin/pkg/async/schedule/noop/event_scheduler.go create mode 100644 flyteadmin/pkg/async/schedule/noop/workflow_executor.go create mode 100644 flyteadmin/pkg/clusterresource/controller.go create mode 100644 flyteadmin/pkg/clusterresource/controller_test.go create mode 100644 flyteadmin/pkg/common/cloud.go create mode 100644 flyteadmin/pkg/common/constants.go create mode 100644 flyteadmin/pkg/common/entity.go create mode 100644 flyteadmin/pkg/common/executions.go create mode 100644 flyteadmin/pkg/common/executions_test.go create mode 100644 flyteadmin/pkg/common/filters.go create mode 100644 flyteadmin/pkg/common/filters_test.go create mode 100644 flyteadmin/pkg/common/mocks/storage.go create mode 100644 flyteadmin/pkg/common/sorting.go create mode 100644 flyteadmin/pkg/common/sorting_test.go create mode 100644 flyteadmin/pkg/config/config.go create mode 100755 flyteadmin/pkg/config/config_flags.go create mode 100755 flyteadmin/pkg/config/config_flags_test.go create mode 100644 flyteadmin/pkg/data/factory.go create mode 100644 flyteadmin/pkg/data/implementations/aws_remote_url.go create mode 100644 flyteadmin/pkg/data/implementations/aws_remote_url_test.go create mode 100644 flyteadmin/pkg/data/implementations/noop_remote_url.go create mode 100644 flyteadmin/pkg/data/implementations/noop_remote_url_test.go create mode 100644 flyteadmin/pkg/data/interfaces/remote.go create mode 100644 flyteadmin/pkg/data/mocks/remote.go create mode 100644 flyteadmin/pkg/errors/errors.go create mode 100644 flyteadmin/pkg/errors/errors_test.go create mode 100644 flyteadmin/pkg/flytek8s/client.go create mode 100644 flyteadmin/pkg/manager/impl/execution_manager.go create mode 100644 flyteadmin/pkg/manager/impl/execution_manager_test.go create mode 100644 flyteadmin/pkg/manager/impl/executions/queues.go create mode 100644 flyteadmin/pkg/manager/impl/executions/queues_test.go create mode 100644 flyteadmin/pkg/manager/impl/launch_plan_manager.go create mode 100644 flyteadmin/pkg/manager/impl/launch_plan_manager_test.go create mode 100644 flyteadmin/pkg/manager/impl/node_execution_manager.go create mode 100644 flyteadmin/pkg/manager/impl/node_execution_manager_test.go create mode 100644 flyteadmin/pkg/manager/impl/project_manager.go create mode 100644 flyteadmin/pkg/manager/impl/project_manager_test.go create mode 100644 flyteadmin/pkg/manager/impl/shared/constants.go create mode 100644 flyteadmin/pkg/manager/impl/shared/errors.go create mode 100644 flyteadmin/pkg/manager/impl/task_execution_manager.go create mode 100644 flyteadmin/pkg/manager/impl/task_execution_manager_test.go create mode 100644 flyteadmin/pkg/manager/impl/task_manager.go create mode 100644 flyteadmin/pkg/manager/impl/task_manager_test.go create mode 100644 flyteadmin/pkg/manager/impl/testutils/config.go create mode 100644 flyteadmin/pkg/manager/impl/testutils/constants.go create mode 100644 
flyteadmin/pkg/manager/impl/testutils/mock_closures.go create mode 100644 flyteadmin/pkg/manager/impl/testutils/mock_requests.go create mode 100644 flyteadmin/pkg/manager/impl/testutils/repository.go create mode 100644 flyteadmin/pkg/manager/impl/util/digests.go create mode 100644 flyteadmin/pkg/manager/impl/util/digests_test.go create mode 100644 flyteadmin/pkg/manager/impl/util/filters.go create mode 100644 flyteadmin/pkg/manager/impl/util/filters_test.go create mode 100644 flyteadmin/pkg/manager/impl/util/shared.go create mode 100644 flyteadmin/pkg/manager/impl/util/shared_test.go create mode 100644 flyteadmin/pkg/manager/impl/util/testdata/workflow.json create mode 100644 flyteadmin/pkg/manager/impl/validation/execution_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/execution_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/launch_plan_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/launch_plan_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/node_execution_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/node_execution_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/project_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/project_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/task_execution_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/task_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/task_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/validation.go create mode 100644 flyteadmin/pkg/manager/impl/validation/validation_test.go create mode 100644 flyteadmin/pkg/manager/impl/validation/workflow_validator.go create mode 100644 flyteadmin/pkg/manager/impl/validation/workflow_validator_test.go create mode 100644 flyteadmin/pkg/manager/impl/workflow_manager.go create mode 100644 flyteadmin/pkg/manager/impl/workflow_manager_test.go create mode 100644 flyteadmin/pkg/manager/interfaces/execution.go create mode 100644 flyteadmin/pkg/manager/interfaces/launch_plan.go create mode 100644 flyteadmin/pkg/manager/interfaces/node_execution.go create mode 100644 flyteadmin/pkg/manager/interfaces/project.go create mode 100644 flyteadmin/pkg/manager/interfaces/task.go create mode 100644 flyteadmin/pkg/manager/interfaces/task_execution.go create mode 100644 flyteadmin/pkg/manager/interfaces/workflow.go create mode 100644 flyteadmin/pkg/manager/mocks/execution.go create mode 100644 flyteadmin/pkg/manager/mocks/launch_plan.go create mode 100644 flyteadmin/pkg/manager/mocks/node_execution.go create mode 100644 flyteadmin/pkg/manager/mocks/project.go create mode 100644 flyteadmin/pkg/manager/mocks/task.go create mode 100644 flyteadmin/pkg/manager/mocks/task_execution.go create mode 100644 flyteadmin/pkg/manager/mocks/workflow.go create mode 100644 flyteadmin/pkg/repositories/config/database.go create mode 100644 flyteadmin/pkg/repositories/config/migration_models.go create mode 100644 flyteadmin/pkg/repositories/config/migrations.go create mode 100644 flyteadmin/pkg/repositories/config/postgres.go create mode 100644 flyteadmin/pkg/repositories/config/postgres_test.go create mode 100644 flyteadmin/pkg/repositories/config/seed_data.go create mode 100644 flyteadmin/pkg/repositories/database_test.go create mode 100644 
flyteadmin/pkg/repositories/errors/error_transformer.go create mode 100644 flyteadmin/pkg/repositories/errors/errors.go create mode 100644 flyteadmin/pkg/repositories/errors/postgres.go create mode 100644 flyteadmin/pkg/repositories/errors/postgres_test.go create mode 100644 flyteadmin/pkg/repositories/errors/test_error_transformer.go create mode 100644 flyteadmin/pkg/repositories/factory.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/common.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/execution_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/launch_plan_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/launch_plan_repo_test.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/metrics.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/project_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/project_repo_test.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/task_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/task_repo_test.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/test_utils.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/workflow_repo.go create mode 100644 flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go create mode 100644 flyteadmin/pkg/repositories/interfaces/common.go create mode 100644 flyteadmin/pkg/repositories/interfaces/execution_repo.go create mode 100644 flyteadmin/pkg/repositories/interfaces/launch_plan_repo.go create mode 100644 flyteadmin/pkg/repositories/interfaces/node_execution_repo.go create mode 100644 flyteadmin/pkg/repositories/interfaces/project_repo.go create mode 100644 flyteadmin/pkg/repositories/interfaces/task_execution_repo.go create mode 100644 flyteadmin/pkg/repositories/interfaces/task_repo.go create mode 100644 flyteadmin/pkg/repositories/interfaces/workflow_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/execution_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/launch_plan_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/node_execution_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/project_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/repository.go create mode 100644 flyteadmin/pkg/repositories/mocks/task_execution_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/task_repo.go create mode 100644 flyteadmin/pkg/repositories/mocks/workflow_repo.go create mode 100644 flyteadmin/pkg/repositories/models/base_model.go create mode 100644 flyteadmin/pkg/repositories/models/execution.go create mode 100644 flyteadmin/pkg/repositories/models/execution_event.go create mode 100644 flyteadmin/pkg/repositories/models/launch_plan.go create mode 100644 flyteadmin/pkg/repositories/models/node_execution.go create mode 100644 flyteadmin/pkg/repositories/models/node_execution_event.go create mode 100644 flyteadmin/pkg/repositories/models/project.go create mode 100644 flyteadmin/pkg/repositories/models/task.go create mode 100644 flyteadmin/pkg/repositories/models/task_execution.go create mode 100644 flyteadmin/pkg/repositories/models/workflow.go create mode 
100644 flyteadmin/pkg/repositories/postgres_repo.go create mode 100644 flyteadmin/pkg/repositories/transformers/execution.go create mode 100644 flyteadmin/pkg/repositories/transformers/execution_event.go create mode 100644 flyteadmin/pkg/repositories/transformers/execution_event_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/execution_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/launch_plan.go create mode 100644 flyteadmin/pkg/repositories/transformers/launch_plan_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/node_execution.go create mode 100644 flyteadmin/pkg/repositories/transformers/node_execution_event.go create mode 100644 flyteadmin/pkg/repositories/transformers/node_execution_event_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/node_execution_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/project.go create mode 100644 flyteadmin/pkg/repositories/transformers/project_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/task.go create mode 100644 flyteadmin/pkg/repositories/transformers/task_execution.go create mode 100644 flyteadmin/pkg/repositories/transformers/task_execution_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/task_test.go create mode 100644 flyteadmin/pkg/repositories/transformers/workflow.go create mode 100644 flyteadmin/pkg/repositories/transformers/workflow_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/base.go create mode 100644 flyteadmin/pkg/rpc/adminservice/execution.go create mode 100644 flyteadmin/pkg/rpc/adminservice/launch_plan.go create mode 100644 flyteadmin/pkg/rpc/adminservice/metrics.go create mode 100644 flyteadmin/pkg/rpc/adminservice/node_execution.go create mode 100644 flyteadmin/pkg/rpc/adminservice/project.go create mode 100644 flyteadmin/pkg/rpc/adminservice/task.go create mode 100644 flyteadmin/pkg/rpc/adminservice/task_execution.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/execution_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/launch_plan_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/node_execution_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/project_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/task_execution_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/task_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/util.go create mode 100644 flyteadmin/pkg/rpc/adminservice/tests/workflow_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/util/metrics.go create mode 100644 flyteadmin/pkg/rpc/adminservice/util/transformers.go create mode 100644 flyteadmin/pkg/rpc/adminservice/util/transformers_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/workflow.go create mode 100644 flyteadmin/pkg/runtime/application_config_provider.go create mode 100644 flyteadmin/pkg/runtime/cluster_config_provider.go create mode 100644 flyteadmin/pkg/runtime/cluster_resource_provider.go create mode 100644 flyteadmin/pkg/runtime/cluster_resource_provider_test.go create mode 100644 flyteadmin/pkg/runtime/config_provider_test.go create mode 100644 flyteadmin/pkg/runtime/configuration_provider.go create mode 100644 flyteadmin/pkg/runtime/execution_queue_provider.go create mode 100644 flyteadmin/pkg/runtime/interfaces/application_configuration.go create mode 100644 flyteadmin/pkg/runtime/interfaces/cluster_configuration.go create mode 100644 
flyteadmin/pkg/runtime/interfaces/cluster_resource_configuration.go create mode 100644 flyteadmin/pkg/runtime/interfaces/configuration.go create mode 100644 flyteadmin/pkg/runtime/interfaces/queue_configuration.go create mode 100644 flyteadmin/pkg/runtime/interfaces/registration_validation_provider.go create mode 100644 flyteadmin/pkg/runtime/interfaces/task_resource_configuration.go create mode 100644 flyteadmin/pkg/runtime/interfaces/whitelist.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_application_provider.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_cluster_resource_provider.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_configuration_provider.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_execution_queue_provider.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_registration_validation_provider.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_task_resource_provider.go create mode 100644 flyteadmin/pkg/runtime/mocks/mock_whitelist_provider.go create mode 100644 flyteadmin/pkg/runtime/registration_validation_provider.go create mode 100644 flyteadmin/pkg/runtime/task_resource_provider.go create mode 100644 flyteadmin/pkg/runtime/testdata/cluster_resource_config.yaml create mode 100644 flyteadmin/pkg/runtime/testdata/clusters_config.yaml create mode 100644 flyteadmin/pkg/runtime/testdata/config.yaml create mode 100644 flyteadmin/pkg/runtime/whitelist_provider.go create mode 100644 flyteadmin/pkg/workflowengine/impl/compiler.go create mode 100644 flyteadmin/pkg/workflowengine/impl/interface_provider.go create mode 100644 flyteadmin/pkg/workflowengine/impl/interface_provider_test.go create mode 100644 flyteadmin/pkg/workflowengine/impl/propeller_executor.go create mode 100644 flyteadmin/pkg/workflowengine/impl/propeller_executor_test.go create mode 100644 flyteadmin/pkg/workflowengine/interfaces/compiler.go create mode 100644 flyteadmin/pkg/workflowengine/interfaces/executor.go create mode 100644 flyteadmin/pkg/workflowengine/mocks/mock_compiler.go create mode 100644 flyteadmin/pkg/workflowengine/mocks/mock_executor.go create mode 100644 flyteadmin/sampleresourcetemplates/docker.yaml create mode 100644 flyteadmin/sampleresourcetemplates/imagepullsecrets.yaml create mode 100644 flyteadmin/sampleresourcetemplates/namespace.yaml create mode 100644 flyteadmin/script/integration/k8s/integration.yaml create mode 100755 flyteadmin/script/integration/k8s/main.sh create mode 100755 flyteadmin/script/integration/launch.sh create mode 100644 flyteadmin/tests/bootstrap.go create mode 100644 flyteadmin/tests/execution_test.go create mode 100644 flyteadmin/tests/helpers.go create mode 100644 flyteadmin/tests/launch_plan_test.go create mode 100644 flyteadmin/tests/node_execution_test.go create mode 100644 flyteadmin/tests/project.go create mode 100644 flyteadmin/tests/shared.go create mode 100644 flyteadmin/tests/task_execution_test.go create mode 100644 flyteadmin/tests/task_test.go create mode 100644 flyteadmin/tests/workflow_test.go diff --git a/flyteadmin/.dockerignore b/flyteadmin/.dockerignore new file mode 100644 index 0000000000..140fada73f --- /dev/null +++ b/flyteadmin/.dockerignore @@ -0,0 +1 @@ +vendor/* diff --git a/flyteadmin/.gitignore b/flyteadmin/.gitignore new file mode 100644 index 0000000000..5ed4a2ac25 --- /dev/null +++ b/flyteadmin/.gitignore @@ -0,0 +1,8 @@ + +.idea/ +.DS_Store +.vscode/ +.vendor-new/ + +vendor/ +node_modules/ diff --git a/flyteadmin/.golangci.yml b/flyteadmin/.golangci.yml new file mode 100644 index 
0000000000..3df02b549d --- /dev/null +++ b/flyteadmin/.golangci.yml @@ -0,0 +1,27 @@ +run: + skip-files: + # because we're skipping TLS verification - for now + - cmd/entrypoints/serve.go + - pkg/async/messages/sqs.go + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/flyteadmin/.travis.yml b/flyteadmin/.travis.yml new file mode 100644 index 0000000000..8764b329ad --- /dev/null +++ b/flyteadmin/.travis.yml @@ -0,0 +1,27 @@ +sudo: required +language: go +go: + - "1.10" +services: + - docker +jobs: + include: + # dont push to dockerhub on forks + - if: fork = true + stage: test + name: build, integration test + install: true + script: BUILD_PHASE=builder make docker_build && make k8s_integration + - if: fork = false + stage: test + name: build, integration test, and push + install: true + script: BUILD_PHASE=builder make docker_build && make k8s_integration && make dockerhub_push + - stage: test + name: unit tests + install: make install + script: make test_unit + - stage: test + install: make install + name: lint + script: make lint diff --git a/flyteadmin/CODE_OF_CONDUCT.md b/flyteadmin/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..803d8a77f3 --- /dev/null +++ b/flyteadmin/CODE_OF_CONDUCT.md @@ -0,0 +1,3 @@ +This project is governed by [Lyft's code of +conduct](https://github.com/lyft/code-of-conduct). All contributors +and participants agree to abide by its terms. diff --git a/flyteadmin/Dockerfile b/flyteadmin/Dockerfile new file mode 100644 index 0000000000..11fde287e9 --- /dev/null +++ b/flyteadmin/Dockerfile @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/flyteadmin/ +WORKDIR /go/src/github.com/lyft/flyteadmin + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . /go/src/github.com/lyft/flyteadmin/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/flyteadmin +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["flyteadmin"] diff --git a/flyteadmin/Gopkg.lock b/flyteadmin/Gopkg.lock new file mode 100644 index 0000000000..864d541601 --- /dev/null +++ b/flyteadmin/Gopkg.lock @@ -0,0 +1,1191 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + digest = "1:80004fcc5cf64e591486b3e11b406f1e0d17bf85d475d64203c8494f5da4fcd1" + name = "cloud.google.com/go" + packages = ["compute/metadata"] + pruneopts = "UT" + revision = "ceeb313ad77b789a7fa5287b36a1d127b69b7093" + version = "v0.44.3" + +[[projects]] + digest = "1:94d4ae958b3d2ab476bef4bed53c1dcc3cb0fb2639bd45dd08b40e57139192e5" + name = "github.com/Azure/azure-sdk-for-go" + packages = ["storage"] + pruneopts = "UT" + revision = "2d49bb8f2cee530cc16f1f1a9f0aae763dee257d" + version = "v10.2.1-beta" + +[[projects]] + digest = "1:0aa68ac7d88c06b85442e07b9e4d56cb5e332df2360fa2a5441b2edc5f1ae32b" + name = "github.com/Azure/go-autorest" + packages = [ + "autorest", + "autorest/adal", + "autorest/azure", + "autorest/date", + "logger", + "tracing", + ] + pruneopts = "UT" + revision = "5e7a399d8bbf4953ab0c8e3167d7fd535fd74ce1" + version = "v13.0.0" + +[[projects]] + digest = "1:4d8aa8bc01f60d0fd7f764e1838f26dbc5a5dec428217f936726007cdf3929f0" + name = "github.com/NYTimes/gizmo" + packages = [ + "config/aws", + "pubsub", + "pubsub/aws", + "pubsub/pubsubtest", + ] + pruneopts = "UT" + revision = "27bac814561a097fe9af4585fcefe223315973b2" + version = "v0.4.3" + +[[projects]] + digest = "1:7e704bce17074e862cfe9e4c2849320c2628fc3501b7d0795c589a427ef2bf50" + name = "github.com/Selvatico/go-mocket" + packages = ["."] + pruneopts = "UT" + revision = "c368d4162be502eea110ae12fb85e98567b0f1e6" + version = "v1.0.7" + +[[projects]] + digest = "1:313b743d54588010f7c6f5e00bbfe00ad0a2d63a075cb7d71ea85eaf8f91efa7" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkmath", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", + "private/protocol/eventstream/eventstreamapi", + "private/protocol/json/jsonutil", + "private/protocol/jsonrpc", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/cloudwatchevents", + "service/elasticache", + "service/s3", + "service/s3/s3iface", + "service/s3/s3manager", + "service/ses", + "service/ses/sesiface", + "service/sns", + "service/sns/snsiface", + "service/sqs", + "service/sqs/sqsiface", + "service/sts", + "service/sts/stsiface", + ] + pruneopts = "UT" + revision = "d57c8d96f72d9475194ccf18d2ba70ac294b0cb3" + version = "v1.23.13" + +[[projects]] + branch = "master" + digest = "1:0ad5484a25fbd88409bae8b8b19134135fe73d3cb00e45d3255280b2ab975fcc" + name = "github.com/benbjohnson/clock" + packages = ["."] + pruneopts = "UT" + revision = "7dc76406b6d3c05b5f71a86293cbcf3c4ea03b19" + +[[projects]] + branch = "master" + digest = "1:a6609679ca468a89b711934f16b346e99f6ec344eadd2f7b00b1156785dd1236" + name = "github.com/benlaurie/objecthash" + packages = ["go/objecthash"] + pruneopts = "UT" + revision = "d1e3d6079fc16f8f542183fb5b2fdc11d9f00866" + +[[projects]] + digest = "1:d6afaeed1502aa28e80a4ed0981d570ad91b2579193404256ce672ed0a609e0d" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "UT" + revision = 
"37c8de3658fcb183f997c4e13e8337516ab753e6" + version = "v1.0.1" + +[[projects]] + branch = "master" + digest = "1:f98385a9b77f6cacae716a59c04e6ac374d101466d4369c4e8cc706a39c4bb2e" + name = "github.com/bradfitz/gomemcache" + packages = ["memcache"] + pruneopts = "UT" + revision = "551aad21a6682b95329c1f5bd62ee5060d64f7e8" + +[[projects]] + digest = "1:998cf998358a303ac2430c386ba3fd3398477d6013153d3c6e11432765cc9ae6" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "UT" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:00eb5d8bd96289512920ac43367d5bee76bbca2062da34862a98b26b92741896" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "UT" + revision = "3c79a0a23c1940ab4479332fb3e0127265650ce3" + version = "v1.1.0" + +[[projects]] + digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "UT" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + digest = "1:76dc72490af7174349349838f2fe118996381b31ea83243812a97e5a0fd5ed55" + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + pruneopts = "UT" + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + +[[projects]] + digest = "1:865079840386857c809b72ce300be7580cb50d3d3129ce11bf9aa6ca2bc1934a" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "UT" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + branch = "master" + digest = "1:78a5b63751bd99054bee07a498f6aa54da0a909922f9365d1aa3339091efa70a" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "UT" + revision = "1485a34d5d5723fea214f5710708e19a831720e4" + +[[projects]] + digest = "1:4d02824a56d268f74a6b6fdd944b20b58a77c3d70e81008b3ee0c4f1a6777340" + name = "github.com/gogo/protobuf" + packages = [ + "proto", + "sortkeys", + ] + pruneopts = "UT" + revision = "ba06b47c162d49f2af050fb4c75bcbc86a159d5c" + version = "v1.2.1" + +[[projects]] + branch = "master" + digest = "1:1ba1d79f2810270045c328ae5d674321db34e3aae468eb4233883b473c5c0467" + name = "github.com/golang/glog" + packages = ["."] + pruneopts = "UT" + revision = "23def4e6c14b4da8ac2ed8007337bc5eb5007998" + +[[projects]] + digest = "1:b532ee3f683c057e797694b5bfeb3827d89e6adf41c53dbc80e549bca76364ea" + name = "github.com/golang/protobuf" + packages = [ + "jsonpb", + "proto", + "protoc-gen-go/descriptor", + "protoc-gen-go/generator", + "protoc-gen-go/generator/internal/remap", + "protoc-gen-go/plugin", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + "ptypes/wrappers", + ] + pruneopts = "UT" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" + +[[projects]] + digest = "1:0bfbe13936953a98ae3cfe8ed6670d396ad81edf069a806d2f6515d7bb6950df" + name = "github.com/google/btree" + packages = ["."] + pruneopts = "UT" + revision = "4030bb1f1f0c35b30ca7009e9ebd06849dd45306" + version = "v1.0.0" + +[[projects]] + digest = "1:a6181aca1fd5e27103f9a920876f29ac72854df7345a39f3b01e61c8c94cc8af" + name = "github.com/google/gofuzz" + packages = ["."] + pruneopts = "UT" + revision = "f140a6486e521aad38f5917de355cbf147cc0496" + version = "v1.0.0" + +[[projects]] + digest = "1:766102087520f9d54f2acc72bd6637045900ac735b4a419b128d216f0c5c4876" + name = "github.com/googleapis/gax-go" + packages = ["v2"] + pruneopts = "UT" + revision = 
"bd5b16380fd03dc758d11cef74ba2e3bc8b0e8c2" + version = "v2.0.5" + +[[projects]] + digest = "1:ca4524b4855ded427c7003ec903a5c854f37e7b1e8e2a93277243462c5b753a8" + name = "github.com/googleapis/gnostic" + packages = [ + "OpenAPIv2", + "compiler", + "extensions", + ] + pruneopts = "UT" + revision = "ab0dd09aa10e2952b28e12ecd35681b20463ebab" + version = "v0.3.1" + +[[projects]] + digest = "1:16e1cbd76f0d4152b5573f08f38b451748f74ec59b99a004a7481342b3fc05af" + name = "github.com/graymeta/stow" + packages = [ + ".", + "azure", + "google", + "local", + "oracle", + "s3", + "swift", + ] + pruneopts = "UT" + revision = "903027f87de7054953efcdb8ba70d5dc02df38c7" + +[[projects]] + branch = "master" + digest = "1:5fc0e23b254a1bd7d8d2d42fa093ba33471d08f52fe04afd3713adabb5888dc3" + name = "github.com/gregjones/httpcache" + packages = [ + ".", + "diskcache", + ] + pruneopts = "UT" + revision = "901d90724c7919163f472a9812253fb26761123d" + +[[projects]] + digest = "1:73513cdd52d6f0768201cebbf82612aa39a9d8022bc6337815cd504e532281b7" + name = "github.com/grpc-ecosystem/go-grpc-middleware" + packages = [ + ".", + "retry", + "util/backoffutils", + "util/metautils", + ] + pruneopts = "UT" + revision = "c250d6563d4d4c20252cd865923440e829844f4e" + version = "v1.0.0" + +[[projects]] + digest = "1:9b7a07ac7577787a8ecc1334cb9f34df1c76ed82a917d556c5713d3ab84fbc43" + name = "github.com/grpc-ecosystem/go-grpc-prometheus" + packages = ["."] + pruneopts = "UT" + revision = "c225b8c3b01faf2899099b768856a9e916e5087b" + version = "v1.2.0" + +[[projects]] + digest = "1:9da9ffdf93e29e054fb3b066e3c258e8ed090f6bec4bba1e86aeb9b1ba0056a9" + name = "github.com/grpc-ecosystem/grpc-gateway" + packages = [ + "internal", + "protoc-gen-swagger/options", + "runtime", + "utilities", + ] + pruneopts = "UT" + revision = "a9bbe40ed238db18f710b0e3d2970348c8fcec41" + version = "v1.10.0" + +[[projects]] + digest = "1:7fae9ec96d10b2afce0da23c378c8b3389319b7f92fa092f2621bba3078cfb4b" + name = "github.com/hashicorp/golang-lru" + packages = ["simplelru"] + pruneopts = "UT" + revision = "7f827b33c0f158ec5dfbba01bb0b14a4541fd81d" + version = "v0.5.3" + +[[projects]] + digest = "1:c0d19ab64b32ce9fe5cf4ddceba78d5bc9807f0016db6b1183599da3dcc24d10" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "UT" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:a0cefd27d12712af4b5018dc7046f245e1e3b5760e2e848c30b171b570708f9b" + name = "github.com/imdario/mergo" + packages = ["."] + pruneopts = "UT" + revision = "7c29201646fa3de8506f701213473dd407f19646" + version = "v0.3.7" + +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "UT" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:da6718abe4d47b1132d98bf3f9b18e302d537bf6daf02bd40804d9295a3f32bd" + name = "github.com/jinzhu/gorm" + packages = [ + ".", + "dialects/postgres", + ] + pruneopts = "UT" + revision = "836fb2c19d84dac7b0272958dfb9af7cf0d0ade4" + version = "v1.9.10" + +[[projects]] + digest = "1:01ed62f8f4f574d8aff1d88caee113700a2b44c42351943fa73cc1808f736a50" + name = "github.com/jinzhu/inflection" + packages = ["."] + pruneopts = "UT" + revision = "f5c5f50e6090ae76a29240b61ae2a90dd810112e" + version = 
"v1.0.0" + +[[projects]] + digest = "1:bb81097a5b62634f3e9fec1014657855610c82d19b9a40c17612e32651e35dca" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "UT" + revision = "c2b33e84" + +[[projects]] + digest = "1:709cd2a2c29cc9b89732f6c24846bbb9d6270f28ef5ef2128cc73bd0d6d7bff9" + name = "github.com/json-iterator/go" + packages = ["."] + pruneopts = "UT" + revision = "27518f6661eba504be5a7a9a9f6d9460d892ade3" + version = "v1.1.7" + +[[projects]] + digest = "1:fd9bea48bbc5bba66d9891c72af7255fbebecdff845c37c679406174ece5ca1b" + name = "github.com/kelseyhightower/envconfig" + packages = ["."] + pruneopts = "UT" + revision = "0b417c4ec4a8a82eecc22a1459a504aa55163d61" + version = "v1.4.0" + +[[projects]] + digest = "1:31e761d97c76151dde79e9d28964a812c46efc5baee4085b86f68f0c654450de" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "UT" + revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" + version = "v1.0.2" + +[[projects]] + digest = "1:0ead8e64fe356bd9221605e3ec40b4438509868018cbbbaaaff3ebae1b69b78b" + name = "github.com/lib/pq" + packages = [ + ".", + "hstore", + "oid", + "scram", + ] + pruneopts = "UT" + revision = "3427c32cb71afc948325f299f040e53c1dd78979" + version = "v1.2.0" + +[[projects]] + digest = "1:4c02e347457c97ee8cfafb413554854fe236d715879ac0a43743017cd179de2e" + name = "github.com/lyft/flyteidl" + packages = [ + "clients/go/admin", + "clients/go/admin/mocks", + "clients/go/events", + "clients/go/events/errors", + "gen/pb-go/flyteidl/admin", + "gen/pb-go/flyteidl/core", + "gen/pb-go/flyteidl/event", + "gen/pb-go/flyteidl/service", + ] + pruneopts = "UT" + revision = "c92b79f5f448ec36420eb79bbebb2b372261b77f" + source = "https://github.com/lyft/flyteidl" + version = "v0.1.1" + +[[projects]] + digest = "1:09785a77f804b9b5524cfec6d6240ea0ce53251a38eb55abeb616bcfdd85de99" + name = "github.com/lyft/flyteplugins" + packages = ["go/tasks/v1/types"] + pruneopts = "UT" + revision = "9156da396c7af5b34b4411c3ec99470864425b18" + source = "https://github.com/lyft/flyteplugins" + version = "v0.1.1" + +[[projects]] + digest = "1:3dfb37d4f608c21e5f1d14de40b82d919b76c5044cc6daf38f94a98162e899c7" + name = "github.com/lyft/flytepropeller" + packages = [ + "pkg/apis/flyteworkflow", + "pkg/apis/flyteworkflow/v1alpha1", + "pkg/client/clientset/versioned", + "pkg/client/clientset/versioned/scheme", + "pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1", + "pkg/compiler", + "pkg/compiler/common", + "pkg/compiler/errors", + "pkg/compiler/transformers/k8s", + "pkg/compiler/typing", + "pkg/compiler/validators", + "pkg/utils", + ] + pruneopts = "UT" + revision = "40db32eaa4dc75293560e50c51c2120c9c41d4bb" + source = "https://github.com/lyft/flytepropeller" + version = "v0.1.0" + +[[projects]] + digest = "1:3218b76036eebb079cc456504891ab7b5edace6bc8ce8473b507a5cfd7a6f81e" + name = "github.com/lyft/flytestdlib" + packages = [ + "atomic", + "config", + "config/files", + "config/viper", + "contextutils", + "errors", + "ioutils", + "logger", + "pbhash", + "profutils", + "promutils", + "promutils/labeled", + "storage", + "version", + ] + pruneopts = "UT" + revision = "7292f20ec17b42f104fd61d7f0120e17bcacf751" + source = "https://github.com/lyft/flytestdlib" + version = "v0.2.16" + +[[projects]] + digest = "1:2a0da3440db3f2892609d99cd0389c2776a3fef24435f7b7b58bfc9030aa86ca" + name = "github.com/magiconair/properties" + packages = [ + ".", + "assert", + ] + pruneopts = "UT" + revision = "de8848e004dd33dc07a2947b3d76f618a7fc7ef1" + 
version = "v1.8.1" + +[[projects]] + digest = "1:c658e84ad3916da105a761660dcaeb01e63416c8ec7bc62256a9b411a05fcd67" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "UT" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:36325ebb862e0382f2f14feef409ba9351271b89ada286ae56836c603d43b59c" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "UT" + revision = "e1f7b56ace729e4a73a29a6b4fac6cd5fcda7ab3" + version = "v0.0.9" + +[[projects]] + digest = "1:ff5ebae34cfbf047d505ee150de27e60570e8c394b3b8fdbb720ff6ac71985fc" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "UT" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:53bc4cd4914cd7cd52139990d5170d6dc99067ae31c56530621b18b35fc30318" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "UT" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:33422d238f147d247752996a26574ac48dcf472976eda7f5134015f06bf16563" + name = "github.com/modern-go/concurrent" + packages = ["."] + pruneopts = "UT" + revision = "bacd9c7ef1dd9b15be4a9909b8ac7a4e313eec94" + version = "1.0.3" + +[[projects]] + digest = "1:e32bdbdb7c377a07a9a46378290059822efdce5c8d96fe71940d87cb4f918855" + name = "github.com/modern-go/reflect2" + packages = ["."] + pruneopts = "UT" + revision = "4b7aa43c6742a2c18fdef89dd197aaae7dac7ccd" + version = "1.0.1" + +[[projects]] + branch = "master" + digest = "1:2339820c575323b56a7f94146a2549fd344c51c637fa5b8bafae9695ffa6e1a5" + name = "github.com/ncw/swift" + packages = ["."] + pruneopts = "UT" + revision = "a24ef33bc9b7e59ae4bed9e87a51d7bc76122731" + +[[projects]] + digest = "1:93131d8002d7025da13582877c32d1fc302486775a1b06f62241741006428c5e" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "UT" + revision = "728039f679cbcd4f6a54e080d2219a4c4928c546" + version = "v1.4.0" + +[[projects]] + branch = "master" + digest = "1:89da0f0574bc94cfd0ac8b59af67bf76cdd110d503df2721006b9f0492394333" + name = "github.com/petar/GoLLRB" + packages = ["llrb"] + pruneopts = "UT" + revision = "33fb24c13b99c46c93183c291836c573ac382536" + +[[projects]] + digest = "1:a8c2725121694dfbf6d552fb86fe6b46e3e7135ea05db580c28695b916162aad" + name = "github.com/peterbourgon/diskv" + packages = ["."] + pruneopts = "UT" + revision = "0be1b92a6df0e4f5cb0a5d15fb7f643d0ad93ce6" + version = "v3.0.0" + +[[projects]] + digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "UT" + revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "UT" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:e89f2cdede55684adbe44b5566f55838ad2aee1dff348d14b73ccf733607b671" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + "prometheus/promhttp", + ] + pruneopts = "UT" + revision = "2641b987480bca71fb39738eb8c8b0d577cb1d76" + version = "v0.9.4" + +[[projects]] + branch = "master" + digest = "1:2d5cd61daa5565187e1d96bae64dbbc6080dacf741448e9629c64fd93203b0d4" + name = "github.com/prometheus/client_model" + packages = 
["go"] + pruneopts = "UT" + revision = "14fe0d1b01d4d5fc031dd4bec1823bd3ebbe8016" + +[[projects]] + digest = "1:8dcedf2e8f06c7f94e48267dea0bc0be261fa97b377f3ae3e87843a92a549481" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "UT" + revision = "31bed53e4047fd6c510e43a941f90cb31be0972a" + version = "v0.6.0" + +[[projects]] + digest = "1:8232537905152d6a0b116b9af5a0868fcac0e84eb02ec5a150624c077bdedb0b" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/fs", + "internal/util", + ] + pruneopts = "UT" + revision = "00ec24a6a2d86e7074629c8384715dbb05adccd8" + version = "v0.0.4" + +[[projects]] + digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925" + name = "github.com/satori/uuid" + packages = ["."] + pruneopts = "UT" + revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" + version = "v1.2.0" + +[[projects]] + digest = "1:04457f9f6f3ffc5fea48e71d62f2ca256637dee0a04d710288e27e05c8b41976" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "UT" + revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" + version = "v1.4.2" + +[[projects]] + digest = "1:bb495ec276ab82d3dd08504bbc0594a65de8c3b22c6f2aaa92d05b73fbf3a82e" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "UT" + revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424" + version = "v1.2.2" + +[[projects]] + digest = "1:08d65904057412fc0270fc4812a1c90c594186819243160dc779a402d4b6d0bc" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "UT" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = "1:e096613fb7cf34743d49af87d197663cfccd61876e2219853005a57baedfa562" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "UT" + revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" + version = "v0.0.5" + +[[projects]] + digest = "1:1b753ec16506f5864d26a28b43703c58831255059644351bbcb019b843950900" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "UT" + revision = "94f6ae3ed3bceceafa716478c5fbf8d29ca601a1" + version = "v1.1.0" + +[[projects]] + digest = "1:c1b1102241e7f645bc8e0c22ae352e8f0dc6484b6cb4d132fa9f24174e0119e2" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "UT" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:2532daa308722c7b65f4566e634dac2ddfaa0a398a17d8418e96ef2af3939e37" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "UT" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = "1:ac83cf90d08b63ad5f7e020ef480d319ae890c208f8524622a2f3136e2686b02" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "UT" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:ad527ce5c6b2426790449db7663fe53f8bb647f9387295406794c8be001238da" + name = "github.com/stretchr/testify" + packages = [ + "assert", + "mock", + ] + pruneopts = "UT" + revision = "221dbe5ed46703ee255b1da0dec05086f5035f62" + version = "v1.4.0" + +[[projects]] + digest = "1:74055050ea547bb04600be79cc501965cb3de8988018262f2ca430f0a0b48ec3" + name = "go.opencensus.io" + packages = [ + ".", + "internal", + "internal/tagencoding", + "metric/metricdata", + "metric/metricproducer", + "plugin/ochttp", + "plugin/ochttp/propagation/b3", + "resource", + "stats", + "stats/internal", + "stats/view", + "tag", + "trace", + 
"trace/internal", + "trace/propagation", + "trace/tracestate", + ] + pruneopts = "UT" + revision = "9c377598961b706d1542bd2d84d538b5094d596e" + version = "v0.22.0" + +[[projects]] + branch = "master" + digest = "1:bbe51412d9915d64ffaa96b51d409e070665efc5194fcf145c4a27d4133107a4" + name = "golang.org/x/crypto" + packages = ["ssh/terminal"] + pruneopts = "UT" + revision = "9756ffdc24725223350eb3266ffb92590d28f278" + +[[projects]] + branch = "master" + digest = "1:e93fe09ca93cf16f8b2dc48053f56c2f91ed4f3fd16bfaf9596b6548c7b48a7f" + name = "golang.org/x/net" + packages = [ + "context", + "context/ctxhttp", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "UT" + revision = "ba9fcec4b297b415637633c5a6e8fa592e4a16c3" + +[[projects]] + branch = "master" + digest = "1:31e33f76456ccf54819ab4a646cf01271d1a99d7712ab84bf1a9e7b61cd2031b" + name = "golang.org/x/oauth2" + packages = [ + ".", + "google", + "internal", + "jws", + "jwt", + ] + pruneopts = "UT" + revision = "0f29369cfe4552d0e4bcddc57cc75f4d7e672a33" + +[[projects]] + branch = "master" + digest = "1:db4d094dcdda93745779828d4f7536085eae66f9ebcba842bda762883db08800" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows", + ] + pruneopts = "UT" + revision = "1e83adbbebd0f5dc971915fd7e5db032c3d2b731" + +[[projects]] + digest = "1:8d8faad6b12a3a4c819a3f9618cb6ee1fa1cfc33253abeeea8b55336721e3405" + name = "golang.org/x/text" + packages = [ + "collate", + "collate/build", + "internal/colltab", + "internal/gen", + "internal/language", + "internal/language/compact", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", + "transform", + "unicode/bidi", + "unicode/cldr", + "unicode/norm", + "unicode/rangetable", + ] + pruneopts = "UT" + revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" + version = "v0.3.2" + +[[projects]] + branch = "master" + digest = "1:9fdc2b55e8e0fafe4b41884091e51e77344f7dc511c5acedcfd98200003bff90" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "UT" + revision = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef" + +[[projects]] + branch = "master" + digest = "1:218feb07b42ba85b991b6f2decbc81e7fa6bec9d59cb0c617be40c65dd5edf22" + name = "google.golang.org/api" + packages = [ + "gensupport", + "googleapi", + "googleapi/internal/uritemplates", + "googleapi/transport", + "internal", + "option", + "storage/v1", + "transport/http", + "transport/http/internal/propagation", + ] + pruneopts = "UT" + revision = "d1c9f49851b5339dea6bf7e4076b60a66e62be1f" + +[[projects]] + digest = "1:498b722d33dde4471e7d6e5d88a5e7132d2a8306fea5ff5ee82d1f418b4f41ed" + name = "google.golang.org/appengine" + packages = [ + ".", + "internal", + "internal/app_identity", + "internal/base", + "internal/datastore", + "internal/log", + "internal/modules", + "internal/remote_api", + "internal/urlfetch", + "urlfetch", + ] + pruneopts = "UT" + revision = "5f2a59506353b8d5ba8cbbcd9f3c1f41f1eaf079" + version = "v1.6.2" + +[[projects]] + branch = "master" + digest = "1:1233ed1b527b0ff66c3df5879f7e80b1d8631e030cc45821b77fc25acd0d72a6" + name = "google.golang.org/genproto" + packages = [ + "googleapis/api/annotations", + "googleapis/api/httpbody", + "googleapis/rpc/status", + "protobuf/field_mask", + ] + pruneopts = "UT" + revision = "24fa4b261c55da65468f2abfdae2b024eef27dfb" + +[[projects]] + digest = "1:3b97661db2e5d4c87f7345e875ea28f911e54c715ba0a74be08e1649d67e05cd" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + 
"balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/balancerload", + "internal/binarylog", + "internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "serviceconfig", + "stats", + "status", + "tap", + ] + pruneopts = "UT" + revision = "6eaf6f47437a6b4e2153a190160ef39a92c7eceb" + version = "v1.23.0" + +[[projects]] + digest = "1:1048ae210f190cd7b6aea19a92a055bd6112b025dd49f560579dfdfd76c8c42e" + name = "gopkg.in/gormigrate.v1" + packages = ["."] + pruneopts = "UT" + revision = "ff46dd7d2c0b00a58540e19ca3d3f5e370fa3607" + version = "v1.6.0" + +[[projects]] + digest = "1:2d1fbdc6777e5408cabeb02bf336305e724b925ff4546ded0fa8715a7267922a" + name = "gopkg.in/inf.v0" + packages = ["."] + pruneopts = "UT" + revision = "d2d2541c53f18d2a059457998ce2876cc8e67cbf" + version = "v0.9.1" + +[[projects]] + digest = "1:4d2e5a73dc1500038e504a8d78b986630e3626dc027bc030ba5c75da257cdb96" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "UT" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + branch = "release-1.13" + digest = "1:86b38004415341a2f678a19f9312213bc851a3620bb42b3dca005c8ad0d3485c" + name = "k8s.io/api" + packages = [ + "admissionregistration/v1alpha1", + "admissionregistration/v1beta1", + "apps/v1", + "apps/v1beta1", + "apps/v1beta2", + "auditregistration/v1alpha1", + "authentication/v1", + "authentication/v1beta1", + "authorization/v1", + "authorization/v1beta1", + "autoscaling/v1", + "autoscaling/v2beta1", + "autoscaling/v2beta2", + "batch/v1", + "batch/v1beta1", + "batch/v2alpha1", + "certificates/v1beta1", + "coordination/v1beta1", + "core/v1", + "events/v1beta1", + "extensions/v1beta1", + "networking/v1", + "policy/v1beta1", + "rbac/v1", + "rbac/v1alpha1", + "rbac/v1beta1", + "scheduling/v1alpha1", + "scheduling/v1beta1", + "settings/v1alpha1", + "storage/v1", + "storage/v1alpha1", + "storage/v1beta1", + ] + pruneopts = "UT" + revision = "ebce17126a01f5fe02364d88c899816bcc2a8165" + +[[projects]] + digest = "1:97be1d171d2125d42ddc05182cb53f0c22bff4d6eb20e6c56709e4173242423f" + name = "k8s.io/apimachinery" + packages = [ + "pkg/api/errors", + "pkg/api/meta", + "pkg/api/resource", + "pkg/apis/meta/v1", + "pkg/apis/meta/v1/unstructured", + "pkg/apis/meta/v1beta1", + "pkg/conversion", + "pkg/conversion/queryparams", + "pkg/fields", + "pkg/labels", + "pkg/runtime", + "pkg/runtime/schema", + "pkg/runtime/serializer", + "pkg/runtime/serializer/json", + "pkg/runtime/serializer/protobuf", + "pkg/runtime/serializer/recognizer", + "pkg/runtime/serializer/streaming", + "pkg/runtime/serializer/versioning", + "pkg/selection", + "pkg/types", + "pkg/util/clock", + "pkg/util/errors", + "pkg/util/framer", + "pkg/util/intstr", + "pkg/util/json", + "pkg/util/naming", + "pkg/util/net", + "pkg/util/rand", + "pkg/util/runtime", + "pkg/util/sets", + "pkg/util/validation", + "pkg/util/validation/field", + "pkg/util/wait", + "pkg/util/yaml", + "pkg/version", + "pkg/watch", + "third_party/forked/golang/reflect", + ] + pruneopts = "UT" + revision = "2b1284ed4c93a43499e781493253e2ac5959c4fd" + version = "kubernetes-1.13.1" + +[[projects]] + digest = 
"1:a7b135b3eb8e33e02745491c89c990daf1c22f096fc168e295c868d2ad617c0c" + name = "k8s.io/client-go" + packages = [ + "discovery", + "dynamic", + "kubernetes/scheme", + "pkg/apis/clientauthentication", + "pkg/apis/clientauthentication/v1alpha1", + "pkg/apis/clientauthentication/v1beta1", + "pkg/version", + "plugin/pkg/client/auth/exec", + "rest", + "rest/watch", + "restmapper", + "tools/auth", + "tools/clientcmd", + "tools/clientcmd/api", + "tools/clientcmd/api/latest", + "tools/clientcmd/api/v1", + "tools/metrics", + "transport", + "util/cert", + "util/connrotation", + "util/flowcontrol", + "util/homedir", + "util/integer", + "util/workqueue", + ] + pruneopts = "UT" + revision = "8d9ed539ba3134352c586810e749e58df4e94e4f" + version = "kubernetes-1.13.1" + +[[projects]] + digest = "1:ccb9be4c583b6ec848eb98aa395a4e8c8f8ad9ebb823642c0dd1c1c45939a5bb" + name = "k8s.io/klog" + packages = ["."] + pruneopts = "UT" + revision = "3ca30a56d8a775276f9cdae009ba326fdc05af7f" + version = "v0.4.0" + +[[projects]] + digest = "1:8f87a12b4d6f63f7787a8b5ca06348741321048d08023f5fda3ddf63e3ca2e6a" + name = "sigs.k8s.io/controller-runtime" + packages = [ + "pkg/client", + "pkg/client/apiutil", + ] + pruneopts = "UT" + revision = "477bf4f046c31c351b46fa00262bc814ac0bbca1" + version = "v0.1.11" + +[[projects]] + digest = "1:7719608fe0b52a4ece56c2dde37bedd95b938677d1ab0f84b8a7852e4c59f849" + name = "sigs.k8s.io/yaml" + packages = ["."] + pruneopts = "UT" + revision = "fd68e9863619f6ec2fdd8625fe1f02e7c877e480" + version = "v1.1.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/NYTimes/gizmo/pubsub", + "github.com/NYTimes/gizmo/pubsub/aws", + "github.com/NYTimes/gizmo/pubsub/pubsubtest", + "github.com/Selvatico/go-mocket", + "github.com/aws/aws-sdk-go/aws", + "github.com/aws/aws-sdk-go/aws/awserr", + "github.com/aws/aws-sdk-go/aws/request", + "github.com/aws/aws-sdk-go/aws/session", + "github.com/aws/aws-sdk-go/service/cloudwatchevents", + "github.com/aws/aws-sdk-go/service/s3", + "github.com/aws/aws-sdk-go/service/ses", + "github.com/aws/aws-sdk-go/service/ses/sesiface", + "github.com/benbjohnson/clock", + "github.com/gogo/protobuf/proto", + "github.com/golang/glog", + "github.com/golang/protobuf/jsonpb", + "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/ptypes", + "github.com/golang/protobuf/ptypes/duration", + "github.com/golang/protobuf/ptypes/struct", + "github.com/golang/protobuf/ptypes/timestamp", + "github.com/grpc-ecosystem/go-grpc-prometheus", + "github.com/grpc-ecosystem/grpc-gateway/runtime", + "github.com/jinzhu/gorm", + "github.com/jinzhu/gorm/dialects/postgres", + "github.com/lib/pq", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event", + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service", + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1", + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned", + "github.com/lyft/flytepropeller/pkg/client/clientset/versioned/typed/flyteworkflow/v1alpha1", + "github.com/lyft/flytepropeller/pkg/compiler", + "github.com/lyft/flytepropeller/pkg/compiler/common", + "github.com/lyft/flytepropeller/pkg/compiler/transformers/k8s", + "github.com/lyft/flytepropeller/pkg/compiler/validators", + "github.com/lyft/flytepropeller/pkg/utils", + "github.com/lyft/flytestdlib/config", + "github.com/lyft/flytestdlib/config/viper", + "github.com/lyft/flytestdlib/contextutils", + 
"github.com/lyft/flytestdlib/logger", + "github.com/lyft/flytestdlib/pbhash", + "github.com/lyft/flytestdlib/profutils", + "github.com/lyft/flytestdlib/promutils", + "github.com/lyft/flytestdlib/promutils/labeled", + "github.com/lyft/flytestdlib/storage", + "github.com/magiconair/properties/assert", + "github.com/mitchellh/mapstructure", + "github.com/pkg/errors", + "github.com/prometheus/client_golang/prometheus", + "github.com/spf13/cobra", + "github.com/spf13/pflag", + "github.com/stretchr/testify/assert", + "google.golang.org/grpc", + "google.golang.org/grpc/codes", + "google.golang.org/grpc/credentials", + "google.golang.org/grpc/grpclog", + "google.golang.org/grpc/status", + "gopkg.in/gormigrate.v1", + "k8s.io/apimachinery/pkg/api/errors", + "k8s.io/apimachinery/pkg/api/resource", + "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/apimachinery/pkg/runtime/schema", + "k8s.io/apimachinery/pkg/util/validation", + "k8s.io/apimachinery/pkg/util/wait", + "k8s.io/client-go/kubernetes/scheme", + "k8s.io/client-go/rest", + "k8s.io/client-go/tools/clientcmd", + "sigs.k8s.io/controller-runtime/pkg/client", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/flyteadmin/Gopkg.toml b/flyteadmin/Gopkg.toml new file mode 100644 index 0000000000..fad2d6f530 --- /dev/null +++ b/flyteadmin/Gopkg.toml @@ -0,0 +1,110 @@ +# Gopkg.toml example +# +# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + + +[[constraint]] + name = "github.com/aws/aws-sdk-go" + version = "1.15.0" + +[[constraint]] + name = "github.com/NYTimes/gizmo" + version = "v0.4.2" + +[[constraint]] + branch = "master" + name = "github.com/golang/glog" + +[[constraint]] + name = "github.com/golang/protobuf" + version = "1.1.0" + +[[constraint]] + name = "github.com/grpc-ecosystem/grpc-gateway" + version = "1.5.1" + +[[constraint]] + name = "github.com/grpc-ecosystem/go-grpc-prometheus" + version = "1.2.0" + +[[constraint]] + name = "github.com/lib/pq" + version = "1.0.0" + +[[override]] + name = "github.com/lyft/flyteidl" + source = "https://github.com/lyft/flyteidl" + version = "^0.1.x" + +[[constraint]] + name = "github.com/lyft/flytepropeller" + source = "https://github.com/lyft/flytepropeller" + version = "^v0.1.x" + +[[override]] + name = "github.com/lyft/flytestdlib" + source = "https://github.com/lyft/flytestdlib" + version = "^v0.2.12" + +[[constraint]] + name = "github.com/magiconair/properties" + version = "1.8.0" + +[[constraint]] + name = "github.com/spf13/cobra" + version = "0.0.3" + +[[constraint]] + name = "github.com/spf13/pflag" + version = "1.0.1" + +[[constraint]] + name = "google.golang.org/grpc" + version = "1.16.0" + +[[constraint]] + name = "gopkg.in/gormigrate.v1" + version = "1.2.1" + +[[constraint]] + name = "k8s.io/apimachinery" + version = "kubernetes-1.13.1" + +[[constraint]] + name = "k8s.io/client-go" + version = "kubernetes-1.13.1" + +[[override]] + branch = "master" + name = "golang.org/x/net" + +[[override]] + name = "github.com/json-iterator/go" + version = "^1.1.5" + 
+[prune] + go-tests = true + unused-packages = true + diff --git a/flyteadmin/LICENSE b/flyteadmin/LICENSE new file mode 100644 index 0000000000..bed437514f --- /dev/null +++ b/flyteadmin/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/flyteadmin/Makefile b/flyteadmin/Makefile new file mode 100644 index 0000000000..cf5ce51d5c --- /dev/null +++ b/flyteadmin/Makefile @@ -0,0 +1,37 @@ +export REPOSITORY=flyteadmin +include boilerplate/lyft/docker_build/Makefile +include boilerplate/lyft/golang_test_targets/Makefile + +.PHONY: update_boilerplate +update_boilerplate: + @boilerplate/update.sh + +.PHONY: integration +integration: + GOCACHE=off go test -v -tags=integration ./tests/... + +.PHONY: k8s_integration +k8s_integration: + @script/integration/launch.sh + +.PHONY: compile +compile: + go build -o flyteadmin ./cmd/ && mv ./flyteadmin ${GOPATH}/bin + +.PHONY: linux_compile +linux_compile: + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flyteadmin ./cmd/ + +.PHONY: server +server: + go run cmd/main.go --logtostderr --application.kube-config ~/.kube/config --config flyteadmin_config.yaml serve + +.PHONY: migrate +migrate: + go run cmd/main.go --logtostderr --application.kube-config ~/.kube/config --config flyteadmin_config.yaml migrate run + +.PHONY: seed_projects +seed_projects: + go run cmd/main.go --logtostderr --application.kube-config ~/.kube/config --config flyteadmin_config.yaml migrate seed-projects project admintests flytekit + +all: compile diff --git a/flyteadmin/NOTICE b/flyteadmin/NOTICE new file mode 100644 index 0000000000..dab3948b4d --- /dev/null +++ b/flyteadmin/NOTICE @@ -0,0 +1,4 @@ +flyteadmin +Copyright 2019 Lyft Inc. + +This product includes software developed at Lyft Inc. diff --git a/flyteadmin/README.rst b/flyteadmin/README.rst new file mode 100644 index 0000000000..2512efbba7 --- /dev/null +++ b/flyteadmin/README.rst @@ -0,0 +1,8 @@ +Flyteadmin +============= + +Flyteadmin is the control plane for Flyte, responsible for managing entities (tasks, workflows, launch plans) and +administering workflow executions. Flyteadmin implements the +`AdminService `_ which +defines a stateless REST/gRPC service for interacting with registered Flyte entities and executions. +Flyteadmin uses a relational-style Metadata Store abstracted by the `GORM `_ ORM library. diff --git a/flyteadmin/boilerplate/lyft/docker_build/Makefile b/flyteadmin/boilerplate/lyft/docker_build/Makefile new file mode 100644 index 0000000000..4019dab839 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/docker_build/Makefile @@ -0,0 +1,12 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES.
+# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +.PHONY: docker_build +docker_build: + IMAGE_NAME=$$REPOSITORY ./boilerplate/lyft/docker_build/docker_build.sh + +.PHONY: dockerhub_push +dockerhub_push: + IMAGE_NAME=lyft/$$REPOSITORY REGISTRY=docker.io ./boilerplate/lyft/docker_build/docker_build.sh diff --git a/flyteadmin/boilerplate/lyft/docker_build/Readme.rst b/flyteadmin/boilerplate/lyft/docker_build/Readme.rst new file mode 100644 index 0000000000..bb6af9b49e --- /dev/null +++ b/flyteadmin/boilerplate/lyft/docker_build/Readme.rst @@ -0,0 +1,23 @@ +Docker Build and Push +~~~~~~~~~~~~~~~~~~~~~ + +Provides a ``make docker_build`` target that builds your image locally. + +Provides a ``make dockerhub_push`` target that pushes your final image to Dockerhub. + +The Dockerhub image will tagged ``:`` + +If git head has a git tag, the Dockerhub image will also be tagged ``:``. + +**To Enable:** + +Add ``lyft/docker_build`` to your ``boilerplate/update.cfg`` file. + +Add ``include boilerplate/lyft/docker_build/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/docker_build/Makefile + +(this ensures the extra Make targets get included in your main Makefile) diff --git a/flyteadmin/boilerplate/lyft/docker_build/docker_build.sh b/flyteadmin/boilerplate/lyft/docker_build/docker_build.sh new file mode 100755 index 0000000000..f504c100c7 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/docker_build/docker_build.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +echo "" +echo "------------------------------------" +echo " DOCKER BUILD" +echo "------------------------------------" +echo "" + +if [ -n "$REGISTRY" ]; then + # Do not push if there are unstaged git changes + CHANGED=$(git status --porcelain) + if [ -n "$CHANGED" ]; then + echo "Please commit git changes before pushing to a registry" + exit 1 + fi +fi + + +GIT_SHA=$(git rev-parse HEAD) + +IMAGE_TAG_SUFFIX="" +# for intermediate build phases, append -$BUILD_PHASE to all image tags +if [ -n "$BUILD_PHASE" ]; then + IMAGE_TAG_SUFFIX="-${BUILD_PHASE}" +fi + +IMAGE_TAG_WITH_SHA="${IMAGE_NAME}:${GIT_SHA}${IMAGE_TAG_SUFFIX}" + +RELEASE_SEMVER=$(git describe --tags --exact-match "$GIT_SHA" 2>/dev/null) || true +if [ -n "$RELEASE_SEMVER" ]; then + IMAGE_TAG_WITH_SEMVER="${IMAGE_NAME}:${RELEASE_SEMVER}${IMAGE_TAG_SUFFIX}" +fi + +# build the image +# passing no build phase will build the final image +docker build -t "$IMAGE_TAG_WITH_SHA" --target=${BUILD_PHASE} . +echo "${IMAGE_TAG_WITH_SHA} built locally." + +# if REGISTRY specified, push the images to the remote registy +if [ -n "$REGISTRY" ]; then + + if [ -n "${DOCKER_REGISTRY_PASSWORD}" ]; then + docker login --username="$DOCKER_REGISTRY_USERNAME" --password="$DOCKER_REGISTRY_PASSWORD" + fi + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SHA}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SHA} pushed to remote." 
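+  # Worked example (assumed values, for illustration only): with IMAGE_NAME=lyft/flyteadmin,
+  # REGISTRY=docker.io, a git sha of abc123 and no BUILD_PHASE, the push above publishes
+  # docker.io/lyft/flyteadmin:abc123; with BUILD_PHASE=builder it would instead publish
+  # docker.io/lyft/flyteadmin:abc123-builder.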
+ + # If the current commit has a semver tag, also push the images with the semver tag + if [ -n "$RELEASE_SEMVER" ]; then + + docker tag "$IMAGE_TAG_WITH_SHA" "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + + docker push "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER}" + echo "${REGISTRY}/${IMAGE_TAG_WITH_SEMVER} pushed to remote." + + fi +fi diff --git a/flyteadmin/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate b/flyteadmin/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate new file mode 100644 index 0000000000..5e7b984a11 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/golang_dockerfile/Dockerfile.GoTemplate @@ -0,0 +1,33 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +# Using go1.10.4 +FROM golang:1.10.4-alpine3.8 as builder +RUN apk add git openssh-client make curl dep + +# COPY only the dep files for efficient caching +COPY Gopkg.* /go/src/github.com/lyft/{{REPOSITORY}}/ +WORKDIR /go/src/github.com/lyft/{{REPOSITORY}} + +# Pull dependencies +RUN dep ensure -vendor-only + +# COPY the rest of the source code +COPY . /go/src/github.com/lyft/{{REPOSITORY}}/ + +# This 'linux_compile' target should compile binaries to the /artifacts directory +# The main entrypoint should be compiled to /artifacts/{{REPOSITORY}} +RUN make linux_compile + +# update the PATH to include the /artifacts directory +ENV PATH="/artifacts:${PATH}" + +# This will eventually move to centurylink/ca-certs:latest for minimum possible image size +FROM alpine:3.8 +COPY --from=builder /artifacts /bin + +RUN apk --update add ca-certificates + +CMD ["{{REPOSITORY}}"] diff --git a/flyteadmin/boilerplate/lyft/golang_dockerfile/Readme.rst b/flyteadmin/boilerplate/lyft/golang_dockerfile/Readme.rst new file mode 100644 index 0000000000..f801ef98d6 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/golang_dockerfile/Readme.rst @@ -0,0 +1,16 @@ +Golang Dockerfile +~~~~~~~~~~~~~~~~~ + +Provides a Dockerfile that produces a small image. + +**To Enable:** + +Add ``lyft/golang_dockerfile`` to your ``boilerplate/update.cfg`` file. + +Create and configure a ``make linux_compile`` target that compiles your go binaries to the ``/artifacts`` directory :: + + .PHONY: linux_compile + linux_compile: + RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts {{ packages }} + +All binaries compiled to ``/artifacts`` will be available at ``/bin`` in your final image. diff --git a/flyteadmin/boilerplate/lyft/golang_dockerfile/update.sh b/flyteadmin/boilerplate/lyft/golang_dockerfile/update.sh new file mode 100755 index 0000000000..7d84663262 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/golang_dockerfile/update.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +echo " - generating Dockerfile in root directory." 
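+# For illustration (assuming REPOSITORY=flyteadmin, as exported by this repository's Makefile):
+# the sed below rewrites every {{REPOSITORY}} placeholder in Dockerfile.GoTemplate, so the
+# generated root Dockerfile copies Gopkg.* into /go/src/github.com/lyft/flyteadmin/ and ends
+# with CMD ["flyteadmin"].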
+sed -e "s/{{REPOSITORY}}/${REPOSITORY}/g" ${DIR}/Dockerfile.GoTemplate > ${DIR}/../../../Dockerfile diff --git a/flyteadmin/boilerplate/lyft/golang_test_targets/Makefile b/flyteadmin/boilerplate/lyft/golang_test_targets/Makefile new file mode 100644 index 0000000000..6c1e527fd6 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,38 @@ +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +DEP_SHA=1f7c19e5f52f49ffb9f956f64c010be14683468b + +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.16.0 + golangci-lint run --exclude deprecated + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl "https://raw.githubusercontent.com/golang/dep/${DEP_SHA}/install.sh" | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... -race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/flyteadmin/boilerplate/lyft/golang_test_targets/Readme.rst b/flyteadmin/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100644 index 0000000000..acc5744f59 --- /dev/null +++ b/flyteadmin/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` install golang dependencies. + +Provides a ``lint`` make target that uses golangci to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. + +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). + +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` _after_ your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/flyteadmin/boilerplate/lyft/golang_test_targets/goimports b/flyteadmin/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 0000000000..160525a8cc --- /dev/null +++ b/flyteadmin/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. 
+# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/flyteadmin/boilerplate/update.cfg b/flyteadmin/boilerplate/update.cfg new file mode 100644 index 0000000000..c454fda70a --- /dev/null +++ b/flyteadmin/boilerplate/update.cfg @@ -0,0 +1,3 @@ +lyft/docker_build +lyft/golang_test_targets +lyft/golang_dockerfile diff --git a/flyteadmin/boilerplate/update.sh b/flyteadmin/boilerplate/update.sh new file mode 100755 index 0000000000..bea661d9a0 --- /dev/null +++ b/flyteadmin/boilerplate/update.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# WARNING: THIS FILE IS MANAGED IN THE 'BOILERPLATE' REPO AND COPIED TO OTHER REPOSITORIES. +# ONLY EDIT THIS FILE FROM WITHIN THE 'LYFT/BOILERPLATE' REPOSITORY: +# +# TO OPT OUT OF UPDATES, SEE https://github.com/lyft/boilerplate/blob/master/Readme.rst + +set -e + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +OUT="$(mktemp -d)" +git clone git@github.com:lyft/boilerplate.git "${OUT}" + +echo "Updating the update.sh script." +cp "${OUT}/boilerplate/update.sh" "${DIR}/update.sh" +echo "" + + +CONFIG_FILE="${DIR}/update.cfg" +README="https://github.com/lyft/boilerplate/blob/master/Readme.rst" + +if [ ! -f "$CONFIG_FILE" ]; then + echo "$CONFIG_FILE not found." + echo "This file is required in order to select which features to include." + echo "See $README for more details." + exit 1 +fi + +if [ -z "$REPOSITORY" ]; then + echo '$REPOSITORY is required to run this script' + echo "See $README for more details." + exit 1 +fi + +while read directory; do + echo "***********************************************************************************" + echo "$directory is configured in update.cfg." + echo "-----------------------------------------------------------------------------------" + echo "syncing files from source." + dir_path="${OUT}/boilerplate/${directory}" + rm -rf "${DIR}/${directory}" + mkdir -p $(dirname "${DIR}/${directory}") + cp -r "$dir_path" "${DIR}/${directory}" + if [ -f "${DIR}/${directory}/update.sh" ]; then + echo "executing ${DIR}/${directory}/update.sh" + "${DIR}/${directory}/update.sh" + fi + echo "***********************************************************************************" + echo "" +done < "$CONFIG_FILE" + +rm -rf "${OUT}" diff --git a/flyteadmin/cmd/entrypoints/clusterresource.go b/flyteadmin/cmd/entrypoints/clusterresource.go new file mode 100644 index 0000000000..05a1e4a8e9 --- /dev/null +++ b/flyteadmin/cmd/entrypoints/clusterresource.go @@ -0,0 +1,111 @@ +package entrypoints + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/clusterresource" + + "github.com/lyft/flyteadmin/pkg/flytek8s" + + "github.com/lyft/flyteadmin/pkg/runtime" + + "github.com/lyft/flytestdlib/logger" + + _ "github.com/jinzhu/gorm/dialects/postgres" // Required to import database driver. + "github.com/lyft/flyteadmin/pkg/config" + "github.com/lyft/flyteadmin/pkg/repositories" + repositoryConfig "github.com/lyft/flyteadmin/pkg/repositories/config" + "github.com/lyft/flytestdlib/promutils" + "github.com/spf13/cobra" +) + +var parentClusterResourceCmd = &cobra.Command{ + Use: "clusterresource", + Short: "This command administers the ClusterResourceController. 
Please choose a subcommand.", +} + +func GetLocalDbConfig() repositoryConfig.DbConfig { + return repositoryConfig.DbConfig{ + Host: "localhost", + Port: 5432, + DbName: "postgres", + User: "postgres", + } +} + +var controllerRunCmd = &cobra.Command{ + Use: "run", + Short: "This command will start a cluster resource controller to periodically sync cluster resources", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + configuration := runtime.NewConfigurationProvider() + scope := promutils.NewScope(configuration.ApplicationConfiguration().GetTopLevelConfig().MetricsScope).NewSubScope("clusterresource") + dbConfigValues := configuration.ApplicationConfiguration().GetDbConfig() + dbConfig := repositoryConfig.DbConfig{ + Host: dbConfigValues.Host, + Port: dbConfigValues.Port, + DbName: dbConfigValues.DbName, + User: dbConfigValues.User, + Password: dbConfigValues.Password, + ExtraOptions: dbConfigValues.ExtraOptions, + } + db := repositories.GetRepository( + repositories.POSTGRES, dbConfig, scope.NewSubScope("database")) + + cfg := config.GetConfig() + kubeClient, err := flytek8s.NewKubeClient(cfg.KubeConfig, cfg.Master, configuration.ClusterConfiguration()) + if err != nil { + scope.NewSubScope("flytekubeconfig").MustNewCounter( + "kubeconfig_get_error", + "count of errors encountered fetching and initializing kube config").Inc() + logger.Fatalf(ctx, "Failed to initialize kubeClient: %+v", err) + } + + clusterResourceController := clusterresource.NewClusterResourceController(db, kubeClient, scope) + clusterResourceController.Run() + logger.Infof(ctx, "ClusterResourceController started successfully") + }, +} + +var controllerSyncCmd = &cobra.Command{ + Use: "sync", + Short: "This command will sync cluster resources", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + configuration := runtime.NewConfigurationProvider() + scope := promutils.NewScope(configuration.ApplicationConfiguration().GetTopLevelConfig().MetricsScope).NewSubScope("clusterresource") + dbConfigValues := configuration.ApplicationConfiguration().GetDbConfig() + dbConfig := repositoryConfig.DbConfig{ + Host: dbConfigValues.Host, + Port: dbConfigValues.Port, + DbName: dbConfigValues.DbName, + User: dbConfigValues.User, + Password: dbConfigValues.Password, + ExtraOptions: dbConfigValues.ExtraOptions, + } + db := repositories.GetRepository( + repositories.POSTGRES, dbConfig, scope.NewSubScope("database")) + + cfg := config.GetConfig() + kubeClient, err := flytek8s.NewKubeClient(cfg.KubeConfig, cfg.Master, configuration.ClusterConfiguration()) + if err != nil { + scope.NewSubScope("flytekubeconfig").MustNewCounter( + "kubeconfig_get_error", + "count of errors encountered fetching and initializing kube config").Inc() + logger.Fatalf(ctx, "Failed to initialize kubeClient: %+v", err) + } + + clusterResourceController := clusterresource.NewClusterResourceController(db, kubeClient, scope) + err = clusterResourceController.Sync(ctx) + if err != nil { + logger.Fatalf(ctx, "Failed to sync cluster resources [%+v]", err) + } + logger.Infof(ctx, "ClusterResourceController started successfully") + }, +} + +func init() { + RootCmd.AddCommand(parentClusterResourceCmd) + parentClusterResourceCmd.AddCommand(controllerRunCmd) + parentClusterResourceCmd.AddCommand(controllerSyncCmd) +} diff --git a/flyteadmin/cmd/entrypoints/migrate.go b/flyteadmin/cmd/entrypoints/migrate.go new file mode 100644 index 0000000000..b081d5a3c7 --- /dev/null +++ b/flyteadmin/cmd/entrypoints/migrate.go @@ -0,0 
+1,133 @@ +package entrypoints + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/runtime" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/logger" + + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/postgres" // Required to import database driver. + "github.com/lyft/flyteadmin/pkg/repositories/config" + "github.com/spf13/cobra" + gormigrate "gopkg.in/gormigrate.v1" +) + +var parentMigrateCmd = &cobra.Command{ + Use: "migrate", + Short: "This command controls migration behavior for the Flyte admin database. Please choose a subcommand.", +} + +var migrationsScope = promutils.NewScope("migrations") +var migrateScope = migrationsScope.NewSubScope("migrate") +var rollbackScope = promutils.NewScope("migrations").NewSubScope("rollback") + +// This runs all the migrations +var migrateCmd = &cobra.Command{ + Use: "run", + Short: "This command will run all the migrations for the database", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + configuration := runtime.NewConfigurationProvider() + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + postgresConfigProvider := config.NewPostgresConfigProvider(config.DbConfig{ + Host: databaseConfig.Host, + Port: databaseConfig.Port, + DbName: databaseConfig.DbName, + User: databaseConfig.User, + Password: databaseConfig.Password, + ExtraOptions: databaseConfig.ExtraOptions, + }, migrateScope) + db, err := gorm.Open(postgresConfigProvider.GetType(), postgresConfigProvider.GetArgs()) + if err != nil { + logger.Fatal(ctx, err) + } + defer db.Close() + db.LogMode(true) + if err = db.DB().Ping(); err != nil { + logger.Fatal(ctx, err) + } + + m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) + if err = m.Migrate(); err != nil { + logger.Fatalf(ctx, "Could not migrate: %v", err) + } + logger.Infof(ctx, "Migration ran successfully") + }, +} + +// Rollback the latest migration +var rollbackCmd = &cobra.Command{ + Use: "rollback", + Short: "This command will rollback one migration", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + configuration := runtime.NewConfigurationProvider() + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + postgresConfigProvider := config.NewPostgresConfigProvider(config.DbConfig{ + Host: databaseConfig.Host, + Port: databaseConfig.Port, + DbName: databaseConfig.DbName, + User: databaseConfig.User, + Password: databaseConfig.Password, + ExtraOptions: databaseConfig.ExtraOptions, + }, rollbackScope) + + db, err := gorm.Open(postgresConfigProvider.GetType(), postgresConfigProvider.GetArgs()) + if err != nil { + logger.Fatal(ctx, err) + } + defer db.Close() + db.LogMode(true) + if err = db.DB().Ping(); err != nil { + logger.Fatal(ctx, err) + } + + m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) + err = m.RollbackLast() + if err != nil { + logger.Fatalf(ctx, "Could not rollback latest migration: %v", err) + } + logger.Infof(ctx, "Rolled back one migration successfully") + }, +} + +// This seeds the database with project values +var seedProjectsCmd = &cobra.Command{ + Use: "seed-projects", + Short: "Seed projects in the database.", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + configuration := runtime.NewConfigurationProvider() + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + postgresConfigProvider := config.NewPostgresConfigProvider(config.DbConfig{ + Host: databaseConfig.Host, + 
Port: databaseConfig.Port, + DbName: databaseConfig.DbName, + User: databaseConfig.User, + Password: databaseConfig.Password, + ExtraOptions: databaseConfig.ExtraOptions, + }, migrateScope) + db, err := gorm.Open(postgresConfigProvider.GetType(), postgresConfigProvider.GetArgs()) + if err != nil { + logger.Fatal(ctx, err) + } + defer db.Close() + db.LogMode(true) + + if err = config.SeedProjects(db, args); err != nil { + logger.Fatalf(ctx, "Could not add projects to database with err: %v", err) + } + logger.Infof(ctx, "Successfully added projects to database") + }, +} + +func init() { + RootCmd.AddCommand(parentMigrateCmd) + parentMigrateCmd.AddCommand(migrateCmd) + parentMigrateCmd.AddCommand(rollbackCmd) + parentMigrateCmd.AddCommand(seedProjectsCmd) +} diff --git a/flyteadmin/cmd/entrypoints/root.go b/flyteadmin/cmd/entrypoints/root.go new file mode 100644 index 0000000000..f2f8c6f1f1 --- /dev/null +++ b/flyteadmin/cmd/entrypoints/root.go @@ -0,0 +1,83 @@ +package entrypoints + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +var ( + cfgFile string + kubeMasterURL string + configAccessor = viper.NewAccessor(config.Options{}) +) + +// RootCmd represents the base command when called without any subcommands +var RootCmd = &cobra.Command{ + Use: "flyteadmin", + Short: "Fill in later", + Long: ` +To get started run the serve subcommand which will start a server on localhost:8088: + + flyteadmin serve + +Then you can hit it with the client: + + flyteadmin adminservice foo bar baz + +Or over HTTP 1.1 with curl: + curl -X POST http://localhost:8088/api/v1/projects' +`, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return initConfig(cmd.Flags()) + }, +} + +// Execute adds all child commands to the root command sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() error { + if err := RootCmd.Execute(); err != nil { + fmt.Println(err) + return err + } + return nil +} + +func init() { + // allows `$ flyteadmin --logtostderr` to work + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + + // Add persistent flags - persistent flags persist through all sub-commands + RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is ./flyteadmin_config.yaml)") + RootCmd.PersistentFlags().StringVar(&kubeMasterURL, "master", "", "The address of the Kubernetes API server. Overrides any value in kubeconfig. 
Only required if out-of-cluster.") + + RootCmd.AddCommand(viper.GetConfigCommand()) + + // Allow viper to read the value of the flags + configAccessor.InitializePflags(RootCmd.PersistentFlags()) + + err := flag.CommandLine.Parse([]string{}) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +func initConfig(flags *pflag.FlagSet) error { + configAccessor = viper.NewAccessor(config.Options{ + SearchPaths: []string{cfgFile, ".", "/etc/flyte/config", "$GOPATH/src/github.com/lyft/flyteadmin"}, + StrictMode: false, + }) + + fmt.Println("Using config file: ", configAccessor.ConfigFilesUsed()) + + configAccessor.InitializePflags(flags) + + return configAccessor.UpdateConfig(context.TODO()) +} diff --git a/flyteadmin/cmd/entrypoints/serve.go b/flyteadmin/cmd/entrypoints/serve.go new file mode 100644 index 0000000000..3c76d6bed0 --- /dev/null +++ b/flyteadmin/cmd/entrypoints/serve.go @@ -0,0 +1,219 @@ +package entrypoints + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + _ "net/http/pprof" // Required to serve application. + "strings" + + "github.com/pkg/errors" + "google.golang.org/grpc/credentials" + + "github.com/lyft/flyteadmin/pkg/common" + + "github.com/lyft/flytestdlib/logger" + + "github.com/grpc-ecosystem/grpc-gateway/runtime" + flyteService "github.com/lyft/flyteidl/gen/pb-go/flyteidl/service" + + "github.com/lyft/flyteadmin/pkg/config" + "github.com/lyft/flyteadmin/pkg/rpc/adminservice" + + "github.com/spf13/cobra" + + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "google.golang.org/grpc" +) + +type ServingOptions struct { + Secure bool + // Optional Arguments, Should be provided to enable secure mode + CertFile string + KeyFile string +} + +var serviceOpts = ServingOptions{} + +// serveCmd represents the serve command +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Launches the Flyte admin server", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := context.Background() + cfg := config.GetConfig() + if serviceOpts.Secure { + return serveGatewaySecure(ctx, cfg, serviceOpts) + } + return serveGatewayInsecure(ctx, cfg) + }, +} + +func init() { + // Command information + RootCmd.AddCommand(serveCmd) + serveCmd.Flags().BoolVarP(&serviceOpts.Secure, "secure", "s", false, "Use ssl") + serveCmd.Flags().StringVarP(&serviceOpts.CertFile, "cert-file", "c", "", "Path of file that contains x509 certificate") + serveCmd.Flags().StringVarP(&serviceOpts.KeyFile, "key-file", "k", "", "Path of file that contains x509 client key") + + // Set Keys + labeled.SetMetricKeys(contextutils.AppNameKey, contextutils.ProjectKey, contextutils.DomainKey, + contextutils.ExecIDKey, contextutils.WorkflowIDKey, contextutils.NodeIDKey, contextutils.TaskIDKey, + contextutils.TaskTypeKey, common.RuntimeTypeKey, common.RuntimeVersionKey) +} + +// Creates a new gRPC Server with all the configuration +func newGRPCServer(_ context.Context, cfg *config.Config, opts ...grpc.ServerOption) (*grpc.Server, error) { + serverOpts := []grpc.ServerOption{ + grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor), + grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor), + } + serverOpts = append(serverOpts, opts...) + grpcServer := grpc.NewServer(serverOpts...) 
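+	// The grpc_prometheus interceptors installed above record per-RPC request counts
+	// and handling time; Register initializes those metrics for the newly created
+	// server so they are exported together with the rest of the admin metrics.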
+ grpc_prometheus.Register(grpcServer) + flyteService.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(cfg.KubeConfig, cfg.Master)) + return grpcServer, nil +} + +func newHTTPServer(ctx context.Context, cfg *config.Config, grpcConnectionOpts []grpc.DialOption, grpcAddress string) (*http.ServeMux, error) { + // Register the server that will serve HTTP/REST Traffic + mux := http.NewServeMux() + + // Register healthcheck + mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + // A very simple health check. + w.WriteHeader(http.StatusOK) + }) + + // Register OpenAPI endpoint + // This endpoint will serve the OpenAPI2 spec generated by the swagger protoc plugin, and bundled by go-bindata + mux.HandleFunc("/api/v1/openapi", func(w http.ResponseWriter, r *http.Request) { + swaggerBytes, err := flyteService.Asset("admin.swagger.json") + if err != nil { + logger.Warningf(ctx, "Err %v", err) + w.WriteHeader(http.StatusFailedDependency) + } else { + w.WriteHeader(http.StatusOK) + _, err := w.Write(swaggerBytes) + if err != nil { + logger.Errorf(ctx, "failed to write openAPI information, error: %s", err.Error()) + } + } + }) + + // Register the actual Server that will service gRPC traffic + gwmux := runtime.NewServeMux(runtime.WithMarshalerOption("application/octet-stream", &runtime.ProtoMarshaller{})) + err := flyteService.RegisterAdminServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering admin service") + } + + mux.Handle("/", gwmux) + + return mux, nil +} + +func serveGatewayInsecure(ctx context.Context, cfg *config.Config) error { + logger.Infof(ctx, "Serving FlyteAdmin Insecure") + grpcServer, err := newGRPCServer(ctx, cfg) + if err != nil { + return errors.Wrap(err, "failed to create GRPC server") + } + + logger.Infof(ctx, "Serving GRPC Traffic on: %s", cfg.GetGrpcHostAddress()) + lis, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) + if err != nil { + return errors.Wrapf(err, "failed to listen on GRPC port: %s", cfg.GetGrpcHostAddress()) + } + + go func() { + err := grpcServer.Serve(lis) + logger.Fatalf(ctx, "Failed to create GRPC Server, Err: ", err) + }() + + logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) + httpServer, err := newHTTPServer(ctx, cfg, []grpc.DialOption{grpc.WithInsecure()}, cfg.GetGrpcHostAddress()) + if err != nil { + return err + } + err = http.ListenAndServe(cfg.GetHostAddress(), httpServer) + if err != nil { + return errors.Wrapf(err, "failed to Start HTTP Server") + } + + return nil +} + +// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC +// connections or otherHandler otherwise. 
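+// gRPC traffic is detected by ProtoMajor == 2 together with an "application/grpc"
+// Content-Type header; everything else (the REST gateway, the OpenAPI endpoint, the
+// healthcheck) is handed to otherHandler.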
+// See https://github.com/philips/grpc-gateway-example/blob/master/cmd/serve.go for reference +func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This is a partial recreation of gRPC's internal checks + if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + logger.Infof(context.TODO(), "Received GRPC request, %s, %s, %s, %d, %d", r.RequestURI, r.Header, r.Method, r.ProtoMajor, r.ProtoMinor) + grpcServer.ServeHTTP(w, r) + } else { + logger.Infof(context.TODO(), "Received regular request - %s, %s, %s, %d, %d", r.RequestURI, r.Header, r.Method, r.ProtoMajor, r.ProtoMinor) + otherHandler.ServeHTTP(w, r) + } + }) +} + +func serveGatewaySecure(ctx context.Context, cfg *config.Config, opts ServingOptions) error { + // This support single cert right now? + var err error + cert, err := tls.LoadX509KeyPair(opts.CertFile, opts.KeyFile) + if err != nil { + return err + } + certPool := x509.NewCertPool() + data, err := ioutil.ReadFile(opts.CertFile) + if err != nil { + return errors.Wrapf(err, "failed to read server cert file: %s", opts.CertFile) + } + if ok := certPool.AppendCertsFromPEM([]byte(data)); !ok { + return fmt.Errorf("failed to load certificate into the pool") + } + + grpcServer, err := newGRPCServer(ctx, cfg, grpc.Creds(credentials.NewClientTLSFromCert(certPool, cfg.GetHostAddress()))) + if err != nil { + return errors.Wrap(err, "failed to create GRPC server") + } + + dialCreds := credentials.NewTLS(&tls.Config{ + ServerName: cfg.GetHostAddress(), + RootCAs: certPool, + }) + httpServer, err := newHTTPServer(ctx, cfg, []grpc.DialOption{grpc.WithTransportCredentials(dialCreds)}, cfg.GetHostAddress()) + if err != nil { + return err + } + + conn, err := net.Listen("tcp", cfg.GetHostAddress()) + if err != nil { + panic(err) + } + + srv := &http.Server{ + Addr: cfg.GetHostAddress(), + Handler: grpcHandlerFunc(grpcServer, httpServer), + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"h2"}, + }, + } + + err = srv.Serve(tls.NewListener(conn, srv.TLSConfig)) + + if err != nil { + return errors.Wrapf(err, "failed to Start HTTP/2 Server") + } + return nil +} diff --git a/flyteadmin/cmd/main.go b/flyteadmin/cmd/main.go new file mode 100644 index 0000000000..a9c0708d9c --- /dev/null +++ b/flyteadmin/cmd/main.go @@ -0,0 +1,14 @@ +package main + +import ( + "github.com/golang/glog" + "github.com/lyft/flyteadmin/cmd/entrypoints" +) + +func main() { + glog.V(2).Info("Beginning Flyte Controller") + err := entrypoints.Execute() + if err != nil { + panic(err) + } +} diff --git a/flyteadmin/flyteadmin_config.yaml b/flyteadmin/flyteadmin_config.yaml new file mode 100644 index 0000000000..377126c61b --- /dev/null +++ b/flyteadmin/flyteadmin_config.yaml @@ -0,0 +1,141 @@ +# This is a sample configuration file. +# Real configuration when running inside K8s (local or otherwise) lives in a ConfigMap +# Look in the artifacts directory in the flyte repo for what's actually run +# https://github.com/lyft/flyte/blob/b47565c9998cde32b0b5f995981e3f3c990fa7cd/artifacts/flyteadmin.yaml#L72 +application: + httpPort: 8088 + grpcPort: 8089 +flyteadmin: + someBoolean: true + someOtherBoolean: false + runScheduler: false + roleNameKey: "iam.amazonaws.com/role" + metricsScope: "flyte:" + profilerPort: 10254 + testing: + host: "http://localhost:8088" + # This last must be in order! 
For example, a file path would be prefixed with metadata/admin/... + metadataStoragePrefix: + - "metadata" + - "admin" +database: + port: 5432 + username: postgres + host: localhost + dbname: postgres + options: "sslmode=disable" +scheduler: + eventScheduler: + scheme: local + region: "my-region" + scheduleRole: "arn:aws:iam::abc123:role/my-iam-role" + targetName: "arn:aws:sqs:my-region:abc123:my-queue" + workflowExecutor: + scheme: local + region: "my-region" + scheduleQueueName: "won't-work-locally" + accountId: "abc123" +remoteData: + region: "my-region" + scheme: local + signedUrls: + durationMinutes: 3 +notifications: + type: local + region: "my-region" + publisher: + topicName: "foo" + processor: + queueName: "queue" + accountId: "bar" + emailer: + subject: "Notice: Execution \"{{ name }}\" has {{ phase }} in \"{{ domain }}\"." + sender: "flyte-notifications@example.com" + body: > + Execution \"{{ name }}\" has {{ phase }} in \"{{ domain }}\". View details at + + http://example.com/projects/{{ project }}/domains/{{ domain }}/executions/{{ name }}. {{ error }} +Logger: + show-source: true + level: 5 +storage: + type: minio + connection: + access-key: minio + auth-type: accesskey + secret-key: miniostorage + disable-ssl: true + endpoint: "http://localhost:9000" + region: my-region + cache: + max_size_mbs: 10 + target_gc_percent: 100 + container: "flyte" +queues: + executionQueues: + - primary: "gpu_primary" + dynamic: "gpu_dynamic" + attributes: + - gpu + - primary: "critical" + dynamic: "critical" + attributes: + - critical + - primary: "default" + dynamic: "default" + attributes: + - defaultclusters + - primary: "my_queue_1" + domain: "production" + workflowName: "my_workflow_1" + tags: + - critical + - primary: "my_queue_1" + domain: "production" + workflowName: "my_workflow_2" + tags: + - gpu + - primary: "my_queue_3" + domain: "production" + workflowName: "my_workflow_3" + tags: + - critical + - tags: + - default +task_resources: + defaults: + cpu: 100m + gpu: 20m + memory: 1Mi + storage: 10M + limits: + cpu: 500m + gpu: 100m + memory: 1Mi + storage: 10G +task_type_whitelist: + sparkonk8s: + - project: my_queue_1 + domain: production + - project: my_queue_2 + domain: production + qubolespark: + - project: my_queue_2 +domains: + - id: development + name: development + - id: staging + name: staging + - id: production + name: production + - id: domain + name: domain +cluster_resources: + templatePath: pkg/clusterresource/sampletemplates + templateData: + foo: + value: "bar" + foofoo: + valueFrom: + env: SHELL + refresh: 3s diff --git a/flyteadmin/pkg/async/notifications/email.go b/flyteadmin/pkg/async/notifications/email.go new file mode 100644 index 0000000000..ed888d381d --- /dev/null +++ b/flyteadmin/pkg/async/notifications/email.go @@ -0,0 +1,54 @@ +package notifications + +import ( + "fmt" + + "strings" + + "github.com/lyft/flyteadmin/pkg/repositories/models" + runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +const executionError = " The execution failed with error: [%s]." 
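+
+// Subject and body templates may reference {{ project }}, {{ domain }}, {{ name }},
+// {{ phase }} and {{ error }}; substituteEmailParameters replaces each placeholder
+// with values taken from the execution model and the workflow execution event.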
+ +const substitutionParam = "{{ %s }}" +const project = "project" +const domain = "domain" +const name = "name" +const phase = "phase" +const errorPlaceholder = "error" +const replaceAllInstances = -1 + +func substituteEmailParameters(message string, request admin.WorkflowExecutionEventRequest, execution models.Execution) string { + response := strings.Replace(message, fmt.Sprintf(substitutionParam, project), execution.Project, replaceAllInstances) + response = strings.Replace(response, fmt.Sprintf(substitutionParam, domain), execution.Domain, replaceAllInstances) + response = strings.Replace(response, fmt.Sprintf(substitutionParam, name), execution.Name, replaceAllInstances) + response = strings.Replace(response, fmt.Sprintf(substitutionParam, phase), + strings.ToLower(request.Event.Phase.String()), replaceAllInstances) + if request.Event.GetError() != nil { + response = strings.Replace(response, fmt.Sprintf(substitutionParam, errorPlaceholder), + fmt.Sprintf(executionError, request.Event.GetError().Message), replaceAllInstances) + } else { + // Replace the optional error placeholder with an empty string. + response = strings.Replace(response, fmt.Sprintf(substitutionParam, errorPlaceholder), "", replaceAllInstances) + } + + return response +} + +// Converts a terminal execution event and existing execution model to an admin.EmailMessage proto, substituting parameters +// in customizable email fields set in the flyteadmin application notifications config. +func ToEmailMessageFromWorkflowExecutionEvent( + config runtimeInterfaces.NotificationsConfig, + emailNotification admin.EmailNotification, + request admin.WorkflowExecutionEventRequest, + execution models.Execution) *admin.EmailMessage { + + return &admin.EmailMessage{ + SubjectLine: substituteEmailParameters(config.NotificationsEmailerConfig.Subject, request, execution), + SenderEmail: config.NotificationsEmailerConfig.Sender, + RecipientsEmail: emailNotification.GetRecipientsEmail(), + Body: substituteEmailParameters(config.NotificationsEmailerConfig.Body, request, execution), + } +} diff --git a/flyteadmin/pkg/async/notifications/email_test.go b/flyteadmin/pkg/async/notifications/email_test.go new file mode 100644 index 0000000000..1c301edbde --- /dev/null +++ b/flyteadmin/pkg/async/notifications/email_test.go @@ -0,0 +1,79 @@ +package notifications + +import ( + "fmt" + "testing" + + "github.com/gogo/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/repositories/models" + runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" + "github.com/stretchr/testify/assert" +) + +func TestSubstituteEmailParameters(t *testing.T) { + message := "{{ unused }}. {{project }} and {{ domain }} and {{ name }} ended up in {{ phase }}.{{ error }}" + request := admin.WorkflowExecutionEventRequest{ + Event: &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_SUCCEEDED, + }, + } + model := models.Execution{ + ExecutionKey: models.ExecutionKey{ + Project: "proj", + Domain: "prod", + Name: "e124", + }, + } + assert.Equal(t, "{{ unused }}. {{project }} and prod and e124 ended up in succeeded.", + substituteEmailParameters(message, request, model)) + request.Event.OutputResult = &event.WorkflowExecutionEvent_Error{ + Error: &core.ExecutionError{ + Message: "uh-oh", + }, + } + assert.Equal(t, "{{ unused }}. {{project }} and prod and e124 ended up in succeeded. 
The execution failed with error: [uh-oh].", + substituteEmailParameters(message, request, model)) +} + +func TestToEmailMessageFromWorkflowExecutionEvent(t *testing.T) { + notificationsConfig := runtimeInterfaces.NotificationsConfig{ + NotificationsEmailerConfig: runtimeInterfaces.NotificationsEmailerConfig{ + Body: "Execution \"{{ name }}\" has succeeded in \"{{ domain }}\". View details at " + + "" + + "https://example.com/executions/{{ project }}/{{ domain }}/{{ name }}.", + Sender: "no-reply@example.com", + Subject: "Notice: Execution \"{{ name }}\" has succeeded in \"{{ domain }}\".", + }, + } + emailNotification := admin.EmailNotification{ + RecipientsEmail: []string{ + "a@example.com", "b@example.org", + }, + } + request := admin.WorkflowExecutionEventRequest{ + Event: &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_ABORTED, + }, + } + model := models.Execution{ + ExecutionKey: models.ExecutionKey{ + Project: "proj", + Domain: "prod", + Name: "e124", + }, + } + emailMessage := ToEmailMessageFromWorkflowExecutionEvent(notificationsConfig, emailNotification, request, model) + assert.True(t, proto.Equal(emailMessage, &admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", "b@example.org", + }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Notice: Execution \"e124\" has succeeded in \"prod\".", + Body: "Execution \"e124\" has succeeded in \"prod\". View details at " + + "" + + "https://example.com/executions/proj/prod/e124.", + }), fmt.Sprintf("%+v", emailMessage)) +} diff --git a/flyteadmin/pkg/async/notifications/factory.go b/flyteadmin/pkg/async/notifications/factory.go new file mode 100644 index 0000000000..72e78cf1de --- /dev/null +++ b/flyteadmin/pkg/async/notifications/factory.go @@ -0,0 +1,111 @@ +package notifications + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/async/notifications/implementations" + "github.com/lyft/flyteadmin/pkg/async/notifications/interfaces" + runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" + "github.com/lyft/flytestdlib/logger" + + "github.com/NYTimes/gizmo/pubsub" + gizmoConfig "github.com/NYTimes/gizmo/pubsub/aws" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ses" + + "github.com/lyft/flyteadmin/pkg/common" + "github.com/lyft/flytestdlib/promutils" +) + +const maxRetries = 3 + +var enable64decoding = false + +type PublisherConfig struct { + TopicName string +} + +type ProcessorConfig struct { + QueueName string + AccountID string +} + +type EmailerConfig struct { + SenderEmail string + BaseURL string +} + +func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Emailer { + switch config.Type { + case common.AWS: + awsConfig := aws.NewConfig().WithRegion(config.Region).WithMaxRetries(maxRetries) + awsSession, err := session.NewSession(awsConfig) + if err != nil { + panic(err) + } + sesClient := ses.New(awsSession) + return implementations.NewAwsEmailer( + config, + scope, + sesClient, + ) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), "Using default noop emailer implementation for config type [%s]", config.Type) + return implementations.NewNoopEmail() + } +} + +func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Processor { + var sub pubsub.Subscriber + var emailer interfaces.Emailer + switch config.Type { + case common.AWS: + sqsConfig := gizmoConfig.SQSConfig{ + QueueName: 
config.NotificationsProcessorConfig.QueueName, + QueueOwnerAccountID: config.NotificationsProcessorConfig.AccountID, + // The AWS configuration type uses SNS to SQS for notifications. + // Gizmo by default will decode the SQS message using Base64 decoding. + // However, the message body of SQS is the SNS message format which isn't Base64 encoded. + ConsumeBase64: &enable64decoding, + } + sqsConfig.Region = config.Region + process, err := gizmoConfig.NewSubscriber(sqsConfig) + if err != nil { + panic(err) + } + sub = process + emailer = GetEmailer(config, scope) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications processor implementation for config type [%s]", config.Type) + return implementations.NewNoopProcess() + } + return implementations.NewProcessor(sub, emailer, scope) +} + +func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Publisher { + switch config.Type { + case common.AWS: + snsConfig := gizmoConfig.SNSConfig{ + Topic: config.NotificationsPublisherConfig.TopicName, + } + snsConfig.Region = config.Region + publisher, err := gizmoConfig.NewPublisher(snsConfig) + // Any errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/aws_emailer.go b/flyteadmin/pkg/async/notifications/implementations/aws_emailer.go new file mode 100644 index 0000000000..516a8aebde --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/aws_emailer.go @@ -0,0 +1,86 @@ +package implementations + +import ( + "context" + + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/ses/sesiface" + "github.com/lyft/flyteadmin/pkg/async/notifications/interfaces" + "github.com/lyft/flyteadmin/pkg/errors" + runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc/codes" +) + +type emailMetrics struct { + Scope promutils.Scope + SendSuccess prometheus.Counter + SendError prometheus.Counter + SendTotal prometheus.Counter +} + +func newEmailMetrics(scope promutils.Scope) emailMetrics { + return emailMetrics{ + Scope: scope, + SendSuccess: scope.MustNewCounter("send_success", "Number of successful emails sent via Emailer."), + SendError: scope.MustNewCounter("send_error", "Number of errors when sending email via Emailer"), + SendTotal: scope.MustNewCounter("send_total", "Total number of emails attempted to be sent"), + } +} + +type AwsEmailer struct { + config runtimeInterfaces.NotificationsConfig + systemMetrics emailMetrics + awsEmail sesiface.SESAPI +} + +func (e *AwsEmailer) SendEmail(ctx context.Context, email admin.EmailMessage) error { + var toAddress []*string + for _, toEmail := range email.RecipientsEmail { + toAddress = append(toAddress, &toEmail) + } + + emailInput := ses.SendEmailInput{ + Destination: &ses.Destination{ + ToAddresses: toAddress, + }, + // Currently use the senderEmail specified apart of the Emailer instead of 
the body. + // Once a more generic way of setting the emailNotification is defined, remove this + // workaround and defer back to email.SenderEmail + Source: &email.SenderEmail, + Message: &ses.Message{ + Body: &ses.Body{ + Html: &ses.Content{ + Data: &email.Body, + }, + }, + Subject: &ses.Content{ + Data: &email.SubjectLine, + }, + }, + } + + _, err := e.awsEmail.SendEmail(&emailInput) + e.systemMetrics.SendTotal.Inc() + + if err != nil { + // TODO: If we see a certain set of AWS errors consistently, we can break the errors down based on type. + logger.Errorf(ctx, "error in sending email [%s] via ses mailer with err: %s", email.String(), err) + e.systemMetrics.SendError.Inc() + return errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails") + } + + e.systemMetrics.SendSuccess.Inc() + return nil +} + +func NewAwsEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, awsEmail sesiface.SESAPI) interfaces.Emailer { + return &AwsEmailer{ + config: config, + systemMetrics: newEmailMetrics(scope.NewSubScope("aws_ses")), + awsEmail: awsEmail, + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/aws_emailer_test.go b/flyteadmin/pkg/async/notifications/implementations/aws_emailer_test.go new file mode 100644 index 0000000000..c72e356037 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/aws_emailer_test.go @@ -0,0 +1,121 @@ +package implementations + +import ( + "testing" + + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/ses/sesiface" + "github.com/lyft/flyteadmin/pkg/async/notifications/mocks" + runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/promutils" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func getNotificationsConfig() runtimeInterfaces.NotificationsConfig { + return runtimeInterfaces.NotificationsConfig{ + NotificationsEmailerConfig: runtimeInterfaces.NotificationsEmailerConfig{ + Body: "Execution \"{{ name }}\" has succeeded in \"{{ domain }}\". View details at " + + "" + + "https://example.com/executions/{{ project }}/{{ domain }}/{{ name }}.", + Sender: "no-reply@example.com", + Subject: "Notice: Execution \"{{ name }}\" has succeeded in \"{{ domain }}\".", + }, + } +} + +func TestAwsEmailer_SendEmail(t *testing.T) { + mockAwsEmail := mocks.SESClient{} + var awsSES sesiface.SESAPI = &mockAwsEmail + expectedSenderEmail := "no-reply@example.com" + emailNotification := admin.EmailMessage{ + SubjectLine: "Notice: Execution \"name\" has succeeded in \"domain\".", + SenderEmail: "no-reply@example.com", + RecipientsEmail: []string{ + "my@example.com", + "john@example.com", + }, + Body: "Execution \"name\" has succeeded in \"domain\". 
View details at " + + "" + + "https://example.com/executions/T/B/D.", + } + + sendEmailValidationFunc := func(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) { + assert.Equal(t, *input.Source, expectedSenderEmail) + assert.Equal(t, *input.Message.Body.Html.Data, emailNotification.Body) + assert.Equal(t, *input.Message.Subject.Data, emailNotification.SubjectLine) + for _, toEmail := range input.Destination.ToAddresses { + var foundEmail = false + for _, verifyToEmail := range emailNotification.RecipientsEmail { + if *toEmail == verifyToEmail { + foundEmail = true + } + } + assert.Truef(t, foundEmail, "To Email address [%s] wasn't apart of original inputs.", *toEmail) + } + assert.Equal(t, len(input.Destination.ToAddresses), len(emailNotification.RecipientsEmail)) + return &ses.SendEmailOutput{}, nil + } + mockAwsEmail.SetSendEmailFunc(sendEmailValidationFunc) + testEmail := NewAwsEmailer(getNotificationsConfig(), promutils.NewTestScope(), awsSES) + + assert.Nil(t, testEmail.SendEmail(context.Background(), emailNotification)) +} + +func TestAwsEmailer_SendEmailError(t *testing.T) { + mockAwsEmail := mocks.SESClient{} + var awsSES sesiface.SESAPI + emailError := errors.New("error sending email") + sendEmailErrorFunc := func(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) { + return nil, emailError + } + mockAwsEmail.SetSendEmailFunc(sendEmailErrorFunc) + awsSES = &mockAwsEmail + + testEmail := NewAwsEmailer(getNotificationsConfig(), promutils.NewTestScope(), awsSES) + + emailNotification := admin.EmailMessage{ + SubjectLine: "Notice: Execution \"name\" has succeeded in \"domain\".", + SenderEmail: "no-reply@example.com", + RecipientsEmail: []string{ + "my@example.com", + "john@example.com", + }, + Body: "Execution \"name\" has succeeded in \"domain\". View details at " + + "" + + "https://example.com/executions/T/B/D.", + } + assert.EqualError(t, testEmail.SendEmail(context.Background(), emailNotification), "errors were seen while sending emails") +} + +func TestAwsEmailer_SendEmailEmailOutput(t *testing.T) { + mockAwsEmail := mocks.SESClient{} + var awsSES sesiface.SESAPI + emailOutput := ses.SendEmailOutput{ + MessageId: aws.String("1234"), + } + sendEmailErrorFunc := func(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) { + return &emailOutput, nil + } + mockAwsEmail.SetSendEmailFunc(sendEmailErrorFunc) + awsSES = &mockAwsEmail + + testEmail := NewAwsEmailer(getNotificationsConfig(), promutils.NewTestScope(), awsSES) + + emailNotification := admin.EmailMessage{ + SubjectLine: "Notice: Execution \"name\" has succeeded in \"domain\".", + SenderEmail: "no-reply@example.com", + RecipientsEmail: []string{ + "my@example.com", + "john@example.com", + }, + Body: "Execution \"name\" has succeeded in \"domain\". 
View details at " + + "" + + "https://example.com/executions/T/B/D.", + } + assert.Nil(t, testEmail.SendEmail(context.Background(), emailNotification)) +} diff --git a/flyteadmin/pkg/async/notifications/implementations/noop_notifications.go b/flyteadmin/pkg/async/notifications/implementations/noop_notifications.go new file mode 100644 index 0000000000..f9a0e1595b --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/noop_notifications.go @@ -0,0 +1,54 @@ +package implementations + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/async/notifications/interfaces" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/logger" + + "strings" + + "github.com/golang/protobuf/proto" +) + +// Email to use when there is no email configuration. +type NoopEmail struct{} + +func (n *NoopEmail) SendEmail(ctx context.Context, email admin.EmailMessage) error { + logger.Debugf(ctx, "received noop SendEmail request with subject [%s] and recipient [%s]", + email.SubjectLine, strings.Join(email.RecipientsEmail, ",")) + return nil +} + +func NewNoopEmail() interfaces.Emailer { + return &NoopEmail{} +} + +type NoopPublish struct{} + +func (n *NoopPublish) Publish(ctx context.Context, notificationType string, msg proto.Message) error { + logger.Debugf(ctx, "call to noop publish with notification type [%s] and proto message [%s]", notificationType, msg.String()) + return nil +} + +func NewNoopPublish() interfaces.Publisher { + return &NoopPublish{} +} + +type NoopProcess struct{} + +func (n *NoopProcess) StartProcessing() error { + logger.Debug(context.Background(), "call to noop start processing.") + return nil +} + +func (n *NoopProcess) StopProcessing() error { + logger.Debug(context.Background(), "call to noop stop processing.") + return nil +} + +func NewNoopProcess() interfaces.Processor { + return &NoopProcess{} +} diff --git a/flyteadmin/pkg/async/notifications/implementations/processor.go b/flyteadmin/pkg/async/notifications/implementations/processor.go new file mode 100644 index 0000000000..a3d5ff9c86 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/processor.go @@ -0,0 +1,160 @@ +package implementations + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/async/notifications/interfaces" + + "encoding/base64" + "encoding/json" + + "github.com/NYTimes/gizmo/pubsub" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" +) + +type processorSystemMetrics struct { + Scope promutils.Scope + MessageTotal prometheus.Counter + MessageDoneError prometheus.Counter + MessageDecodingError prometheus.Counter + MessageDataError prometheus.Counter + MessageProcessorError prometheus.Counter + MessageSuccess prometheus.Counter + ChannelClosedError prometheus.Counter + StopError prometheus.Counter +} + +// TODO: Add a counter that encompasses the publisher stats grouped by project and domain. +type Processor struct { + sub pubsub.Subscriber + email interfaces.Emailer + systemMetrics processorSystemMetrics +} + +// Currently only email is the supported notification because slack and pagerduty both use +// email client to trigger those notifications. +// When Pagerduty and other notifications are supported, a publisher per type should be created. 
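+//
+// Each message consumed from SQS is the SNS delivery envelope, roughly of the form:
+//
+//   {
+//     "Type":    "Notification",
+//     "Subject": "flyteidl.admin.EmailNotification",
+//     "Message": "<base64-encoded, serialized admin.EmailMessage>"
+//   }
+//
+// StartProcessing unwraps that envelope, base64-decodes the "Message" field,
+// unmarshals the admin.EmailMessage proto and hands it to the configured Emailer.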
+func (p *Processor) StartProcessing() error { + var emailMessage admin.EmailMessage + var err error + for msg := range p.sub.Start() { + + p.systemMetrics.MessageTotal.Inc() + // Currently this is safe because Gizmo takes a string and casts it to a byte array. + var stringMsg = string(msg.Message()) + // Amazon doesn't provide a struct that can be used to unmarshall into. A generic JSON struct is used in its place. + var snsJSONFormat map[string]interface{} + + // At Lyft, SNS populates SQS. This results in the message body of SQS having the SNS message format. + // The message format is documented here: https://docs.aws.amazon.com/sns/latest/dg/sns-message-and-json-formats.html + // The notification published is stored in the message field after unmarshalling the SQS message. + if err := json.Unmarshal(msg.Message(), &snsJSONFormat); err != nil { + p.systemMetrics.MessageDecodingError.Inc() + logger.Errorf(context.Background(), "failed to unmarshall JSON message [%s] from processor with err: %v", stringMsg, err) + p.markMessageDone(msg) + continue + } + + var value interface{} + var ok bool + var valueString string + + if value, ok = snsJSONFormat["Message"]; !ok { + logger.Errorf(context.Background(), "failed to retrieve message from unmarshalled JSON object [%s]", stringMsg) + p.systemMetrics.MessageDataError.Inc() + p.markMessageDone(msg) + continue + } + + if valueString, ok = value.(string); !ok { + p.systemMetrics.MessageDataError.Inc() + logger.Errorf(context.Background(), "failed to retrieve notification message (in string format) from unmarshalled JSON object for message [%s]", stringMsg) + p.markMessageDone(msg) + continue + } + + // The Publish method for SNS Encodes the notification using Base64 then stringifies it before + // setting that as the message body for SNS. Do the inverse to retrieve the notification. + notificationBytes, err := base64.StdEncoding.DecodeString(valueString) + if err != nil { + logger.Errorf(context.Background(), "failed to Base64 decode from message string [%s] from message [%s] with err: %v", valueString, stringMsg, err) + p.systemMetrics.MessageDecodingError.Inc() + p.markMessageDone(msg) + continue + } + + if err = proto.Unmarshal(notificationBytes, &emailMessage); err != nil { + logger.Debugf(context.Background(), "failed to unmarshal to notification object from decoded string[%s] from message [%s] with err: %v", valueString, stringMsg, err) + p.systemMetrics.MessageDecodingError.Inc() + p.markMessageDone(msg) + continue + } + + if err = p.email.SendEmail(context.Background(), emailMessage); err != nil { + p.systemMetrics.MessageProcessorError.Inc() + logger.Errorf(context.Background(), "Error sending an email message for message [%s] with emailM with err: %v", emailMessage.String(), err) + } else { + p.systemMetrics.MessageSuccess.Inc() + } + + p.markMessageDone(msg) + + } + + // According to https://github.com/NYTimes/gizmo/blob/f2b3deec03175b11cdfb6642245a49722751357f/pubsub/pubsub.go#L36-L39, + // the channel backing the subscriber will just close if there is an error. The call to Err() is needed to identify + // there was an error in the channel or there are no more messages left (resulting in no errors when calling Err()). + if err = p.sub.Err(); err != nil { + p.systemMetrics.ChannelClosedError.Inc() + logger.Warningf(context.Background(), "The stream for the subscriber channel closed with err: %v", err) + } + + // If there are no errors, nil will be returned. 
+ return err +} + +func (p *Processor) markMessageDone(message pubsub.SubscriberMessage) { + if err := message.Done(); err != nil { + p.systemMetrics.MessageDoneError.Inc() + logger.Errorf(context.Background(), "failed to mark message as Done() in processor with err: %v", err) + } +} + +func (p *Processor) StopProcessing() error { + // Note: If the underlying channel is already closed, then Stop() will return an error. + err := p.sub.Stop() + if err != nil { + p.systemMetrics.StopError.Inc() + logger.Errorf(context.Background(), "Failed to stop the subscriber channel gracefully with err: %v", err) + } + return err +} + +func newProcessorSystemMetrics(scope promutils.Scope) processorSystemMetrics { + return processorSystemMetrics{ + Scope: scope, + MessageTotal: scope.MustNewCounter("message_total", "overall count of messages processed"), + MessageDecodingError: scope.MustNewCounter("message_decoding_error", "count of messages with decoding errors"), + MessageDataError: scope.MustNewCounter("message_data_error", "count of message data processing errors experience when preparing the message to be notified."), + MessageDoneError: scope.MustNewCounter("message_done_error", + "count of message errors when marking it as done with underlying processor"), + MessageProcessorError: scope.MustNewCounter("message_processing_error", + "count of errors when interacting with notification processor"), + MessageSuccess: scope.MustNewCounter("message_ok", + "count of messages successfully processed by underlying notification mechanism"), + ChannelClosedError: scope.MustNewCounter("channel_closed_error", "count of channel closing errors"), + StopError: scope.MustNewCounter("stop_error", "count of errors in Stop() method"), + } +} + +func NewProcessor(sub pubsub.Subscriber, emailer interfaces.Emailer, scope promutils.Scope) interfaces.Processor { + return &Processor{ + sub: sub, + email: emailer, + systemMetrics: newProcessorSystemMetrics(scope.NewSubScope("processor")), + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/processor_test.go b/flyteadmin/pkg/async/notifications/implementations/processor_test.go new file mode 100644 index 0000000000..c714b5480a --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/processor_test.go @@ -0,0 +1,140 @@ +package implementations + +import ( + "context" + "errors" + "testing" + + "encoding/base64" + + "github.com/aws/aws-sdk-go/aws" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + + "github.com/lyft/flyteadmin/pkg/async/notifications/mocks" + "github.com/stretchr/testify/assert" +) + +var mockEmailer mocks.MockEmailer + +// This method should be invoked before every test to Subscriber. +func initializeProcessor() { + testSubscriber.GivenStopError = nil + testSubscriber.GivenErrError = nil + testSubscriber.FoundError = nil + testSubscriber.ProtoMessages = nil + testSubscriber.JSONMessages = nil +} + +func TestProcessor_StartProcessing(t *testing.T) { + initializeProcessor() + + // Because the message stored in Amazon SQS is a JSON of the SNS output, store the test output in the JSON Messages. 
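+	// testSubscriberMessage is declared next to the publisher tests in this package and
+	// carries a base64-encoded, marshalled admin.EmailMessage in its "Message" field.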
+ testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testSubscriberMessage) + + sendEmailValidationFunc := func(ctx context.Context, email admin.EmailMessage) error { + assert.Equal(t, email.Body, testEmail.Body) + assert.Equal(t, email.RecipientsEmail, testEmail.RecipientsEmail) + assert.Equal(t, email.SubjectLine, testEmail.SubjectLine) + assert.Equal(t, email.SenderEmail, testEmail.SenderEmail) + return nil + } + mockEmailer.SetSendEmailFunc(sendEmailValidationFunc) + // TODO Add test for metric inc for number of messages processed. + // Assert 1 message processed and 1 total. + assert.Nil(t, testProcessor.StartProcessing()) +} + +func TestProcessor_StartProcessingNoMessages(t *testing.T) { + initializeProcessor() + // Expect no errors are returned. + assert.Nil(t, testProcessor.StartProcessing()) + // TODO add test for metric inc() for number of messages processed. + // Assert 0 messages processed and 0 total. +} + +func TestProcessor_StartProcessingNoNotificationMessage(t *testing.T) { + var testMessage = map[string]interface{}{ + "Type": "Not a real notification", + "MessageId": "1234", + } + initializeProcessor() + testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage) + assert.Nil(t, testProcessor.StartProcessing()) + // TODO add test for metric inc() for number of messages processed. + // Assert 1 messages error and 1 total. +} + +func TestProcessor_StartProcessingMessageWrongDataType(t *testing.T) { + var testMessage = map[string]interface{}{ + "Type": "Not a real notification", + "MessageId": "1234", + "Message": 12, + } + initializeProcessor() + testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage) + assert.Nil(t, testProcessor.StartProcessing()) + // TODO add test for metric inc() for number of messages processed. + // Assert 1 messages error and 1 total. +} + +func TestProcessor_StartProcessingBase64DecodeError(t *testing.T) { + var testMessage = map[string]interface{}{ + "Type": "Not a real notification", + "MessageId": "1234", + "Message": "NotBase64encoded", + } + initializeProcessor() + testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage) + assert.Nil(t, testProcessor.StartProcessing()) + // TODO add test for metric inc() for number of messages processed. + // Assert 1 messages error and 1 total. +} + +func TestProcessor_StartProcessingProtoMarshallError(t *testing.T) { + var badByte = []byte("atreyu") + var testMessage = map[string]interface{}{ + "Type": "Not a real notification", + "MessageId": "1234", + "Message": aws.String(base64.StdEncoding.EncodeToString(badByte)), + } + initializeProcessor() + testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage) + assert.Nil(t, testProcessor.StartProcessing()) + // TODO add test for metric inc() for number of messages processed. + // Assert 1 messages error and 1 total. +} + +func TestProcessor_StartProcessingError(t *testing.T) { + initializeProcessor() + var ret = errors.New("err() returned an error") + // The error set by GivenErrError is returned by Err(). + // Err() is checked before Run() returning. 
+ testSubscriber.GivenErrError = ret + assert.Equal(t, ret, testProcessor.StartProcessing()) +} + +func TestProcessor_StartProcessingEmailError(t *testing.T) { + initializeProcessor() + emailError := errors.New("error sending email") + sendEmailErrorFunc := func(ctx context.Context, email admin.EmailMessage) error { + return emailError + } + mockEmailer.SetSendEmailFunc(sendEmailErrorFunc) + testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testSubscriberMessage) + + // Even if there is an error in sending an email StartProcessing will return no errors. + // TODO: Once stats have been added check for an email error stat. + assert.Nil(t, testProcessor.StartProcessing()) +} + +func TestProcessor_StopProcessing(t *testing.T) { + initializeProcessor() + assert.Nil(t, testProcessor.StopProcessing()) +} + +func TestProcessor_StopProcessingError(t *testing.T) { + initializeProcessor() + var stopError = errors.New("stop() returns an error") + testSubscriber.GivenStopError = stopError + assert.Equal(t, stopError, testProcessor.StopProcessing()) +} diff --git a/flyteadmin/pkg/async/notifications/implementations/publisher.go b/flyteadmin/pkg/async/notifications/implementations/publisher.go new file mode 100644 index 0000000000..63d00883c0 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/publisher.go @@ -0,0 +1,52 @@ +package implementations + +import ( + "context" + + "github.com/lyft/flyteadmin/pkg/async/notifications/interfaces" + + "github.com/NYTimes/gizmo/pubsub" + "github.com/golang/protobuf/proto" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" + "github.com/prometheus/client_golang/prometheus" +) + +type publisherSystemMetrics struct { + Scope promutils.Scope + PublishTotal prometheus.Counter + PublishError prometheus.Counter +} + +// TODO: Add a counter that encompasses the publisher stats grouped by project and domain. +type Publisher struct { + pub pubsub.Publisher + systemMetrics publisherSystemMetrics +} + +// The key is the notification type as defined as an enum. 
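+// The tests in this package use the fully qualified proto message name as the key,
+// e.g. publisher.Publish(ctx, proto.MessageName(&emailMessage), &emailMessage).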
+func (p *Publisher) Publish(ctx context.Context, notificationType string, msg proto.Message) error { + p.systemMetrics.PublishTotal.Inc() + logger.Debugf(ctx, "Publishing the following message [%s]", msg.String()) + err := p.pub.Publish(ctx, notificationType, msg) + if err != nil { + p.systemMetrics.PublishError.Inc() + logger.Errorf(ctx, "Failed to publish a message with key [%s] and message [%s] and error: %v", notificationType, msg.String(), err) + } + return err +} + +func newPublisherSystemMetrics(scope promutils.Scope) publisherSystemMetrics { + return publisherSystemMetrics{ + Scope: scope, + PublishTotal: scope.MustNewCounter("publish_total", "overall count of publish messages"), + PublishError: scope.MustNewCounter("publish_errors", "count of publish errors"), + } +} + +func NewPublisher(pub pubsub.Publisher, scope promutils.Scope) interfaces.Publisher { + return &Publisher{ + pub: pub, + systemMetrics: newPublisherSystemMetrics(scope.NewSubScope("publisher")), + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/publisher_test.go b/flyteadmin/pkg/async/notifications/implementations/publisher_test.go new file mode 100644 index 0000000000..893a2e2c11 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/publisher_test.go @@ -0,0 +1,76 @@ +package implementations + +import ( + "context" + "errors" + "testing" + + "github.com/lyft/flyteadmin/pkg/async/notifications/mocks" + + "encoding/base64" + + "github.com/NYTimes/gizmo/pubsub" + "github.com/NYTimes/gizmo/pubsub/pubsubtest" + "github.com/aws/aws-sdk-go/aws" + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +var testPublisher pubsubtest.TestPublisher +var mockPublisher pubsub.Publisher = &testPublisher +var currentPublisher = NewPublisher(mockPublisher, promutils.NewTestScope()) +var testEmail = admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", + "b@example.com", + }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Test email", + Body: "This is a sample email.", +} + +var msg, _ = proto.Marshal(&testEmail) + +var testSubscriberMessage = map[string]interface{}{ + "Type": "Notification", + "MessageId": "1-a-3-c", + "TopicArn": "arn:aws:sns:my-region:123:flyte-test-notifications", + "Subject": "flyteidl.admin.EmailNotification", + "Message": aws.String(base64.StdEncoding.EncodeToString(msg)), + "Timestamp": "2019-01-04T22:59:32.849Z", + "SignatureVersion": "1", + "Signature": "some&ignature==", + "SigningCertURL": "https://sns.my-region.amazonaws.com/afdaf", + "UnsubscribeURL": "https://sns.my-region.amazonaws.com/sns:my-region:123:flyte-test-notifications:1-2-3-4-5", +} +var testSubscriber pubsubtest.TestSubscriber +var mockSub pubsub.Subscriber = &testSubscriber +var mockEmail mocks.MockEmailer +var testProcessor = NewProcessor(mockSub, &mockEmail, promutils.NewTestScope()) + +// This method should be invoked before every test around Publisher. 
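+// Explanatory note (behavior inferred from the tests below): Published records the messages captured by
+// gizmo's pubsubtest.TestPublisher, and GivenError/FoundError are its injectable error fields. Clearing
+// all three keeps each test hermetic so state from one test cannot leak into the next.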
+func initializePublisher() { + testPublisher.Published = nil + testPublisher.GivenError = nil + testPublisher.FoundError = nil +} + +func TestPublisher_PublishSuccess(t *testing.T) { + initializePublisher() + assert.Nil(t, currentPublisher.Publish(context.Background(), proto.MessageName(&testEmail), &testEmail)) + assert.Equal(t, 1, len(testPublisher.Published)) + assert.Equal(t, proto.MessageName(&testEmail), testPublisher.Published[0].Key) + marshalledData, err := proto.Marshal(&testEmail) + assert.Nil(t, err) + assert.Equal(t, marshalledData, testPublisher.Published[0].Body) + +} + +func TestPublisher_PublishError(t *testing.T) { + initializePublisher() + var publishError = errors.New("publish() returns an error") + testPublisher.GivenError = publishError + assert.Equal(t, publishError, currentPublisher.Publish(context.Background(), "test", &testEmail)) +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/emailer.go b/flyteadmin/pkg/async/notifications/interfaces/emailer.go new file mode 100644 index 0000000000..8970ff41f6 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/emailer.go @@ -0,0 +1,13 @@ +package interfaces + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +// The implementation of Emailer needs to be passed to the implementation of Processor +// in order for emails to be sent. +type Emailer interface { + SendEmail(ctx context.Context, email admin.EmailMessage) error +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/processor.go b/flyteadmin/pkg/async/notifications/interfaces/processor.go new file mode 100644 index 0000000000..0a8534d57c --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/processor.go @@ -0,0 +1,17 @@ +package interfaces + +// Exposes the common methods required for a subscriber. +// There is one Processor per notification type. +type Processor interface { + + // Starts processing messages from the underlying subscriber. + // If the channel closes gracefully, no error will be returned. + // If the underlying channel experiences errors, + // an error is returned and the channel is closed. + StartProcessing() error + + // This should be invoked when the application is shutting down. + // If StartProcessing() returned an error, StopProcessing() will return an error because + // the channel was already closed. + StopProcessing() error +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/publisher.go b/flyteadmin/pkg/async/notifications/interfaces/publisher.go new file mode 100644 index 0000000000..94e45180fb --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/publisher.go @@ -0,0 +1,27 @@ +package interfaces + +import ( + "context" + + "github.com/golang/protobuf/proto" +) + +// Note on Notifications + +// Notifications are handled in two steps. +// 1. Publishing a notification +// 2. Processing a notification + +// Publishing a notification enqueues a notification message to be processed. Currently there is only +// one publisher for all notification types, with the type differing based on the key. +// The notification hasn't been delivered at this stage. +// Processing a notification takes a notification message from the publisher and delivers +// the notification using the desired delivery method (ex: email). There is one processor per +// notification type. + +// Publishing a notification differs between notification types based on the key. +// The contract requires one subscription per type, i.e. one for email, one for Slack, etc.
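+// Illustrative wiring sketch (assumed composition; the constructors below live in the implementations
+// package and are not prescribed by this interface):
+//
+//	publisher := implementations.NewPublisher(pub, scope)          // enqueue side, pub is a pubsub.Publisher
+//	processor := implementations.NewProcessor(sub, emailer, scope) // delivery side, sub subscribes to the email topic
+//	go func() { _ = processor.StartProcessing() }()
+//	defer func() { _ = processor.StopProcessing() }()
+//
+// A message published with the email notification key is then consumed and delivered by that processor.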
+type Publisher interface { + // The notification type is inferred from the Notification object in the Execution Spec. + Publish(ctx context.Context, notificationType string, msg proto.Message) error +} diff --git a/flyteadmin/pkg/async/notifications/mocks/emailer.go b/flyteadmin/pkg/async/notifications/mocks/emailer.go new file mode 100644 index 0000000000..3705d49355 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/emailer.go @@ -0,0 +1,24 @@ +package mocks + +import ( + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/ses/sesiface" +) + +type AwsSendEmailFunc func(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) + +type SESClient struct { + sesiface.SESAPI + sendEmail AwsSendEmailFunc +} + +func (m *SESClient) SetSendEmailFunc(emailFunc AwsSendEmailFunc) { + m.sendEmail = emailFunc +} + +func (m *SESClient) SendEmail(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) { + if m.sendEmail != nil { + return m.sendEmail(input) + } + return &ses.SendEmailOutput{}, nil +} diff --git a/flyteadmin/pkg/async/notifications/mocks/processor.go b/flyteadmin/pkg/async/notifications/mocks/processor.go new file mode 100644 index 0000000000..9ef18a8da4 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/processor.go @@ -0,0 +1,47 @@ +package mocks + +import ( + "context" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" +) + +type RunFunc func() error + +type StopFunc func() error + +type MockSubscriber struct { + runFunc RunFunc + stopFunc StopFunc +} + +func (m *MockSubscriber) Run() error { + if m.runFunc != nil { + return m.runFunc() + } + return nil +} + +func (m *MockSubscriber) Stop() error { + if m.stopFunc != nil { + return m.stopFunc() + } + return nil +} + +type SendEmailFunc func(ctx context.Context, email admin.EmailMessage) error + +type MockEmailer struct { + sendEmailFunc SendEmailFunc +} + +func (m *MockEmailer) SetSendEmailFunc(sendEmail SendEmailFunc) { + m.sendEmailFunc = sendEmail +} + +func (m *MockEmailer) SendEmail(ctx context.Context, email admin.EmailMessage) error { + if m.sendEmailFunc != nil { + return m.sendEmailFunc(ctx, email) + } + return nil +} diff --git a/flyteadmin/pkg/async/notifications/mocks/publisher.go b/flyteadmin/pkg/async/notifications/mocks/publisher.go new file mode 100644 index 0000000000..ccfa041eb1 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/publisher.go @@ -0,0 +1,24 @@ +package mocks + +import ( + "context" + + "github.com/golang/protobuf/proto" +) + +type PublishFunc func(ctx context.Context, key string, msg proto.Message) error + +type MockPublisher struct { + publishFunc PublishFunc +} + +func (m *MockPublisher) SetPublishCallback(publishFunction PublishFunc) { + m.publishFunc = publishFunction +} + +func (m *MockPublisher) Publish(ctx context.Context, notificationType string, msg proto.Message) error { + if m.publishFunc != nil { + return m.publishFunc(ctx, notificationType, msg) + } + return nil +} diff --git a/flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler.go b/flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler.go new file mode 100644 index 0000000000..297a57b689 --- /dev/null +++ b/flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler.go @@ -0,0 +1,263 @@ +package aws + +import ( + "context" + "fmt" + "strings" + + "github.com/lyft/flyteadmin/pkg/async/schedule/aws/interfaces" + scheduleInterfaces "github.com/lyft/flyteadmin/pkg/async/schedule/interfaces" + + "github.com/lyft/flytestdlib/promutils" + 
"github.com/prometheus/client_golang/prometheus" + + "github.com/aws/aws-sdk-go/aws/awserr" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudwatchevents" + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/logger" + "google.golang.org/grpc/codes" +) + +// To indicate that a schedule rule is enabled. +var enableState = "ENABLED" + +// CloudWatch schedule expressions. +const ( + cronExpression = "cron(%s)" + rateExpression = "rate(%v %s)" +) + +const timePlaceholder = "time" + +var timeValue = "$.time" + +const scheduleNameInputsFormat = "%s:%s:%s" +const scheduleDescriptionFormat = "Schedule for Project:%s Domain:%s Name:%s launch plan" +const scheduleNameFormat = "flyte_%d" + +// Container for initialized metrics objects +type cloudWatchSchedulerMetrics struct { + Scope promutils.Scope + InvalidSchedules prometheus.Counter + AddRuleFailures prometheus.Counter + AddTargetFailures prometheus.Counter + SchedulesAdded prometheus.Counter + + RemoveRuleFailures prometheus.Counter + RemoveRuleDoesntExist prometheus.Counter + RemoveTargetFailures prometheus.Counter + RemoveTargetDoesntExist prometheus.Counter + RemovedSchedules prometheus.Counter + + ActiveSchedules prometheus.Gauge +} + +// An AWS CloudWatch implementation of the EventScheduler. +type cloudWatchScheduler struct { + // The ARN of the IAM role associated with the scheduler. + scheduleRoleArn string + // The ARN of the SQS target used for registering schedule events. + targetSqsArn string + // AWS CloudWatchEvents service client. + cloudWatchEventClient interfaces.CloudWatchEventClient + // For emitting scheduler-related metrics + metrics cloudWatchSchedulerMetrics +} + +func getScheduleName(identifier admin.NamedEntityIdentifier) string { + hashedIdentifier := hashIdentifier(identifier) + return fmt.Sprintf(scheduleNameFormat, hashedIdentifier) +} + +func getScheduleDescription(identifier admin.NamedEntityIdentifier) string { + return fmt.Sprintf(scheduleDescriptionFormat, + identifier.Project, identifier.Domain, identifier.Name) +} + +func getScheduleExpression(schedule admin.Schedule) (string, error) { + if schedule.GetCronExpression() != "" { + return fmt.Sprintf(cronExpression, schedule.GetCronExpression()), nil + } + if schedule.GetRate() != nil { + // AWS uses pluralization for units of values not equal to 1. 
+ // See https://docs.aws.amazon.com/lambda/latest/dg/tutorial-scheduled-events-schedule-expressions.html + unit := strings.ToLower(schedule.GetRate().Unit.String()) + if schedule.GetRate().Value != 1 { + unit = fmt.Sprintf("%ss", unit) + } + return fmt.Sprintf(rateExpression, schedule.GetRate().Value, unit), nil + } + logger.Debugf(context.Background(), "scheduler encountered invalid schedule expression: %s", schedule.String()) + return "", errors.NewFlyteAdminErrorf(codes.InvalidArgument, "unrecognized schedule expression") +} + +func formatEventScheduleInputs(inputTemplate *string) cloudwatchevents.InputTransformer { + inputsPathMap := map[string]*string{ + timePlaceholder: &timeValue, + } + return cloudwatchevents.InputTransformer{ + InputPathsMap: inputsPathMap, + InputTemplate: inputTemplate, + } +} + +func (s *cloudWatchScheduler) AddSchedule(ctx context.Context, input scheduleInterfaces.AddScheduleInput) error { + if input.Payload == nil { + logger.Debugf(ctx, "AddSchedule called with empty input payload: %+v", input) + return errors.NewFlyteAdminError(codes.InvalidArgument, "payload serialization function cannot be nil") + } + scheduleExpression, err := getScheduleExpression(input.ScheduleExpression) + if err != nil { + s.metrics.InvalidSchedules.Inc() + return err + } + scheduleName := getScheduleName(input.Identifier) + scheduleDescription := getScheduleDescription(input.Identifier) + // First define a rule which gets triggered on a schedule. + requestInput := cloudwatchevents.PutRuleInput{ + ScheduleExpression: &scheduleExpression, + Name: &scheduleName, + Description: &scheduleDescription, + RoleArn: &s.scheduleRoleArn, + State: &enableState, + } + putRuleOutput, err := s.cloudWatchEventClient.PutRule(&requestInput) + if err != nil { + logger.Infof(ctx, "Failed to add rule to cloudwatch for schedule [%+v] with name %s and expression %s with err: %v", + input.Identifier, scheduleName, scheduleExpression, err) + s.metrics.AddRuleFailures.Inc() + return errors.NewFlyteAdminErrorf(codes.Internal, "failed to add rule to cloudwatch with err: %v", err) + } + eventInputTransformer := formatEventScheduleInputs(input.Payload) + // Next, add a target which gets invoked when the above rule is triggered. 
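+	// Illustrative shape of what gets registered for a rate(1 minute) schedule (values elided/assumed):
+	//   Rule:   Name=flyte_<hash>  ScheduleExpression="rate(1 minute)"  State=ENABLED  RoleArn=<scheduleRoleArn>
+	//   Target: Id=flyte_<hash>    Arn=<targetSqsArn>  InputTransformer mapping "time" -> "$.time" into the payload template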
+ putTargetOutput, err := s.cloudWatchEventClient.PutTargets(&cloudwatchevents.PutTargetsInput{ + Rule: &scheduleName, + Targets: []*cloudwatchevents.Target{ + { + Arn: &s.targetSqsArn, + Id: &scheduleName, + InputTransformer: &eventInputTransformer, + }, + }, + }) + if err != nil { + logger.Infof(ctx, "Failed to add target for event schedule [%+v] with name %s with err: %v", + input.Identifier, scheduleName, err) + s.metrics.AddTargetFailures.Inc() + return errors.NewFlyteAdminErrorf(codes.Internal, "failed to add target for event schedule with err: %v", err) + } else if putTargetOutput.FailedEntryCount != nil && *putTargetOutput.FailedEntryCount > 0 { + logger.Infof(ctx, "Failed to add target for event schedule [%+v] with name %s with failed entries: %d", + input.Identifier, scheduleName, *putTargetOutput.FailedEntryCount) + s.metrics.AddTargetFailures.Inc() + return errors.NewFlyteAdminErrorf(codes.Internal, + "failed to add target for event schedule with %v errs", *putTargetOutput.FailedEntryCount) + } + var putRuleOutputName string + if putRuleOutput != nil && putRuleOutput.RuleArn != nil { + putRuleOutputName = *putRuleOutput.RuleArn + } + logger.Debugf(ctx, "Added schedule %s [%s] with arn: %s (%s)", + scheduleName, scheduleExpression, putRuleOutputName, scheduleDescription) + s.metrics.SchedulesAdded.Inc() + s.metrics.ActiveSchedules.Inc() + return nil +} + +func isResourceNotFoundException(err error) bool { + switch err.(type) { + case awserr.Error: + return err.(awserr.Error).Code() == cloudwatchevents.ErrCodeResourceNotFoundException + } + return false +} + +func (s *cloudWatchScheduler) RemoveSchedule(ctx context.Context, identifier admin.NamedEntityIdentifier) error { + name := getScheduleName(identifier) + // All outbound targets for a rule must be deleted before the rule itself can be deleted. + output, err := s.cloudWatchEventClient.RemoveTargets(&cloudwatchevents.RemoveTargetsInput{ + Ids: []*string{ + &name, + }, + Rule: &name, + }) + if err != nil { + if isResourceNotFoundException(err) { + s.metrics.RemoveTargetDoesntExist.Inc() + logger.Debugf(ctx, "Tried to remove cloudwatch target %s but it was not found", name) + } else { + s.metrics.RemoveTargetFailures.Inc() + logger.Errorf(ctx, "failed to remove cloudwatch target %s with err: %v", name, err) + return errors.NewFlyteAdminErrorf(codes.Internal, "failed to remove cloudwatch target %s with err: %v", name, err) + } + } + if output != nil && output.FailedEntryCount != nil && *output.FailedEntryCount > 0 { + s.metrics.RemoveTargetFailures.Inc() + logger.Errorf(ctx, "failed to remove cloudwatch target %s with %v errs", + name, *output.FailedEntryCount) + return errors.NewFlyteAdminErrorf(codes.Internal, "failed to remove cloudwatch target %s with %v errs", + name, *output.FailedEntryCount) + } + + // Output from the call to DeleteRule is an empty struct. 
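+	// Targets were removed above before the rule itself is deleted here, and a ResourceNotFoundException in
+	// either step is tolerated, which makes RemoveSchedule effectively idempotent for already-removed schedules.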
+ _, err = s.cloudWatchEventClient.DeleteRule(&cloudwatchevents.DeleteRuleInput{ + Name: &name, + }) + if err != nil { + if isResourceNotFoundException(err) { + s.metrics.RemoveRuleDoesntExist.Inc() + logger.Debugf(ctx, "Tried to remove cloudwatch rule %s but it was not found", name) + } else { + s.metrics.RemoveRuleFailures.Inc() + logger.Errorf(ctx, "failed to remove cloudwatch rule %s with err: %v", name, err) + return errors.NewFlyteAdminErrorf(codes.Internal, + "failed to remove cloudwatch rule %s with err: %v", name, err) + } + } + s.metrics.RemovedSchedules.Inc() + s.metrics.ActiveSchedules.Dec() + logger.Debugf(ctx, "Removed schedule %s for identifier [%+v]", name, identifier) + return nil +} + +// Initializes a new set of metrics specific to the cloudwatch scheduler implementation. +func newCloudWatchSchedulerMetrics(scope promutils.Scope) cloudWatchSchedulerMetrics { + return cloudWatchSchedulerMetrics{ + Scope: scope, + InvalidSchedules: scope.MustNewCounter("schedules_invalid", "count of invalid schedule expressions submitted"), + AddRuleFailures: scope.MustNewCounter("add_rule_failures", + "count of attempts to add a cloudwatch rule that have failed"), + AddTargetFailures: scope.MustNewCounter("add_target_failures", + "count of attempts to add a cloudwatch target that have failed"), + SchedulesAdded: scope.MustNewCounter("schedules_added", + "count of all schedules successfully added to cloudwatch"), + RemoveRuleFailures: scope.MustNewCounter("delete_rule_failures", + "count of attempts to remove a cloudwatch rule that have failed"), + RemoveRuleDoesntExist: scope.MustNewCounter("delete_rule_no_rule", + "count of attempts to remove a cloudwatch rule that doesn't exist"), + RemoveTargetFailures: scope.MustNewCounter("delete_target_failures", + "count of attempts to remove a cloudwatch target that have failed"), + RemoveTargetDoesntExist: scope.MustNewCounter("delete_target_no_target", + "count of attempts to remove a cloudwatch target that doesn't exist"), + RemovedSchedules: scope.MustNewCounter("schedules_removed", + "count of all schedules successfully removed from cloudwatch"), + ActiveSchedules: scope.MustNewGauge("active_schedules", + "count of all active schedules currently in cloudwatch"), + } +} + +func NewCloudWatchScheduler( + scheduleRoleArn, targetSqsArn string, session *session.Session, config *aws.Config, + scope promutils.Scope) scheduleInterfaces.EventScheduler { + cloudwatchEventClient := cloudwatchevents.New(session, config) + metrics := newCloudWatchSchedulerMetrics(scope) + return &cloudWatchScheduler{ + scheduleRoleArn: scheduleRoleArn, + targetSqsArn: targetSqsArn, + cloudWatchEventClient: cloudwatchEventClient, + metrics: metrics, + } +} diff --git a/flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler_test.go b/flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler_test.go new file mode 100644 index 0000000000..91e00bfcc1 --- /dev/null +++ b/flyteadmin/pkg/async/schedule/aws/cloud_watch_scheduler_test.go @@ -0,0 +1,272 @@ +package aws + +import ( + "context" + "fmt" + + "github.com/lyft/flyteadmin/pkg/async/schedule/aws/interfaces" + "github.com/lyft/flyteadmin/pkg/async/schedule/aws/mocks" + scheduleInterfaces "github.com/lyft/flyteadmin/pkg/async/schedule/interfaces" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/aws/aws-sdk-go/aws/awserr" + + "testing" + + "github.com/aws/aws-sdk-go/service/cloudwatchevents" + flyteAdminErrors "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + 
"github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" +) + +const testScheduleName = "flyte_16301494360130577061" +const testScheduleDescription = "Schedule for Project:project Domain:domain Name:name launch plan" + +var expectedError = flyteAdminErrors.NewFlyteAdminError(codes.Internal, "foo") + +var testSerializedPayload = fmt.Sprintf("event triggered at '%s'", awsTimestampPlaceholder) + +var testSchedulerIdentifier = admin.NamedEntityIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", +} + +var scope = promutils.NewScope("test_scheduler") + +var testCloudWatchSchedulerMetrics = newCloudWatchSchedulerMetrics(scope) + +func TestGetScheduleName(t *testing.T) { + scheduleName := getScheduleName(testSchedulerIdentifier) + assert.Equal(t, "flyte_16301494360130577061", scheduleName) +} + +func TestGetScheduleDescription(t *testing.T) { + scheduleDescription := getScheduleDescription(testSchedulerIdentifier) + assert.Equal(t, "Schedule for Project:project Domain:domain Name:name launch plan", scheduleDescription) +} + +func TestGetScheduleExpression(t *testing.T) { + expression, err := getScheduleExpression(admin.Schedule{ + ScheduleExpression: &admin.Schedule_CronExpression{ + CronExpression: "foo", + }, + }) + assert.Nil(t, err) + assert.Equal(t, "cron(foo)", expression) + + expression, err = getScheduleExpression(admin.Schedule{ + ScheduleExpression: &admin.Schedule_Rate{ + Rate: &admin.FixedRate{ + Value: 1, + Unit: admin.FixedRateUnit_DAY, + }, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "rate(1 day)", expression) + + expression, err = getScheduleExpression(admin.Schedule{ + ScheduleExpression: &admin.Schedule_Rate{ + Rate: &admin.FixedRate{ + Value: 2, + Unit: admin.FixedRateUnit_HOUR, + }, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "rate(2 hours)", expression) + + _, err = getScheduleExpression(admin.Schedule{}) + assert.Equal(t, codes.InvalidArgument, err.(flyteAdminErrors.FlyteAdminError).Code()) +} + +func TestFormatEventScheduleInputs(t *testing.T) { + inputTransformer := formatEventScheduleInputs(&testSerializedPayload) + assert.EqualValues(t, map[string]*string{ + "time": &timeValue, + }, inputTransformer.InputPathsMap) + assert.Equal(t, testSerializedPayload, *inputTransformer.InputTemplate) +} + +func getCloudWatchSchedulerForTest(client interfaces.CloudWatchEventClient) scheduleInterfaces.EventScheduler { + + return &cloudWatchScheduler{ + scheduleRoleArn: "ScheduleRole", + targetSqsArn: "TargetSqsArn", + cloudWatchEventClient: client, + metrics: testCloudWatchSchedulerMetrics, + } +} + +func TestAddSchedule(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetPutRuleFunc(func( + input *cloudwatchevents.PutRuleInput) (*cloudwatchevents.PutRuleOutput, error) { + assert.Equal(t, "rate(1 minute)", *input.ScheduleExpression) + assert.Equal(t, testScheduleName, *input.Name) + assert.Equal(t, testScheduleDescription, *input.Description) + assert.Equal(t, "ScheduleRole", *input.RoleArn) + assert.Equal(t, enableState, *input.State) + return &cloudwatchevents.PutRuleOutput{}, nil + }) + + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetPutTargetsFunc(func( + input *cloudwatchevents.PutTargetsInput) (*cloudwatchevents.PutTargetsOutput, error) { + assert.Equal(t, testScheduleName, *input.Rule) + assert.Len(t, input.Targets, 1) + assert.Equal(t, "TargetSqsArn", *input.Targets[0].Arn) + assert.Equal(t, testScheduleName, 
*input.Targets[0].Id) + assert.NotEmpty(t, *input.Targets[0].InputTransformer) + return &cloudwatchevents.PutTargetsOutput{}, nil + }) + + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + assert.Nil(t, scheduler.AddSchedule(context.Background(), + scheduleInterfaces.AddScheduleInput{ + Identifier: testSchedulerIdentifier, + ScheduleExpression: admin.Schedule{ + ScheduleExpression: &admin.Schedule_Rate{ + Rate: &admin.FixedRate{ + Value: 1, + Unit: admin.FixedRateUnit_MINUTE, + }, + }, + }, + Payload: &testSerializedPayload, + })) +} + +func TestAddSchedule_InvalidScheduleExpression(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.AddSchedule(context.Background(), + scheduleInterfaces.AddScheduleInput{ + Identifier: testSchedulerIdentifier, + Payload: &testSerializedPayload, + }) + assert.Equal(t, codes.InvalidArgument, err.(flyteAdminErrors.FlyteAdminError).Code()) +} + +func TestAddSchedule_PutRuleError(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetPutRuleFunc(func( + input *cloudwatchevents.PutRuleInput) (*cloudwatchevents.PutRuleOutput, error) { + return nil, expectedError + }) + + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.AddSchedule(context.Background(), + scheduleInterfaces.AddScheduleInput{ + Identifier: testSchedulerIdentifier, + ScheduleExpression: admin.Schedule{ + ScheduleExpression: &admin.Schedule_Rate{ + Rate: &admin.FixedRate{ + Value: 1, + Unit: admin.FixedRateUnit_MINUTE, + }, + }, + }, + Payload: &testSerializedPayload, + }) + assert.Equal(t, codes.Internal, err.(flyteAdminErrors.FlyteAdminError).Code()) +} + +func TestAddSchedule_PutTargetsError(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetPutRuleFunc(func( + input *cloudwatchevents.PutRuleInput) (*cloudwatchevents.PutRuleOutput, error) { + return &cloudwatchevents.PutRuleOutput{}, nil + }) + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetPutTargetsFunc(func( + input *cloudwatchevents.PutTargetsInput) (*cloudwatchevents.PutTargetsOutput, error) { + return nil, expectedError + }) + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.AddSchedule(context.Background(), + scheduleInterfaces.AddScheduleInput{ + Identifier: testSchedulerIdentifier, + ScheduleExpression: admin.Schedule{ + ScheduleExpression: &admin.Schedule_Rate{ + Rate: &admin.FixedRate{ + Value: 1, + Unit: admin.FixedRateUnit_MINUTE, + }, + }, + }, + Payload: &testSerializedPayload, + }) + assert.Equal(t, codes.Internal, err.(flyteAdminErrors.FlyteAdminError).Code()) +} + +func TestRemoveSchedule(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetRemoveTargetsFunc(func( + input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) { + assert.Len(t, input.Ids, 1) + assert.Equal(t, testScheduleName, *input.Ids[0]) + assert.Equal(t, testScheduleName, *input.Rule) + return &cloudwatchevents.RemoveTargetsOutput{}, nil + }) + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetDeleteRuleFunc(func( + input *cloudwatchevents.DeleteRuleInput) (*cloudwatchevents.DeleteRuleOutput, 
error) { + assert.Equal(t, testScheduleName, *input.Name) + return &cloudwatchevents.DeleteRuleOutput{}, nil + }) + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + assert.Nil(t, scheduler.RemoveSchedule(context.Background(), testSchedulerIdentifier)) +} + +func TestRemoveSchedule_RemoveTargetsError(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetRemoveTargetsFunc(func( + input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) { + return nil, expectedError + }) + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.RemoveSchedule(context.Background(), testSchedulerIdentifier) + assert.Equal(t, codes.Internal, err.(flyteAdminErrors.FlyteAdminError).Code()) +} + +func TestRemoveSchedule_InvalidTarget(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetRemoveTargetsFunc(func( + input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) { + return nil, awserr.New(cloudwatchevents.ErrCodeResourceNotFoundException, "foo", expectedError) + }) + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.RemoveSchedule(context.Background(), testSchedulerIdentifier) + assert.Nil(t, err) +} + +func TestRemoveSchedule_DeleteRuleError(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetRemoveTargetsFunc(func( + input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) { + return &cloudwatchevents.RemoveTargetsOutput{}, nil + }) + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetDeleteRuleFunc(func( + input *cloudwatchevents.DeleteRuleInput) (*cloudwatchevents.DeleteRuleOutput, error) { + return nil, expectedError + }) + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.RemoveSchedule(context.Background(), testSchedulerIdentifier) + assert.Equal(t, codes.Internal, err.(flyteAdminErrors.FlyteAdminError).Code()) +} + +func TestRemoveSchedule_InvalidRule(t *testing.T) { + mockCloudWatchEventClient := mocks.NewMockCloudWatchEventClient() + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetRemoveTargetsFunc(func( + input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) { + return &cloudwatchevents.RemoveTargetsOutput{}, nil + }) + mockCloudWatchEventClient.(*mocks.MockCloudWatchEventClient).SetDeleteRuleFunc(func( + input *cloudwatchevents.DeleteRuleInput) (*cloudwatchevents.DeleteRuleOutput, error) { + return nil, awserr.New(cloudwatchevents.ErrCodeResourceNotFoundException, "foo", expectedError) + }) + scheduler := getCloudWatchSchedulerForTest(mockCloudWatchEventClient) + err := scheduler.RemoveSchedule(context.Background(), testSchedulerIdentifier) + assert.Nil(t, err) +} diff --git a/flyteadmin/pkg/async/schedule/aws/interfaces/cloud_watch_event_client.go b/flyteadmin/pkg/async/schedule/aws/interfaces/cloud_watch_event_client.go new file mode 100644 index 0000000000..2ba47974e7 --- /dev/null +++ b/flyteadmin/pkg/async/schedule/aws/interfaces/cloud_watch_event_client.go @@ -0,0 +1,11 @@ +package interfaces + +import "github.com/aws/aws-sdk-go/service/cloudwatchevents" + +// A subset of the AWS CloudWatchEvents service 
client. +type CloudWatchEventClient interface { + PutRule(input *cloudwatchevents.PutRuleInput) (*cloudwatchevents.PutRuleOutput, error) + PutTargets(input *cloudwatchevents.PutTargetsInput) (*cloudwatchevents.PutTargetsOutput, error) + DeleteRule(input *cloudwatchevents.DeleteRuleInput) (*cloudwatchevents.DeleteRuleOutput, error) + RemoveTargets(input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) +} diff --git a/flyteadmin/pkg/async/schedule/aws/mocks/mock_cloud_watch_event_client.go b/flyteadmin/pkg/async/schedule/aws/mocks/mock_cloud_watch_event_client.go new file mode 100644 index 0000000000..0a85b60bbb --- /dev/null +++ b/flyteadmin/pkg/async/schedule/aws/mocks/mock_cloud_watch_event_client.go @@ -0,0 +1,71 @@ +package mocks + +import ( + "github.com/aws/aws-sdk-go/service/cloudwatchevents" + "github.com/lyft/flyteadmin/pkg/async/schedule/aws/interfaces" +) + +type putRuleFunc func(input *cloudwatchevents.PutRuleInput) (*cloudwatchevents.PutRuleOutput, error) +type putTargetsFunc func(input *cloudwatchevents.PutTargetsInput) (*cloudwatchevents.PutTargetsOutput, error) +type deleteRuleFunc func(input *cloudwatchevents.DeleteRuleInput) (*cloudwatchevents.DeleteRuleOutput, error) +type removeTargetsFunc func(input *cloudwatchevents.RemoveTargetsInput) (*cloudwatchevents.RemoveTargetsOutput, error) + +// A mock implementation of CloudWatchEventClient for use in tests. +type MockCloudWatchEventClient struct { + putRule putRuleFunc + putTargets putTargetsFunc + deleteRule deleteRuleFunc + removeTargets removeTargetsFunc +} + +func (c *MockCloudWatchEventClient) SetPutRuleFunc(putRule putRuleFunc) { + c.putRule = putRule +} + +func (c *MockCloudWatchEventClient) PutRule(input *cloudwatchevents.PutRuleInput) ( + *cloudwatchevents.PutRuleOutput, error) { + if c.putRule != nil { + return c.putRule(input) + } + return nil, nil +} + +func (c *MockCloudWatchEventClient) SetPutTargetsFunc(putTargets putTargetsFunc) { + c.putTargets = putTargets +} + +func (c *MockCloudWatchEventClient) PutTargets(input *cloudwatchevents.PutTargetsInput) ( + *cloudwatchevents.PutTargetsOutput, error) { + if c.putTargets != nil { + return c.putTargets(input) + } + return nil, nil +} + +func (c *MockCloudWatchEventClient) SetDeleteRuleFunc(deleteRule deleteRuleFunc) { + c.deleteRule = deleteRule +} + +func (c *MockCloudWatchEventClient) DeleteRule(input *cloudwatchevents.DeleteRuleInput) ( + *cloudwatchevents.DeleteRuleOutput, error) { + if c.deleteRule != nil { + return c.deleteRule(input) + } + return nil, nil +} + +func (c *MockCloudWatchEventClient) SetRemoveTargetsFunc(removeTargets removeTargetsFunc) { + c.removeTargets = removeTargets +} + +func (c *MockCloudWatchEventClient) RemoveTargets(input *cloudwatchevents.RemoveTargetsInput) ( + *cloudwatchevents.RemoveTargetsOutput, error) { + if c.removeTargets != nil { + return c.removeTargets(input) + } + return nil, nil +} + +func NewMockCloudWatchEventClient() interfaces.CloudWatchEventClient { + return &MockCloudWatchEventClient{} +} diff --git a/flyteadmin/pkg/async/schedule/aws/serialization.go b/flyteadmin/pkg/async/schedule/aws/serialization.go new file mode 100644 index 0000000000..e6b455a9d2 --- /dev/null +++ b/flyteadmin/pkg/async/schedule/aws/serialization.go @@ -0,0 +1,86 @@ +// Functions for serializing and deserializing scheduled events in AWS. 
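+// (Contextual note, inferred from cloud_watch_scheduler.go: the registered InputTransformer substitutes the
+// event time for the "time" placeholder ($.time) in the serialized payload, so the helpers in this file are
+// expected to insert and parse that timestamp placeholder when round-tripping scheduled launch plan payloads.)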
+package aws + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "time" + + "github.com/lyft/flytestdlib/logger" + + "github.com/golang/protobuf/proto" + "github.com/lyft/flyteadmin/pkg/errors" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "google.golang.org/grpc/codes" +) + +const awsTimestampPlaceholder = "