From 537f86093d86af270aab300742ee1a56f2905885 Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Tue, 9 Apr 2019 13:26:25 -0700 Subject: [PATCH] Initial Release --- .gitignore | 113 +++ .golangci.yml | 25 + .travis.yml | 20 + CODE_OF_CONDUCT.md | 2 + Gopkg.lock | 516 ++++++++++++++ Gopkg.toml | 72 ++ LICENSE | 202 ++++++ Makefile | 24 + NOTICE | 21 + README.rst | 35 + atomic/atomic.go | 129 ++++ atomic/atomic_test.go | 61 ++ atomic/non_blocking_lock.go | 23 + atomic/non_blocking_lock_test.go | 15 + boilerplate/lyft/golang_test_targets/Makefile | 31 + .../lyft/golang_test_targets/Readme.rst | 31 + .../lyft/golang_test_targets/goimports | 3 + boilerplate/update.cfg | 2 + cli/pflags/api/generator.go | 277 ++++++++ cli/pflags/api/generator_test.go | 56 ++ cli/pflags/api/pflag_provider.go | 90 +++ cli/pflags/api/sample.go | 53 ++ cli/pflags/api/tag.go | 66 ++ cli/pflags/api/templates.go | 175 +++++ cli/pflags/api/testdata/testtype.go | 36 + cli/pflags/api/testdata/testtype_test.go | 520 ++++++++++++++ cli/pflags/api/types.go | 36 + cli/pflags/api/utils.go | 32 + cli/pflags/cmd/root.go | 71 ++ cli/pflags/cmd/version.go | 17 + cli/pflags/main.go | 15 + cli/pflags/readme.rst | 24 + config/accessor.go | 61 ++ config/accessor_test.go | 91 +++ config/config_cmd.go | 113 +++ config/config_cmd_test.go | 64 ++ config/duration.go | 47 ++ config/duration_test.go | 66 ++ config/errors.go | 37 + config/errors_test.go | 31 + config/files/finder.go | 83 +++ config/files/finder_test.go | 28 + config/files/testdata/config-1.yaml | 9 + config/files/testdata/config-2.yaml | 2 + config/files/testdata/other-group-1.yaml | 9 + config/files/testdata/other-group-2.yaml | 2 + config/port.go | 56 ++ config/port_test.go | 81 +++ config/section.go | 230 +++++++ config/section_test.go | 119 ++++ config/testdata/config.yaml | 11 + config/tests/accessor_test.go | 641 +++++++++++++++++ config/tests/config_cmd_test.go | 92 +++ config/tests/testdata/array_configs.yaml | 7 + config/tests/testdata/bad_config.yaml | 13 + config/tests/testdata/config.yaml | 11 + config/tests/testdata/nested_config.yaml | 11 + config/tests/types_test.go | 66 ++ config/url.go | 36 + config/url_test.go | 59 ++ config/utils.go | 69 ++ config/utils_test.go | 39 ++ config/viper/collection.go | 175 +++++ config/viper/viper.go | 357 ++++++++++ contextutils/context.go | 140 ++++ contextutils/context_test.go | 113 +++ internal/utils/parsers.go | 20 + internal/utils/parsers_test.go | 25 + ioutils/bytes.go | 21 + ioutils/bytes_test.go | 17 + ioutils/timed_readers.go | 17 + ioutils/timed_readers_test.go | 20 + logger/config.go | 81 +++ logger/config_flags.go | 21 + logger/config_flags_test.go | 190 ++++++ logger/config_test.go | 20 + logger/logger.go | 337 +++++++++ logger/logger_test.go | 643 ++++++++++++++++++ pbhash/pbhash.go | 58 ++ pbhash/pbhash_test.go | 145 ++++ profutils/server.go | 116 ++++ profutils/server_test.go | 103 +++ promutils/labeled/counter.go | 65 ++ promutils/labeled/counter_test.go | 31 + promutils/labeled/keys.go | 47 ++ promutils/labeled/keys_test.go | 24 + promutils/labeled/metric_option.go | 15 + promutils/labeled/metric_option_test.go | 13 + promutils/labeled/stopwatch.go | 87 +++ promutils/labeled/stopwatch_test.go | 47 ++ promutils/labeled/timer_wrapper.go | 20 + promutils/labeled/timer_wrapper_test.go | 28 + promutils/scope.go | 434 ++++++++++++ promutils/scope_test.go | 151 ++++ promutils/workqueue.go | 82 +++ promutils/workqueue_test.go | 42 ++ sets/generic_set.go | 195 ++++++ sets/generic_set_test.go | 116 ++++ 
storage/cached_rawstore.go | 123 ++++ storage/cached_rawstore_test.go | 182 +++++ storage/config.go | 85 +++ storage/config_flags.go | 28 + storage/config_flags_test.go | 344 ++++++++++ storage/config_test.go | 45 ++ storage/copy_impl.go | 60 ++ storage/copy_impl_test.go | 82 +++ storage/localstore.go | 48 ++ storage/localstore_test.go | 66 ++ storage/mem_store.go | 74 ++ storage/mem_store_test.go | 78 +++ storage/protobuf_store.go | 85 +++ storage/protobuf_store_test.go | 41 ++ storage/rawstores.go | 40 ++ storage/s3store.go | 102 +++ storage/s3stsore_test.go | 26 + storage/storage.go | 95 +++ storage/storage_test.go | 52 ++ storage/stow_store.go | 174 +++++ storage/stow_store_test.go | 133 ++++ storage/testdata/config.yaml | 14 + storage/url_path.go | 44 ++ storage/url_path_test.go | 15 + storage/utils.go | 34 + tests/config_test.go | 78 +++ tests/testdata/combined.yaml | 19 + utils/auto_refresh_cache.go | 99 +++ utils/auto_refresh_cache_test.go | 108 +++ utils/auto_refresh_example_test.go | 88 +++ utils/rate_limiter.go | 35 + utils/rate_limiter_test.go | 39 ++ utils/sequencer.go | 39 ++ utils/sequencer_test.go | 54 ++ version/version.go | 29 + version/version_test.go | 29 + yamlutils/yaml_json.go | 17 + 135 files changed, 11697 insertions(+) create mode 100644 .gitignore create mode 100644 .golangci.yml create mode 100644 .travis.yml create mode 100644 CODE_OF_CONDUCT.md create mode 100644 Gopkg.lock create mode 100644 Gopkg.toml create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 NOTICE create mode 100644 README.rst create mode 100644 atomic/atomic.go create mode 100644 atomic/atomic_test.go create mode 100644 atomic/non_blocking_lock.go create mode 100644 atomic/non_blocking_lock_test.go create mode 100644 boilerplate/lyft/golang_test_targets/Makefile create mode 100644 boilerplate/lyft/golang_test_targets/Readme.rst create mode 100755 boilerplate/lyft/golang_test_targets/goimports create mode 100644 boilerplate/update.cfg create mode 100644 cli/pflags/api/generator.go create mode 100644 cli/pflags/api/generator_test.go create mode 100644 cli/pflags/api/pflag_provider.go create mode 100644 cli/pflags/api/sample.go create mode 100644 cli/pflags/api/tag.go create mode 100644 cli/pflags/api/templates.go create mode 100755 cli/pflags/api/testdata/testtype.go create mode 100755 cli/pflags/api/testdata/testtype_test.go create mode 100644 cli/pflags/api/types.go create mode 100644 cli/pflags/api/utils.go create mode 100644 cli/pflags/cmd/root.go create mode 100644 cli/pflags/cmd/version.go create mode 100644 cli/pflags/main.go create mode 100644 cli/pflags/readme.rst create mode 100644 config/accessor.go create mode 100644 config/accessor_test.go create mode 100644 config/config_cmd.go create mode 100644 config/config_cmd_test.go create mode 100644 config/duration.go create mode 100644 config/duration_test.go create mode 100644 config/errors.go create mode 100644 config/errors_test.go create mode 100644 config/files/finder.go create mode 100644 config/files/finder_test.go create mode 100644 config/files/testdata/config-1.yaml create mode 100755 config/files/testdata/config-2.yaml create mode 100644 config/files/testdata/other-group-1.yaml create mode 100755 config/files/testdata/other-group-2.yaml create mode 100644 config/port.go create mode 100644 config/port_test.go create mode 100644 config/section.go create mode 100644 config/section_test.go create mode 100755 config/testdata/config.yaml create mode 100644 config/tests/accessor_test.go create mode 100644 
config/tests/config_cmd_test.go create mode 100644 config/tests/testdata/array_configs.yaml create mode 100644 config/tests/testdata/bad_config.yaml create mode 100755 config/tests/testdata/config.yaml create mode 100755 config/tests/testdata/nested_config.yaml create mode 100644 config/tests/types_test.go create mode 100644 config/url.go create mode 100644 config/url_test.go create mode 100644 config/utils.go create mode 100644 config/utils_test.go create mode 100644 config/viper/collection.go create mode 100644 config/viper/viper.go create mode 100644 contextutils/context.go create mode 100644 contextutils/context_test.go create mode 100644 internal/utils/parsers.go create mode 100644 internal/utils/parsers_test.go create mode 100644 ioutils/bytes.go create mode 100644 ioutils/bytes_test.go create mode 100644 ioutils/timed_readers.go create mode 100644 ioutils/timed_readers_test.go create mode 100644 logger/config.go create mode 100755 logger/config_flags.go create mode 100755 logger/config_flags_test.go create mode 100644 logger/config_test.go create mode 100644 logger/logger.go create mode 100644 logger/logger_test.go create mode 100644 pbhash/pbhash.go create mode 100644 pbhash/pbhash_test.go create mode 100644 profutils/server.go create mode 100644 profutils/server_test.go create mode 100644 promutils/labeled/counter.go create mode 100644 promutils/labeled/counter_test.go create mode 100644 promutils/labeled/keys.go create mode 100644 promutils/labeled/keys_test.go create mode 100644 promutils/labeled/metric_option.go create mode 100644 promutils/labeled/metric_option_test.go create mode 100644 promutils/labeled/stopwatch.go create mode 100644 promutils/labeled/stopwatch_test.go create mode 100644 promutils/labeled/timer_wrapper.go create mode 100644 promutils/labeled/timer_wrapper_test.go create mode 100644 promutils/scope.go create mode 100644 promutils/scope_test.go create mode 100644 promutils/workqueue.go create mode 100644 promutils/workqueue_test.go create mode 100644 sets/generic_set.go create mode 100644 sets/generic_set_test.go create mode 100644 storage/cached_rawstore.go create mode 100644 storage/cached_rawstore_test.go create mode 100644 storage/config.go create mode 100755 storage/config_flags.go create mode 100755 storage/config_flags_test.go create mode 100644 storage/config_test.go create mode 100644 storage/copy_impl.go create mode 100644 storage/copy_impl_test.go create mode 100644 storage/localstore.go create mode 100644 storage/localstore_test.go create mode 100644 storage/mem_store.go create mode 100644 storage/mem_store_test.go create mode 100644 storage/protobuf_store.go create mode 100644 storage/protobuf_store_test.go create mode 100644 storage/rawstores.go create mode 100644 storage/s3store.go create mode 100644 storage/s3stsore_test.go create mode 100644 storage/storage.go create mode 100644 storage/storage_test.go create mode 100644 storage/stow_store.go create mode 100644 storage/stow_store_test.go create mode 100755 storage/testdata/config.yaml create mode 100644 storage/url_path.go create mode 100644 storage/url_path_test.go create mode 100644 storage/utils.go create mode 100644 tests/config_test.go create mode 100755 tests/testdata/combined.yaml create mode 100644 utils/auto_refresh_cache.go create mode 100644 utils/auto_refresh_cache_test.go create mode 100644 utils/auto_refresh_example_test.go create mode 100644 utils/rate_limiter.go create mode 100644 utils/rate_limiter_test.go create mode 100644 utils/sequencer.go create mode 100644 
utils/sequencer_test.go create mode 100644 version/version.go create mode 100644 version/version_test.go create mode 100644 yamlutils/yaml_json.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..00820a03e4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,113 @@ + +# Temporary Build Files +tmp/_output +tmp/_test + + +# Created by https://www.gitignore.io/api/go,vim,emacs,visualstudiocode + +### Emacs ### +# -*- mode: gitignore; -*- +*~ +\#*\# +/.emacs.desktop +/.emacs.desktop.lock +*.elc +auto-save-list +tramp +.\#* + +# Org-mode +.org-id-locations +*_archive + +# flymake-mode +*_flymake.* + +# eshell files +/eshell/history +/eshell/lastdir + +# elpa packages +/elpa/ + +# reftex files +*.rel + +# AUCTeX auto folder +/auto/ + +# cask packages +.cask/ +dist/ + +# Flycheck +flycheck_*.el + +# server auth directory +/server/ + +# projectiles files +.projectile +projectile-bookmarks.eld + +# directory configuration +.dir-locals.el + +# saveplace +places + +# url cache +url/cache/ + +# cedet +ede-projects.el + +# smex +smex-items + +# company-statistics +company-statistics-cache.el + +# anaconda-mode +anaconda-mode/ + +### Go ### +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +vendor/ + +# Test binary, build with 'go test -c' +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +### Vim ### +# swap +.sw[a-p] +.*.sw[a-p] +# session +Session.vim +# temporary +.netrwhist +# auto-generated tag files +tags + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +.history + +### GoLand ### +.idea/* + + +# End of https://www.gitignore.io/api/go,vim,emacs,visualstudiocode diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000000..dbfea73e09 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,25 @@ +run: + skip-dirs: + - pkg/client + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gas + - goconst + - goimports + - golint + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - staticcheck + - structcheck + - typecheck + - unconvert + - unparam + - unused + - varcheck diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000000..91723384da --- /dev/null +++ b/.travis.yml @@ -0,0 +1,20 @@ +sudo: required +language: go +go: + - "1.10" +services: + - docker +jobs: + include: + - stage: test + name: unit tests + install: make install + script: make test_unit + - stage: test + name: benchmark tests + install: make install + script: make test_benchmark + - stage: test + install: make install + name: lint + script: make lint diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..4c3a38cc48 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,2 @@ +This project is governed by [Lyft's code of conduct](https://github.com/lyft/code-of-conduct). +All contributors and participants agree to abide by its terms. diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 0000000000..a56ebd2dba --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,516 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + digest = "1:7fbdc0ca5fc0b0bb66b81ec2fdca82fbe64416742267f11aceb8ed56e6ca3121" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/processcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/ini", + "internal/s3err", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/eventstream", + "private/protocol/eventstream/eventstreamapi", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/restxml", + "private/protocol/xml/xmlutil", + "service/s3", + "service/sts", + ] + pruneopts = "UT" + revision = "81f3829f5a9d041041bdf56e55926691309d7699" + version = "v1.16.26" + +[[projects]] + branch = "master" + digest = "1:a6609679ca468a89b711934f16b346e99f6ec344eadd2f7b00b1156785dd1236" + name = "github.com/benlaurie/objecthash" + packages = ["go/objecthash"] + pruneopts = "UT" + revision = "d1e3d6079fc16f8f542183fb5b2fdc11d9f00866" + +[[projects]] + branch = "master" + digest = "1:d6afaeed1502aa28e80a4ed0981d570ad91b2579193404256ce672ed0a609e0d" + name = "github.com/beorn7/perks" + packages = ["quantile"] + pruneopts = "UT" + revision = "3a771d992973f24aa725d07868b467d1ddfceafb" + +[[projects]] + digest = "1:998cf998358a303ac2430c386ba3fd3398477d6013153d3c6e11432765cc9ae6" + name = "github.com/cespare/xxhash" + packages = ["."] + pruneopts = "UT" + revision = "3b82fb7d186719faeedd0c2864f868c74fbf79a1" + version = "v2.0.0" + +[[projects]] + digest = "1:04179a5bcbecdb18f06cca42e3808ae8560f86ad7fe470fde21206008f0c5e26" + name = "github.com/coocood/freecache" + packages = ["."] + pruneopts = "UT" + revision = "f3233c8095b26cd0dea0b136b931708c05defa08" + version = "v1.0.1" + +[[projects]] + digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "UT" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + +[[projects]] + branch = "master" + digest = "1:dc8bf44b7198605c83a4f2bb36a92c4d9f71eab2e8cf8094ce31b0297dd8ea89" + name = "github.com/ernesto-jimenez/gogen" + packages = [ + "gogenutil", + "imports", + ] + pruneopts = "UT" + revision = "d7d4131e6607813977e78297a6060f360f056a97" + +[[projects]] + digest = "1:865079840386857c809b72ce300be7580cb50d3d3129ce11bf9aa6ca2bc1934a" + name = "github.com/fatih/color" + packages = ["."] + pruneopts = "UT" + revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" + version = "v1.7.0" + +[[projects]] + digest = "1:060e2ff7ee3e51b4a0fadf46308033bfe3b8030af6a8078ec26916e2e9b2fdc3" + name = "github.com/fatih/structtag" + packages = ["."] + pruneopts = "UT" + revision = "76ae1d6d2117609598c7d4e8f3e938145f204e8f" + version = "v1.0.0" + +[[projects]] + branch = "master" + digest = "1:b9414457752702c53f6afd3838da3d89b9513ada40cdbe9603bdf54b1ceb5014" + name = "github.com/fsnotify/fsnotify" + packages = ["."] + pruneopts = "UT" + revision = "ccc981bf80385c528a65fbfdd49bf2d8da22aa23" + +[[projects]] + digest = "1:2cd7915ab26ede7d95b8749e6b1f933f1c6d5398030684e6505940a10f31cfda" + name = "github.com/ghodss/yaml" + packages = ["."] + pruneopts = "UT" + revision = 
"0ca9ea5df5451ffdf184b4428c902747c2c11cd7" + version = "v1.0.0" + +[[projects]] + branch = "master" + digest = "1:1ba1d79f2810270045c328ae5d674321db34e3aae468eb4233883b473c5c0467" + name = "github.com/golang/glog" + packages = ["."] + pruneopts = "UT" + revision = "23def4e6c14b4da8ac2ed8007337bc5eb5007998" + +[[projects]] + digest = "1:9d6dc4d6de69b330d0de86494d6db90c09848c003d5db748f40c925f865c8534" + name = "github.com/golang/protobuf" + packages = [ + "jsonpb", + "proto", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/struct", + "ptypes/timestamp", + ] + pruneopts = "UT" + revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" + version = "v1.2.0" + +[[projects]] + digest = "1:3dab0e385faed192353d2150f6a041f4607f04a0e885f4a5a824eee6b676b4b9" + name = "github.com/graymeta/stow" + packages = [ + ".", + "local", + "s3", + ] + pruneopts = "UT" + revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + +[[projects]] + digest = "1:c0d19ab64b32ce9fe5cf4ddceba78d5bc9807f0016db6b1183599da3dcc24d10" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/printer", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token", + ] + pruneopts = "UT" + revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" + version = "v1.0.0" + +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "UT" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + digest = "1:bb81097a5b62634f3e9fec1014657855610c82d19b9a40c17612e32651e35dca" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "UT" + revision = "c2b33e84" + +[[projects]] + digest = "1:0a69a1c0db3591fcefb47f115b224592c8dfa4368b7ba9fae509d5e16cdc95c8" + name = "github.com/konsorten/go-windows-terminal-sequences" + packages = ["."] + pruneopts = "UT" + revision = "5c8c8bd35d3832f5d134ae1e1e375b69a4d25242" + version = "v1.0.1" + +[[projects]] + digest = "1:53e8c5c79716437e601696140e8b1801aae4204f4ec54a504333702a49572c4f" + name = "github.com/magiconair/properties" + packages = [ + ".", + "assert", + ] + pruneopts = "UT" + revision = "c2353362d570a7bfa228149c62842019201cfb71" + version = "v1.8.0" + +[[projects]] + digest = "1:c658e84ad3916da105a761660dcaeb01e63416c8ec7bc62256a9b411a05fcd67" + name = "github.com/mattn/go-colorable" + packages = ["."] + pruneopts = "UT" + revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" + version = "v0.0.9" + +[[projects]] + digest = "1:0981502f9816113c9c8c4ac301583841855c8cf4da8c72f696b3ebedf6d0e4e5" + name = "github.com/mattn/go-isatty" + packages = ["."] + pruneopts = "UT" + revision = "6ca4dbf54d38eea1a992b3c722a76a5d1c4cb25c" + version = "v0.0.4" + +[[projects]] + digest = "1:ff5ebae34cfbf047d505ee150de27e60570e8c394b3b8fdbb720ff6ac71985fc" + name = "github.com/matttproud/golang_protobuf_extensions" + packages = ["pbutil"] + pruneopts = "UT" + revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" + version = "v1.0.1" + +[[projects]] + digest = "1:53bc4cd4914cd7cd52139990d5170d6dc99067ae31c56530621b18b35fc30318" + name = "github.com/mitchellh/mapstructure" + packages = ["."] + pruneopts = "UT" + revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" + version = "v1.1.2" + +[[projects]] + digest = "1:95741de3af260a92cc5c7f3f3061e85273f5a81b5db20d4bd68da74bd521675e" + name = "github.com/pelletier/go-toml" + packages = ["."] + pruneopts = "UT" + revision = 
"c01d1270ff3e442a8a57cddc1c92dc1138598194" + version = "v1.2.0" + +[[projects]] + digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" + name = "github.com/pkg/errors" + packages = ["."] + pruneopts = "UT" + revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" + version = "v0.8.1" + +[[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "UT" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + digest = "1:93a746f1060a8acbcf69344862b2ceced80f854170e1caae089b2834c5fbf7f4" + name = "github.com/prometheus/client_golang" + packages = [ + "prometheus", + "prometheus/internal", + "prometheus/promhttp", + ] + pruneopts = "UT" + revision = "505eaef017263e299324067d40ca2c48f6a2cf50" + version = "v0.9.2" + +[[projects]] + branch = "master" + digest = "1:2d5cd61daa5565187e1d96bae64dbbc6080dacf741448e9629c64fd93203b0d4" + name = "github.com/prometheus/client_model" + packages = ["go"] + pruneopts = "UT" + revision = "fd36f4220a901265f90734c3183c5f0c91daa0b8" + +[[projects]] + digest = "1:35cf6bdf68db765988baa9c4f10cc5d7dda1126a54bd62e252dbcd0b1fc8da90" + name = "github.com/prometheus/common" + packages = [ + "expfmt", + "internal/bitbucket.org/ww/goautoneg", + "model", + ] + pruneopts = "UT" + revision = "cfeb6f9992ffa54aaa4f2170ade4067ee478b250" + version = "v0.2.0" + +[[projects]] + branch = "master" + digest = "1:5833c61ebbd625a6bad8e5a1ada2b3e13710cf3272046953a2c8915340fe60a3" + name = "github.com/prometheus/procfs" + packages = [ + ".", + "internal/util", + "nfs", + "xfs", + ] + pruneopts = "UT" + revision = "316cf8ccfec56d206735d46333ca162eb374da8b" + +[[projects]] + digest = "1:87c2e02fb01c27060ccc5ba7c5a407cc91147726f8f40b70cceeedbc52b1f3a8" + name = "github.com/sirupsen/logrus" + packages = ["."] + pruneopts = "UT" + revision = "e1e72e9de974bd926e5c56f83753fba2df402ce5" + version = "v1.3.0" + +[[projects]] + digest = "1:3e39bafd6c2f4bf3c76c3bfd16a2e09e016510ad5db90dc02b88e2f565d6d595" + name = "github.com/spf13/afero" + packages = [ + ".", + "mem", + ] + pruneopts = "UT" + revision = "f4711e4db9e9a1d3887343acb72b2bbfc2f686f5" + version = "v1.2.1" + +[[projects]] + digest = "1:08d65904057412fc0270fc4812a1c90c594186819243160dc779a402d4b6d0bc" + name = "github.com/spf13/cast" + packages = ["."] + pruneopts = "UT" + revision = "8c9545af88b134710ab1cd196795e7f2388358d7" + version = "v1.3.0" + +[[projects]] + digest = "1:645cabccbb4fa8aab25a956cbcbdf6a6845ca736b2c64e197ca7cbb9d210b939" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "UT" + revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385" + version = "v0.0.3" + +[[projects]] + digest = "1:68ea4e23713989dc20b1bded5d9da2c5f9be14ff9885beef481848edd18c26cb" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + pruneopts = "UT" + revision = "4a4406e478ca629068e7768fc33f3f044173c0a6" + version = "v1.0.0" + +[[projects]] + digest = "1:c1b1102241e7f645bc8e0c22ae352e8f0dc6484b6cb4d132fa9f24174e0119e2" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "UT" + revision = "298182f68c66c05229eb03ac171abe6e309ee79a" + version = "v1.0.3" + +[[projects]] + digest = "1:2532daa308722c7b65f4566e634dac2ddfaa0a398a17d8418e96ef2af3939e37" + name = "github.com/spf13/viper" + packages = ["."] + pruneopts = "UT" + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[projects]] + digest = 
"1:972c2427413d41a1e06ca4897e8528e5a1622894050e2f527b38ddf0f343f759" + name = "github.com/stretchr/testify" + packages = ["assert"] + pruneopts = "UT" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + +[[projects]] + branch = "master" + digest = "1:fde12c4da6237363bf36b81b59aa36a43d28061167ec4acb0d41fc49464e28b9" + name = "golang.org/x/crypto" + packages = ["ssh/terminal"] + pruneopts = "UT" + revision = "b01c7a72566457eb1420261cdafef86638fc3861" + +[[projects]] + branch = "master" + digest = "1:7941e2f16c0833b438cbef7fccfe4f8346f9f7876b42b29717a75d7e8c4800cb" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows", + ] + pruneopts = "UT" + revision = "aca44879d5644da7c5b8ec6a1115e9b6ea6c40d9" + +[[projects]] + digest = "1:8029e9743749d4be5bc9f7d42ea1659471767860f0cdc34d37c3111bd308a295" + name = "golang.org/x/text" + packages = [ + "internal/gen", + "internal/triegen", + "internal/ucd", + "transform", + "unicode/cldr", + "unicode/norm", + ] + pruneopts = "UT" + revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" + version = "v0.3.0" + +[[projects]] + branch = "master" + digest = "1:9fdc2b55e8e0fafe4b41884091e51e77344f7dc511c5acedcfd98200003bff90" + name = "golang.org/x/time" + packages = ["rate"] + pruneopts = "UT" + revision = "85acf8d2951cb2a3bde7632f9ff273ef0379bcbd" + +[[projects]] + branch = "master" + digest = "1:86d002f2c67e364e097c5047f517ab38cdef342c3c20be53974c1bfd5b191d30" + name = "golang.org/x/tools" + packages = [ + "go/ast/astutil", + "go/gcexportdata", + "go/internal/cgo", + "go/internal/gcimporter", + "go/internal/packagesdriver", + "go/packages", + "go/types/typeutil", + "imports", + "internal/fastwalk", + "internal/gopathwalk", + "internal/module", + "internal/semver", + ] + pruneopts = "UT" + revision = "58ecf64b2ccd4e014267d2ea143d23c617ee7e4c" + +[[projects]] + digest = "1:4d2e5a73dc1500038e504a8d78b986630e3626dc027bc030ba5c75da257cdb96" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "UT" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + +[[projects]] + digest = "1:5922c4db083d03579c576df514f096003f422b602aeb30028aedd892b69a4876" + name = "k8s.io/apimachinery" + packages = [ + "pkg/util/clock", + "pkg/util/rand", + "pkg/util/runtime", + "pkg/util/wait", + ] + pruneopts = "UT" + revision = "103fd098999dc9c0c88536f5c9ad2e5da39373ae" + version = "kubernetes-1.11.2" + +[[projects]] + digest = "1:8d66fef1249b9b2105840377af3bab078604d3c298058f563685e88d2a9e6ad3" + name = "k8s.io/client-go" + packages = ["util/workqueue"] + pruneopts = "UT" + revision = "1f13a808da65775f22cbf47862c4e5898d8f4ca1" + version = "kubernetes-1.11.2" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/aws/aws-sdk-go/aws/awserr", + "github.com/aws/aws-sdk-go/service/s3", + "github.com/benlaurie/objecthash/go/objecthash", + "github.com/coocood/freecache", + "github.com/ernesto-jimenez/gogen/gogenutil", + "github.com/ernesto-jimenez/gogen/imports", + "github.com/fatih/color", + "github.com/fatih/structtag", + "github.com/fsnotify/fsnotify", + "github.com/ghodss/yaml", + "github.com/golang/protobuf/jsonpb", + "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/ptypes", + "github.com/golang/protobuf/ptypes/duration", + "github.com/golang/protobuf/ptypes/timestamp", + "github.com/graymeta/stow", + "github.com/graymeta/stow/local", + "github.com/graymeta/stow/s3", + "github.com/magiconair/properties/assert", + "github.com/mitchellh/mapstructure", + 
"github.com/pkg/errors", + "github.com/prometheus/client_golang/prometheus", + "github.com/prometheus/client_golang/prometheus/promhttp", + "github.com/sirupsen/logrus", + "github.com/spf13/cobra", + "github.com/spf13/pflag", + "github.com/spf13/viper", + "github.com/stretchr/testify/assert", + "golang.org/x/time/rate", + "golang.org/x/tools/imports", + "k8s.io/apimachinery/pkg/util/rand", + "k8s.io/apimachinery/pkg/util/wait", + "k8s.io/client-go/util/workqueue", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 0000000000..3d6cf08f6f --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,72 @@ +[[constraint]] + name = "github.com/aws/aws-sdk-go" + version = "1.15.0" + +[[constraint]] + name = "github.com/coocood/freecache" + version = "1.0.1" + +[[constraint]] + branch = "master" + name = "github.com/ernesto-jimenez/gogen" + +[[constraint]] + name = "github.com/fatih/color" + version = "1.7.0" + +[[constraint]] + name = "github.com/fatih/structtag" + version = "1.0.0" + +[[constraint]] + branch = "master" + name = "github.com/fsnotify/fsnotify" + +[[constraint]] + name = "github.com/golang/protobuf" + version = "1.1.0" + +[[constraint]] + name = "github.com/mitchellh/mapstructure" + version = "1.1.2" + +[[constraint]] + name = "github.com/pkg/errors" + version = "0.8.0" + +[[constraint]] + name = "github.com/prometheus/client_golang" + version = "^0.9.0" + +[[constraint]] + name = "github.com/spf13/cobra" + version = "0.0.3" + +[[constraint]] + name = "github.com/spf13/pflag" + version = "1.0.1" + +[[constraint]] + name = "github.com/spf13/viper" + # Viper only fixed symlink config watching after this SHA. move to a proper semVer when one is available. + revision = "ae103d7e593e371c69e832d5eb3347e2b80cbbc9" + +[[constraint]] + branch = "master" + name = "golang.org/x/time" + +[[constraint]] + name = "k8s.io/apimachinery" + version = "kubernetes-1.11.2" + +[[constraint]] + name = "k8s.io/client-go" + version = "kubernetes-1.11.2" + +[[constraint]] + name = "github.com/graymeta/stow" + revision = "77c84b1dd69c41b74fe0a94ca8ee257d85947327" + +[prune] + go-tests = true + unused-packages = true diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..bed437514f --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..623663f5f1 --- /dev/null +++ b/Makefile @@ -0,0 +1,24 @@ +export REPOSITORY=flytestdlib +include boilerplate/lyft/golang_test_targets/Makefile + +# Generate golden files. Add test packages that generate golden files here. +golden: + go test ./cli/pflags/api -update + go test ./config -update + go test ./storage -update + go test ./tests -update + + +generate: + @echo "************************ go generate **********************************" + go generate ./... + +# This is the only target that should be overridden by the project. Get your binary into ${GOREPO}/bin +.PHONY: compile +compile: + mkdir -p ./bin + go build -o pflags ./cli/pflags/main.go && mv ./pflags ./bin + +gen-config: + which pflags || (go get github.com/lyft/flytestdlib/cli/pflags) + @go generate ./... diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000000..9316928ad6 --- /dev/null +++ b/NOTICE @@ -0,0 +1,21 @@ +flytestdlib +Copyright 2019-2020 Lyft Inc.
+ +This product includes software developed at Lyft Inc. + +Notices for file(s): + promutils/workqueue.go contains work from https://github.com/kubernetes/kubernetes/ + under the Apache 2.0 license. + +/* +Copyright 2016 The Kubernetes Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ diff --git a/README.rst b/README.rst new file mode 100644 index 0000000000..f78785fa3d --- /dev/null +++ b/README.rst @@ -0,0 +1,35 @@ +K8s Standard Library +===================== +Shared components we found ourselves building time and time again, so we collected them in one place! + +This library consists of: + - config + + Enables strongly typed config throughout your application. Offers a way to represent config in Go structs, and takes care of parsing, validating, and watching for changes on config. + + - cli/pflags + + Tool to generate pflags for all fields in a given struct. + - storage + + Abstract storage library that uses stow behind the scenes to connect to S3/Azure/GCS, but also offers a configurable factory, in-memory storage (for testing), as well as native protobuf support. + - contextutils + + Wrapper around Go's context to set/get known keys. + - logger + + Wrapper around logrus that's configurable, taggable, and context-aware. + - profutils + + Starts an HTTP server that serves /metrics (exposes Prometheus metrics), /healthcheck, and /version endpoints. + - promutils + + Exposes a Scope instance that's a more convenient way to construct Prometheus metrics and scope them per component. + - atomic + + Wrapper around the sync/atomic library to offer convenient typed wrappers such as Bool and Int32. + - sets + + Offers strongly typed and convenient set interfaces. + - utils + - version diff --git a/atomic/atomic.go b/atomic/atomic.go new file mode 100644 index 0000000000..26d04ab721 --- /dev/null +++ b/atomic/atomic.go @@ -0,0 +1,129 @@ +package atomic + +import "sync/atomic" + +// This file contains some simplified atomic primitives that the Go standard library does not offer, +// like an atomic Boolean. + +// Takes in a uint32 and converts it to a bool by checking whether the last bit is set to 1 +func toBool(n uint32) bool { + return n&1 == 1 +} + +// Takes in a bool and returns a uint32 representation +func toInt(b bool) uint32 { + if b { + return 1 + } + return 0 +} + +// Bool is an atomic Boolean. +// It stores the bool as a uint32 internally. This is to use the uint32 atomic functions available in Go's sync/atomic. +type Bool struct{ v uint32 } + +// NewBool creates a Bool. +func NewBool(initial bool) Bool { + return Bool{v: toInt(initial)} +} + +// Load atomically loads the Boolean. +func (b *Bool) Load() bool { + return toBool(atomic.LoadUint32(&b.v)) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (b *Bool) CompareAndSwap(old, new bool) bool { + return atomic.CompareAndSwapUint32(&b.v, toInt(old), toInt(new)) +} + +// Store atomically stores the passed value. +func (b *Bool) Store(new bool) { + atomic.StoreUint32(&b.v, toInt(new)) +} + +// Swap sets the given value and returns the previous value.
+func (b *Bool) Swap(new bool) bool { + return toBool(atomic.SwapUint32(&b.v, toInt(new))) +} + +// Toggle atomically negates the Boolean and returns the previous value. +// Incrementing flips the low bit (which toBool reads); subtracting 1 from the returned new value recovers the previous state. +func (b *Bool) Toggle() bool { + return toBool(atomic.AddUint32(&b.v, 1) - 1) +} + +type Uint32 struct { + v uint32 +} + +// Returns the loaded uint32 value +func (u *Uint32) Load() uint32 { + return atomic.LoadUint32(&u.v) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (u *Uint32) CompareAndSwap(old, new uint32) bool { + return atomic.CompareAndSwapUint32(&u.v, old, new) +} + +// Add a delta to the number +func (u *Uint32) Add(delta uint32) uint32 { + return atomic.AddUint32(&u.v, delta) +} + +// Increment the value +func (u *Uint32) Inc() uint32 { + return atomic.AddUint32(&u.v, 1) +} + +// Set the value +func (u *Uint32) Store(v uint32) { + atomic.StoreUint32(&u.v, v) +} + +func NewUint32(v uint32) Uint32 { + return Uint32{v: v} +} + +type Int32 struct { + v int32 +} + +// Returns the loaded int32 value +func (i *Int32) Load() int32 { + return atomic.LoadInt32(&i.v) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Int32) CompareAndSwap(old, new int32) bool { + return atomic.CompareAndSwapInt32(&i.v, old, new) +} + +// Add a delta to the number +func (i *Int32) Add(delta int32) int32 { + return atomic.AddInt32(&i.v, delta) +} + +// Subtract a delta from the number +func (i *Int32) Sub(delta int32) int32 { + return atomic.AddInt32(&i.v, -delta) +} + +// Increment the value +func (i *Int32) Inc() int32 { + return atomic.AddInt32(&i.v, 1) +} + +// Decrement the value +func (i *Int32) Dec() int32 { + return atomic.AddInt32(&i.v, -1) +} + +// Set the value +func (i *Int32) Store(v int32) { + atomic.StoreInt32(&i.v, v) +} + +func NewInt32(v int32) Int32 { + return Int32{v: v} +} diff --git a/atomic/atomic_test.go b/atomic/atomic_test.go new file mode 100644 index 0000000000..c441a88993 --- /dev/null +++ b/atomic/atomic_test.go @@ -0,0 +1,61 @@ +package atomic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBool(t *testing.T) { + atom := NewBool(false) + assert.False(t, atom.Toggle(), "Expected Toggle to return False.") + assert.True(t, atom.Load(), "Unexpected state after swap. Expected True") + + assert.True(t, atom.CompareAndSwap(true, true), "CAS should swap when old matches") + assert.True(t, atom.Load(), "previous swap should have no effect") + assert.True(t, atom.CompareAndSwap(true, false), "CAS should swap when old matches") + assert.False(t, atom.Load(), "Post swap the value should be false") + assert.False(t, atom.CompareAndSwap(true, false), "CAS should fail on old mismatch") + assert.False(t, atom.Load(), "CAS should not have modified the value") + + atom.Store(false) + assert.False(t, atom.Load(), "Unexpected state after store.") + + prev := atom.Swap(false) + assert.False(t, prev, "Expected Swap to return previous value.") + + prev = atom.Swap(true) + assert.False(t, prev, "Expected Swap to return previous value.") +} + +func TestInt32(t *testing.T) { + atom := NewInt32(2) + assert.False(t, atom.CompareAndSwap(3, 4), "Expected swap to return False.") + assert.Equal(t, int32(2), atom.Load(), "Unexpected state after swap.
Expected True") + + assert.True(t, atom.CompareAndSwap(2, 2), "CAS should swap when old matches") + assert.Equal(t, int32(2), atom.Load(), "previous swap should have no effect") + assert.True(t, atom.CompareAndSwap(2, 4), "CAS should swap when old matches") + assert.Equal(t, int32(4), atom.Load(), "Post swap the value should be true") + assert.False(t, atom.CompareAndSwap(2, 3), "CAS should fail on old mismatch") + assert.Equal(t, int32(4), atom.Load(), "CAS should not have modified the value") + + atom.Store(5) + assert.Equal(t, int32(5), atom.Load(), "Unexpected state after store.") +} + +func TestUint32(t *testing.T) { + atom := NewUint32(2) + assert.False(t, atom.CompareAndSwap(3, 4), "Expected swap to return False.") + assert.Equal(t, uint32(2), atom.Load(), "Unexpected state after swap. Expected True") + + assert.True(t, atom.CompareAndSwap(2, 2), "CAS should swap when old matches") + assert.Equal(t, uint32(2), atom.Load(), "previous swap should have no effect") + assert.True(t, atom.CompareAndSwap(2, 4), "CAS should swap when old matches") + assert.Equal(t, uint32(4), atom.Load(), "Post swap the value should be true") + assert.False(t, atom.CompareAndSwap(2, 3), "CAS should fail on old mismatch") + assert.Equal(t, uint32(4), atom.Load(), "CAS should not have modified the value") + + atom.Store(5) + assert.Equal(t, uint32(5), atom.Load(), "Unexpected state after store.") +} diff --git a/atomic/non_blocking_lock.go b/atomic/non_blocking_lock.go new file mode 100644 index 0000000000..449841a960 --- /dev/null +++ b/atomic/non_blocking_lock.go @@ -0,0 +1,23 @@ +package atomic + +// Lock that provides TryLock method instead of blocking lock +type NonBlockingLock interface { + TryLock() bool + Release() +} + +func NewNonBlockingLock() NonBlockingLock { + return &nonBlockingLock{lock: NewBool(false)} +} + +type nonBlockingLock struct { + lock Bool +} + +func (n *nonBlockingLock) TryLock() bool { + return n.lock.CompareAndSwap(false, true) +} + +func (n *nonBlockingLock) Release() { + n.lock.Store(false) +} diff --git a/atomic/non_blocking_lock_test.go b/atomic/non_blocking_lock_test.go new file mode 100644 index 0000000000..ddd0a2123d --- /dev/null +++ b/atomic/non_blocking_lock_test.go @@ -0,0 +1,15 @@ +package atomic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewNonBlockingLock(t *testing.T) { + lock := NewNonBlockingLock() + assert.True(t, lock.TryLock(), "Unexpected lock acquire failure") + assert.False(t, lock.TryLock(), "already-acquired lock acquired again") + lock.Release() + assert.True(t, lock.TryLock(), "Unexpected lock acquire failure") +} diff --git a/boilerplate/lyft/golang_test_targets/Makefile b/boilerplate/lyft/golang_test_targets/Makefile new file mode 100644 index 0000000000..1c6f893521 --- /dev/null +++ b/boilerplate/lyft/golang_test_targets/Makefile @@ -0,0 +1,31 @@ +.PHONY: lint +lint: #lints the package for common code smells + which golangci-lint || curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s -- -b $$GOPATH/bin v1.10 + golangci-lint run + +# If code is failing goimports linter, this will fix. +# skips 'vendor' +.PHONY: goimports +goimports: + @boilerplate/lyft/golang_test_targets/goimports + +.PHONY: install +install: #download dependencies (including test deps) for the package + which dep || (curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh) + dep ensure + +.PHONY: test_unit +test_unit: + go test -cover ./... 
-race + +.PHONY: test_benchmark +test_benchmark: + go test -bench . ./... + +.PHONY: test_unit_cover +test_unit_cover: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -func /tmp/cover.out + +.PHONY: test_unit_visual +test_unit_visual: + go test ./... -coverprofile /tmp/cover.out -covermode=count; go tool cover -html=/tmp/cover.out diff --git a/boilerplate/lyft/golang_test_targets/Readme.rst b/boilerplate/lyft/golang_test_targets/Readme.rst new file mode 100644 index 0000000000..acc5744f59 --- /dev/null +++ b/boilerplate/lyft/golang_test_targets/Readme.rst @@ -0,0 +1,31 @@ +Golang Test Targets +~~~~~~~~~~~~~~~~~~~ + +Provides an ``install`` make target that uses ``dep`` to install golang dependencies. + +Provides a ``lint`` make target that uses golangci-lint to lint your code. + +Provides a ``test_unit`` target for unit tests. + +Provides a ``test_unit_cover`` target for analysing coverage of unit tests, which will output the coverage of each function and total statement coverage. + +Provides a ``test_unit_visual`` target for visualizing coverage of unit tests through an interactive html code heat map. + +Provides a ``test_benchmark`` target for benchmark tests. + +**To Enable:** + +Add ``lyft/golang_test_targets`` to your ``boilerplate/update.cfg`` file. + +Make sure you're using ``dep`` for dependency management. + +Provide a ``.golangci`` configuration (the lint target requires it). + +Add ``include boilerplate/lyft/golang_test_targets/Makefile`` in your main ``Makefile`` *after* your REPOSITORY environment variable + +:: + + REPOSITORY= + include boilerplate/lyft/golang_test_targets/Makefile + +(this ensures the extra make targets get included in your main Makefile) diff --git a/boilerplate/lyft/golang_test_targets/goimports b/boilerplate/lyft/golang_test_targets/goimports new file mode 100755 index 0000000000..11d3c9af06 --- /dev/null +++ b/boilerplate/lyft/golang_test_targets/goimports @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./pkg/client/*") diff --git a/boilerplate/update.cfg b/boilerplate/update.cfg new file mode 100644 index 0000000000..f861a23ccd --- /dev/null +++ b/boilerplate/update.cfg @@ -0,0 +1,2 @@ +lyft/golang_test_targets +lyft/golangci_file diff --git a/cli/pflags/api/generator.go b/cli/pflags/api/generator.go new file mode 100644 index 0000000000..2e4dd30c54 --- /dev/null +++ b/cli/pflags/api/generator.go @@ -0,0 +1,277 @@ +package api + +import ( + "context" + "fmt" + "go/types" + "path/filepath" + + "github.com/lyft/flytestdlib/logger" + + "go/importer" + + "github.com/ernesto-jimenez/gogen/gogenutil" +) + +const ( + indent = " " +) + +// PFlagProviderGenerator parses and generates a GetPFlagSet implementation to add PFlags for a given struct's fields. +type PFlagProviderGenerator struct { + pkg *types.Package + st *types.Named +} + +// This list is restricted because those are the only kinds viper parses out; otherwise it assumes strings.
+// github.com/spf13/viper/viper.go:1016 +var allowedKinds = []types.Type{ + types.Typ[types.Int], + types.Typ[types.Int8], + types.Typ[types.Int16], + types.Typ[types.Int32], + types.Typ[types.Int64], + types.Typ[types.Bool], + types.Typ[types.String], +} + +type SliceOrArray interface { + Elem() types.Type +} + +func capitalize(s string) string { + if s[0] >= 'a' && s[0] <= 'z' { + return string(s[0]-'a'+'A') + s[1:] + } + + return s +} + +func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage, defaultValue string) (FieldInfo, error) { + strategy := SliceRaw + FlagMethodName := "StringSlice" + typ := types.NewSlice(types.Typ[types.String]) + emptyDefaultValue := `[]string{}` + if b, ok := t.Elem().(*types.Basic); !ok { + logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem()) + if !jsonUnmarshaler(t.Elem()) { + return FieldInfo{}, + fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported", + t.Elem().String()) + } + } else { + logger.Infof(ctx, "Elem of type [%v] is a basic type. Will use a pflag as a Slice.", b) + strategy = SliceJoined + FlagMethodName = fmt.Sprintf("%vSlice", capitalize(b.Name())) + typ = types.NewSlice(b) + emptyDefaultValue = fmt.Sprintf(`[]%v{}`, b.Name()) + } + + testValue := defaultValue + if len(defaultValue) == 0 { + defaultValue = emptyDefaultValue + testValue = `"1,1"` + } + + return FieldInfo{ + Name: name, + GoName: goName, + Typ: typ, + FlagMethodName: FlagMethodName, + DefaultValue: defaultValue, + UsageString: usage, + TestValue: testValue, + TestStrategy: strategy, + }, nil +} + +// Traverses fields in a type and follows the recursion tree to discover all fields. It stops when one of two conditions is +// met: it encounters a basic type (e.g. string, int, etc.) or the field type implements UnmarshalJSON. +func discoverFieldsRecursive(ctx context.Context, typ *types.Named) ([]FieldInfo, error) { + logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", + typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) + + ctx = logger.WithIndent(ctx, indent) + + st := typ.Underlying().(*types.Struct) + fields := make([]FieldInfo, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + v := st.Field(i) + if !v.IsField() { + continue + } + + // Parses out the tag if one exists. + tag, err := ParseTag(st.Tag(i)) + if err != nil { + return nil, err + } + + if len(tag.Name) == 0 { + tag.Name = v.Name() + } + + typ := v.Type() + if ptr, isPtr := typ.(*types.Pointer); isPtr { + typ = ptr.Elem() + } + + switch t := typ.(type) { + case *types.Basic: + if len(tag.DefaultValue) == 0 { + tag.DefaultValue = fmt.Sprintf("*new(%v)", typ.String()) + } + + logger.Infof(ctx, "[%v] is of a basic type with default value [%v].", tag.Name, tag.DefaultValue) + + isAllowed := false + for _, k := range allowedKinds { + if t.String() == k.String() { + isAllowed = true + break + } + } + + if !isAllowed { + return nil, fmt.Errorf("only these basic kinds are allowed. given [%v] (Kind: [%v]). expected: [%+v]", + t.String(), t.Kind(), allowedKinds) + } + + fields = append(fields, FieldInfo{ + Name: tag.Name, + GoName: v.Name(), + Typ: t, + FlagMethodName: camelCase(t.String()), + DefaultValue: tag.DefaultValue, + UsageString: tag.Usage, + TestValue: `"1"`, + TestStrategy: JSON, + }) + case *types.Named: + if _, isStruct := t.Underlying().(*types.Struct); !isStruct { + // TODO: Add a more descriptive error message.
+ return nil, fmt.Errorf("invalid type. it must be struct, received [%v] for field [%v]", t.Underlying().String(), tag.Name) + } + + // If the type has json unmarshaler, then stop the recursion and assume the type is string. config package + // will use json unmarshaler to fill in the final config object. + jsonUnmarshaler := jsonUnmarshaler(t) + + testValue := tag.DefaultValue + if len(tag.DefaultValue) == 0 { + tag.DefaultValue = `""` + testValue = `"1"` + } + + logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue) + + if jsonUnmarshaler { + logger.Infof(logger.WithIndent(ctx, indent), "Type is json unmarhslalable.") + + fields = append(fields, FieldInfo{ + Name: tag.Name, + GoName: v.Name(), + Typ: types.Typ[types.String], + FlagMethodName: "String", + DefaultValue: tag.DefaultValue, + UsageString: tag.Usage, + TestValue: testValue, + TestStrategy: JSON, + }) + } else { + logger.Infof(ctx, "Traversing fields in type.") + + nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t) + if err != nil { + return nil, err + } + + for _, subField := range nested { + fields = append(fields, FieldInfo{ + Name: fmt.Sprintf("%v.%v", tag.Name, subField.Name), + GoName: fmt.Sprintf("%v.%v", v.Name(), subField.GoName), + Typ: subField.Typ, + FlagMethodName: subField.FlagMethodName, + DefaultValue: subField.DefaultValue, + UsageString: subField.UsageString, + TestValue: subField.TestValue, + TestStrategy: subField.TestStrategy, + }) + } + } + case *types.Slice: + logger.Infof(ctx, "[%v] is of a slice type with default value [%v].", tag.Name, tag.DefaultValue) + + f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue) + if err != nil { + return nil, err + } + + fields = append(fields, f) + case *types.Array: + logger.Infof(ctx, "[%v] is of an array with default value [%v].", tag.Name, tag.DefaultValue) + + f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue) + if err != nil { + return nil, err + } + + fields = append(fields, f) + default: + return nil, fmt.Errorf("unexpected type %v", t.String()) + } + } + + return fields, nil +} + +// NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in, +// it's assumed to be current package (which is expected to be the common use case when invoking pflags from +// go:generate comments) +func NewGenerator(pkg, targetTypeName string) (*PFlagProviderGenerator, error) { + var err error + // Resolve package path + if pkg == "" || pkg[0] == '.' 
{ + pkg, err = filepath.Abs(filepath.Clean(pkg)) + if err != nil { + return nil, err + } + pkg = gogenutil.StripGopath(pkg) + } + + targetPackage, err := importer.For("source", nil).Import(pkg) + if err != nil { + return nil, err + } + + obj := targetPackage.Scope().Lookup(targetTypeName) + if obj == nil { + return nil, fmt.Errorf("struct %s missing", targetTypeName) + } + + var st *types.Named + switch obj.Type().Underlying().(type) { + case *types.Struct: + st = obj.Type().(*types.Named) + default: + return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying()) + } + + return &PFlagProviderGenerator{ + st: st, + pkg: targetPackage, + }, nil +} + +func (g PFlagProviderGenerator) GetTargetPackage() *types.Package { + return g.pkg +} + +func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) { + fields, err := discoverFieldsRecursive(ctx, g.st) + if err != nil { + return PFlagProvider{}, err + } + + return newPflagProvider(g.pkg, g.st.Obj().Name(), fields), nil +} diff --git a/cli/pflags/api/generator_test.go b/cli/pflags/api/generator_test.go new file mode 100644 index 0000000000..edfab6c1ca --- /dev/null +++ b/cli/pflags/api/generator_test.go @@ -0,0 +1,56 @@ +package api + +import ( + "context" + "flag" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! +var update = flag.Bool("update", false, "Updates testdata") + +func TestNewGenerator(t *testing.T) { + g, err := NewGenerator(".", "TestType") + assert.NoError(t, err) + + ctx := context.Background() + p, err := g.Generate(ctx) + assert.NoError(t, err) + + codeOutput, err := ioutil.TempFile("", "output-*.go") + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Remove(codeOutput.Name())) }() + + testOutput, err := ioutil.TempFile("", "output-*_test.go") + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Remove(testOutput.Name())) }() + + assert.NoError(t, p.WriteCodeFile(codeOutput.Name())) + assert.NoError(t, p.WriteTestFile(testOutput.Name())) + + codeBytes, err := ioutil.ReadFile(codeOutput.Name()) + assert.NoError(t, err) + + testBytes, err := ioutil.ReadFile(testOutput.Name()) + assert.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "testtype.go") + goldenTestFilePath := filepath.Join("testdata", "testtype_test.go") + if *update { + assert.NoError(t, ioutil.WriteFile(goldenFilePath, codeBytes, os.ModePerm)) + assert.NoError(t, ioutil.WriteFile(goldenTestFilePath, testBytes, os.ModePerm)) + } + + goldenOutput, err := ioutil.ReadFile(filepath.Clean(goldenFilePath)) + assert.NoError(t, err) + assert.Equal(t, goldenOutput, codeBytes) + + goldenTestOutput, err := ioutil.ReadFile(filepath.Clean(goldenTestFilePath)) + assert.NoError(t, err) + assert.Equal(t, string(goldenTestOutput), string(testBytes)) +} diff --git a/cli/pflags/api/pflag_provider.go b/cli/pflags/api/pflag_provider.go new file mode 100644 index 0000000000..e414398063 --- /dev/null +++ b/cli/pflags/api/pflag_provider.go @@ -0,0 +1,90 @@ +package api + +import ( + "bytes" + "fmt" + "go/types" + "io/ioutil" + "os" + "time" + + "github.com/ernesto-jimenez/gogen/imports" + goimports "golang.org/x/tools/imports" +) + +type PFlagProvider struct { + typeName string + pkg *types.Package + fields []FieldInfo +} + +// Adds any needed imports for types not directly declared in this package. 
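+// It walks each discovered field's type and records the package that declares it, producing a map of import path to
+// package name that the code templates use to render the generated file's import block.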
+func (p PFlagProvider) Imports() map[string]string { + imp := imports.New(p.pkg.Name()) + for _, m := range p.fields { + imp.AddImportsFrom(m.Typ) + } + + return imp.Imports() +} + +// Evaluates the main code file template and writes the output to outputFilePath +func (p PFlagProvider) WriteCodeFile(outputFilePath string) error { + buf := bytes.Buffer{} + err := p.generate(GenerateCodeFile, &buf, outputFilePath) + if err != nil { + return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String()) + } + + return p.writeToFile(&buf, outputFilePath) +} + +// Evaluates the test code file template and writes the output to outputFilePath +func (p PFlagProvider) WriteTestFile(outputFilePath string) error { + buf := bytes.Buffer{} + err := p.generate(GenerateTestFile, &buf, outputFilePath) + if err != nil { + return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String()) + } + + return p.writeToFile(&buf, outputFilePath) +} + +func (p PFlagProvider) writeToFile(buffer *bytes.Buffer, fileName string) error { + return ioutil.WriteFile(fileName, buffer.Bytes(), os.ModePerm) +} + +// Evaluates the generator and writes the output to buffer. targetFileName is used only to influence how imports are +// generated/optimized. +func (p PFlagProvider) generate(generator func(buffer *bytes.Buffer, info TypeInfo) error, buffer *bytes.Buffer, targetFileName string) error { + info := TypeInfo{ + Name: p.typeName, + Fields: p.fields, + Package: p.pkg.Name(), + Timestamp: time.Now(), + Imports: p.Imports(), + } + + if err := generator(buffer, info); err != nil { + return err + } + + // Update imports + newBytes, err := goimports.Process(targetFileName, buffer.Bytes(), nil) + if err != nil { + return err + } + + buffer.Reset() + _, err = buffer.Write(newBytes) + + return err +} + +func newPflagProvider(pkg *types.Package, typeName string, fields []FieldInfo) PFlagProvider { + return PFlagProvider{ + typeName: typeName, + pkg: pkg, + fields: fields, + } +} diff --git a/cli/pflags/api/sample.go b/cli/pflags/api/sample.go new file mode 100644 index 0000000000..b1ebb50684 --- /dev/null +++ b/cli/pflags/api/sample.go @@ -0,0 +1,53 @@ +package api + +import ( + "encoding/json" + "errors" + + "github.com/lyft/flytestdlib/storage" +) + +type TestType struct { + StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""` + BoolValue bool `json:"bl" pflag:"true"` + NestedType NestedType `json:"nested"` + IntArray []int `json:"ints" pflag:"[]int{12%2C1}"` + StringArray []string `json:"strs" pflag:"[]string{\"12\"%2C\"1\"}"` + ComplexJSONArray []ComplexJSONType `json:"complexArr"` + StringToJSON ComplexJSONType `json:"c" pflag:",I'm a complex type but can be converted from string."` + StorageConfig storage.Config `json:"storage"` + IntValue *int `json:"i"` +} + +type NestedType struct { + IntValue int `json:"i" pflag:",this is an important flag"` +} + +type ComplexJSONType struct { + StringValue string `json:"str"` + IntValue int `json:"i"` +} + +func (c *ComplexJSONType) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + c.StringValue = "" + return nil + } + + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case string: + if len(value) == 0 { + c.StringValue = "" + } else { + c.StringValue = value + } + default: + return errors.New("invalid duration") + } + + return nil +} diff --git a/cli/pflags/api/tag.go b/cli/pflags/api/tag.go new file mode 100644 index 0000000000..5d4a2d8e57 --- /dev/null +++ 
b/cli/pflags/api/tag.go @@ -0,0 +1,66 @@ +package api + +import ( + "fmt" + "net/url" + "strings" + + "github.com/fatih/structtag" +) + +const ( + TagName = "pflag" + JSONTagName = "json" +) + +// Represents parsed PFlag Go-struct tag. +// type Foo struct { +// StringValue string `json:"str" pflag:"\"hello world\",This is a string value"` +// } +// Name will be "str", Default value is "hello world" and Usage is "This is a string value" +type Tag struct { + Name string + DefaultValue string + Usage string +} + +// Parses tag. Name is computed from json tag, defaultvalue is the name of the pflag tag and usage is the concatenation +// of all options for pflag tag. +// e.g. `json:"name" pflag:"2,this is a useful param"` +func ParseTag(tag string) (t Tag, err error) { + tags, err := structtag.Parse(tag) + if err != nil { + return Tag{}, err + } + + t = Tag{} + + jsonTag, err := tags.Get(JSONTagName) + if err == nil { + t.Name = jsonTag.Name + } + + pflagTag, err := tags.Get(TagName) + if err == nil { + t.DefaultValue, err = url.QueryUnescape(pflagTag.Name) + if err != nil { + fmt.Printf("Failed to Query unescape tag name [%v], will use value as is. Error: %v", pflagTag.Name, err) + t.DefaultValue = pflagTag.Name + } + + t.Usage = strings.Join(pflagTag.Options, ", ") + if len(t.Usage) == 0 { + t.Usage = `""` + } + + if t.Usage[0] != '"' { + t.Usage = fmt.Sprintf(`"%v"`, t.Usage) + } + } else { + // We receive an error when the tag isn't present (or is malformed). Because there is no strongly-typed way to + // do that, we will just set Usage to empty string and move on. + t.Usage = `""` + } + + return t, nil +} diff --git a/cli/pflags/api/templates.go b/cli/pflags/api/templates.go new file mode 100644 index 0000000000..a7adba8922 --- /dev/null +++ b/cli/pflags/api/templates.go @@ -0,0 +1,175 @@ +package api + +import ( + "bytes" + "text/template" +) + +func GenerateCodeFile(buffer *bytes.Buffer, info TypeInfo) error { + return mainTmpl.Execute(buffer, info) +} + +func GenerateTestFile(buffer *bytes.Buffer, info TypeInfo) error { + return testTmpl.Execute(buffer, info) +} + +var mainTmpl = template.Must(template.New("MainFile").Parse( + `// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package {{ .Package }} + +import ( + "github.com/spf13/pflag" + "fmt" +{{range $path, $name := .Imports}} + {{$name}} "{{$path}}"{{end}} +) + +// GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func ({{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError) + {{- range .Fields }} + cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }}) + {{- end }} + return cmdFlags +} +`)) + +var testTmpl = template.Must(template.New("TestFile").Parse( + `// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package {{ .Package }} + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +{{- range $path, $name := .Imports}} + {{$name}} "{{$path}}" +{{- end}} +) + +var dereferencableKinds{{ .Name }} = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. 
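+// (reflect.Ptr, for instance, is dereferencable, which lets the decoder hook below detect a json.Unmarshaler
+// implemented on the pointed-to type.)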
+func canGetElement{{ .Name }}(t reflect.Kind) bool { + _, exists := dereferencableKinds{{ .Name }}[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHook{{ .Name }}(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElement{{ .Name }}(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_{{ .Name }}(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHook{{ .Name }}, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_{{ .Name }}(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_{{ .Name }}(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_{{ .Name }}(val, result)) +} + +func testDecodeSlice_{{ .Name }}(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_{{ .Name }}(vStringSlice, result)) +} + +func Test{{ .Name }}_GetPFlagSet(t *testing.T) { + val := {{ .Name }}{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func Test{{ .Name }}_SetFlags(t *testing.T) { + actual := {{ .Name }}{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + {{ $ParentName := .Name }} + {{- range .Fields }} + t.Run("Test_{{ .Name }}", func(t *testing.T) { {{ $varName := print "v" .FlagMethodName }} + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil { + assert.Equal(t, {{ .Typ }}({{ .DefaultValue }}), {{ $varName }}) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + {{ if eq .TestStrategy "Json" }}testValue := {{ .TestValue }} + {{ else if eq .TestStrategy "SliceRaw" }}testValue := {{ .TestValue }} + {{ else }}testValue := join_{{ $ParentName }}({{ .TestValue }}, ",") + {{ end }} + cmdFlags.Set("{{ .Name }}", testValue) + if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil { + {{ if eq .TestStrategy "Json" }}testDecodeJson_{{ $ParentName }}(t, fmt.Sprintf("%v", {{ print "v" .FlagMethodName }}), &actual.{{ .GoName }}) + {{ else if eq .TestStrategy "SliceRaw" }}testDecodeSlice_{{ $ParentName }}(t, {{ print "v" .FlagMethodName }}, &actual.{{ .GoName }}) + {{ else 
}}testDecodeSlice_{{ $ParentName }}(t, join_{{ $ParentName }}({{ print "v" .FlagMethodName }}, ","), &actual.{{ .GoName }}) + {{ end }} + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + {{- end }} +} +`)) diff --git a/cli/pflags/api/testdata/testtype.go b/cli/pflags/api/testdata/testtype.go new file mode 100755 index 0000000000..87f5cb7dfe --- /dev/null +++ b/cli/pflags/api/testdata/testtype.go @@ -0,0 +1,36 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package api + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in TestType and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (TestType) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("TestType", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), true, "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "nested.i"), *new(int), "this is an important flag") + cmdFlags.IntSlice(fmt.Sprintf("%v%v", prefix, "ints"), []int{12, 1}, "") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strs"), []string{"12", "1"}, "") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "complexArr"), []string{}, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "c"), "", "I'm a complex type but can be converted from string.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.endpoint"), "", "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.secret-key"), *new(string), "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.connection.region"), "us-east-1", "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "storage.connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storage.container"), *new(string), "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "storage.cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "storage.limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "i"), *new(int), "") + return cmdFlags +} diff --git a/cli/pflags/api/testdata/testtype_test.go b/cli/pflags/api/testdata/testtype_test.go new file mode 100755 index 0000000000..f8b81bbe11 --- /dev/null +++ b/cli/pflags/api/testdata/testtype_test.go @@ -0,0 +1,520 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. 
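+// To regenerate this golden file, run the generator tests with the -update flag defined in generator_test.go:
+//   go test ./cli/pflags/api -update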
+ +package api + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsTestType = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementTestType(t reflect.Kind) bool { + _, exists := dereferencableKindsTestType[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookTestType(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementTestType(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_TestType(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookTestType, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_TestType(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_TestType(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_TestType(val, result)) +} + +func testDecodeSlice_TestType(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_TestType(vStringSlice, result)) +} + +func TestTestType_GetPFlagSet(t *testing.T) { + val := TestType{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestTestType_SetFlags(t *testing.T) { + actual := TestType{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_str", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("str"); err == nil { + assert.Equal(t, string("hello world"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("str", testValue) + if vString, err := cmdFlags.GetString("str"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StringValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_bl", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("bl"); err == nil 
{ + assert.Equal(t, bool(true), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("bl", testValue) + if vBool, err := cmdFlags.GetBool("bl"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vBool), &actual.BoolValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_nested.i", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("nested.i"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("nested.i", testValue) + if vInt, err := cmdFlags.GetInt("nested.i"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.NestedType.IntValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_ints", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vIntSlice, err := cmdFlags.GetIntSlice("ints"); err == nil { + assert.Equal(t, []int([]int{12, 1}), vIntSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := join_TestType([]int{12, 1}, ",") + + cmdFlags.Set("ints", testValue) + if vIntSlice, err := cmdFlags.GetIntSlice("ints"); err == nil { + testDecodeSlice_TestType(t, join_TestType(vIntSlice, ","), &actual.IntArray) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_strs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vStringSlice, err := cmdFlags.GetStringSlice("strs"); err == nil { + assert.Equal(t, []string([]string{"12", "1"}), vStringSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := join_TestType([]string{"12", "1"}, ",") + + cmdFlags.Set("strs", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("strs"); err == nil { + testDecodeSlice_TestType(t, join_TestType(vStringSlice, ","), &actual.StringArray) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_complexArr", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vStringSlice, err := cmdFlags.GetStringSlice("complexArr"); err == nil { + assert.Equal(t, []string([]string{}), vStringSlice) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1,1" + + cmdFlags.Set("complexArr", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("complexArr"); err == nil { + testDecodeSlice_TestType(t, vStringSlice, &actual.ComplexJSONArray) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_c", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("c"); err == nil { + assert.Equal(t, string(""), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("c", testValue) + if vString, err := cmdFlags.GetString("c"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StringToJSON) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.type", func(t *testing.T) { + 
t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.type"); err == nil { + assert.Equal(t, string("s3"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.type", testValue) + if vString, err := cmdFlags.GetString("storage.type"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { + assert.Equal(t, string(""), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.endpoint", testValue) + if vString, err := cmdFlags.GetString("storage.connection.endpoint"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.auth-type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.auth-type"); err == nil { + assert.Equal(t, string("iam"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.auth-type", testValue) + if vString, err := cmdFlags.GetString("storage.connection.auth-type"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.AuthType) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.access-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.access-key"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.access-key", testValue) + if vString, err := cmdFlags.GetString("storage.connection.access-key"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.AccessKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.secret-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.connection.secret-key"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.secret-key", testValue) + if vString, err := cmdFlags.GetString("storage.connection.secret-key"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.SecretKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.region", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set 
properly + if vString, err := cmdFlags.GetString("storage.connection.region"); err == nil { + assert.Equal(t, string("us-east-1"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.region", testValue) + if vString, err := cmdFlags.GetString("storage.connection.region"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.Connection.Region) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.connection.disable-ssl", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("storage.connection.disable-ssl"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.connection.disable-ssl", testValue) + if vBool, err := cmdFlags.GetBool("storage.connection.disable-ssl"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vBool), &actual.StorageConfig.Connection.DisableSSL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.container", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("storage.container"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.container", testValue) + if vString, err := cmdFlags.GetString("storage.container"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vString), &actual.StorageConfig.InitContainer) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.cache.max_size_mbs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("storage.cache.max_size_mbs"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.cache.max_size_mbs", testValue) + if vInt, err := cmdFlags.GetInt("storage.cache.max_size_mbs"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.StorageConfig.Cache.MaxSizeMegabytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.cache.target_gc_percent", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("storage.cache.target_gc_percent"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.cache.target_gc_percent", testValue) + if vInt, err := cmdFlags.GetInt("storage.cache.target_gc_percent"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.StorageConfig.Cache.TargetGCPercent) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_storage.limits.maxDownloadMBs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt64, err := cmdFlags.GetInt64("storage.limits.maxDownloadMBs"); err == nil { + 
assert.Equal(t, int64(2), vInt64) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("storage.limits.maxDownloadMBs", testValue) + if vInt64, err := cmdFlags.GetInt64("storage.limits.maxDownloadMBs"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt64), &actual.StorageConfig.Limits.GetLimitMegabytes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_i", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("i"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("i", testValue) + if vInt, err := cmdFlags.GetInt("i"); err == nil { + testDecodeJson_TestType(t, fmt.Sprintf("%v", vInt), &actual.IntValue) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/cli/pflags/api/types.go b/cli/pflags/api/types.go new file mode 100644 index 0000000000..1e6c1297ff --- /dev/null +++ b/cli/pflags/api/types.go @@ -0,0 +1,36 @@ +package api + +import ( + "go/types" + "time" +) + +// Determines how tests should be generated. +type TestStrategy string + +const ( + JSON TestStrategy = "Json" + SliceJoined TestStrategy = "SliceJoined" + SliceRaw TestStrategy = "SliceRaw" +) + +type FieldInfo struct { + Name string + GoName string + Typ types.Type + DefaultValue string + UsageString string + FlagMethodName string + TestValue string + TestStrategy TestStrategy +} + +// Holds the finalized information passed to the template for evaluation. +type TypeInfo struct { + Timestamp time.Time + Fields []FieldInfo + Package string + Name string + TypeRef string + Imports map[string]string +} diff --git a/cli/pflags/api/utils.go b/cli/pflags/api/utils.go new file mode 100644 index 0000000000..4c71fbb1c4 --- /dev/null +++ b/cli/pflags/api/utils.go @@ -0,0 +1,32 @@ +package api + +import ( + "bytes" + "fmt" + "go/types" + "unicode" +) + +func camelCase(str string) string { + if len(str) == 0 { + return str + } + + firstRune := bytes.Runes([]byte(str))[0] + if unicode.IsLower(firstRune) { + return fmt.Sprintf("%v%v", string(unicode.ToUpper(firstRune)), str[1:]) + } + + return str +} + +func jsonUnmarshaler(t types.Type) bool { + mset := types.NewMethodSet(t) + jsonUnmarshaler := mset.Lookup(nil, "UnmarshalJSON") + if jsonUnmarshaler == nil { + mset = types.NewMethodSet(types.NewPointer(t)) + jsonUnmarshaler = mset.Lookup(nil, "UnmarshalJSON") + } + + return jsonUnmarshaler != nil +} diff --git a/cli/pflags/cmd/root.go b/cli/pflags/cmd/root.go new file mode 100644 index 0000000000..b6562d8a1b --- /dev/null +++ b/cli/pflags/cmd/root.go @@ -0,0 +1,71 @@ +package cmd + +import ( + "bytes" + "context" + "flag" + "fmt" + "strings" + + "github.com/lyft/flytestdlib/cli/pflags/api" + "github.com/lyft/flytestdlib/logger" + "github.com/spf13/cobra" +) + +var ( + pkg = flag.String("pkg", ".", "what package to get the interface from") +) + +var root = cobra.Command{ + Use: "pflags MyStructName --package myproject/mypackage", + Args: cobra.ExactArgs(1), + RunE: generatePflagsProvider, + Example: ` +// go:generate pflags MyStruct +type MyStruct struct { + BoolValue bool ` + "`json:\"bl\" pflag:\"true\"`" + ` + NestedType NestedType ` + "`json:\"nested\"`" + ` + IntArray []int ` + "`json:\"ints\" pflag:\"[]int{12%2C1}\"`" + ` +} + `, +} + +func init() { + 
root.Flags().StringP("package", "p", ".", "Determines the source/destination package.")
+}
+
+func Execute() error {
+	return root.Execute()
+}
+
+func generatePflagsProvider(cmd *cobra.Command, args []string) error {
+	structName := args[0]
+	if structName == "" {
+		return fmt.Errorf("need to specify a struct name")
+	}
+
+	ctx := context.Background()
+	gen, err := api.NewGenerator(*pkg, structName)
+	if err != nil {
+		return err
+	}
+
+	provider, err := gen.Generate(ctx)
+	if err != nil {
+		return err
+	}
+
+	var buf bytes.Buffer
+	defer buf.Reset()
+
+	logger.Infof(ctx, "Generating PFlags for type [%v.%v.%v]\n", gen.GetTargetPackage().Path(), gen.GetTargetPackage().Name(), structName)
+
+	outFilePath := fmt.Sprintf("%s_flags.go", strings.ToLower(structName))
+	err = provider.WriteCodeFile(outFilePath)
+	if err != nil {
+		return err
+	}
+
+	tOutFilePath := fmt.Sprintf("%s_flags_test.go", strings.ToLower(structName))
+	return provider.WriteTestFile(tOutFilePath)
+}
diff --git a/cli/pflags/cmd/version.go b/cli/pflags/cmd/version.go
new file mode 100644
index 0000000000..8ee00af2e4
--- /dev/null
+++ b/cli/pflags/cmd/version.go
@@ -0,0 +1,17 @@
+package cmd
+
+import (
+	"github.com/lyft/flytestdlib/version"
+	"github.com/spf13/cobra"
+)
+
+var versionCmd = &cobra.Command{
+	Aliases: []string{"version", "ver"},
+	Run: func(cmd *cobra.Command, args []string) {
+		cmd.Printf("Version: %s\nBuildSHA: %s\nBuildTS: %s\n", version.Version, version.Build, version.BuildTime.String())
+	},
+}
+
+func init() {
+	root.AddCommand(versionCmd)
+}
diff --git a/cli/pflags/main.go b/cli/pflags/main.go
new file mode 100644
index 0000000000..e4c784a7b1
--- /dev/null
+++ b/cli/pflags/main.go
@@ -0,0 +1,15 @@
+// Generates a Register method to automatically add pflags to a pflagSet for all fields in a given type.
+package main
+
+import (
+	"log"
+
+	"github.com/lyft/flytestdlib/cli/pflags/cmd"
+)
+
+func main() {
+	err := cmd.Execute()
+	if err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/cli/pflags/readme.rst b/cli/pflags/readme.rst
new file mode 100644
index 0000000000..8a47d921f8
--- /dev/null
+++ b/cli/pflags/readme.rst
@@ -0,0 +1,24 @@
+================
+Pflags Generator
+================
+
+This tool enables you to generate code to add pflags for all fields in a struct (recursively). In conjunction with the config package, this can be useful to generate CLI flags that override configs while maintaining type safety and not having to deal with string typos.
+
+Getting Started
+^^^^^^^^^^^^^^^
+ - ``go get github.com/lyft/flytestdlib/cli/pflags``
+ - call ``pflags <StructName> --package <path/to/package>`` OR
+ - add ``//go:generate pflags <StructName>`` to the top of the file where the struct is declared.
+   ``<StructName>`` has to be a struct type (it can't be, for instance, a slice type).
+   Supported field types within the struct: basic types (string, int8, int16, int32, int64, bool), json-unmarshalable types, and other structs that conform to the same rules, or slices of these types.
+
+This generates two files (``<structname>_flags.go`` and ``<structname>_flags_test.go``). If you open those, you will notice that all generated flags default to empty/zero values and no usage strings. That behavior can be customized using the ``pflag`` tag.
+
+.. code-block:: go
+
+  type TestType struct {
+    StringValue string `json:"str" pflag:"\"hello world\",\"life is short\""`
+    BoolValue   bool   `json:"bl" pflag:",This is a bool value that will default to false."`
+  }
+
+The ``pflag`` tag is a comma-separated list: the first item is the default value and the second is the usage string.
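+
+For example, a rough sketch of what gets generated for the ``TestType`` above (the real output also covers the nested
+and slice fields; see ``cli/pflags/api/testdata/testtype.go`` for a complete golden file):
+
+.. code-block:: go
+
+  // GetPFlagSet will return strongly typed pflags for all fields in TestType and its nested types.
+  func (TestType) GetPFlagSet(prefix string) *pflag.FlagSet {
+    cmdFlags := pflag.NewFlagSet("TestType", pflag.ExitOnError)
+    cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short")
+    cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "bl"), false, "This is a bool value that will default to false.")
+    return cmdFlags
+  }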
diff --git a/config/accessor.go b/config/accessor.go
new file mode 100644
index 0000000000..1b5a693316
--- /dev/null
+++ b/config/accessor.go
@@ -0,0 +1,61 @@
+// A strongly-typed config library to parse configs from PFlags, Env Vars and Config files.
+// The config package enables consumers to access (readonly for now) strongly-typed configs without worrying about
+// mismatching keys or casting to the wrong type. It supports basic types (e.g. int, string) as well as more complex
+// structures through json encoding/decoding.
+//
+// The config package introduces the concept of Sections. Each section should be given a unique section key. The
+// binary will not load if there is a conflict. Each section should be represented as a Go struct and registered at
+// startup before config is loaded/parsed.
+//
+// Sections can be nested too. A new config section can be registered as a sub-section of an existing one. This allows
+// dynamic grouping of sections while continuing to enforce strongly-typed parsing of configs.
+//
+// Config data can be parsed from supported config file(s) (yaml, prop, toml), env vars, PFlags, or a combination of
+// these sources. Precedence is (flags, env vars, config file, defaults). When data is read from config files, a file
+// watcher is started to monitor for changes in those files. If the registrant of a section subscribes to changes, a
+// handler is called when the relevant section has been updated. Sections within a single config file will be invoked
+// after all sections from that particular config file are parsed. It follows that if there are inter-dependent
+// sections (e.g. changing one MUST be followed by a change in another), those sections should be placed in the same
+// config file.
+//
+// A convenience tool is also provided in the cli package (pflags) that generates an implementation of the
+// PFlagProvider interface based on the json names of a struct's fields.
package config
+
+import (
+	"context"
+	"flag"
+
+	"github.com/spf13/pflag"
+)
+
+// Provides a simple config parser interface.
+type Accessor interface {
+	// Gets a friendly identifier for the accessor.
+	ID() string
+
+	// Initializes the config parser with golang's default flagset.
+	InitializeFlags(cmdFlags *flag.FlagSet)
+
+	// Initializes the config parser with pflag's flagset.
+	InitializePflags(cmdFlags *pflag.FlagSet)
+
+	// Parses and validates the config file(s) discovered, then updates the underlying config section with the
+	// results. Exercise caution when calling this because multiple invocations will overwrite each other's results.
+	UpdateConfig(ctx context.Context) error
+
+	// Gets path(s) to the config file(s) used.
+	ConfigFilesUsed() []string
+}
+
+// Options used to initialize a Config Accessor.
+type Options struct {
+	// Instructs the parser to fail if any section/key in the config file(s) read does not have a corresponding
+	// registered section.
+	StrictMode bool
+
+	// Search paths to look for config file(s). If not specified, it searches for config.yaml under the current
+	// directory as well as /etc/flyte/config directories.
+	SearchPaths []string
+
+	// Defines the root section to use with the accessor.
+	RootSection Section
+}
diff --git a/config/accessor_test.go b/config/accessor_test.go
new file mode 100644
index 0000000000..f386b30b9c
--- /dev/null
+++ b/config/accessor_test.go
@@ -0,0 +1,91 @@
+package config
+
+import (
+	"context"
+	"fmt"
+	"path/filepath"
+)
+
+func Example() {
+	// This example demonstrates basic usage of config sections.
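+	// A matching config.yaml section might look like this (assumed layout, mirroring config/files/testdata):
+	//
+	//   other-component:
+	//     duration-value: 20s
+	//     string-value: Hey there!
+	//     url-value: http://something.com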
+ + //go:generate pflags OtherComponentConfig + + type OtherComponentConfig struct { + DurationValue Duration `json:"duration-value"` + URLValue URL `json:"url-value"` + StringValue string `json:"string-value"` + } + + // Each component should register their section in package init() or as a package var + section := MustRegisterSection("other-component", &OtherComponentConfig{}) + + // Override configpath to look for a custom location. + configPath := filepath.Join("testdata", "config.yaml") + + // Initialize an accessor. + var accessor Accessor + // e.g. + // accessor = viper.NewAccessor(viper.Options{ + // StrictMode: true, + // SearchPaths: []string{configPath, configPath2}, + // }) + + // Optionally bind to Pflags. + // accessor.InitializePflags(flags) + + // Parse config from file or pass empty to rely on env variables and PFlags + err := accessor.UpdateConfig(context.Background()) + if err != nil { + fmt.Printf("Failed to validate config from [%v], error: %v", configPath, err) + return + } + + // Get parsed config value. + parsedConfig := section.GetConfig().(*OtherComponentConfig) + fmt.Printf("Config: %v", parsedConfig) +} + +func Example_nested() { + // This example demonstrates registering nested config sections dynamically. + + //go:generate pflags OtherComponentConfig + + type OtherComponentConfig struct { + DurationValue Duration `json:"duration-value"` + URLValue URL `json:"url-value"` + StringValue string `json:"string-value"` + } + + // Each component should register their section in package init() or as a package var + Section := MustRegisterSection("my-component", &MyComponentConfig{}) + + // Other packages can register their sections at the root level (like the above line) or as nested sections of other + // sections (like the below line) + NestedSection := Section.MustRegisterSection("nested", &OtherComponentConfig{}) + + // Override configpath to look for a custom location. + configPath := filepath.Join("testdata", "nested_config.yaml") + + // Initialize an accessor. + var accessor Accessor + // e.g. + // accessor = viper.NewAccessor(viper.Options{ + // StrictMode: true, + // SearchPaths: []string{configPath, configPath2}, + // }) + + // Optionally bind to Pflags. + // accessor.InitializePflags(flags) + + // Parse config from file or pass empty to rely on env variables and PFlags + err := accessor.UpdateConfig(context.Background()) + if err != nil { + fmt.Printf("Failed to validate config from [%v], error: %v", configPath, err) + return + } + + // Get parsed config value. 
+ parsedConfig := NestedSection.GetConfig().(*OtherComponentConfig) + fmt.Printf("Config: %v", parsedConfig) +} diff --git a/config/config_cmd.go b/config/config_cmd.go new file mode 100644 index 0000000000..a5023ceb61 --- /dev/null +++ b/config/config_cmd.go @@ -0,0 +1,113 @@ +package config + +import ( + "context" + "os" + "strings" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +const ( + PathFlag = "file" + StrictModeFlag = "strict" + CommandValidate = "validate" + CommandDiscover = "discover" +) + +type AccessorProvider func(options Options) Accessor + +type printer interface { + Printf(format string, i ...interface{}) + Println(i ...interface{}) +} + +func NewConfigCommand(accessorProvider AccessorProvider) *cobra.Command { + opts := Options{} + rootCmd := &cobra.Command{ + Use: "config", + Short: "Runs various config commands, look at the help of this command to get a list of available commands..", + ValidArgs: []string{CommandValidate, CommandDiscover}, + } + + validateCmd := &cobra.Command{ + Use: "validate", + Short: "Validates the loaded config.", + RunE: func(cmd *cobra.Command, args []string) error { + return validate(accessorProvider(opts), cmd) + }, + } + + discoverCmd := &cobra.Command{ + Use: "discover", + Short: "Searches for a config in one of the default search paths.", + RunE: func(cmd *cobra.Command, args []string) error { + return validate(accessorProvider(opts), cmd) + }, + } + + // Configure Root Command + rootCmd.PersistentFlags().StringArrayVar(&opts.SearchPaths, PathFlag, []string{}, `Passes the config file to load. +If empty, it'll first search for the config file path then, if found, will load config from there.`) + + rootCmd.AddCommand(validateCmd) + rootCmd.AddCommand(discoverCmd) + + // Configure Validate Command + validateCmd.Flags().BoolVar(&opts.StrictMode, StrictModeFlag, false, `Validates that all keys in loaded config +map to already registered sections.`) + + return rootCmd +} + +// Redirects Stdout to a string buffer until context is cancelled. 
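+// It returns the saved os.Stdout as `old` and the write end of a pipe as `new`; callers are expected to restore
+// os.Stdout from `old` when done, as validate below does.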
+func redirectStdOut() (old, new *os.File) { + old = os.Stdout // keep backup of the real stdout + var err error + _, new, err = os.Pipe() + if err != nil { + panic(err) + } + + os.Stdout = new + + return +} + +func validate(accessor Accessor, p printer) error { + // Redirect stdout + old, n := redirectStdOut() + defer func() { + err := n.Close() + if err != nil { + panic(err) + } + }() + defer func() { os.Stdout = old }() + + err := accessor.UpdateConfig(context.Background()) + + printInfo(p, accessor) + if err == nil { + green := color.New(color.FgGreen).SprintFunc() + p.Println(green("Validated config file successfully.")) + } else { + red := color.New(color.FgRed).SprintFunc() + p.Println(red("Failed to validate config file.")) + } + + return err +} + +func printInfo(p printer, v Accessor) { + cfgFile := v.ConfigFilesUsed() + if len(cfgFile) != 0 { + green := color.New(color.FgGreen).SprintFunc() + + p.Printf("Config file(s) found at: %v\n", green(strings.Join(cfgFile, "\n"))) + } else { + red := color.New(color.FgRed).SprintFunc() + p.Println(red("Couldn't find a config file.")) + } +} diff --git a/config/config_cmd_test.go b/config/config_cmd_test.go new file mode 100644 index 0000000000..d8ef31d298 --- /dev/null +++ b/config/config_cmd_test.go @@ -0,0 +1,64 @@ +package config + +import ( + "bytes" + "context" + "flag" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" +) + +type MockAccessor struct { +} + +func (MockAccessor) ID() string { + panic("implement me") +} + +func (MockAccessor) InitializeFlags(cmdFlags *flag.FlagSet) { +} + +func (MockAccessor) InitializePflags(cmdFlags *pflag.FlagSet) { +} + +func (MockAccessor) UpdateConfig(ctx context.Context) error { + return nil +} + +func (MockAccessor) ConfigFilesUsed() []string { + return []string{"test"} +} + +func (MockAccessor) RefreshFromConfig() error { + return nil +} + +func newMockAccessor(options Options) Accessor { + return MockAccessor{} +} + +func executeCommandC(root *cobra.Command, args ...string) (c *cobra.Command, output string, err error) { + buf := new(bytes.Buffer) + root.SetOutput(buf) + root.SetArgs(args) + + c, err = root.ExecuteC() + + return c, buf.String(), err +} + +func TestNewConfigCommand(t *testing.T) { + cmd := NewConfigCommand(newMockAccessor) + assert.NotNil(t, cmd) + + _, output, err := executeCommandC(cmd, CommandDiscover) + assert.NoError(t, err) + assert.Contains(t, output, "test") + + _, output, err = executeCommandC(cmd, CommandValidate) + assert.NoError(t, err) + assert.Contains(t, output, "test") +} diff --git a/config/duration.go b/config/duration.go new file mode 100644 index 0000000000..e1a978d5c8 --- /dev/null +++ b/config/duration.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/json" + "errors" + "time" +) + +// A wrapper around time.Duration that enables Json Marshalling capabilities +type Duration struct { + time.Duration +} + +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +func (d *Duration) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + d.Duration = time.Duration(0) + return nil + } + + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) + return nil + case string: + if len(value) == 0 { + d.Duration = time.Duration(0) + } else { + var err error + d.Duration, err = time.ParseDuration(value) + if err != nil { + return err + } + } + default: + return 
errors.New("invalid duration") + } + + return nil +} diff --git a/config/duration_test.go b/config/duration_test.go new file mode 100644 index 0000000000..8e411987ac --- /dev/null +++ b/config/duration_test.go @@ -0,0 +1,66 @@ +package config + +import ( + "encoding/json" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDuration_MarshalJSON(t *testing.T) { + t.Run("Valid", func(t *testing.T) { + expected := Duration{ + Duration: time.Second * 2, + } + + b, err := expected.MarshalJSON() + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(b) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(expected, actual)) + }) +} + +func TestDuration_UnmarshalJSON(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + actual := Duration{} + err := actual.UnmarshalJSON([]byte{}) + assert.NoError(t, err) + assert.Equal(t, time.Duration(0), actual.Duration) + }) + + t.Run("Invalid_string", func(t *testing.T) { + input := "blah" + raw, err := json.Marshal(input) + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) + + t.Run("Valid_float", func(t *testing.T) { + input := float64(12345) + raw, err := json.Marshal(input) + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(raw) + assert.NoError(t, err) + }) + + t.Run("Invalid_bool", func(t *testing.T) { + input := true + raw, err := json.Marshal(input) + assert.NoError(t, err) + + actual := Duration{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) +} diff --git a/config/errors.go b/config/errors.go new file mode 100644 index 0000000000..b46459a9f9 --- /dev/null +++ b/config/errors.go @@ -0,0 +1,37 @@ +package config + +import "fmt" + +var ( + ErrStrictModeValidation = fmt.Errorf("failed strict mode check") + ErrChildConfigOverridesConfig = fmt.Errorf("child config attempts to override an existing native config property") +) + +// A helper object that collects errors. 
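+// A minimal usage sketch (step1/step2 are placeholder functions returning error):
+//
+//	errs := ErrorCollection{}
+//	errs.Append(step1())
+//	errs.Append(step2())
+//	return errs.ErrorOrDefault()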
+type ErrorCollection []error + +func (e ErrorCollection) Error() string { + res := "" + for _, err := range e { + res = fmt.Sprintf("%v\n%v", res, err.Error()) + } + + return res +} + +func (e ErrorCollection) ErrorOrDefault() error { + if len(e) == 0 { + return nil + } + + return e +} + +func (e *ErrorCollection) Append(err error) bool { + if err != nil { + *e = append(*e, err) + return true + } + + return false +} diff --git a/config/errors_test.go b/config/errors_test.go new file mode 100644 index 0000000000..4dc1e72876 --- /dev/null +++ b/config/errors_test.go @@ -0,0 +1,31 @@ +package config + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorCollection_ErrorOrDefault(t *testing.T) { + errs := ErrorCollection{} + assert.Error(t, errs) + assert.NoError(t, errs.ErrorOrDefault()) +} + +func TestErrorCollection_Append(t *testing.T) { + errs := ErrorCollection{} + errs.Append(nil) + errs.Append(fmt.Errorf("this is an actual error")) + assert.Error(t, errs.ErrorOrDefault()) + assert.Len(t, errs, 1) + assert.Len(t, errs.ErrorOrDefault(), 1) +} + +func TestErrorCollection_Error(t *testing.T) { + errs := ErrorCollection{} + errs.Append(nil) + errs.Append(fmt.Errorf("this is an actual error")) + assert.Error(t, errs.ErrorOrDefault()) + assert.Contains(t, errs.ErrorOrDefault().Error(), "this is an actual error") +} diff --git a/config/files/finder.go b/config/files/finder.go new file mode 100644 index 0000000000..e389d88306 --- /dev/null +++ b/config/files/finder.go @@ -0,0 +1,83 @@ +package files + +import ( + "os" + "path/filepath" +) + +const ( + configFileType = "yaml" + configFileName = "config" +) + +var configLocations = [][]string{ + {"."}, + {"/etc", "flyte", "config"}, + {os.ExpandEnv("$GOPATH"), "src", "github.com", "lyft", "flytestdlib"}, +} + +// Check if File / Directory Exists +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + + if os.IsNotExist(err) { + return false, nil + } + + return false, err +} + +func isFile(path string) (bool, error) { + s, err := os.Stat(path) + if err != nil { + return false, err + } + + return !s.IsDir(), nil +} + +func contains(slice []string, value string) bool { + for _, s := range slice { + if s == value { + return true + } + } + + return false +} + +// Finds config files in search paths. If searchPaths is empty, it'll look in default locations (see configLocations above) +// If searchPaths is not empty but no configs are found there, it'll still look into configLocations. +// If it found any config file in searchPaths, it'll stop the search. +// searchPaths can contain patterns to match (behavior is OS-dependent). And it'll try to Glob the pattern for any matching +// files. +func FindConfigFiles(searchPaths []string) []string { + res := make([]string, 0, 1) + + for _, location := range searchPaths { + matchedFiles, err := filepath.Glob(location) + if err != nil { + continue + } + + for _, matchedFile := range matchedFiles { + if file, err := isFile(matchedFile); err == nil && file && !contains(res, matchedFile) { + res = append(res, matchedFile) + } + } + } + + if len(res) == 0 { + for _, location := range configLocations { + pathToTest := filepath.Join(append(location, configFileName+"."+configFileType)...) 
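+			// e.g. "config.yaml", "/etc/flyte/config/config.yaml" or
+			// "$GOPATH/src/github.com/lyft/flytestdlib/config.yaml".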
+ if b, err := exists(pathToTest); err == nil && b && !contains(res, pathToTest) { + res = append(res, pathToTest) + } + } + } + + return res +} diff --git a/config/files/finder_test.go b/config/files/finder_test.go new file mode 100644 index 0000000000..833ac461f1 --- /dev/null +++ b/config/files/finder_test.go @@ -0,0 +1,28 @@ +package files + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFindConfigFiles(t *testing.T) { + t.Run("Find config-* group", func(t *testing.T) { + files := FindConfigFiles([]string{filepath.Join("testdata", "config*.yaml")}) + assert.Equal(t, 2, len(files)) + }) + + t.Run("Find other-group-* group", func(t *testing.T) { + files := FindConfigFiles([]string{filepath.Join("testdata", "other-group*.yaml")}) + assert.Equal(t, 2, len(files)) + }) + + t.Run("Absolute path", func(t *testing.T) { + files := FindConfigFiles([]string{filepath.Join("testdata", "other-group-1.yaml")}) + assert.Equal(t, 1, len(files)) + + files = FindConfigFiles([]string{filepath.Join("testdata", "other-group-3.yaml")}) + assert.Equal(t, 0, len(files)) + }) +} diff --git a/config/files/testdata/config-1.yaml b/config/files/testdata/config-1.yaml new file mode 100644 index 0000000000..a5b6191d98 --- /dev/null +++ b/config/files/testdata/config-1.yaml @@ -0,0 +1,9 @@ +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/config/files/testdata/config-2.yaml b/config/files/testdata/config-2.yaml new file mode 100755 index 0000000000..5c28d00a09 --- /dev/null +++ b/config/files/testdata/config-2.yaml @@ -0,0 +1,2 @@ +my-component: + str: Hello World diff --git a/config/files/testdata/other-group-1.yaml b/config/files/testdata/other-group-1.yaml new file mode 100644 index 0000000000..a5b6191d98 --- /dev/null +++ b/config/files/testdata/other-group-1.yaml @@ -0,0 +1,9 @@ +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/config/files/testdata/other-group-2.yaml b/config/files/testdata/other-group-2.yaml new file mode 100755 index 0000000000..5c28d00a09 --- /dev/null +++ b/config/files/testdata/other-group-2.yaml @@ -0,0 +1,2 @@ +my-component: + str: Hello World diff --git a/config/port.go b/config/port.go new file mode 100644 index 0000000000..87bbc854e2 --- /dev/null +++ b/config/port.go @@ -0,0 +1,56 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" +) + +// A common port struct that supports Json marshal/unmarshal into/from simple strings/floats. 
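+// As an illustrative sketch (server-port is a made-up field name), a config
+// field declared with this type is expected to accept either form:
+//
+//    server-port: 8080
+//    server-port: "8080"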
+type Port struct {
+    Port int `json:"port,omitempty"`
+}
+
+func (p Port) MarshalJSON() ([]byte, error) {
+    return json.Marshal(p.Port)
+}
+
+func (p *Port) UnmarshalJSON(b []byte) error {
+    var v interface{}
+    if err := json.Unmarshal(b, &v); err != nil {
+        return err
+    }
+
+    switch value := v.(type) {
+    case string:
+        u, err := parsePortString(value)
+        if err != nil {
+            return err
+        }
+
+        p.Port = u
+        return nil
+    case float64:
+        if !validPortRange(value) {
+            return fmt.Errorf("port must be a valid number between 0 and 65535, inclusive")
+        }
+
+        p.Port = int(value)
+        return nil
+    default:
+        return errors.New("invalid port")
+    }
+}
+
+func parsePortString(port string) (int, error) {
+    if portInt, err := strconv.Atoi(port); err == nil && validPortRange(float64(portInt)) {
+        return portInt, nil
+    }
+
+    return 0, fmt.Errorf("port must be a valid number between 0 and 65535, inclusive")
+}
+
+func validPortRange(port float64) bool {
+    return 0 <= port && port <= 65535
+}
diff --git a/config/port_test.go b/config/port_test.go
new file mode 100644
index 0000000000..c69c09570d
--- /dev/null
+++ b/config/port_test.go
@@ -0,0 +1,81 @@
+package config
+
+import (
+    "encoding/json"
+    "fmt"
+    "reflect"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+)
+
+type PortTestCase struct {
+    Expected Port
+    Input    interface{}
+}
+
+func TestPort_MarshalJSON(t *testing.T) {
+    validPorts := []PortTestCase{
+        {Expected: Port{Port: 8080}, Input: 8080},
+        {Expected: Port{Port: 1}, Input: 1},
+        {Expected: Port{Port: 65535}, Input: "65535"},
+        {Expected: Port{Port: 65535}, Input: 65535},
+    }
+
+    for i, validPort := range validPorts {
+        t.Run(fmt.Sprintf("Valid %v [%v]", i, validPort.Input), func(t *testing.T) {
+            b, err := json.Marshal(validPort.Input)
+            assert.NoError(t, err)
+
+            actual := Port{}
+            err = actual.UnmarshalJSON(b)
+            assert.NoError(t, err)
+
+            assert.True(t, reflect.DeepEqual(validPort.Expected, actual))
+        })
+    }
+}
+
+func TestPort_UnmarshalJSON(t *testing.T) {
+    invalidValues := []interface{}{
+        "%gh&%ij",
+        1000000,
+        true,
+    }
+
+    for i, invalidPort := range invalidValues {
+        t.Run(fmt.Sprintf("Invalid %v", i), func(t *testing.T) {
+            raw, err := json.Marshal(invalidPort)
+            assert.NoError(t, err)
+
+            actual := Port{}
+            err = actual.UnmarshalJSON(raw)
+            assert.Error(t, err)
+        })
+    }
+
+    t.Run("Invalid json", func(t *testing.T) {
+        actual := Port{}
+        err := actual.UnmarshalJSON([]byte{})
+        assert.Error(t, err)
+    })
+
+    t.Run("Invalid Range", func(t *testing.T) {
+        b, err := json.Marshal(float64(100000))
+        assert.NoError(t, err)
+
+        actual := Port{}
+        err = actual.UnmarshalJSON(b)
+        assert.Error(t, err)
+    })
+
+    t.Run("Unmarshal Empty", func(t *testing.T) {
+        p := Port{}
+        raw, err := json.Marshal(p)
+        assert.NoError(t, err)
+
+        actual := Port{}
+        assert.NoError(t, actual.UnmarshalJSON(raw))
+        assert.Equal(t, 0, actual.Port)
+    })
+}
diff --git a/config/section.go b/config/section.go
new file mode 100644
index 0000000000..41fddd0c1b
--- /dev/null
+++ b/config/section.go
@@ -0,0 +1,230 @@
+package config
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "reflect"
+    "strings"
+    "sync"
+
+    "github.com/lyft/flytestdlib/atomic"
+
+    "github.com/spf13/pflag"
+)
+
+type Section interface {
+    // Gets a cloned copy of the Config registered to this section. This config instance does not account for any child
+    // section registered.
+    GetConfig() Config
+
+    // Gets the function to call when the config has been updated; returns nil if no update handler was registered.
+    GetConfigUpdatedHandler() SectionUpdated
+
+    // Sets the config and sets a bit indicating whether the new config is different when compared to the existing value.
+    SetConfig(config Config) error
+
+    // Gets a value indicating whether the config has changed since the last call to GetConfigChangedAndClear and clears
+    // the changed bit. This operation is atomic.
+    GetConfigChangedAndClear() bool
+
+    // Retrieves the loaded values for section key if one exists, or nil otherwise.
+    GetSection(key SectionKey) Section
+
+    // Gets all child config sections.
+    GetSections() SectionMap
+
+    // Registers a section with the config manager. Section keys are case insensitive and must be unique.
+    // The section object must be passed by reference since it'll be used to unmarshal into. It must also support json
+    // marshaling. If the section registered gets updated at runtime, the updatesFn will be invoked to handle the propagation
+    // of changes.
+    RegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) (Section, error)
+
+    // Like RegisterSectionWithUpdates, but panics instead of returning an error if registration fails.
+    MustRegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) Section
+
+    // Registers a section with the config manager. Section keys are case insensitive and must be unique.
+    // The section object must be passed by reference since it'll be used to unmarshal into. It must also support json
+    // marshaling.
+    RegisterSection(key SectionKey, configSection Config) (Section, error)
+
+    // Like RegisterSection, but panics instead of returning an error if registration fails.
+    MustRegisterSection(key SectionKey, configSection Config) Section
+}
+
+type Config = interface{}
+type SectionKey = string
+type SectionMap map[SectionKey]Section
+
+// A section can optionally implement this interface to add its fields as cmdline arguments.
+type PFlagProvider interface {
+    GetPFlagSet(prefix string) *pflag.FlagSet
+}
+
+type SectionUpdated func(ctx context.Context, newValue Config)
+
+// Global section to use with any root-level config sections registered.
+var rootSection = NewRootSection()
+
+type section struct {
+    config   Config
+    handler  SectionUpdated
+    isDirty  atomic.Bool
+    sections SectionMap
+    lockObj  sync.RWMutex
+}
+
+// Gets the global root section.
+func GetRootSection() Section {
+    return rootSection
+}
+
+func MustRegisterSection(key SectionKey, configSection Config) Section {
+    s, err := RegisterSection(key, configSection)
+    if err != nil {
+        panic(err)
+    }
+
+    return s
+}
+
+func (r *section) MustRegisterSection(key SectionKey, configSection Config) Section {
+    s, err := r.RegisterSection(key, configSection)
+    if err != nil {
+        panic(err)
+    }
+
+    return s
+}
+
+// Registers a section with the config manager. Section keys are case insensitive and must be unique.
+// The section object must be passed by reference since it'll be used to unmarshal into. It must also support json
+// marshaling.
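+// A hedged example of the expected call shape (DatabaseConfig and the "database"
+// key are made up for illustration):
+//
+//    type DatabaseConfig struct {
+//        Host string `json:"host"`
+//    }
+//
+//    section, err := config.RegisterSection("database", &DatabaseConfig{})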
+func RegisterSection(key SectionKey, configSection Config) (Section, error) { + return rootSection.RegisterSection(key, configSection) +} + +func (r *section) RegisterSection(key SectionKey, configSection Config) (Section, error) { + return r.RegisterSectionWithUpdates(key, configSection, nil) +} + +func MustRegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) Section { + s, err := RegisterSectionWithUpdates(key, configSection, updatesFn) + if err != nil { + panic(err) + } + + return s +} + +func (r *section) MustRegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) Section { + s, err := r.RegisterSectionWithUpdates(key, configSection, updatesFn) + if err != nil { + panic(err) + } + + return s +} + +// Registers a section with the config manager. Section keys are case insensitive and must be unique. +// The section object must be passed by reference since it'll be used to unmarshal into. It must also support json +// marshaling. If the section registered gets updated at runtime, the updatesFn will be invoked to handle the propagation +// of changes. +func RegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) (Section, error) { + return rootSection.RegisterSectionWithUpdates(key, configSection, updatesFn) +} + +func (r *section) RegisterSectionWithUpdates(key SectionKey, configSection Config, updatesFn SectionUpdated) (Section, error) { + r.lockObj.Lock() + defer r.lockObj.Unlock() + + key = strings.ToLower(key) + + if len(key) == 0 { + return nil, errors.New("key must be a non-zero string") + } + + if configSection == nil { + return nil, fmt.Errorf("configSection must be a non-nil pointer. SectionKey: %v", key) + } + + if reflect.TypeOf(configSection).Kind() != reflect.Ptr { + return nil, fmt.Errorf("section must be a Pointer. SectionKey: %v", key) + } + + if _, alreadyExists := r.sections[key]; alreadyExists { + return nil, fmt.Errorf("key already exists [%v]", key) + } + + section := NewSection(configSection, updatesFn) + r.sections[key] = section + return section, nil +} + +// Retrieves the loaded values for section key if one exists, or nil otherwise. 
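+// Illustrative read-back of the sketch above (same made-up "database" key;
+// GetConfig returns an interface{}, hence the type assertion):
+//
+//    if s := config.GetSection("database"); s != nil {
+//        cfg := s.GetConfig().(*DatabaseConfig)
+//        fmt.Println(cfg.Host)
+//    }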
+func GetSection(key SectionKey) Section { + return rootSection.GetSection(key) +} + +func (r *section) GetSection(key SectionKey) Section { + r.lockObj.RLock() + defer r.lockObj.RUnlock() + + key = strings.ToLower(key) + + if section, alreadyExists := r.sections[key]; alreadyExists { + return section + } + + return nil +} + +func (r *section) GetSections() SectionMap { + return r.sections +} + +func (r *section) GetConfig() Config { + r.lockObj.RLock() + defer r.lockObj.RUnlock() + + return r.config +} + +func (r *section) SetConfig(c Config) error { + r.lockObj.Lock() + defer r.lockObj.Unlock() + + if !DeepEqual(r.config, c) { + r.config = c + r.isDirty.Store(true) + } + + return nil +} + +func (r *section) GetConfigUpdatedHandler() SectionUpdated { + return r.handler +} + +func (r *section) GetConfigChangedAndClear() bool { + return r.isDirty.CompareAndSwap(true, false) +} + +func NewSection(configSection Config, updatesFn SectionUpdated) Section { + return §ion{ + config: configSection, + handler: updatesFn, + isDirty: atomic.NewBool(false), + sections: map[SectionKey]Section{}, + lockObj: sync.RWMutex{}, + } +} + +func NewRootSection() Section { + return NewSection(nil, nil) +} diff --git a/config/section_test.go b/config/section_test.go new file mode 100644 index 0000000000..5b314bb339 --- /dev/null +++ b/config/section_test.go @@ -0,0 +1,119 @@ +package config + +import ( + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/ghodss/yaml" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/spf13/pflag" + + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! +var update = flag.Bool("update", false, "Updates testdata") + +type MyComponentConfig struct { + StringValue string `json:"str"` +} + +type OtherComponentConfig struct { + DurationValue Duration `json:"duration-value"` + URLValue URL `json:"url-value"` + StringValue string `json:"string-value"` + IntValue int `json:"int-val"` + StringArray []string `json:"strings"` +} + +func (MyComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") + return cmdFlags +} + +func (OtherComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "string-value"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "duration-value"), "20s", "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "int-val"), 4, "this is an important flag") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "url-value"), "http://blah.com", "Sets the type of storage to configure [s3/minio/local/mem].") + return cmdFlags +} + +type TestConfig struct { + MyComponentConfig MyComponentConfig `json:"my-component"` + OtherComponentConfig OtherComponentConfig `json:"other-component"` +} + +func TestMarshal(t *testing.T) { + expected := TestConfig{ + MyComponentConfig: MyComponentConfig{ + StringValue: "Hello World", + }, + OtherComponentConfig: OtherComponentConfig{ + StringValue: "Hey there!", + IntValue: 4, + URLValue: URL{URL: utils.MustParseURL("http://something.com")}, + DurationValue: Duration{Duration: time.Second * 20}, + StringArray: []string{"hello", "world", "!"}, + }, + } + + configPath := 
filepath.Join("testdata", "config.yaml") + if *update { + t.Log("Updating config file.") + raw, err := yaml.Marshal(expected) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(configPath, raw, os.ModePerm)) + } + + r := TestConfig{} + raw, err := ioutil.ReadFile(configPath) + assert.NoError(t, err) + assert.NoError(t, yaml.Unmarshal(raw, &r)) + assert.True(t, reflect.DeepEqual(expected, r)) +} + +func TestRegisterSection(t *testing.T) { + t.Run("New Section", func(t *testing.T) { + _, err := RegisterSection(rand.String(6), &TestConfig{}) + assert.NoError(t, err) + }) + + t.Run("Duplicate", func(t *testing.T) { + s := rand.String(6) + _, err := RegisterSection(s, &TestConfig{}) + assert.NoError(t, err) + _, err = RegisterSection(s, &TestConfig{}) + assert.Error(t, err) + }) + + t.Run("Register Nested", func(t *testing.T) { + root := NewRootSection() + s := rand.String(6) + _, err := root.RegisterSection(s, &TestConfig{}) + assert.NoError(t, err) + _, err = root.RegisterSection(s, &TestConfig{}) + assert.Error(t, err) + }) +} + +func TestGetSection(t *testing.T) { + sectionName := rand.String(6) + actual1, err := RegisterSection(sectionName, &TestConfig{}) + assert.NoError(t, err) + assert.Equal(t, reflect.TypeOf(&TestConfig{}), reflect.TypeOf(actual1.GetConfig())) + + actual2 := GetSection(sectionName) + assert.NotNil(t, actual2) + assert.Equal(t, reflect.TypeOf(&TestConfig{}), reflect.TypeOf(actual2.GetConfig())) +} diff --git a/config/testdata/config.yaml b/config/testdata/config.yaml new file mode 100755 index 0000000000..2f20ad97b5 --- /dev/null +++ b/config/testdata/config.yaml @@ -0,0 +1,11 @@ +my-component: + str: Hello World +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' 
+ url-value: http://something.com diff --git a/config/tests/accessor_test.go b/config/tests/accessor_test.go new file mode 100644 index 0000000000..34d86237e9 --- /dev/null +++ b/config/tests/accessor_test.go @@ -0,0 +1,641 @@ +package tests + +import ( + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/fsnotify/fsnotify" + + k8sRand "k8s.io/apimachinery/pkg/util/rand" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/spf13/pflag" + + "github.com/ghodss/yaml" + "github.com/stretchr/testify/assert" +) + +type accessorCreatorFn func(registry config.Section, configPath string) config.Accessor + +func getRandInt() uint64 { + c := 10 + b := make([]byte, c) + _, err := rand.Read(b) + if err != nil { + return 0 + } + + return binary.BigEndian.Uint64(b) +} + +func tempFileName(pattern string) string { + // TODO: Remove this hack after we use Go1.11 everywhere: + // https://github.com/golang/go/commit/191efbc419d7e5dec842c20841f6f716da4b561d + + var prefix, suffix string + if pos := strings.LastIndex(pattern, "*"); pos != -1 { + prefix, suffix = pattern[:pos], pattern[pos+1:] + } else { + prefix = pattern + } + + return filepath.Join(os.TempDir(), prefix+k8sRand.String(6)+suffix) +} + +func populateConfigData(configPath string) (TestConfig, error) { + expected := TestConfig{ + MyComponentConfig: MyComponentConfig{ + StringValue: fmt.Sprintf("Hello World %v", getRandInt()), + }, + OtherComponentConfig: OtherComponentConfig{ + StringValue: fmt.Sprintf("Hello World %v", getRandInt()), + URLValue: config.URL{URL: utils.MustParseURL("http://something.com")}, + DurationValue: config.Duration{Duration: time.Second * 20}, + }, + } + + raw, err := yaml.Marshal(expected) + if err != nil { + return TestConfig{}, err + } + + return expected, ioutil.WriteFile(configPath, raw, os.ModePerm) +} + +func TestGetEmptySection(t *testing.T) { + t.Run("empty", func(t *testing.T) { + r := config.GetSection("Empty") + assert.Nil(t, r) + }) +} + +type ComplexType struct { + IntValue int `json:"int-val"` +} + +type ComplexTypeArray []ComplexType + +type ConfigWithLists struct { + ListOfStuff []ComplexType `json:"list"` + StringValue string `json:"string-val"` +} + +type ConfigWithMaps struct { + MapOfStuff map[string]ComplexType `json:"m"` + MapWithoutJSON map[string]ComplexType +} + +type ConfigWithJSONTypes struct { + Duration config.Duration `json:"duration"` +} + +func TestAccessor_InitializePflags(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf("[%v] Unused flag", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ContinueOnError) + set.String("flag1", "123", "") + v.InitializePflags(set) + assert.NoError(t, v.UpdateConfig(context.TODO())) + }) + + t.Run(fmt.Sprintf("[%v] Override string value", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + 
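+            // The env var key set below follows the binding convention implemented by
+            // bindViperConfigsFromEnv in config/viper/viper.go: the section path is
+            // upper-cased, dashes become underscores, and the "." delimiter is kept.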
+ set := pflag.NewFlagSet("test", pflag.ContinueOnError) + v.InitializePflags(set) + key := "MY_COMPONENT.STR2" + assert.NoError(t, os.Setenv(key, "123")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "123", r.StringValue2) + }) + + t.Run(fmt.Sprintf("[%v] Parse from config file", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = reg.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ExitOnError) + v.InitializePflags(set) + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + otherC := reg.GetSection(OtherComponentSectionKey).GetConfig().(*OtherComponentConfig) + assert.Equal(t, 4, otherC.IntValue) + assert.Equal(t, []string{"default value"}, otherC.StringArrayWithDefaults) + }) + } +} + +func TestStrictAccessor(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf("[%v] Bad config", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + v := provider(config.Options{ + StrictMode: true, + SearchPaths: []string{filepath.Join("testdata", "bad_config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("test", pflag.ExitOnError) + v.InitializePflags(set) + assert.Error(t, v.UpdateConfig(context.TODO())) + }) + + t.Run(fmt.Sprintf("[%v] Set through env", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = reg.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + StrictMode: true, + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + set := pflag.NewFlagSet("other-component.string-value", pflag.ExitOnError) + v.InitializePflags(set) + + key := "OTHER_COMPONENT.STRING_VALUE" + assert.NoError(t, os.Setenv(key, "set from env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + }) + } +} + +func TestAccessor_UpdateConfig(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf("[%v] Static File", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + }) + + t.Run(fmt.Sprintf("[%v] Nested", provider(config.Options{}).ID()), func(t *testing.T) { + root := config.NewRootSection() + section, err := root.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = 
section.RegisterSection("nested", &OtherComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "nested_config.yaml")}, + RootSection: root, + }) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := root.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + + nested := section.GetSection("nested").GetConfig().(*OtherComponentConfig) + assert.Equal(t, "Hey there!", nested.StringValue) + }) + + t.Run(fmt.Sprintf("[%v] Array Configs", provider(config.Options{}).ID()), func(t *testing.T) { + root := config.NewRootSection() + section, err := root.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + _, err = section.RegisterSection("nested", &ComplexTypeArray{}) + assert.NoError(t, err) + + _, err = root.RegisterSection("array-config", &ComplexTypeArray{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "array_configs.yaml")}, + RootSection: root, + }) + + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := root.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Hello World", r.StringValue) + + nested := section.GetSection("nested").GetConfig().(*ComplexTypeArray) + assert.Len(t, *nested, 1) + assert.Equal(t, 1, (*nested)[0].IntValue) + + topLevel := root.GetSection("array-config").GetConfig().(*ComplexTypeArray) + assert.Len(t, *topLevel, 2) + assert.Equal(t, 4, (*topLevel)[1].IntValue) + }) + + t.Run(fmt.Sprintf("[%v] Override in Env Var", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "config.yaml")}, + RootSection: reg, + }) + key := strings.ToUpper("my-component.str") + assert.NoError(t, os.Setenv(key, "Set From Env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Set From Env", r.StringValue) + }) + + t.Run(fmt.Sprintf("[%v] Override in Env Var no config file", provider(config.Options{}).ID()), func(t *testing.T) { + reg := config.NewRootSection() + _, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + v := provider(config.Options{RootSection: reg}) + key := strings.ToUpper("my-component.str3") + assert.NoError(t, os.Setenv(key, "Set From Env")) + defer func() { assert.NoError(t, os.Unsetenv(key)) }() + assert.NoError(t, v.UpdateConfig(context.TODO())) + r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig) + assert.Equal(t, "Set From Env", r.StringValue3) + }) + + t.Run(fmt.Sprintf("[%v] Change handler", provider(config.Options{}).ID()), func(t *testing.T) { + configFile := tempFileName("config-*.yaml") + defer func() { assert.NoError(t, os.Remove(configFile)) }() + _, err := populateConfigData(configFile) + assert.NoError(t, err) + + reg := config.NewRootSection() + _, err = reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}) + assert.NoError(t, err) + + opts := config.Options{ + SearchPaths: []string{configFile}, + RootSection: reg, + } + v := provider(opts) + err = v.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + r := 
reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig)
+            firstValue := r.StringValue
+
+            fileUpdated, err := beginWaitForFileChange(configFile)
+            assert.NoError(t, err)
+
+            _, err = populateConfigData(configFile)
+            assert.NoError(t, err)
+
+            // Wait for the filewatcher event that confirms the write was observed.
+            assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated))
+
+            time.Sleep(2 * time.Second)
+
+            r = reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig)
+            secondValue := r.StringValue
+            assert.NotEqual(t, firstValue, secondValue)
+        })
+
+        t.Run(fmt.Sprintf("[%v] Change handler k8s configmaps", provider(config.Options{}).ID()), func(t *testing.T) {
+            // 1. Create Dir structure
+            watchDir, configFile, cleanup := newSymlinkedConfigFile(t)
+            defer cleanup()
+
+            // Independently watch for the underlying symlink change so we know when to
+            // expect the accessor to have picked up the changes.
+            fileUpdated, err := beginWaitForFileChange(configFile)
+            assert.NoError(t, err)
+
+            // 2. Start accessor with the symlink as config location
+            reg := config.NewRootSection()
+            section, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{})
+            assert.NoError(t, err)
+
+            opts := config.Options{
+                SearchPaths: []string{configFile},
+                RootSection: reg,
+            }
+            v := provider(opts)
+            err = v.UpdateConfig(context.TODO())
+            assert.NoError(t, err)
+
+            r := section.GetConfig().(*MyComponentConfig)
+            firstValue := r.StringValue
+
+            // 3. Now update /data symlink to point to data2
+            dataDir2 := path.Join(watchDir, "data2")
+            err = os.Mkdir(dataDir2, os.ModePerm)
+            assert.NoError(t, err)
+
+            configFile2 := path.Join(dataDir2, "config.yaml")
+            _, err = populateConfigData(configFile2)
+            assert.NoError(t, err)
+
+            // change the symlink using the `ln -sfn` command
+            err = changeSymLink(dataDir2, path.Join(watchDir, "data"))
+            assert.NoError(t, err)
+
+            t.Logf("New config Location: %v", configFile2)
+
+            // Wait for filewatcher event
+            assert.NoError(t, waitForFileChangeOrTimeout(fileUpdated))
+
+            time.Sleep(2 * time.Second)
+
+            r = section.GetConfig().(*MyComponentConfig)
+            secondValue := r.StringValue
+            // Make sure values have changed
+            assert.NotEqual(t, firstValue, secondValue)
+        })
+    }
+}
+
+func changeSymLink(targetPath, symLink string) error {
+    if runtime.GOOS == "windows" {
+        tmpLink := tempFileName("temp-sym-link-*")
+        err := exec.Command("mklink", filepath.Clean(tmpLink), filepath.Clean(targetPath)).Run()
+        if err != nil {
+            return err
+        }
+
+        err = exec.Command("copy", "/l", "/y", filepath.Clean(tmpLink), filepath.Clean(symLink)).Run()
+        if err != nil {
+            return err
+        }
+
+        return exec.Command("del", filepath.Clean(tmpLink)).Run()
+    }
+
+    return exec.Command("ln", "-sfn", filepath.Clean(targetPath), filepath.Clean(symLink)).Run()
+}
+
+// 1. Create Dir structure:
+//    |_ data1
+//       |_ config.yaml
+//    |_ data (symlink to data1)
+//    |_ config.yaml (symlink to data/config.yaml, and therefore transitively to data1/config.yaml)
+func newSymlinkedConfigFile(t *testing.T) (watchDir, configFile string, cleanup func()) {
+    watchDir, err := ioutil.TempDir("", "config-test-")
+    assert.NoError(t, err)
+
+    dataDir1 := path.Join(watchDir, "data1")
+    err = os.Mkdir(dataDir1, os.ModePerm)
+    assert.NoError(t, err)
+
+    realConfigFile := path.Join(dataDir1, "config.yaml")
+    t.Logf("Real config file location: %s\n", realConfigFile)
+    _, err = populateConfigData(realConfigFile)
+    assert.NoError(t, err)
+
+    cleanup = func() {
+        assert.NoError(t, os.RemoveAll(watchDir))
+    }
+
+    // now, symlink the temp `data1` dir to `data` in the baseDir
+    assert.NoError(t, os.Symlink(dataDir1, path.Join(watchDir, "data")))
+
+    // and link `<watchDir>/data/config.yaml` to `<watchDir>/config.yaml`
+    configFile = path.Join(watchDir, "config.yaml")
+    assert.NoError(t, os.Symlink(path.Join(watchDir, "data", "config.yaml"), configFile))
+
+    t.Logf("Config file location: %s\n", path.Join(watchDir, "config.yaml"))
+    return watchDir, configFile, cleanup
+}
+
+func waitForFileChangeOrTimeout(done chan error) error {
+    timeout := make(chan bool, 1)
+    go func() {
+        time.Sleep(5 * time.Second)
+        timeout <- true
+    }()
+
+    select {
+    case <-timeout:
+        return fmt.Errorf("timed out")
+    case err := <-done:
+        return err
+    }
+}
+
+func beginWaitForFileChange(filename string) (done chan error, terminalErr error) {
+    watcher, err := fsnotify.NewWatcher()
+    if err != nil {
+        return nil, err
+    }
+
+    configFile := filepath.Clean(filename)
+    realConfigFile, err := filepath.EvalSymlinks(configFile)
+    if err != nil {
+        return nil, err
+    }
+
+    configDir, _ := filepath.Split(configFile)
+
+    done = make(chan error)
+    go func() {
+        for {
+            select {
+            case event := <-watcher.Events:
+                // we only care about the config file
+                currentConfigFile, err := filepath.EvalSymlinks(filename)
+                if err != nil {
+                    closeErr := watcher.Close()
+                    if closeErr != nil {
+                        done <- closeErr
+                    } else {
+                        done <- err
+                    }
+
+                    return
+                }
+
+                // We only care about the config file with the following cases:
+                // 1 - if the config file was modified or created
+                // 2 - if the real path to the config file changed (e.g. k8s ConfigMap replacement)
+                const writeOrCreateMask = fsnotify.Write | fsnotify.Create
+                if (filepath.Clean(event.Name) == configFile &&
+                    event.Op&writeOrCreateMask != 0) ||
+                    (currentConfigFile != "" && currentConfigFile != realConfigFile) {
+                    realConfigFile = currentConfigFile
+                    closeErr := watcher.Close()
+                    if closeErr != nil {
+                        fmt.Printf("Close Watcher error: %v\n", closeErr)
+                    } else {
+                        done <- nil
+                    }
+
+                    return
+                } else if filepath.Clean(event.Name) == configFile &&
+                    event.Op&fsnotify.Remove != 0 {
+                    closeErr := watcher.Close()
+                    if closeErr != nil {
+                        fmt.Printf("Close Watcher error: %v\n", closeErr)
+                    } else {
+                        done <- nil
+                    }
+
+                    return
+                }
+            case err, ok := <-watcher.Errors:
+                if ok {
+                    fmt.Printf("Watcher error: %v\n", err)
+                    closeErr := watcher.Close()
+                    if closeErr != nil {
+                        fmt.Printf("Close Watcher error: %v\n", closeErr)
+                    }
+                }
+
+                done <- nil
+                return
+            }
+        }
+    }()
+
+    err = watcher.Add(configDir)
+    if err != nil {
+        return nil, err
+    }
+
+    return done, err
+}
+
+func testTypes(accessor accessorCreatorFn) func(t *testing.T) {
+    return func(t *testing.T) {
+        t.Run("ArrayConfigType", func(t *testing.T) {
+            expected := ComplexTypeArray{
+                {IntValue: 1},
+                {IntValue: 4},
+            }
+
+
runEqualTest(t, accessor, &expected, &ComplexTypeArray{}) + }) + + t.Run("Lists", func(t *testing.T) { + expected := ConfigWithLists{ + ListOfStuff: []ComplexType{ + {IntValue: 1}, + {IntValue: 4}, + }, + } + + runEqualTest(t, accessor, &expected, &ConfigWithLists{}) + }) + + t.Run("Maps", func(t *testing.T) { + expected := ConfigWithMaps{ + MapOfStuff: map[string]ComplexType{ + "item1": {IntValue: 1}, + "item2": {IntValue: 3}, + }, + MapWithoutJSON: map[string]ComplexType{ + "it-1": {IntValue: 5}, + }, + } + + runEqualTest(t, accessor, &expected, &ConfigWithMaps{}) + }) + + t.Run("JsonUnmarshalableTypes", func(t *testing.T) { + expected := ConfigWithJSONTypes{ + Duration: config.Duration{ + Duration: time.Second * 10, + }, + } + + runEqualTest(t, accessor, &expected, &ConfigWithJSONTypes{}) + }) + } +} + +func runEqualTest(t *testing.T, accessor accessorCreatorFn, expected interface{}, emptyType interface{}) { + assert.NotPanics(t, func() { + reflect.TypeOf(expected).Elem() + }, "expected must be a Pointer type. Instead, it was %v", reflect.TypeOf(expected)) + + assert.Equal(t, reflect.TypeOf(expected), reflect.TypeOf(emptyType)) + + rootSection := config.NewRootSection() + sectionKey := fmt.Sprintf("rand-key-%v", getRandInt()%2000) + _, err := rootSection.RegisterSection(sectionKey, emptyType) + assert.NoError(t, err) + + m := map[string]interface{}{ + sectionKey: expected, + } + + raw, err := yaml.Marshal(m) + assert.NoError(t, err) + f := tempFileName("test_type_*.yaml") + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Remove(f)) }() + + assert.NoError(t, ioutil.WriteFile(f, raw, os.ModePerm)) + t.Logf("Generated yaml: %v", string(raw)) + assert.NoError(t, accessor(rootSection, f).UpdateConfig(context.TODO())) + + res := rootSection.GetSection(sectionKey).GetConfig() + t.Logf("Expected: %+v", expected) + t.Logf("Actual: %+v", res) + assert.True(t, reflect.DeepEqual(res, expected)) +} + +func TestAccessor_Integration(t *testing.T) { + accessorsToTest := make([]accessorCreatorFn, 0, len(providers)) + for _, provider := range providers { + accessorsToTest = append(accessorsToTest, func(r config.Section, configPath string) config.Accessor { + return provider(config.Options{ + SearchPaths: []string{configPath}, + RootSection: r, + }) + }) + } + + for _, accessor := range accessorsToTest { + t.Run(fmt.Sprintf(testNameFormatter, accessor(nil, "").ID(), "Types"), testTypes(accessor)) + } +} diff --git a/config/tests/config_cmd_test.go b/config/tests/config_cmd_test.go new file mode 100644 index 0000000000..3b15268292 --- /dev/null +++ b/config/tests/config_cmd_test.go @@ -0,0 +1,92 @@ +package tests + +import ( + "bytes" + "fmt" + "os" + "testing" + + "github.com/lyft/flytestdlib/config" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func executeCommand(root *cobra.Command, args ...string) (output string, err error) { + _, output, err = executeCommandC(root, args...) 
+ return output, err +} + +func executeCommandC(root *cobra.Command, args ...string) (c *cobra.Command, output string, err error) { + buf := new(bytes.Buffer) + root.SetOutput(buf) + root.SetArgs(args) + + c, err = root.ExecuteC() + + return c, buf.String(), err +} + +func TestDiscoverCommand(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "No config file"), func(t *testing.T) { + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandDiscover) + assert.NoError(t, err) + assert.Contains(t, output, "Couldn't find a config file.") + }) + + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "Valid config file"), func(t *testing.T) { + dir, err := os.Getwd() + assert.NoError(t, err) + wd := os.ExpandEnv("$PWD/testdata") + err = os.Chdir(wd) + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Chdir(dir)) }() + + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandDiscover) + assert.NoError(t, err) + assert.Contains(t, output, "Config") + }) + } +} + +func TestValidateCommand(t *testing.T) { + for _, provider := range providers { + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "No config file"), func(t *testing.T) { + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandValidate) + assert.NoError(t, err) + assert.Contains(t, output, "Couldn't find a config file.") + }) + + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "Invalid Config file"), func(t *testing.T) { + dir, err := os.Getwd() + assert.NoError(t, err) + wd := os.ExpandEnv("$PWD/testdata") + err = os.Chdir(wd) + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Chdir(dir)) }() + + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandValidate, "--file=bad_config.yaml", "--strict") + assert.Error(t, err) + assert.Contains(t, output, "Failed") + }) + + t.Run(fmt.Sprintf(testNameFormatter, provider(config.Options{}).ID(), "Valid config file"), func(t *testing.T) { + dir, err := os.Getwd() + assert.NoError(t, err) + wd := os.ExpandEnv("$PWD/testdata") + err = os.Chdir(wd) + assert.NoError(t, err) + defer func() { assert.NoError(t, os.Chdir(dir)) }() + + cmd := config.NewConfigCommand(provider) + output, err := executeCommand(cmd, config.CommandValidate) + assert.NoError(t, err) + assert.Contains(t, output, "successfully") + }) + } +} diff --git a/config/tests/testdata/array_configs.yaml b/config/tests/testdata/array_configs.yaml new file mode 100644 index 0000000000..6a02a280a6 --- /dev/null +++ b/config/tests/testdata/array_configs.yaml @@ -0,0 +1,7 @@ +my-component: + str: Hello World + nested: + - int-val: 1 +array-config: + - int-val: 1 + - int-val: 4 diff --git a/config/tests/testdata/bad_config.yaml b/config/tests/testdata/bad_config.yaml new file mode 100644 index 0000000000..c0707abdeb --- /dev/null +++ b/config/tests/testdata/bad_config.yaml @@ -0,0 +1,13 @@ +my-component: + str: Hello World +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' 
+ url-value: http://something.com + unknown-key: "something" + diff --git a/config/tests/testdata/config.yaml b/config/tests/testdata/config.yaml new file mode 100755 index 0000000000..ca78698fae --- /dev/null +++ b/config/tests/testdata/config.yaml @@ -0,0 +1,11 @@ +my-component: + str: Hello World +other-component: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/config/tests/testdata/nested_config.yaml b/config/tests/testdata/nested_config.yaml new file mode 100755 index 0000000000..321f563a42 --- /dev/null +++ b/config/tests/testdata/nested_config.yaml @@ -0,0 +1,11 @@ +my-component: + str: Hello World + nested: + duration-value: 20s + int-val: 4 + string-value: Hey there! + strings: + - hello + - world + - '!' + url-value: http://something.com diff --git a/config/tests/types_test.go b/config/tests/types_test.go new file mode 100644 index 0000000000..c6150e93fa --- /dev/null +++ b/config/tests/types_test.go @@ -0,0 +1,66 @@ +package tests + +import ( + "fmt" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + "github.com/spf13/pflag" +) + +const testNameFormatter = "[%v] %v" + +var providers = []config.AccessorProvider{viper.NewAccessor} + +type MyComponentConfig struct { + StringValue string `json:"str"` + StringValue2 string `json:"str2"` + StringValue3 string `json:"str3"` +} + +type OtherComponentConfig struct { + DurationValue config.Duration `json:"duration-value"` + URLValue config.URL `json:"url-value"` + StringValue string `json:"string-value"` + IntValue int `json:"int-val"` + StringArray []string `json:"strings"` + StringArrayWithDefaults []string `json:"strings-def"` +} + +func (MyComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str2"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "str3"), "hello world", "life is short") + return cmdFlags +} + +func (OtherComponentConfig) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("MyComponentConfig", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "string-value"), "hello world", "life is short") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "duration-value"), "20s", "") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "int-val"), 4, "this is an important flag") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "url-value"), "http://blah.com", "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "strings-def"), []string{"default value"}, "Sets the type of storage to configure [s3/minio/local/mem].") + return cmdFlags +} + +type TestConfig struct { + MyComponentConfig MyComponentConfig `json:"my-component"` + OtherComponentConfig OtherComponentConfig `json:"other-component"` +} + +const ( + MyComponentSectionKey = "my-component" + OtherComponentSectionKey = "other-component" +) + +func init() { + if _, err := config.RegisterSection(MyComponentSectionKey, &MyComponentConfig{}); err != nil { + panic(err) + } + + if _, err := config.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{}); err != nil { + panic(err) + } +} diff --git a/config/url.go b/config/url.go new file mode 100644 index 0000000000..4045caf962 --- /dev/null +++ b/config/url.go @@ -0,0 +1,36 
@@ +package config + +import ( + "encoding/json" + "errors" + "net/url" +) + +// A url.URL wrapper that can marshal and unmarshal into simple URL strings. +type URL struct { + url.URL +} + +func (d URL) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +func (d *URL) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + switch value := v.(type) { + case string: + u, err := url.Parse(value) + if err != nil { + return err + } + + d.URL = *u + return nil + default: + return errors.New("invalid url") + } +} diff --git a/config/url_test.go b/config/url_test.go new file mode 100644 index 0000000000..e4046b3b16 --- /dev/null +++ b/config/url_test.go @@ -0,0 +1,59 @@ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/lyft/flytestdlib/internal/utils" + + "github.com/stretchr/testify/assert" +) + +func TestURL_MarshalJSON(t *testing.T) { + validURLs := []string{ + "http://localhost:123", + "http://localhost", + "https://non-existent.com/path/to/something", + } + + for i, validURL := range validURLs { + t.Run(fmt.Sprintf("Valid %v", i), func(t *testing.T) { + expected := URL{URL: utils.MustParseURL(validURL)} + + b, err := expected.MarshalJSON() + assert.NoError(t, err) + + actual := URL{} + err = actual.UnmarshalJSON(b) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(expected, actual)) + }) + } +} + +func TestURL_UnmarshalJSON(t *testing.T) { + invalidValues := []interface{}{ + "%gh&%ij", + 123, + true, + } + for i, invalidURL := range invalidValues { + t.Run(fmt.Sprintf("Invalid %v", i), func(t *testing.T) { + raw, err := json.Marshal(invalidURL) + assert.NoError(t, err) + + actual := URL{} + err = actual.UnmarshalJSON(raw) + assert.Error(t, err) + }) + } + + t.Run("Invalid json", func(t *testing.T) { + actual := URL{} + err := actual.UnmarshalJSON([]byte{}) + assert.Error(t, err) + }) +} diff --git a/config/utils.go b/config/utils.go new file mode 100644 index 0000000000..fcd833ff75 --- /dev/null +++ b/config/utils.go @@ -0,0 +1,69 @@ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +// Uses Json marshal/unmarshal to make a deep copy of a config object. +func DeepCopyConfig(config Config) (Config, error) { + raw, err := json.Marshal(config) + if err != nil { + return nil, err + } + + t := reflect.TypeOf(config) + ptrValue := reflect.New(t) + newObj := ptrValue.Interface() + if err = json.Unmarshal(raw, newObj); err != nil { + return nil, err + } + + return ptrValue.Elem().Interface(), nil +} + +func DeepEqual(config1, config2 Config) bool { + return reflect.DeepEqual(config1, config2) +} + +func toInterface(config Config) (interface{}, error) { + raw, err := json.Marshal(config) + if err != nil { + return nil, err + } + + var m interface{} + err = json.Unmarshal(raw, &m) + return m, err +} + +// Builds a generic map out of the root section config and its sub-sections configs. 
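+// Shape sketch: for a root with a single hypothetical "logger" child section
+// holding {level: 4}, the nested result would resemble
+//
+//    map[string]interface{}{
+//        "logger": map[string]interface{}{"level": 4},
+//    }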
+func AllConfigsAsMap(root Section) (m map[string]interface{}, err error) { + errs := ErrorCollection{} + allConfigs := make(map[string]interface{}, len(root.GetSections())) + if root.GetConfig() != nil { + rootConfig, err := toInterface(root.GetConfig()) + if !errs.Append(err) { + if asMap, isCasted := rootConfig.(map[string]interface{}); isCasted { + allConfigs = asMap + } else { + allConfigs[""] = rootConfig + } + } + } + + for k, section := range root.GetSections() { + if _, alreadyExists := allConfigs[k]; alreadyExists { + errs.Append(errors.Wrap(ErrChildConfigOverridesConfig, + fmt.Sprintf("section key [%v] overrides an existing native config property", k))) + } + + allConfigs[k], err = AllConfigsAsMap(section) + errs.Append(err) + } + + return allConfigs, errs.ErrorOrDefault() +} diff --git a/config/utils_test.go b/config/utils_test.go new file mode 100644 index 0000000000..e5d7d02105 --- /dev/null +++ b/config/utils_test.go @@ -0,0 +1,39 @@ +package config + +import ( + "reflect" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDeepCopyConfig(t *testing.T) { + type pair struct { + first interface{} + second interface{} + } + + type fakeConfig struct { + I int + S string + Ptr *fakeConfig + } + + testCases := []pair{ + {3, 3}, + {"word", "word"}, + {fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}, fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}}, + {&fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}, &fakeConfig{I: 4, S: "four", Ptr: &fakeConfig{I: 5, S: "five"}}}, + } + + for i, testCase := range testCases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + input, expected := testCase.first, testCase.second + actual, err := DeepCopyConfig(input) + assert.NoError(t, err) + assert.Equal(t, reflect.TypeOf(expected).String(), reflect.TypeOf(actual).String()) + assert.Equal(t, expected, actual) + }) + } +} diff --git a/config/viper/collection.go b/config/viper/collection.go new file mode 100644 index 0000000000..680052f494 --- /dev/null +++ b/config/viper/collection.go @@ -0,0 +1,175 @@ +package viper + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + "github.com/lyft/flytestdlib/config" + + "github.com/lyft/flytestdlib/logger" + + viperLib "github.com/spf13/viper" + + "github.com/fsnotify/fsnotify" + "github.com/spf13/pflag" +) + +type Viper interface { + BindPFlags(flags *pflag.FlagSet) error + BindEnv(input ...string) error + AutomaticEnv() + ReadInConfig() error + OnConfigChange(run func(in fsnotify.Event)) + WatchConfig() + AllSettings() map[string]interface{} + ConfigFileUsed() string + MergeConfig(in io.Reader) error +} + +// A proxy object for a collection of Viper instances. 
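+// Conceptual fan-out sketch (v1, v2, and flags are placeholders): every call is
+// applied to each wrapped instance and the per-instance errors are folded into a
+// single config.ErrorCollection, e.g.
+//
+//    proxy := &CollectionProxy{underlying: []Viper{v1, v2}}
+//    err := proxy.BindPFlags(flags) // binds the flags on both v1 and v2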
+type CollectionProxy struct { + underlying []Viper + pflags *pflag.FlagSet + envVars [][]string + automaticEnv bool +} + +func (c *CollectionProxy) BindPFlags(flags *pflag.FlagSet) error { + err := config.ErrorCollection{} + for _, v := range c.underlying { + err.Append(v.BindPFlags(flags)) + } + + c.pflags = flags + + return err.ErrorOrDefault() +} + +func (c *CollectionProxy) BindEnv(input ...string) error { + err := config.ErrorCollection{} + for _, v := range c.underlying { + err.Append(v.BindEnv(input...)) + } + + if c.envVars == nil { + c.envVars = make([][]string, 0, 1) + } + + c.envVars = append(c.envVars, input) + + return err.ErrorOrDefault() +} + +func (c *CollectionProxy) AutomaticEnv() { + for _, v := range c.underlying { + v.AutomaticEnv() + } + + c.automaticEnv = true +} + +func (c CollectionProxy) ReadInConfig() error { + err := config.ErrorCollection{} + for _, v := range c.underlying { + err.Append(v.ReadInConfig()) + } + + return err.ErrorOrDefault() +} + +func (c CollectionProxy) OnConfigChange(run func(in fsnotify.Event)) { + for _, v := range c.underlying { + v.OnConfigChange(run) + } +} + +func (c CollectionProxy) WatchConfig() { + for _, v := range c.underlying { + v.WatchConfig() + } +} + +func (c CollectionProxy) AllSettings() map[string]interface{} { + finalRes := map[string]interface{}{} + if len(c.underlying) == 0 { + return finalRes + } + + combinedConfig, err := c.MergeAllConfigs() + if err != nil { + logger.Warnf(context.TODO(), "Failed to merge config. Error: %v", err) + return finalRes + } + + return combinedConfig.AllSettings() +} + +func (c CollectionProxy) ConfigFileUsed() string { + return fmt.Sprintf("[%v]", strings.Join(c.ConfigFilesUsed(), ",")) +} + +func (c CollectionProxy) MergeConfig(in io.Reader) error { + panic("Not yet implemented.") +} + +func (c CollectionProxy) MergeAllConfigs() (all Viper, err error) { + combinedConfig := viperLib.New() + if c.envVars != nil { + for _, envConfig := range c.envVars { + err = combinedConfig.BindEnv(envConfig...) 
+ if err != nil { + return nil, err + } + } + } + + if c.automaticEnv { + combinedConfig.AutomaticEnv() + } + + if c.pflags != nil { + err = combinedConfig.BindPFlags(c.pflags) + if err != nil { + return nil, err + } + } + + for _, v := range c.underlying { + if _, isCollection := v.(*CollectionProxy); isCollection { + return nil, fmt.Errorf("merging nested CollectionProxies is not yet supported") + } + + if len(v.ConfigFileUsed()) == 0 { + continue + } + + combinedConfig.SetConfigFile(v.ConfigFileUsed()) + + reader, err := os.Open(v.ConfigFileUsed()) + if err != nil { + return nil, err + } + + err = combinedConfig.MergeConfig(reader) + if err != nil { + return nil, err + } + } + + return combinedConfig, nil +} + +func (c CollectionProxy) ConfigFilesUsed() []string { + res := make([]string, 0, len(c.underlying)) + for _, v := range c.underlying { + filePath := v.ConfigFileUsed() + if len(filePath) > 0 { + res = append(res, filePath) + } + } + + return res +} diff --git a/config/viper/viper.go b/config/viper/viper.go new file mode 100644 index 0000000000..05a5378ec7 --- /dev/null +++ b/config/viper/viper.go @@ -0,0 +1,357 @@ +package viper + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/pkg/errors" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/files" + "github.com/lyft/flytestdlib/logger" + "github.com/spf13/cobra" + + "github.com/fsnotify/fsnotify" + "github.com/mitchellh/mapstructure" + + "github.com/spf13/pflag" + viperLib "github.com/spf13/viper" +) + +const ( + keyDelim = "." +) + +var ( + dereferencableKinds = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, + } +) + +type viperAccessor struct { + // Determines whether parsing config should fail if it contains un-registered sections. + strictMode bool + viper *CollectionProxy + rootConfig config.Section + // Ensures we initialize the file Watcher once. + watcherInitializer *sync.Once +} + +func (viperAccessor) ID() string { + return "Viper" +} + +func (viperAccessor) InitializeFlags(cmdFlags *flag.FlagSet) { + // TODO: Implement? +} + +func (v viperAccessor) InitializePflags(cmdFlags *pflag.FlagSet) { + err := v.addSectionsPFlags(cmdFlags) + if err != nil { + panic(errors.Wrap(err, "error adding config PFlags to flag set")) + } + + // Allow viper to read the value of the flags + err = v.viper.BindPFlags(cmdFlags) + if err != nil { + panic(errors.Wrap(err, "error binding PFlags")) + } +} + +func (v viperAccessor) addSectionsPFlags(flags *pflag.FlagSet) (err error) { + for key, section := range v.rootConfig.GetSections() { + if asPFlagProvider, ok := section.GetConfig().(config.PFlagProvider); ok { + flags.AddFlagSet(asPFlagProvider.GetPFlagSet(key + keyDelim)) + } + } + + return nil +} + +// Binds keys from all sections to viper env vars. 
This instructs viper to lookup those from env vars when we ask for +// viperLib.AllSettings() +func (v viperAccessor) bindViperConfigsFromEnv(root config.Section) (err error) { + allConfigs, err := config.AllConfigsAsMap(root) + if err != nil { + return err + } + + return v.bindViperConfigsEnvDepth(allConfigs, "") +} + +func (v viperAccessor) bindViperConfigsEnvDepth(m map[string]interface{}, prefix string) error { + errs := config.ErrorCollection{} + for key, val := range m { + subKey := prefix + key + if asMap, ok := val.(map[string]interface{}); ok { + errs.Append(v.bindViperConfigsEnvDepth(asMap, subKey+keyDelim)) + } else { + errs.Append(v.viper.BindEnv(subKey, strings.ToUpper(strings.Replace(subKey, "-", "_", -1)))) + } + } + + return errs.ErrorOrDefault() +} + +func (v viperAccessor) updateConfig(ctx context.Context, r config.Section) error { + // Binds all keys to env vars. + err := v.bindViperConfigsFromEnv(r) + if err != nil { + return err + } + + v.viper.AutomaticEnv() // read in environment variables that match + + shouldWatchChanges := true + // If a config file is found, read it in. + if err = v.viper.ReadInConfig(); err == nil { + logger.Printf(ctx, "Using config file: %+v", v.viper.ConfigFilesUsed()) + } else if asErrorCollection, ok := err.(config.ErrorCollection); ok { + shouldWatchChanges = false + for i, e := range asErrorCollection { + if _, isNotFound := errors.Cause(e).(viperLib.ConfigFileNotFoundError); isNotFound { + logger.Printf(ctx, "[%v] Couldn't find a config file [%v]. Relying on env vars and pflags.", + i, v.viper.underlying[i].ConfigFileUsed()) + } else { + return err + } + } + } else if reflect.TypeOf(err) == reflect.TypeOf(viperLib.ConfigFileNotFoundError{}) { + shouldWatchChanges = false + logger.Printf(ctx, "Couldn't find a config file. Relying on env vars and pflags.") + } else { + return err + } + + if shouldWatchChanges { + v.watcherInitializer.Do(func() { + // Watch config files to pick up on file changes without requiring a full application restart. + // This call must occur after *all* config paths have been added. + v.viper.OnConfigChange(func(e fsnotify.Event) { + fmt.Printf("Got a notification change for file [%v]\n", e.Name) + v.configChangeHandler() + }) + v.viper.WatchConfig() + }) + } + + return v.RefreshFromConfig(ctx, r) +} + +func (v viperAccessor) UpdateConfig(ctx context.Context) error { + return v.updateConfig(ctx, v.rootConfig) +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElement(t reflect.Kind) bool { + _, exists := dereferencableKinds[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshallerHook(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElement(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + ctx := context.Background() + raw, err := json.Marshal(data) + if err != nil { + logger.Printf(ctx, "Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + logger.Printf(ctx, "Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +// Parses RootType config from parsed Viper settings. This should be called after viper has parsed config file/pflags...etc. +func (v viperAccessor) parseViperConfig(root config.Section) error { + // We use AllSettings instead of AllKeys to get the root level keys folded. + return v.parseViperConfigRecursive(root, v.viper.AllSettings()) +} + +func (v viperAccessor) parseViperConfigRecursive(root config.Section, settings interface{}) error { + errs := config.ErrorCollection{} + var mine interface{} + myKeysCount := 0 + if asMap, casted := settings.(map[string]interface{}); casted { + myMap := map[string]interface{}{} + for childKey, childValue := range asMap { + if childSection, found := root.GetSections()[childKey]; found { + errs.Append(v.parseViperConfigRecursive(childSection, childValue)) + } else { + myMap[childKey] = childValue + } + } + + mine = myMap + myKeysCount = len(myMap) + } else if asSlice, casted := settings.([]interface{}); casted { + mine = settings + myKeysCount = len(asSlice) + } else { + mine = settings + if settings != nil { + myKeysCount = 1 + } + } + + if root.GetConfig() != nil { + c, err := config.DeepCopyConfig(root.GetConfig()) + errs.Append(err) + if err != nil { + return errs.ErrorOrDefault() + } + + errs.Append(decode(mine, defaultDecoderConfig(c, v.decoderConfigs()...))) + errs.Append(root.SetConfig(c)) + + return errs.ErrorOrDefault() + } else if myKeysCount > 0 { + // There are keys set that are meant to be decoded but no config to receive them. Fail if strict mode is on. + if v.strictMode { + errs.Append(errors.Wrap( + config.ErrStrictModeValidation, + fmt.Sprintf("strict mode is on but received keys [%+v] to decode with no config assigned to"+ + " receive them", mine))) + } + } + + return errs.ErrorOrDefault() +} + +// Adds any specific configs controlled by this viper accessor instance. +func (v viperAccessor) decoderConfigs() []viperLib.DecoderConfigOption { + return []viperLib.DecoderConfigOption{ + func(config *mapstructure.DecoderConfig) { + config.ErrorUnused = v.strictMode + }, + } +} + +// defaultDecoderConfig returns default mapsstructure.DecoderConfig with support +// of time.Duration values & string slices +func defaultDecoderConfig(output interface{}, opts ...viperLib.DecoderConfigOption) *mapstructure.DecoderConfig { + c := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: output, + WeaklyTypedInput: true, + TagName: "json", + DecodeHook: mapstructure.ComposeDecodeHookFunc( + jsonUnmarshallerHook, + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + ), + } + + for _, opt := range opts { + opt(c) + } + + return c +} + +// A wrapper around mapstructure.Decode that mimics the WeakDecode functionality +func decode(input interface{}, config *mapstructure.DecoderConfig) error { + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + return decoder.Decode(input) +} + +func (v viperAccessor) configChangeHandler() { + ctx := context.Background() + err := v.RefreshFromConfig(ctx, v.rootConfig) + if err != nil { + // TODO: Retry? panic? + logger.Printf(ctx, "Failed to update config. 
Error: %v", err) + } else { + logger.Printf(ctx, "Refreshed config in response to file(s) change.") + } +} + +func (v viperAccessor) RefreshFromConfig(ctx context.Context, r config.Section) error { + err := v.parseViperConfig(r) + if err != nil { + return err + } + + v.sendUpdatedEvents(ctx, r, "") + + return nil +} + +func (v viperAccessor) sendUpdatedEvents(ctx context.Context, root config.Section, sectionKey config.SectionKey) { + for key, section := range root.GetSections() { + if !section.GetConfigChangedAndClear() { + logger.Infof(ctx, "Config section [%v] hasn't changed.", sectionKey+key) + } else if section.GetConfigUpdatedHandler() == nil { + logger.Infof(ctx, "Config section [%v] updated. No update handler registered.", sectionKey+key) + } else { + logger.Infof(ctx, "Config section [%v] updated. Firing updated event.", sectionKey+key) + section.GetConfigUpdatedHandler()(ctx, section.GetConfig()) + } + + v.sendUpdatedEvents(ctx, section, sectionKey+key+keyDelim) + } +} + +func (v viperAccessor) ConfigFilesUsed() []string { + return v.viper.ConfigFilesUsed() +} + +// Creates a config accessor that implements Accessor interface and uses viper to load configs. +func NewAccessor(opts config.Options) config.Accessor { + return newAccessor(opts) +} + +func newAccessor(opts config.Options) viperAccessor { + vipers := make([]Viper, 0, 1) + configFiles := files.FindConfigFiles(opts.SearchPaths) + for _, configFile := range configFiles { + v := viperLib.New() + v.SetConfigFile(configFile) + + vipers = append(vipers, v) + } + + // Create a default viper even if we couldn't find any matching files + if len(configFiles) == 0 { + v := viperLib.New() + vipers = append(vipers, v) + } + + r := opts.RootSection + if r == nil { + r = config.GetRootSection() + } + + return viperAccessor{ + strictMode: opts.StrictMode, + rootConfig: r, + viper: &CollectionProxy{underlying: vipers}, + watcherInitializer: &sync.Once{}, + } +} + +// Gets the root level command that can be added to any cobra-powered cli to get config* commands. +func GetConfigCommand() *cobra.Command { + return config.NewConfigCommand(NewAccessor) +} diff --git a/contextutils/context.go b/contextutils/context.go new file mode 100644 index 0000000000..b5a9c00fa2 --- /dev/null +++ b/contextutils/context.go @@ -0,0 +1,140 @@ +// Contains common flyte context utils. +package contextutils + +import ( + "context" + "fmt" +) + +type Key string + +const ( + AppNameKey Key = "app_name" + NamespaceKey Key = "ns" + TaskTypeKey Key = "tasktype" + ProjectKey Key = "project" + DomainKey Key = "domain" + WorkflowIDKey Key = "wf" + NodeIDKey Key = "node" + TaskIDKey Key = "task" + ExecIDKey Key = "exec_id" + JobIDKey Key = "job_id" + PhaseKey Key = "phase" +) + +func (k Key) String() string { + return string(k) +} + +var logKeys = []Key{ + AppNameKey, + JobIDKey, + NamespaceKey, + ExecIDKey, + NodeIDKey, + WorkflowIDKey, + TaskTypeKey, + PhaseKey, +} + +// Gets a new context with namespace set. +func WithNamespace(ctx context.Context, namespace string) context.Context { + return context.WithValue(ctx, NamespaceKey, namespace) +} + +// Gets a new context with JobId set. If the existing context already has a job id, the new context will have +// / set as the job id. +func WithJobID(ctx context.Context, jobID string) context.Context { + existingJobID := ctx.Value(JobIDKey) + if existingJobID != nil { + jobID = fmt.Sprintf("%v/%v", existingJobID, jobID) + } + + return context.WithValue(ctx, JobIDKey, jobID) +} + +// Gets a new context with AppName set. 
+func WithAppName(ctx context.Context, appName string) context.Context {
+	return context.WithValue(ctx, AppNameKey, appName)
+}
+
+// Gets a new context with Phase set.
+func WithPhase(ctx context.Context, phase string) context.Context {
+	return context.WithValue(ctx, PhaseKey, phase)
+}
+
+// Gets a new context with ExecutionID set.
+func WithExecutionID(ctx context.Context, execID string) context.Context {
+	return context.WithValue(ctx, ExecIDKey, execID)
+}
+
+// Gets a new context with NodeID (nested) set.
+func WithNodeID(ctx context.Context, nodeID string) context.Context {
+	existingNodeID := ctx.Value(NodeIDKey)
+	if existingNodeID != nil {
+		nodeID = fmt.Sprintf("%v/%v", existingNodeID, nodeID)
+	}
+	return context.WithValue(ctx, NodeIDKey, nodeID)
+}
+
+// Gets a new context with WorkflowID set.
+func WithWorkflowID(ctx context.Context, workflow string) context.Context {
+	return context.WithValue(ctx, WorkflowIDKey, workflow)
+}
+
+// Gets a new context with Project and Domain values set.
+func WithProjectDomain(ctx context.Context, project, domain string) context.Context {
+	c := context.WithValue(ctx, ProjectKey, project)
+	return context.WithValue(c, DomainKey, domain)
+}
+
+// Gets a new context with TaskID set.
+func WithTaskID(ctx context.Context, taskID string) context.Context {
+	return context.WithValue(ctx, TaskIDKey, taskID)
+}
+
+// Gets a new context with TaskType set.
+func WithTaskType(ctx context.Context, taskType string) context.Context {
+	return context.WithValue(ctx, TaskTypeKey, taskType)
+}
+
+func addFieldIfNotNil(ctx context.Context, m map[string]interface{}, fieldKey Key) {
+	val := ctx.Value(fieldKey)
+	if val != nil {
+		m[fieldKey.String()] = val
+	}
+}
+
+func addStringFieldWithDefaults(ctx context.Context, m map[string]string, fieldKey Key) {
+	val := ctx.Value(fieldKey)
+	if val == nil {
+		val = ""
+	}
+	m[fieldKey.String()] = val.(string)
+}
+
+// Gets a map of all known logKeys set on the context. logKeys are special and should be used in case context fields
+// are to be added to the log lines.
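+// For example:
+//
+//	ctx := WithJobID(WithNamespace(context.Background(), "prod"), "job42")
+//	GetLogFields(ctx) // => map[ns:prod job_id:job42]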
+func GetLogFields(ctx context.Context) map[string]interface{} { + res := map[string]interface{}{} + for _, k := range logKeys { + addFieldIfNotNil(ctx, res, k) + } + return res +} + +func Value(ctx context.Context, key Key) string { + val := ctx.Value(key) + if val != nil { + return val.(string) + } + return "" +} + +func Values(ctx context.Context, keys ...Key) map[string]string { + res := map[string]string{} + for _, k := range keys { + addStringFieldWithDefaults(ctx, res, k) + } + return res +} diff --git a/contextutils/context_test.go b/contextutils/context_test.go new file mode 100644 index 0000000000..99d29d3e3d --- /dev/null +++ b/contextutils/context_test.go @@ -0,0 +1,113 @@ +package contextutils + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithAppName(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(AppNameKey)) + ctx = WithAppName(ctx, "application-name-123") + assert.Equal(t, "application-name-123", ctx.Value(AppNameKey)) + + ctx = WithAppName(ctx, "app-name-modified") + assert.Equal(t, "app-name-modified", ctx.Value(AppNameKey)) +} + +func TestWithPhase(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(PhaseKey)) + ctx = WithPhase(ctx, "Running") + assert.Equal(t, "Running", ctx.Value(PhaseKey)) + + ctx = WithPhase(ctx, "Failed") + assert.Equal(t, "Failed", ctx.Value(PhaseKey)) +} + +func TestWithJobId(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(JobIDKey)) + ctx = WithJobID(ctx, "job123") + assert.Equal(t, "job123", ctx.Value(JobIDKey)) + + ctx = WithJobID(ctx, "subtask") + assert.Equal(t, "job123/subtask", ctx.Value(JobIDKey)) +} + +func TestWithNamespace(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(NamespaceKey)) + ctx = WithNamespace(ctx, "flyte") + assert.Equal(t, "flyte", ctx.Value(NamespaceKey)) + + ctx = WithNamespace(ctx, "flyte2") + assert.Equal(t, "flyte2", ctx.Value(NamespaceKey)) +} + +func TestWithExecutionID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(ExecIDKey)) + ctx = WithExecutionID(ctx, "job123") + assert.Equal(t, "job123", ctx.Value(ExecIDKey)) +} + +func TestWithTaskType(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(TaskTypeKey)) + ctx = WithTaskType(ctx, "flyte") + assert.Equal(t, "flyte", ctx.Value(TaskTypeKey)) +} + +func TestWithWorkflowID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(WorkflowIDKey)) + ctx = WithWorkflowID(ctx, "flyte") + assert.Equal(t, "flyte", ctx.Value(WorkflowIDKey)) +} + +func TestWithNodeID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(NodeIDKey)) + ctx = WithNodeID(ctx, "n1") + assert.Equal(t, "n1", ctx.Value(NodeIDKey)) + + ctx = WithNodeID(ctx, "n2") + assert.Equal(t, "n1/n2", ctx.Value(NodeIDKey)) +} + +func TestWithProjectDomain(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(ProjectKey)) + assert.Nil(t, ctx.Value(DomainKey)) + ctx = WithProjectDomain(ctx, "proj", "domain") + assert.Equal(t, "proj", ctx.Value(ProjectKey)) + assert.Equal(t, "domain", ctx.Value(DomainKey)) +} + +func TestWithTaskID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(TaskIDKey)) + ctx = WithTaskID(ctx, "task") + assert.Equal(t, "task", ctx.Value(TaskIDKey)) +} + +func TestGetFields(t *testing.T) { + ctx := context.Background() + ctx = WithJobID(WithNamespace(ctx, "ns123"), "job123") + m := GetLogFields(ctx) + assert.Equal(t, "ns123", 
m[NamespaceKey.String()]) + assert.Equal(t, "job123", m[JobIDKey.String()]) +} + +func TestValues(t *testing.T) { + ctx := context.Background() + ctx = WithWorkflowID(ctx, "flyte") + m := Values(ctx, ProjectKey, WorkflowIDKey) + assert.NotNil(t, m) + assert.Equal(t, 2, len(m)) + assert.Equal(t, "flyte", m[WorkflowIDKey.String()]) + assert.Equal(t, "", m[ProjectKey.String()]) +} diff --git a/internal/utils/parsers.go b/internal/utils/parsers.go new file mode 100644 index 0000000000..c1fcfa3a4a --- /dev/null +++ b/internal/utils/parsers.go @@ -0,0 +1,20 @@ +package utils + +import ( + "net/url" +) + +// A utility function to be used in tests. It parses urlString as url.URL or panics if it's invalid. +func MustParseURL(urlString string) url.URL { + u, err := url.Parse(urlString) + if err != nil { + panic(err) + } + + return *u +} + +// A utility function to be used in tests. It returns the address of the passed value. +func RefInt(val int) *int { + return &val +} diff --git a/internal/utils/parsers_test.go b/internal/utils/parsers_test.go new file mode 100644 index 0000000000..c15202b6ac --- /dev/null +++ b/internal/utils/parsers_test.go @@ -0,0 +1,25 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMustParseURL(t *testing.T) { + t.Run("Valid URL", func(t *testing.T) { + MustParseURL("http://something-profound-localhost.com") + }) + + t.Run("Invalid URL", func(t *testing.T) { + assert.Panics(t, func() { + MustParseURL("invalid_url:is_here\\") + }) + }) +} + +func TestRefUint32(t *testing.T) { + input := int(5) + res := RefInt(input) + assert.Equal(t, input, *res) +} diff --git a/ioutils/bytes.go b/ioutils/bytes.go new file mode 100644 index 0000000000..ad69c0d97b --- /dev/null +++ b/ioutils/bytes.go @@ -0,0 +1,21 @@ +package ioutils + +import ( + "bytes" + "io" +) + +// A Closeable Reader for bytes to mimic stream from inmemory byte storage +type BytesReadCloser struct { + *bytes.Reader +} + +func (*BytesReadCloser) Close() error { + return nil +} + +func NewBytesReadCloser(b []byte) io.ReadCloser { + return &BytesReadCloser{ + Reader: bytes.NewReader(b), + } +} diff --git a/ioutils/bytes_test.go b/ioutils/bytes_test.go new file mode 100644 index 0000000000..9745c62c91 --- /dev/null +++ b/ioutils/bytes_test.go @@ -0,0 +1,17 @@ +package ioutils + +import ( + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewBytesReadCloser(t *testing.T) { + i := []byte("abc") + r := NewBytesReadCloser(i) + o, e := ioutil.ReadAll(r) + assert.NoError(t, e) + assert.Equal(t, o, i) + assert.NoError(t, r.Close()) +} diff --git a/ioutils/timed_readers.go b/ioutils/timed_readers.go new file mode 100644 index 0000000000..ceb6415952 --- /dev/null +++ b/ioutils/timed_readers.go @@ -0,0 +1,17 @@ +package ioutils + +import ( + "io" + "io/ioutil" +) + +// Defines a common interface for timers. +type Timer interface { + // Stops the timer and reports observation. 
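+	// For the promutils stopwatch used in the test below, the returned float64 is the
+	// observed elapsed time (an assumption based on promutils; other Timer
+	// implementations may report differently).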
+	Stop() float64
+}
+
+func ReadAll(r io.Reader, t Timer) ([]byte, error) {
+	defer t.Stop()
+	return ioutil.ReadAll(r)
+}
diff --git a/ioutils/timed_readers_test.go b/ioutils/timed_readers_test.go
new file mode 100644
index 0000000000..7fa74f7241
--- /dev/null
+++ b/ioutils/timed_readers_test.go
@@ -0,0 +1,20 @@
+package ioutils
+
+import (
+	"bytes"
+	"testing"
+	"time"
+
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestReadAll(t *testing.T) {
+	r := bytes.NewReader([]byte("hello"))
+	s := promutils.NewTestScope()
+	w, e := s.NewStopWatch("x", "empty", time.Millisecond)
+	assert.NoError(t, e)
+	b, err := ReadAll(r, w.Start())
+	assert.NoError(t, err)
+	assert.Equal(t, "hello", string(b))
+}
diff --git a/logger/config.go b/logger/config.go
new file mode 100644
index 0000000000..faa2da6f9e
--- /dev/null
+++ b/logger/config.go
@@ -0,0 +1,81 @@
+package logger
+
+import (
+	"context"
+
+	"github.com/lyft/flytestdlib/config"
+)
+
+//go:generate pflags Config
+
+const configSectionKey = "Logger"
+
+type FormatterType = string
+
+const (
+	FormatterJSON FormatterType = "json"
+	FormatterText FormatterType = "text"
+)
+
+const (
+	jsonDataKey string = "json"
+)
+
+// Global logger config.
+type Config struct {
+	// Determines whether to include source code location in logs. This might incur a performance hit and is only
+	// recommended on debug/development builds.
+	IncludeSourceCode bool `json:"show-source" pflag:",Includes source code location in logs."`
+
+	// Determines whether the logger should mute all logs (including panics).
+	Mute bool `json:"mute" pflag:",Mutes all logs regardless of severity. Intended for benchmarks/tests only."`
+
+	// Determines the minimum log level to log.
+	Level Level `json:"level" pflag:"4,Sets the minimum logging level."`
+
+	Formatter FormatterConfig `json:"formatter" pflag:",Sets logging format."`
+}
+
+type FormatterConfig struct {
+	Type FormatterType `json:"type" pflag:"\"json\",Sets logging format type."`
+}
+
+var globalConfig = Config{}
+
+// Sets the global logger config.
+func SetConfig(cfg Config) {
+	globalConfig = cfg
+
+	onConfigUpdated(cfg)
+}
+
+// Level type.
+type Level = int
+
+// These are the different logging levels.
+const (
+	// PanicLevel level, highest level of severity. Logs and then calls panic with the
+	// message passed to Debug, Info, ...
+	PanicLevel Level = iota
+	// FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the
+	// logging level is set to Panic.
+	FatalLevel
+	// ErrorLevel level. Logs. Used for errors that should definitely be noted.
+	// Commonly used for hooks to send errors to an error tracking service.
+	ErrorLevel
+	// WarnLevel level. Non-critical entries that deserve eyes.
+	WarnLevel
+	// InfoLevel level. General operational entries about what's going on inside the
+	// application.
+	InfoLevel
+	// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
+	DebugLevel
+)
+
+func init() {
+	if _, err := config.RegisterSectionWithUpdates(configSectionKey, &Config{}, func(ctx context.Context, newValue config.Config) {
+		SetConfig(*newValue.(*Config))
+	}); err != nil {
+		panic(err)
+	}
+}
diff --git a/logger/config_flags.go b/logger/config_flags.go
new file mode 100755
index 0000000000..cf8950b94f
--- /dev/null
+++ b/logger/config_flags.go
@@ -0,0 +1,21 @@
+// Code generated by go generate; DO NOT EDIT.
+// This file was generated by robots.
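+//
+// The generated flag names below are "<prefix><json-name>"; for example (prefix chosen
+// for illustration), Config{}.GetPFlagSet("logger.") registers --logger.show-source,
+// --logger.mute, --logger.level and --logger.formatter.type.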
+
+package logger
+
+import (
+	"fmt"
+
+	"github.com/spf13/pflag"
+)
+
+// GetPFlagSet will return strongly typed pflags for all fields in Config and its nested types. The format of the
+// flags is json-name.json-sub-name... etc.
+func (Config) GetPFlagSet(prefix string) *pflag.FlagSet {
+	cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError)
+	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "show-source"), *new(bool), "Includes source code location in logs.")
+	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "mute"), *new(bool), "Mutes all logs regardless of severity. Intended for benchmarks/tests only.")
+	cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "level"), 4, "Sets the minimum logging level.")
+	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "formatter.type"), "json", "Sets logging format type.")
+	return cmdFlags
+}
diff --git a/logger/config_flags_test.go b/logger/config_flags_test.go
new file mode 100755
index 0000000000..401d58d493
--- /dev/null
+++ b/logger/config_flags_test.go
@@ -0,0 +1,190 @@
+// Code generated by go generate; DO NOT EDIT.
+// This file was generated by robots.
+
+package logger
+
+import (
+	"encoding/json"
+	"fmt"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/mitchellh/mapstructure"
+	"github.com/stretchr/testify/assert"
+)
+
+var dereferencableKindsConfig = map[reflect.Kind]struct{}{
+	reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {},
+}
+
+// Checks if t is a kind that can be dereferenced to get its underlying type.
+func canGetElementConfig(t reflect.Kind) bool {
+	_, exists := dereferencableKindsConfig[t]
+	return exists
+}
+
+// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the
+// object. Otherwise, it'll just pass on the original data.
+func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) {
+	unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
+	if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) ||
+		(canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) {
+
+		raw, err := json.Marshal(data)
+		if err != nil {
+			fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err)
+			return data, nil
+		}
+
+		res := reflect.New(to).Interface()
+		err = json.Unmarshal(raw, &res)
+		if err != nil {
+			fmt.Printf("Failed to unmarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_show-source", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("show-source"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("show-source", testValue) + if vBool, err := cmdFlags.GetBool("show-source"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.IncludeSourceCode) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_mute", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("mute"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("mute", testValue) + if vBool, err := cmdFlags.GetBool("mute"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Mute) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_level", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("level"); err == nil { + assert.Equal(t, int(4), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("level", testValue) + if vInt, err := cmdFlags.GetInt("level"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Level) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_formatter.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("formatter.type"); err == nil { + assert.Equal(t, string("json"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("formatter.type", 
testValue) + if vString, err := cmdFlags.GetString("formatter.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Formatter.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/logger/config_test.go b/logger/config_test.go new file mode 100644 index 0000000000..7d2d3782b1 --- /dev/null +++ b/logger/config_test.go @@ -0,0 +1,20 @@ +package logger + +import "testing" + +func TestSetConfig(t *testing.T) { + type args struct { + cfg Config + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SetConfig(tt.args.cfg) + }) + } +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000000..29aadc8bf8 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,337 @@ +// Defines global context-aware logger. +// The default implementation uses logrus. This package registers "logger" config section on init(). The structure of the +// config section is expected to be un-marshal-able to Config struct. +package logger + +import ( + "context" + + "github.com/lyft/flytestdlib/contextutils" + + "fmt" + "runtime" + "strings" + + "github.com/sirupsen/logrus" +) + +//go:generate gotests -w -all $FILE + +const indentLevelKey contextutils.Key = "LoggerIndentLevel" + +func onConfigUpdated(cfg Config) { + logrus.SetLevel(logrus.Level(cfg.Level)) + + switch cfg.Formatter.Type { + case FormatterText: + if _, isText := logrus.StandardLogger().Formatter.(*logrus.JSONFormatter); !isText { + logrus.SetFormatter(&logrus.TextFormatter{ + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "ts", + }, + }) + } + default: + if _, isJSON := logrus.StandardLogger().Formatter.(*logrus.JSONFormatter); !isJSON { + logrus.SetFormatter(&logrus.JSONFormatter{ + DataKey: jsonDataKey, + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "ts", + }, + }) + } + } +} + +func getSourceLocation() string { + if globalConfig.IncludeSourceCode { + _, file, line, ok := runtime.Caller(3) + if !ok { + file = "???" + line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] + } + } + + return fmt.Sprintf("[%v:%v] ", file, line) + } + + return "" +} + +func wrapHeader(ctx context.Context, args ...interface{}) []interface{} { + args = append([]interface{}{getIndent(ctx)}, args...) + + if globalConfig.IncludeSourceCode { + return append( + []interface{}{ + fmt.Sprintf("%v", getSourceLocation()), + }, + args...) + } + + return args +} + +func wrapHeaderForMessage(ctx context.Context, message string) string { + message = fmt.Sprintf("%v%v", getIndent(ctx), message) + + if globalConfig.IncludeSourceCode { + return fmt.Sprintf("%v%v", getSourceLocation(), message) + } + + return message +} + +func getLogger(ctx context.Context) *logrus.Entry { + entry := logrus.WithFields(logrus.Fields(contextutils.GetLogFields(ctx))) + entry.Level = logrus.Level(globalConfig.Level) + return entry +} + +func WithIndent(ctx context.Context, additionalIndent string) context.Context { + indentLevel := getIndent(ctx) + additionalIndent + return context.WithValue(ctx, indentLevelKey, indentLevel) +} + +func getIndent(ctx context.Context) string { + if existing := ctx.Value(indentLevelKey); existing != nil { + return existing.(string) + } + + return "" +} + +// Gets a value indicating whether logs at this level will be written to the logger. This is particularly useful to avoid +// computing log messages unnecessarily. 
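+// For example:
+//
+//	if IsLoggable(ctx, DebugLevel) {
+//		Debugf(ctx, "detailed state: %v", expensiveDump()) // expensiveDump is hypothetical
+//	}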
+func IsLoggable(ctx context.Context, level Level) bool { + return getLogger(ctx).Level >= logrus.Level(level) +} + +// Debug logs a message at level Debug on the standard logger. +func Debug(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Debug(wrapHeader(ctx, args)...) +} + +// Print logs a message at level Info on the standard logger. +func Print(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Print(wrapHeader(ctx, args)...) +} + +// Info logs a message at level Info on the standard logger. +func Info(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Info(wrapHeader(ctx, args)...) +} + +// Warn logs a message at level Warn on the standard logger. +func Warn(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warn(wrapHeader(ctx, args)...) +} + +// Warning logs a message at level Warn on the standard logger. +func Warning(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warning(wrapHeader(ctx, args)...) +} + +// Error logs a message at level Error on the standard logger. +func Error(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Error(wrapHeader(ctx, args)...) +} + +// Panic logs a message at level Panic on the standard logger. +func Panic(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Panic(wrapHeader(ctx, args)...) +} + +// Fatal logs a message at level Fatal on the standard logger. +func Fatal(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Fatal(wrapHeader(ctx, args)...) +} + +// Debugf logs a message at level Debug on the standard logger. +func Debugf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Debugf(wrapHeaderForMessage(ctx, format), args...) +} + +// Printf logs a message at level Info on the standard logger. +func Printf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Printf(wrapHeaderForMessage(ctx, format), args...) +} + +// Infof logs a message at level Info on the standard logger. +func Infof(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Infof(wrapHeaderForMessage(ctx, format), args...) +} + +// InfofNoCtx logs a formatted message without context. +func InfofNoCtx(format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(context.TODO()).Infof(format, args...) +} + +// Warnf logs a message at level Warn on the standard logger. +func Warnf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warnf(wrapHeaderForMessage(ctx, format), args...) +} + +// Warningf logs a message at level Warn on the standard logger. +func Warningf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warningf(wrapHeaderForMessage(ctx, format), args...) +} + +// Errorf logs a message at level Error on the standard logger. +func Errorf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Errorf(wrapHeaderForMessage(ctx, format), args...) 
+} + +// Panicf logs a message at level Panic on the standard logger. +func Panicf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Panicf(wrapHeaderForMessage(ctx, format), args...) +} + +// Fatalf logs a message at level Fatal on the standard logger. +func Fatalf(ctx context.Context, format string, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Fatalf(wrapHeaderForMessage(ctx, format), args...) +} + +// Debugln logs a message at level Debug on the standard logger. +func Debugln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Debugln(wrapHeader(ctx, args)...) +} + +// Println logs a message at level Info on the standard logger. +func Println(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Println(wrapHeader(ctx, args)...) +} + +// Infoln logs a message at level Info on the standard logger. +func Infoln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Infoln(wrapHeader(ctx, args)...) +} + +// Warnln logs a message at level Warn on the standard logger. +func Warnln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warnln(wrapHeader(ctx, args)...) +} + +// Warningln logs a message at level Warn on the standard logger. +func Warningln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Warningln(wrapHeader(ctx, args)...) +} + +// Errorln logs a message at level Error on the standard logger. +func Errorln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Errorln(wrapHeader(ctx, args)...) +} + +// Panicln logs a message at level Panic on the standard logger. +func Panicln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Panicln(wrapHeader(ctx, args)...) +} + +// Fatalln logs a message at level Fatal on the standard logger. +func Fatalln(ctx context.Context, args ...interface{}) { + if globalConfig.Mute { + return + } + + getLogger(ctx).Fatalln(wrapHeader(ctx, args)...) +} diff --git a/logger/logger_test.go b/logger/logger_test.go new file mode 100644 index 0000000000..75c73a9432 --- /dev/null +++ b/logger/logger_test.go @@ -0,0 +1,643 @@ +// Defines global context-aware logger. +// The default implementation uses logrus. This package registers "logger" config section on init(). The structure of the +// config section is expected to be un-marshal-able to Config struct. 
+package logger + +import ( + "context" + "reflect" + "strings" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func init() { + SetConfig(Config{ + Level: InfoLevel, + IncludeSourceCode: true, + }) +} + +func Test_getSourceLocation(t *testing.T) { + tests := []struct { + name string + want string + }{ + {"current", " "}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getSourceLocation(); !strings.HasSuffix(got, tt.want) { + t.Errorf("getSourceLocation() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_wrapHeaderForMessage(t *testing.T) { + type args struct { + message string + } + tests := []struct { + name string + args args + want string + }{ + {"no args", args{message: ""}, " "}, + {"1 arg", args{message: "hello"}, " hello"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapHeaderForMessage(context.TODO(), tt.args.message); !strings.HasSuffix(got, tt.want) { + t.Errorf("wrapHeaderForMessage() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsLoggable(t *testing.T) { + type args struct { + ctx context.Context + level Level + } + tests := []struct { + name string + args args + want bool + }{ + {"Debug Is not loggable", args{ctx: context.TODO(), level: DebugLevel}, false}, + {"Info Is loggable", args{ctx: context.TODO(), level: InfoLevel}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsLoggable(tt.args.ctx, tt.args.level); got != tt.want { + t.Errorf("IsLoggable() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDebug(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Debug(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPrint(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Print(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestInfo(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Info(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarn(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warn(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarning(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warning(tt.args.ctx, tt.args.args...) 
+ }) + } +} + +func TestError(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Error(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPanic(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { + Panic(tt.args.ctx, tt.args.args...) + }) + }) + } +} + +func TestDebugf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Debugf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestPrintf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Printf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestInfof(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Infof(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestInfofNoCtx(t *testing.T) { + type args struct { + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{format: "%v", args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + InfofNoCtx(tt.args.format, tt.args.args...) + }) + } +} + +func TestWarnf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warnf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestWarningf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warningf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestErrorf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Errorf(tt.args.ctx, tt.args.format, tt.args.args...) 
+ }) + } +} + +func TestPanicf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { + Panicf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + }) + } +} + +func TestDebugln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Debugln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPrintln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Println(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestInfoln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Infoln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarnln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warnln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestWarningln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Warningln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestErrorln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Errorln(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestPanicln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + {"test", args{ctx: context.TODO(), args: []interface{}{"arg"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Panics(t, func() { + Panicln(tt.args.ctx, tt.args.args...) + }) + }) + } +} + +func Test_wrapHeader(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + want []interface{} + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapHeader(tt.args.ctx, tt.args.args...); !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrapHeader() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getLogger(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want *logrus.Entry + }{ + // TODO: Add test cases. 
+ } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getLogger(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getLogger() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWithIndent(t *testing.T) { + type args struct { + ctx context.Context + additionalIndent string + } + tests := []struct { + name string + args args + want context.Context + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := WithIndent(tt.args.ctx, tt.args.additionalIndent); !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithIndent() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getIndent(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getIndent(tt.args.ctx); got != tt.want { + t.Errorf("getIndent() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFatal(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Fatal(tt.args.ctx, tt.args.args...) + }) + } +} + +func TestFatalf(t *testing.T) { + type args struct { + ctx context.Context + format string + args []interface{} + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Fatalf(tt.args.ctx, tt.args.format, tt.args.args...) + }) + } +} + +func TestFatalln(t *testing.T) { + type args struct { + ctx context.Context + args []interface{} + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Fatalln(tt.args.ctx, tt.args.args...) + }) + } +} + +func Test_onConfigUpdated(t *testing.T) { + type args struct { + cfg Config + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + onConfigUpdated(tt.args.cfg) + }) + } +} diff --git a/pbhash/pbhash.go b/pbhash/pbhash.go new file mode 100644 index 0000000000..820b5c511a --- /dev/null +++ b/pbhash/pbhash.go @@ -0,0 +1,58 @@ +// This is a package that provides hashing utilities for Protobuf objects. 
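+// A minimal usage sketch (msg stands for any generated proto.Message):
+//
+//	digest, err := pbhash.ComputeHashString(ctx, msg)
+//
+// Semantically equal messages hash to the same digest regardless of map key order or
+// unset fields.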
+package pbhash
+
+import (
+	"context"
+	"encoding/base64"
+
+	goObjectHash "github.com/benlaurie/objecthash/go/objecthash"
+	"github.com/golang/protobuf/jsonpb"
+	"github.com/golang/protobuf/proto"
+	"github.com/lyft/flytestdlib/logger"
+)
+
+var marshaller = &jsonpb.Marshaler{}
+
+func fromHashToByteArray(input [32]byte) []byte {
+	output := make([]byte, 32)
+	for idx, val := range input {
+		output[idx] = val
+	}
+	return output
+}
+
+// Generates a deterministic hash in bytes for the pb object.
+func ComputeHash(ctx context.Context, pb proto.Message) ([]byte, error) {
+	// We marshal the pb object to JSON first which should provide a consistent mapping of pb to json fields as stated
+	// here: https://developers.google.com/protocol-buffers/docs/proto3#json
+	// jsonpb marshalling includes:
+	// - sorting map values to provide a stable output
+	// - omitting empty values which supports backwards compatibility of old protobuf definitions
+	// We do not use protobuf marshalling because it does not guarantee stable output because of how it handles
+	// unknown fields and ordering of fields. https://github.com/protocolbuffers/protobuf/issues/2830
+	pbJSON, err := marshaller.MarshalToString(pb)
+	if err != nil {
+		logger.Warningf(ctx, "failed to marshal pb [%+v] to JSON with err %v", pb, err)
+		return nil, err
+	}
+
+	// Deterministically hash the JSON object to a byte array. The library will sort the map keys of the JSON object
+	// so that we do not run into the issues from pb marshalling.
+	hash, err := goObjectHash.CommonJSONHash(pbJSON)
+	if err != nil {
+		logger.Warningf(ctx, "failed to hash JSON for pb [%+v] with err %v", pb, err)
+		return nil, err
+	}
+
+	return fromHashToByteArray(hash), err
+}
+
+// Generates a deterministic hash as a base64 encoded string for the pb object.
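+// The result is the 32-byte ComputeHash output encoded with base64.StdEncoding (a
+// 44-character string; see the expected values in the tests below).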
+func ComputeHashString(ctx context.Context, pb proto.Message) (string, error) { + hashBytes, err := ComputeHash(ctx, pb) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(hashBytes), err +} diff --git a/pbhash/pbhash_test.go b/pbhash/pbhash_test.go new file mode 100644 index 0000000000..75735b4135 --- /dev/null +++ b/pbhash/pbhash_test.go @@ -0,0 +1,145 @@ +package pbhash + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/duration" + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/stretchr/testify/assert" +) + +// Mock a Protobuf generated GO object +type mockProtoMessage struct { + Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` + FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` + Datetime *timestamp.Timestamp `protobuf:"bytes,5,opt,name=datetime,proto3" json:"datetime,omitempty"` + Duration *duration.Duration `protobuf:"bytes,6,opt,name=duration,proto3" json:"duration,omitempty"` + MapValue map[string]string `protobuf:"bytes,7,rep,name=map_value,json=mapValue,proto3" json:"map_value,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Collections []string `protobuf:"bytes,8,rep,name=collections,proto3" json:"collections,omitempty"` +} + +func (mockProtoMessage) Reset() { +} + +func (m mockProtoMessage) String() string { + return proto.CompactTextString(m) +} + +func (mockProtoMessage) ProtoMessage() { +} + +// Mock an older version of the above pb object that doesn't have some fields +type mockOlderProto struct { + Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` + FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` + StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` +} + +func (mockOlderProto) Reset() { +} + +func (m mockOlderProto) String() string { + return proto.CompactTextString(m) +} + +func (mockOlderProto) ProtoMessage() { +} + +var sampleTime, _ = ptypes.TimestampProto( + time.Date(2019, 03, 29, 12, 0, 0, 0, time.UTC)) + +func TestProtoHash(t *testing.T) { + mockProto := &mockProtoMessage{ + Integer: 18, + FloatValue: 1.3, + StringValue: "lets test this", + Boolean: true, + Datetime: sampleTime, + Duration: ptypes.DurationProto(time.Millisecond), + MapValue: map[string]string{ + "z": "last", + "a": "first", + }, + Collections: []string{"1", "2", "3"}, + } + + expectedHashedMockProto := []byte{0x62, 0x95, 0xb2, 0x2c, 0x23, 0xf5, 0x35, 0x6d, 0x3, 0x56, 0x4d, 0xc7, 0x8f, 0xae, + 0x2d, 0x2b, 0xbd, 0x7, 0xff, 0xdb, 0x7e, 0xe5, 0xf4, 0x25, 0x8f, 0xbc, 0xb2, 0xc, 0xad, 0xa5, 0x48, 0x44} + expectedHashString := "YpWyLCP1NW0DVk3Hj64tK70H/9t+5fQlj7yyDK2lSEQ=" + + t.Run("TestFullProtoHash", func(t *testing.T) { + hashedBytes, err := ComputeHash(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProto, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := 
ComputeHashString(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashString) + }) + + t.Run("TestFullProtoHashReorderKeys", func(t *testing.T) { + mockProto.MapValue = map[string]string{"a": "first", "z": "last"} + hashedBytes, err := ComputeHash(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProto, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := ComputeHashString(context.Background(), mockProto) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashString) + }) +} + +func TestPartialFilledProtoHash(t *testing.T) { + + mockProtoOmitEmpty := &mockProtoMessage{ + Integer: 18, + FloatValue: 1.3, + StringValue: "lets test this", + Boolean: true, + } + + expectedHashedMockProtoOmitEmpty := []byte{0x1a, 0x13, 0xcc, 0x4c, 0xab, 0xc9, 0x7d, 0x43, 0xc7, 0x2b, 0xc5, 0x37, + 0xbc, 0x49, 0xa8, 0x8b, 0xfc, 0x1d, 0x54, 0x1c, 0x7b, 0x21, 0x04, 0x8f, 0xab, 0x28, 0xc6, 0x5c, 0x06, 0x73, + 0xaa, 0xe2} + + expectedHashStringOmitEmpty := "GhPMTKvJfUPHK8U3vEmoi/wdVBx7IQSPqyjGXAZzquI=" + + t.Run("TestPartial", func(t *testing.T) { + hashedBytes, err := ComputeHash(context.Background(), mockProtoOmitEmpty) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProtoOmitEmpty, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := ComputeHashString(context.Background(), mockProtoOmitEmpty) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashStringOmitEmpty) + }) + + mockOldProtoMessage := &mockOlderProto{ + Integer: 18, + FloatValue: 1.3, + StringValue: "lets test this", + Boolean: true, + } + + t.Run("TestOlderProto", func(t *testing.T) { + hashedBytes, err := ComputeHash(context.Background(), mockOldProtoMessage) + assert.Nil(t, err) + assert.Equal(t, expectedHashedMockProtoOmitEmpty, hashedBytes) + assert.Len(t, hashedBytes, 32) + + hashedString, err := ComputeHashString(context.Background(), mockProtoOmitEmpty) + assert.Nil(t, err) + assert.Equal(t, hashedString, expectedHashStringOmitEmpty) + }) + +} diff --git a/profutils/server.go b/profutils/server.go new file mode 100644 index 0000000000..b8240cf25c --- /dev/null +++ b/profutils/server.go @@ -0,0 +1,116 @@ +package profutils + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/lyft/flytestdlib/config" + + "github.com/lyft/flytestdlib/version" + + "github.com/lyft/flytestdlib/logger" + "github.com/prometheus/client_golang/prometheus/promhttp" + + _ "net/http/pprof" // Import for pprof server +) + +const ( + healthcheck = "/healthcheck" + metricsPath = "/metrics" + versionPath = "/version" + configPath = "/config" +) + +const ( + contentTypeHeader = "Content-Type" + contentTypeJSON = "application/json; charset=utf-8" +) + +func WriteStringResponse(resp http.ResponseWriter, code int, body string) error { + resp.WriteHeader(code) + _, err := resp.Write([]byte(body)) + return err +} + +func WriteJSONResponse(resp http.ResponseWriter, code int, body interface{}) error { + resp.Header().Set(contentTypeHeader, contentTypeJSON) + resp.WriteHeader(code) + j, err := json.Marshal(body) + if err != nil { + return WriteStringResponse(resp, http.StatusInternalServerError, err.Error()) + } + return WriteStringResponse(resp, http.StatusOK, string(j)) +} + +func healtcheckHandler(w http.ResponseWriter, req *http.Request) { + err := WriteStringResponse(w, http.StatusOK, http.StatusText(http.StatusOK)) + if err != nil { + panic(err) + } +} + +func versionHandler(w http.ResponseWriter, req 
*http.Request) { + err := WriteStringResponse(w, http.StatusOK, fmt.Sprintf("Build [%s], Version [%s]", version.Build, version.Version)) + if err != nil { + panic(err) + } +} + +func configHandler(w http.ResponseWriter, req *http.Request) { + m, err := config.AllConfigsAsMap(config.GetRootSection()) + if err != nil { + err = WriteStringResponse(w, http.StatusInternalServerError, err.Error()) + if err != nil { + logger.Errorf(context.TODO(), "Failed to write error response. Error: %v", err) + panic(err) + } + } + + if err := WriteJSONResponse(w, http.StatusOK, m); err != nil { + panic(err) + } +} + +// Starts an http server on the given port +func StartProfilingServer(ctx context.Context, pprofPort int) error { + logger.Infof(ctx, "Starting profiling server on port [%v]", pprofPort) + e := http.ListenAndServe(fmt.Sprintf(":%d", pprofPort), nil) + if e != nil { + logger.Errorf(ctx, "Failed to start profiling server. Error: %v", e) + return fmt.Errorf("failed to start profiling server, %s", e) + } + + return nil +} + +func configureGlobalHTTPHandler(handlers map[string]http.Handler) error { + if handlers == nil { + handlers = map[string]http.Handler{} + } + handlers[metricsPath] = promhttp.Handler() + handlers[healthcheck] = http.HandlerFunc(healtcheckHandler) + handlers[versionPath] = http.HandlerFunc(versionHandler) + handlers[configPath] = http.HandlerFunc(configHandler) + + for p, h := range handlers { + http.Handle(p, h) + } + + return nil +} + +// Forwards the call to StartProfilingServer +// Also registers: +// 1. the prometheus HTTP handler on '/metrics' path shared with the profiling server. +// 2. A healthcheck (L7) handler on '/healthcheck'. +// 3. A version handler on '/version' provides information about the specific build. +// 4. A config handler on '/config' provides a dump of the currently loaded config. 
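+//
+// A typical invocation sketch (the port value is illustrative):
+//
+//	go func() {
+//		if err := profutils.StartProfilingServerWithDefaultHandlers(ctx, 10254, nil); err != nil {
+//			logger.Panicf(ctx, "failed to start profiling server: %v", err)
+//		}
+//	}()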
+func StartProfilingServerWithDefaultHandlers(ctx context.Context, pprofPort int, handlers map[string]http.Handler) error {
+	if err := configureGlobalHTTPHandler(handlers); err != nil {
+		return err
+	}
+
+	return StartProfilingServer(ctx, pprofPort)
+}
diff --git a/profutils/server_test.go b/profutils/server_test.go
new file mode 100644
index 0000000000..e2eb709943
--- /dev/null
+++ b/profutils/server_test.go
@@ -0,0 +1,103 @@
+package profutils
+
+import (
+	"encoding/json"
+	"net/http"
+	"testing"
+
+	"github.com/lyft/flytestdlib/internal/utils"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type MockResponseWriter struct {
+	Status  int
+	Headers http.Header
+	Body    []byte
+}
+
+func (m *MockResponseWriter) Write(b []byte) (int, error) {
+	m.Body = b
+	return len(b), nil
+}
+
+func (m *MockResponseWriter) WriteHeader(statusCode int) {
+	m.Status = statusCode
+}
+
+func (m *MockResponseWriter) Header() http.Header {
+	return m.Headers
+}
+
+type TestObj struct {
+	X int `json:"x"`
+}
+
+func init() {
+	if err := configureGlobalHTTPHandler(nil); err != nil {
+		panic(err)
+	}
+}
+
+func TestWriteJsonResponse(t *testing.T) {
+	m := &MockResponseWriter{Headers: http.Header{}}
+	assert.NoError(t, WriteJSONResponse(m, http.StatusOK, TestObj{10}))
+	assert.Equal(t, http.StatusOK, m.Status)
+	assert.Equal(t, http.Header{contentTypeHeader: []string{contentTypeJSON}}, m.Headers)
+	assert.Equal(t, `{"x":10}`, string(m.Body))
+}
+
+func TestWriteStringResponse(t *testing.T) {
+	m := &MockResponseWriter{Headers: http.Header{}}
+	assert.NoError(t, WriteStringResponse(m, http.StatusOK, "OK"))
+	assert.Equal(t, http.StatusOK, m.Status)
+	assert.Equal(t, "OK", string(m.Body))
+}
+
+func TestConfigHandler(t *testing.T) {
+	writer := &MockResponseWriter{Headers: http.Header{}}
+	testURL := utils.MustParseURL(configPath)
+	request := &http.Request{
+		URL: &testURL,
+	}
+
+	http.DefaultServeMux.ServeHTTP(writer, request)
+	assert.Equal(t, http.StatusOK, writer.Status)
+
+	m := map[string]interface{}{}
+	assert.NoError(t, json.Unmarshal(writer.Body, &m))
+	assert.Equal(t, map[string]interface{}{
+		"logger": map[string]interface{}{
+			"show-source": false,
+			"mute":        false,
+			"level":       float64(0),
+			"formatter": map[string]interface{}{
+				"type": "",
+			},
+		},
+	}, m)
+}
+
+func TestVersionHandler(t *testing.T) {
+	writer := &MockResponseWriter{Headers: http.Header{}}
+	testURL := utils.MustParseURL(versionPath)
+	request := &http.Request{
+		URL: &testURL,
+	}
+
+	http.DefaultServeMux.ServeHTTP(writer, request)
+	assert.Equal(t, http.StatusOK, writer.Status)
+	assert.Equal(t, `Build [unknown], Version [unknown]`, string(writer.Body))
+}
+
+func TestHealthcheckHandler(t *testing.T) {
+	writer := &MockResponseWriter{Headers: http.Header{}}
+	testURL := utils.MustParseURL(healthcheck)
+	request := &http.Request{
+		URL: &testURL,
+	}
+
+	http.DefaultServeMux.ServeHTTP(writer, request)
+	assert.Equal(t, http.StatusOK, writer.Status)
+	assert.Equal(t, `OK`, string(writer.Body))
+}
diff --git a/promutils/labeled/counter.go b/promutils/labeled/counter.go
new file mode 100644
index 0000000000..c68ce02d23
--- /dev/null
+++ b/promutils/labeled/counter.go
@@ -0,0 +1,65 @@
+package labeled
+
+import (
+	"context"
+
+	"github.com/lyft/flytestdlib/contextutils"
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+// Represents a counter labeled with values from the context. See labeled.SetMetricKeys for information on how to
+// configure that.
+type Counter struct {
+	*prometheus.CounterVec
+
+	prometheus.Counter
+}
+
+// Inc increments the counter by 1. Use Add to increment it by arbitrary non-negative values. The data point will be
+// labeled with values from context. See labeled.SetMetricKeys for information on how to configure that.
+func (c Counter) Inc(ctx context.Context) {
+	counter, err := c.CounterVec.GetMetricWith(contextutils.Values(ctx, metricKeys...))
+	if err != nil {
+		panic(err.Error())
+	}
+	counter.Inc()
+
+	if c.Counter != nil {
+		c.Counter.Inc()
+	}
+}
+
+// Add adds the given value to the counter. It panics if the value is < 0. The data point will be labeled with values
+// from context. See labeled.SetMetricKeys for information on how to configure that.
+func (c Counter) Add(ctx context.Context, v float64) {
+	counter, err := c.CounterVec.GetMetricWith(contextutils.Values(ctx, metricKeys...))
+	if err != nil {
+		panic(err.Error())
+	}
+	counter.Add(v)
+
+	if c.Counter != nil {
+		c.Counter.Add(v)
+	}
+}
+
+// NewCounter creates a new labeled counter. Label keys must be set before instantiating a counter. See
+// labeled.SetMetricKeys for information on how to configure that.
+func NewCounter(name, description string, scope promutils.Scope, opts ...MetricOption) Counter {
+	if len(metricKeys) == 0 {
+		panic(ErrNeverSet)
+	}
+
+	c := Counter{
+		CounterVec: scope.MustNewCounterVec(name, description, metricStringKeys...),
+	}
+
+	for _, opt := range opts {
+		if _, emitUnlabeledMetric := opt.(EmitUnlabeledMetricOption); emitUnlabeledMetric {
+			c.Counter = scope.MustNewCounter(GetUnlabeledMetricName(name), description)
+		}
+	}
+
+	return c
+}
diff --git a/promutils/labeled/counter_test.go b/promutils/labeled/counter_test.go
new file mode 100644
index 0000000000..130b8217a2
--- /dev/null
+++ b/promutils/labeled/counter_test.go
@@ -0,0 +1,31 @@
+package labeled
+
+import (
+	"context"
+	"testing"
+
+	"github.com/lyft/flytestdlib/contextutils"
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestLabeledCounter(t *testing.T) {
+	assert.NotPanics(t, func() {
+		SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey)
+	})
+
+	scope := promutils.NewTestScope()
+	c := NewCounter("lbl_counter", "help", scope)
+	assert.NotNil(t, c)
+	ctx := context.TODO()
+	c.Inc(ctx)
+	c.Add(ctx, 1.0)
+
+	ctx = contextutils.WithProjectDomain(ctx, "project", "domain")
+	c.Inc(ctx)
+	c.Add(ctx, 1.0)
+
+	ctx = contextutils.WithTaskID(ctx, "task")
+	c.Inc(ctx)
+	c.Add(ctx, 1.0)
+}
diff --git a/promutils/labeled/keys.go b/promutils/labeled/keys.go
new file mode 100644
index 0000000000..d8c8683750
--- /dev/null
+++ b/promutils/labeled/keys.go
@@ -0,0 +1,47 @@
+package labeled
+
+import (
+	"fmt"
+	"reflect"
+	"sync"
+
+	"github.com/lyft/flytestdlib/contextutils"
+)
+
+var (
+	ErrAlreadySet = fmt.Errorf("cannot set metric keys more than once")
+	ErrEmpty      = fmt.Errorf("cannot set metric keys to an empty set")
+	ErrNeverSet   = fmt.Errorf("must call SetMetricKeys prior to using labeled package")
+
+	// Metric keys to label metrics with. These keys get pulled from context if they are present. Use contextutils to
+	// fill them in.
+	metricKeys = make([]contextutils.Key, 0)
+
+	// :(, we have to create a separate list to satisfy the MustNewCounterVec API as it accepts strings only.
+	metricStringKeys = make([]string, 0)
+	metricKeysAreSet = sync.Once{}
+)
+
+// SetMetricKeys sets the keys to use with labeled metrics. The values of these keys will be pulled from context at
+// runtime.
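+//
+// A typical setup calls this exactly once during process initialization, before any labeled metric is
+// instantiated (a sketch; the chosen keys are illustrative and scope is any promutils.Scope):
+//
+//	labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey)
+//	c := labeled.NewCounter("errors", "count of errors", scope)
+//	c.Inc(contextutils.WithProjectDomain(ctx, "my_project", "development"))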
+func SetMetricKeys(keys ...contextutils.Key) { + if len(keys) == 0 { + panic(ErrEmpty) + } + + ran := false + metricKeysAreSet.Do(func() { + ran = true + metricKeys = keys + for _, metricKey := range metricKeys { + metricStringKeys = append(metricStringKeys, metricKey.String()) + } + }) + + if !ran && !reflect.DeepEqual(keys, metricKeys) { + panic(ErrAlreadySet) + } +} + +func GetUnlabeledMetricName(metricName string) string { + return metricName + "_unlabeled" +} diff --git a/promutils/labeled/keys_test.go b/promutils/labeled/keys_test.go new file mode 100644 index 0000000000..4a8600aea3 --- /dev/null +++ b/promutils/labeled/keys_test.go @@ -0,0 +1,24 @@ +package labeled + +import ( + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/stretchr/testify/assert" +) + +func TestMetricKeys(t *testing.T) { + input := []contextutils.Key{ + contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, + } + + assert.NotPanics(t, func() { SetMetricKeys(input...) }) + assert.Equal(t, input, metricKeys) + + for i, k := range metricKeys { + assert.Equal(t, k.String(), metricStringKeys[i]) + } + + assert.NotPanics(t, func() { SetMetricKeys(input...) }) + assert.Panics(t, func() { SetMetricKeys(contextutils.DomainKey) }) +} diff --git a/promutils/labeled/metric_option.go b/promutils/labeled/metric_option.go new file mode 100644 index 0000000000..08fb2f76f9 --- /dev/null +++ b/promutils/labeled/metric_option.go @@ -0,0 +1,15 @@ +package labeled + +// Defines extra set of options to customize the emitted metric. +type MetricOption interface { + isMetricOption() +} + +// Instructs the metric to emit unlabeled metric (besides the labeled one). This is useful to get overall system +// performance. +type EmitUnlabeledMetricOption struct { +} + +func (EmitUnlabeledMetricOption) isMetricOption() {} + +var EmitUnlabeledMetric = EmitUnlabeledMetricOption{} diff --git a/promutils/labeled/metric_option_test.go b/promutils/labeled/metric_option_test.go new file mode 100644 index 0000000000..0a070f7420 --- /dev/null +++ b/promutils/labeled/metric_option_test.go @@ -0,0 +1,13 @@ +package labeled + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMetricOption(t *testing.T) { + var opt MetricOption = &EmitUnlabeledMetric + _, isMetricOption := opt.(MetricOption) + assert.True(t, isMetricOption) +} diff --git a/promutils/labeled/stopwatch.go b/promutils/labeled/stopwatch.go new file mode 100644 index 0000000000..90e29971f9 --- /dev/null +++ b/promutils/labeled/stopwatch.go @@ -0,0 +1,87 @@ +package labeled + +import ( + "context" + "time" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" +) + +type StopWatch struct { + *promutils.StopWatchVec + + // We use SummaryVec for emitting StopWatchVec, this computes percentiles per metric tags combination on the client- + // side. This makes it impossible to aggregate percentiles across tags (e.g. to have system-wide view). When enabled + // through a flag in the constructor, we initialize this additional untagged stopwatch to compute percentiles + // across tags. + promutils.StopWatch +} + +// Start creates a new Instance of the StopWatch called a Timer that is closeable/stoppable. +// Common pattern to time a scope would be +// { +// timer := stopWatch.Start(ctx) +// defer timer.Stop() +// .... 
+// }
+func (c StopWatch) Start(ctx context.Context) Timer {
+	w, err := c.StopWatchVec.GetMetricWith(contextutils.Values(ctx, metricKeys...))
+	if err != nil {
+		panic(err.Error())
+	}
+
+	if c.StopWatch.Observer == nil {
+		return w.Start()
+	}
+
+	return timer{
+		Timers: []Timer{
+			w.Start(),
+			c.StopWatch.Start(),
+		},
+	}
+}
+
+// Observe records the specified duration between the start and end time. The data point will be labeled with values
+// from context. See labeled.SetMetricKeys for information on how to configure that.
+func (c StopWatch) Observe(ctx context.Context, start, end time.Time) {
+	w, err := c.StopWatchVec.GetMetricWith(contextutils.Values(ctx, metricKeys...))
+	if err != nil {
+		panic(err.Error())
+	}
+	w.Observe(start, end)
+
+	if c.StopWatch.Observer != nil {
+		c.StopWatch.Observe(start, end)
+	}
+}
+
+// Time observes the time it takes to execute the given function synchronously. The data point will be labeled with
+// values from context. See labeled.SetMetricKeys for information on how to configure that.
+func (c StopWatch) Time(ctx context.Context, f func()) {
+	t := c.Start(ctx)
+	f()
+	t.Stop()
+}
+
+// NewStopWatch creates a new labeled stopwatch. Label keys must be set before instantiating a stopwatch. See
+// labeled.SetMetricKeys for information on how to configure that.
+func NewStopWatch(name, description string, scale time.Duration, scope promutils.Scope, opts ...MetricOption) StopWatch {
+	if len(metricKeys) == 0 {
+		panic(ErrNeverSet)
+	}
+
+	sw := StopWatch{
+		StopWatchVec: scope.MustNewStopWatchVec(name, description, scale, metricStringKeys...),
+	}
+
+	for _, opt := range opts {
+		if _, emitUnlabeledMetric := opt.(EmitUnlabeledMetricOption); emitUnlabeledMetric {
+			sw.StopWatch = scope.MustNewStopWatch(GetUnlabeledMetricName(name), description, scale)
+		}
+	}
+
+	return sw
+}
diff --git a/promutils/labeled/stopwatch_test.go b/promutils/labeled/stopwatch_test.go
new file mode 100644
index 0000000000..d5adf0eade
--- /dev/null
+++ b/promutils/labeled/stopwatch_test.go
@@ -0,0 +1,47 @@
+package labeled
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/lyft/flytestdlib/contextutils"
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestLabeledStopWatch(t *testing.T) {
+	assert.NotPanics(t, func() {
+		SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey)
+	})
+
+	t.Run("always labeled", func(t *testing.T) {
+		scope := promutils.NewTestScope()
+		c := NewStopWatch("lbl_counter", "help", time.Second, scope)
+		assert.NotNil(t, c)
+		ctx := context.TODO()
+		w := c.Start(ctx)
+		w.Stop()
+
+		ctx = contextutils.WithProjectDomain(ctx, "project", "domain")
+		w = c.Start(ctx)
+		w.Stop()
+
+		ctx = contextutils.WithTaskID(ctx, "task")
+		w = c.Start(ctx)
+		w.Stop()
+
+		c.Observe(ctx, time.Now(), time.Now().Add(time.Second))
+		c.Time(ctx, func() {
+			// Do nothing
+		})
+	})
+
+	t.Run("unlabeled", func(t *testing.T) {
+		scope := promutils.NewTestScope()
+		c := NewStopWatch("lbl_counter_2", "help", time.Second, scope, EmitUnlabeledMetric)
+		assert.NotNil(t, c)
+
+		c.Start(context.TODO())
+	})
+}
diff --git a/promutils/labeled/timer_wrapper.go b/promutils/labeled/timer_wrapper.go
new file mode 100644
index 0000000000..75aa4bee94
--- /dev/null
+++ b/promutils/labeled/timer_wrapper.go
@@ -0,0 +1,20 @@
+package labeled
+
+// Defines a common interface for timers.
+type Timer interface {
+	// Stop stops the timer and reports the observation.
+ Stop() float64 +} + +type timer struct { + Timers []Timer +} + +func (t timer) Stop() float64 { + var res float64 + for _, timer := range t.Timers { + res = timer.Stop() + } + + return res +} diff --git a/promutils/labeled/timer_wrapper_test.go b/promutils/labeled/timer_wrapper_test.go new file mode 100644 index 0000000000..375836c557 --- /dev/null +++ b/promutils/labeled/timer_wrapper_test.go @@ -0,0 +1,28 @@ +package labeled + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type fakeTimer struct { + stopCount int +} + +func (f *fakeTimer) Stop() float64 { + f.stopCount++ + return 0 +} + +func TestTimerStop(t *testing.T) { + ft := &fakeTimer{} + tim := timer{ + Timers: []Timer{ + ft, ft, ft, + }, + } + + tim.Stop() + assert.Equal(t, 3, ft.stopCount) +} diff --git a/promutils/scope.go b/promutils/scope.go new file mode 100644 index 0000000000..d32fab4cfa --- /dev/null +++ b/promutils/scope.go @@ -0,0 +1,434 @@ +package promutils + +import ( + "strings" + "time" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/prometheus/client_golang/prometheus" +) + +const defaultScopeDelimiterStr = ":" +const defaultMetricDelimiterStr = "_" + +func panicIfError(err error) { + if err != nil { + panic("Failed to register metrics. Error: " + err.Error()) + } +} + +// A Simple StopWatch that works with prometheus summary +// It will scale the output to match the expected time scale (milliseconds, seconds etc) +// NOTE: Do not create a StopWatch object by hand, use a Scope to get a new instance of the StopWatch object +type StopWatch struct { + prometheus.Observer + outputScale time.Duration +} + +// Start creates a new Instance of the StopWatch called a Timer that is closeable/stoppable. +// Common pattern to time a scope would be +// { +// timer := stopWatch.Start() +// defer timer.Stop() +// .... +// } +func (s StopWatch) Start() Timer { + return Timer{ + start: time.Now(), + outputScale: s.outputScale, + timer: s.Observer, + } +} + +// Observes specified duration between the start and end time +func (s StopWatch) Observe(start, end time.Time) { + observed := end.Sub(start).Nanoseconds() + outputScaleDuration := s.outputScale.Nanoseconds() + if outputScaleDuration == 0 { + s.Observer.Observe(0) + return + } + scaled := float64(observed / outputScaleDuration) + s.Observer.Observe(scaled) +} + +// Observes/records the time to execute the given function synchronously +func (s StopWatch) Time(f func()) { + t := s.Start() + f() + t.Stop() +} + +// A Simple StopWatch that works with prometheus summary +// It will scale the output to match the expected time scale (milliseconds, seconds etc) +// NOTE: Do not create a StopWatch object by hand, use a Scope to get a new instance of the StopWatch object +type StopWatchVec struct { + *prometheus.SummaryVec + outputScale time.Duration +} + +// Gets a concrete StopWatch instance that can be used to start a timer and record observations. +func (s StopWatchVec) WithLabelValues(values ...string) StopWatch { + return StopWatch{ + Observer: s.SummaryVec.WithLabelValues(values...), + outputScale: s.outputScale, + } +} + +func (s StopWatchVec) GetMetricWith(labels prometheus.Labels) (StopWatch, error) { + sVec, err := s.SummaryVec.GetMetricWith(labels) + if err != nil { + return StopWatch{}, err + } + return StopWatch{ + Observer: sVec, + outputScale: s.outputScale, + }, nil +} + +// This is a stoppable instance of a StopWatch or a Timer +// A Timer can only be stopped. 
On stopping it will output the elapsed duration to prometheus +type Timer struct { + start time.Time + outputScale time.Duration + timer prometheus.Observer +} + +// This method observes the elapsed duration since the creation of the timer. The timer is created using a StopWatch +func (s Timer) Stop() float64 { + observed := time.Since(s.start).Nanoseconds() + outputScaleDuration := s.outputScale.Nanoseconds() + if outputScaleDuration == 0 { + s.timer.Observe(0) + return 0 + } + scaled := float64(observed / outputScaleDuration) + s.timer.Observe(scaled) + return scaled +} + +// A Scope represents a prefix in Prometheus. It is nestable, thus every metric that is published does not need to +// provide a prefix, but just the name of the metric. As long as the Scope is used to create a new instance of the metric +// The prefix (or scope) is automatically set. +type Scope interface { + // Creates new prometheus.Gauge metric with the prefix as the CurrentScope + // Name is a string that follows prometheus conventions (mostly [_a-z]) + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewGauge(name, description string) (prometheus.Gauge, error) + MustNewGauge(name, description string) prometheus.Gauge + + // Creates new prometheus.GaugeVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewGaugeVec(name, description string, labelNames ...string) (*prometheus.GaugeVec, error) + MustNewGaugeVec(name, description string, labelNames ...string) *prometheus.GaugeVec + + // Creates new prometheus.Summary metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewSummary(name, description string) (prometheus.Summary, error) + MustNewSummary(name, description string) prometheus.Summary + + // Creates new prometheus.SummaryVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewSummaryVec(name, description string, labelNames ...string) (*prometheus.SummaryVec, error) + MustNewSummaryVec(name, description string, labelNames ...string) *prometheus.SummaryVec + + // Creates new prometheus.Histogram metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewHistogram(name, description string) (prometheus.Histogram, error) + MustNewHistogram(name, description string) prometheus.Histogram + + // Creates new prometheus.HistogramVec metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + NewHistogramVec(name, description string, labelNames ...string) (*prometheus.HistogramVec, error) + MustNewHistogramVec(name, description string, labelNames ...string) *prometheus.HistogramVec + + // Creates new prometheus.Counter metric with the prefix as the CurrentScope + // Refer to https://prometheus.io/docs/concepts/metric_types/ for more information + // Important to note, counters are not like typical counters. These are ever increasing and cumulative. 
+	// So if you want to observe counters within buckets, use a Summary/Histogram.
+	NewCounter(name, description string) (prometheus.Counter, error)
+	MustNewCounter(name, description string) prometheus.Counter
+
+	// Creates new prometheus.CounterVec metric with the prefix as the CurrentScope
+	// Refer to https://prometheus.io/docs/concepts/metric_types/ for more information
+	NewCounterVec(name, description string, labelNames ...string) (*prometheus.CounterVec, error)
+	MustNewCounterVec(name, description string, labelNames ...string) *prometheus.CounterVec
+
+	// This is a custom wrapper to create a StopWatch object in the current Scope.
+	// Duration is to specify the scale of the Timer. For example if you are measuring times in milliseconds
+	// pass scale=time.Millisecond
+	// https://golang.org/pkg/time/#Duration
+	// The metric name is auto-suffixed with the right scale. Refer to DurationToString to understand
+	NewStopWatch(name, description string, scale time.Duration) (StopWatch, error)
+	MustNewStopWatch(name, description string, scale time.Duration) StopWatch
+
+	// This is a custom wrapper to create a StopWatchVec object in the current Scope.
+	// Duration is to specify the scale of the Timer. For example if you are measuring times in milliseconds
+	// pass scale=time.Millisecond
+	// https://golang.org/pkg/time/#Duration
+	// The metric name is auto-suffixed with the right scale. Refer to DurationToString to understand
+	NewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) (*StopWatchVec, error)
+	MustNewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) *StopWatchVec
+
+	// In case nesting is desired for metrics, create a new subScope. This is generally useful in creating
+	// Scoped and SubScoped metrics
+	NewSubScope(name string) Scope
+
+	// Returns the current ScopeName. Use for creating your own metrics
+	CurrentScope() string
+
+	// Method that provides a scoped metric name. Can be used, if you want to directly create your own metric
+	NewScopedMetricName(name string) string
+}
+
+type metricsScope struct {
+	scope string
+}
+
+func (m metricsScope) NewGauge(name, description string) (prometheus.Gauge, error) {
+	g := prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Name: m.NewScopedMetricName(name),
+			Help: description,
+		},
+	)
+	return g, prometheus.Register(g)
+}
+
+func (m metricsScope) MustNewGauge(name, description string) prometheus.Gauge {
+	g, err := m.NewGauge(name, description)
+	panicIfError(err)
+	return g
+}
+
+func (m metricsScope) NewGaugeVec(name, description string, labelNames ...string) (*prometheus.GaugeVec, error) {
+	g := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: m.NewScopedMetricName(name),
+			Help: description,
+		},
+		labelNames,
+	)
+	return g, prometheus.Register(g)
+}
+
+func (m metricsScope) MustNewGaugeVec(name, description string, labelNames ...string) *prometheus.GaugeVec {
+	g, err := m.NewGaugeVec(name, description, labelNames...)
+ panicIfError(err) + return g +} + +func (m metricsScope) NewSummary(name, description string) (prometheus.Summary, error) { + s := prometheus.NewSummary( + prometheus.SummaryOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + + return s, prometheus.Register(s) +} + +func (m metricsScope) MustNewSummary(name, description string) prometheus.Summary { + s, err := m.NewSummary(name, description) + panicIfError(err) + return s +} + +func (m metricsScope) NewSummaryVec(name, description string, labelNames ...string) (*prometheus.SummaryVec, error) { + s := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + + return s, prometheus.Register(s) +} +func (m metricsScope) MustNewSummaryVec(name, description string, labelNames ...string) *prometheus.SummaryVec { + s, err := m.NewSummaryVec(name, description, labelNames...) + panicIfError(err) + return s +} + +func (m metricsScope) NewHistogram(name, description string) (prometheus.Histogram, error) { + h := prometheus.NewHistogram( + prometheus.HistogramOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + return h, prometheus.Register(h) +} + +func (m metricsScope) MustNewHistogram(name, description string) prometheus.Histogram { + h, err := m.NewHistogram(name, description) + panicIfError(err) + return h +} + +func (m metricsScope) NewHistogramVec(name, description string, labelNames ...string) (*prometheus.HistogramVec, error) { + h := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + return h, prometheus.Register(h) +} + +func (m metricsScope) MustNewHistogramVec(name, description string, labelNames ...string) *prometheus.HistogramVec { + h, err := m.NewHistogramVec(name, description, labelNames...) + panicIfError(err) + return h +} + +func (m metricsScope) NewCounter(name, description string) (prometheus.Counter, error) { + c := prometheus.NewCounter( + prometheus.CounterOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + ) + return c, prometheus.Register(c) +} + +func (m metricsScope) MustNewCounter(name, description string) prometheus.Counter { + c, err := m.NewCounter(name, description) + panicIfError(err) + return c +} + +func (m metricsScope) NewCounterVec(name, description string, labelNames ...string) (*prometheus.CounterVec, error) { + c := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: m.NewScopedMetricName(name), + Help: description, + }, + labelNames, + ) + return c, prometheus.Register(c) +} + +func (m metricsScope) MustNewCounterVec(name, description string, labelNames ...string) *prometheus.CounterVec { + c, err := m.NewCounterVec(name, description, labelNames...) 
+	panicIfError(err)
+	return c
+}
+
+func (m metricsScope) NewStopWatch(name, description string, scale time.Duration) (StopWatch, error) {
+	if !strings.HasSuffix(name, defaultMetricDelimiterStr) {
+		name += defaultMetricDelimiterStr
+	}
+
+	name += DurationToString(scale)
+	s, err := m.NewSummary(name, description)
+	if err != nil {
+		return StopWatch{}, err
+	}
+
+	return StopWatch{
+		Observer:    s,
+		outputScale: scale,
+	}, nil
+}
+
+func (m metricsScope) MustNewStopWatch(name, description string, scale time.Duration) StopWatch {
+	s, err := m.NewStopWatch(name, description, scale)
+	panicIfError(err)
+	return s
+}
+
+func (m metricsScope) NewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) (*StopWatchVec, error) {
+	if !strings.HasSuffix(name, defaultMetricDelimiterStr) {
+		name += defaultMetricDelimiterStr
+	}
+
+	name += DurationToString(scale)
+	s, err := m.NewSummaryVec(name, description, labelNames...)
+	if err != nil {
+		return &StopWatchVec{}, err
+	}
+
+	return &StopWatchVec{
+		SummaryVec:  s,
+		outputScale: scale,
+	}, nil
+}
+
+func (m metricsScope) MustNewStopWatchVec(name, description string, scale time.Duration, labelNames ...string) *StopWatchVec {
+	s, err := m.NewStopWatchVec(name, description, scale, labelNames...)
+	panicIfError(err)
+	return s
+}
+
+func (m metricsScope) CurrentScope() string {
+	return m.scope
+}
+
+// Creates a metric name under the scope. The scope always ends with defaultScopeDelimiterStr, so the metric name is
+// simply appended to it.
+func (m metricsScope) NewScopedMetricName(name string) string {
+	if name == "" {
+		panic("metric name cannot be an empty string")
+	}
+
+	return m.scope + name
+}
+
+func (m metricsScope) NewSubScope(subscopeName string) Scope {
+	if subscopeName == "" {
+		panic("scope name cannot be an empty string")
+	}
+
+	// If the subscope name already ends with defaultScopeDelimiterStr, do not add another one
+	if !strings.HasSuffix(subscopeName, defaultScopeDelimiterStr) {
+		subscopeName += defaultScopeDelimiterStr
+	}
+
+	return NewScope(m.scope + subscopeName)
+}
+
+// Creates a new scope in the format `name + defaultScopeDelimiterStr`.
+// If the name already ends with defaultScopeDelimiterStr, it is not appended again.
+func NewScope(name string) Scope {
+	if name == "" {
+		panic("base scope for a metric cannot be an empty string")
+	}
+
+	if !strings.HasSuffix(name, defaultScopeDelimiterStr) {
+		name += defaultScopeDelimiterStr
+	}
+
+	return metricsScope{
+		scope: name,
+	}
+}
+
+// Returns a randomly-named scope for use in tests.
+// Prometheus requires that metric names begin with a letter, so the generated name starts with the literal prefix
+// "test" followed by random alphanumeric characters.
+func NewTestScope() Scope {
+	return NewScope("test" + rand.String(6))
+}
+
+// DurationToString converts the duration to a string suffix that indicates the scale of the timer.
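+// For example, a stopwatch created with scale time.Millisecond is suffixed with "_ms" (a sketch; the metric name
+// is illustrative):
+//
+//	s := scope.MustNewStopWatch("fetch", "time to fetch", time.Millisecond)
+//	// registers a summary named "<scope>:fetch_ms" that reports observations in milliseconds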
+func DurationToString(duration time.Duration) string { + if duration >= time.Hour { + return "h" + } + if duration >= time.Minute { + return "m" + } + if duration >= time.Second { + return "s" + } + if duration >= time.Millisecond { + return "ms" + } + if duration >= time.Microsecond { + return "us" + } + return "ns" +} diff --git a/promutils/scope_test.go b/promutils/scope_test.go new file mode 100644 index 0000000000..cb076ee980 --- /dev/null +++ b/promutils/scope_test.go @@ -0,0 +1,151 @@ +package promutils + +import ( + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +func TestDurationToString(t *testing.T) { + assert.Equal(t, "m", DurationToString(time.Minute)) + assert.Equal(t, "m", DurationToString(time.Minute*10)) + assert.Equal(t, "h", DurationToString(time.Hour)) + assert.Equal(t, "h", DurationToString(time.Hour*10)) + assert.Equal(t, "s", DurationToString(time.Second)) + assert.Equal(t, "s", DurationToString(time.Second*10)) + assert.Equal(t, "us", DurationToString(time.Microsecond*10)) + assert.Equal(t, "us", DurationToString(time.Microsecond)) + assert.Equal(t, "ms", DurationToString(time.Millisecond*10)) + assert.Equal(t, "ms", DurationToString(time.Millisecond)) + assert.Equal(t, "ns", DurationToString(1)) +} + +func TestNewScope(t *testing.T) { + assert.Panics(t, func() { + NewScope("") + }) + s := NewScope("test") + assert.Equal(t, "test:", s.CurrentScope()) + assert.Equal(t, "test:hello:", s.NewSubScope("hello").CurrentScope()) + assert.Panics(t, func() { + s.NewSubScope("") + }) + assert.Equal(t, "test:timer_x", s.NewScopedMetricName("timer_x")) + assert.Equal(t, "test:hello:timer:", s.NewSubScope("hello").NewSubScope("timer").CurrentScope()) + assert.Equal(t, "test:hello:timer:", s.NewSubScope("hello").NewSubScope("timer:").CurrentScope()) +} + +func TestMetricsScope(t *testing.T) { + s := NewScope("test") + const description = "some x" + if !assert.NotNil(t, prometheus.DefaultRegisterer) { + assert.Fail(t, "Prometheus registrar failed") + } + t.Run("Counter", func(t *testing.T) { + m := s.MustNewCounter("xc", description) + assert.Equal(t, `Desc{fqName: "test:xc", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewCounterVec("xcv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + _ = s.MustNewCounter("xc", description) + }) + assert.Panics(t, func() { + _ = s.MustNewCounterVec("xcv", description) + }) + }) + + t.Run("Histogram", func(t *testing.T) { + m := s.MustNewHistogram("xh", description) + assert.Equal(t, `Desc{fqName: "test:xh", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewHistogramVec("xhv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + _ = s.MustNewHistogram("xh", description) + }) + assert.Panics(t, func() { + _ = s.MustNewHistogramVec("xhv", description) + }) + }) + + t.Run("Summary", func(t *testing.T) { + m := s.MustNewSummary("xs", description) + assert.Equal(t, `Desc{fqName: "test:xs", help: "some x", constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewSummaryVec("xsv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + _ = s.MustNewSummary("xs", description) + }) + assert.Panics(t, func() { + _ = s.MustNewSummaryVec("xsv", description) + }) + }) + + t.Run("Gauge", func(t *testing.T) { + m := s.MustNewGauge("xg", description) + assert.Equal(t, `Desc{fqName: "test:xg", help: "some x", 
constLabels: {}, variableLabels: []}`, m.Desc().String()) + mv := s.MustNewGaugeVec("xgv", description) + assert.NotNil(t, mv) + assert.Panics(t, func() { + m = s.MustNewGauge("xg", description) + }) + assert.Panics(t, func() { + _ = s.MustNewGaugeVec("xgv", description) + }) + }) + + t.Run("Timer", func(t *testing.T) { + m := s.MustNewStopWatch("xt", description, time.Second) + asDesc, ok := m.Observer.(prometheus.Metric) + assert.True(t, ok) + assert.Equal(t, `Desc{fqName: "test:xt_s", help: "some x", constLabels: {}, variableLabels: []}`, asDesc.Desc().String()) + assert.Panics(t, func() { + _ = s.MustNewStopWatch("xt", description, time.Second) + }) + }) + +} + +func TestStopWatch_Start(t *testing.T) { + scope := NewTestScope() + s, e := scope.NewStopWatch("yt"+rand.String(3), "timer", time.Millisecond) + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, s.outputScale) + i := s.Start() + assert.Equal(t, time.Millisecond, i.outputScale) + assert.NotNil(t, i.start) + i.Stop() +} + +func TestStopWatch_Observe(t *testing.T) { + scope := NewTestScope() + s, e := scope.NewStopWatch("yt"+rand.String(3), "timer", time.Millisecond) + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, s.outputScale) + s.Observe(time.Now(), time.Now().Add(time.Second)) +} + +func TestStopWatch_Time(t *testing.T) { + scope := NewTestScope() + s, e := scope.NewStopWatch("yt"+rand.String(3), "timer", time.Millisecond) + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, s.outputScale) + s.Time(func() { + }) +} + +func TestStopWatchVec_WithLabelValues(t *testing.T) { + scope := NewTestScope() + v, e := scope.NewStopWatchVec("yt"+rand.String(3), "timer", time.Millisecond, "workflow", "label") + assert.NoError(t, e) + assert.Equal(t, time.Millisecond, v.outputScale) + s := v.WithLabelValues("my_wf", "something") + assert.NotNil(t, s) + i := s.Start() + assert.Equal(t, time.Millisecond, i.outputScale) + assert.NotNil(t, i.start) + i.Stop() +} diff --git a/promutils/workqueue.go b/promutils/workqueue.go new file mode 100644 index 0000000000..24de6a0053 --- /dev/null +++ b/promutils/workqueue.go @@ -0,0 +1,82 @@ +// Source: https://raw.githubusercontent.com/kubernetes/kubernetes/3dbbd0bdf44cb07fdde85aa392adf99ea7e95939/pkg/util/workqueue/prometheus/prometheus.go +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package promutils + +import ( + "k8s.io/client-go/util/workqueue" + + "github.com/prometheus/client_golang/prometheus" +) + +// Package prometheus sets the workqueue DefaultMetricsFactory to produce +// prometheus metrics. To use this package, you just have to import it. 
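+//
+// Note that this code now lives in package promutils and registers the provider via workqueue.SetProvider in
+// init(), so a blank import of this package is what installs it (a sketch; the queue name is illustrative):
+//
+//	import (
+//		"k8s.io/client-go/util/workqueue"
+//
+//		_ "github.com/lyft/flytestdlib/promutils" // registers the metrics provider via init()
+//	)
+//
+//	q := workqueue.NewNamed("events") // metrics are emitted under the "events" subsystem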
+
+func init() {
+	workqueue.SetProvider(prometheusMetricsProvider{})
+}
+
+type prometheusMetricsProvider struct{}
+
+func (prometheusMetricsProvider) NewDepthMetric(name string) workqueue.GaugeMetric {
+	depth := prometheus.NewGauge(prometheus.GaugeOpts{
+		Subsystem: name,
+		Name:      "depth",
+		Help:      "Current depth of workqueue: " + name,
+	})
+	prometheus.MustRegister(depth)
+	return depth
+}
+
+func (prometheusMetricsProvider) NewAddsMetric(name string) workqueue.CounterMetric {
+	adds := prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: name,
+		Name:      "adds",
+		Help:      "Total number of adds handled by workqueue: " + name,
+	})
+	prometheus.MustRegister(adds)
+	return adds
+}
+
+func (prometheusMetricsProvider) NewLatencyMetric(name string) workqueue.SummaryMetric {
+	latency := prometheus.NewSummary(prometheus.SummaryOpts{
+		Subsystem: name,
+		Name:      "queue_latency_us",
+		Help:      "How long an item stays in workqueue " + name + " before being requested.",
+	})
+	prometheus.MustRegister(latency)
+	return latency
+}
+
+func (prometheusMetricsProvider) NewWorkDurationMetric(name string) workqueue.SummaryMetric {
+	workDuration := prometheus.NewSummary(prometheus.SummaryOpts{
+		Subsystem: name,
+		Name:      "work_duration_us",
+		Help:      "How long processing an item from workqueue " + name + " takes.",
+	})
+	prometheus.MustRegister(workDuration)
+	return workDuration
+}
+
+func (prometheusMetricsProvider) NewRetriesMetric(name string) workqueue.CounterMetric {
+	retries := prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: name,
+		Name:      "retries",
+		Help:      "Total number of retries handled by workqueue: " + name,
+	})
+	prometheus.MustRegister(retries)
+	return retries
+}
diff --git a/promutils/workqueue_test.go b/promutils/workqueue_test.go
new file mode 100644
index 0000000000..4c5bbcae9e
--- /dev/null
+++ b/promutils/workqueue_test.go
@@ -0,0 +1,42 @@
+package promutils
+
+import (
+	"testing"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
+)
+
+var provider = prometheusMetricsProvider{}
+
+func TestPrometheusMetricsProvider(t *testing.T) {
+	t.Run("Adds", func(t *testing.T) {
+		c := provider.NewAddsMetric("x")
+		_, ok := c.(prometheus.Counter)
+		assert.True(t, ok)
+	})
+
+	t.Run("Depth", func(t *testing.T) {
+		c := provider.NewDepthMetric("x")
+		_, ok := c.(prometheus.Gauge)
+		assert.True(t, ok)
+	})
+
+	t.Run("Latency", func(t *testing.T) {
+		c := provider.NewLatencyMetric("x")
+		_, ok := c.(prometheus.Summary)
+		assert.True(t, ok)
+	})
+
+	t.Run("Retries", func(t *testing.T) {
+		c := provider.NewRetriesMetric("x")
+		_, ok := c.(prometheus.Counter)
+		assert.True(t, ok)
+	})
+
+	t.Run("WorkDuration", func(t *testing.T) {
+		c := provider.NewWorkDurationMetric("x")
+		_, ok := c.(prometheus.Summary)
+		assert.True(t, ok)
+	})
+}
diff --git a/sets/generic_set.go b/sets/generic_set.go
new file mode 100644
index 0000000000..abc739f8aa
--- /dev/null
+++ b/sets/generic_set.go
@@ -0,0 +1,195 @@
+package sets
+
+import (
+	"sort"
+)
+
+type SetObject interface {
+	GetID() string
+}
+
+type Generic map[string]SetObject
+
+// NewGeneric creates a Generic from a list of values.
+func NewGeneric(items ...SetObject) Generic {
+	gs := Generic{}
+	gs.Insert(items...)
+	return gs
+}
+
+// Insert adds the given items to the set.
+func (g Generic) Insert(items ...SetObject) {
+	for _, item := range items {
+		g[item.GetID()] = item
+	}
+}
+
+// Delete removes the given items from the set.
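+// A usage sketch (GenericVal is the helper type from the tests implementing SetObject):
+//
+//	s := NewGeneric(GenericVal("a"), GenericVal("b"))
+//	s.Delete(GenericVal("a"))
+//	fmt.Println(s.Has(GenericVal("a"))) // prints "false"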
+func (g Generic) Delete(items ...SetObject) {
+	for _, item := range items {
+		delete(g, item.GetID())
+	}
+}
+
+// Has returns true if and only if item is contained in the set.
+func (g Generic) Has(item SetObject) bool {
+	_, contained := g[item.GetID()]
+	return contained
+}
+
+// HasAll returns true if and only if all items are contained in the set.
+func (g Generic) HasAll(items ...SetObject) bool {
+	for _, item := range items {
+		if !g.Has(item) {
+			return false
+		}
+	}
+	return true
+}
+
+// HasAny returns true if any items are contained in the set.
+func (g Generic) HasAny(items ...SetObject) bool {
+	for _, item := range items {
+		if g.Has(item) {
+			return true
+		}
+	}
+	return false
+}
+
+// Difference returns a set of objects that are not in g2.
+// For example:
+// s1 = {a1, a2, a3}
+// s2 = {a1, a2, a4, a5}
+// s1.Difference(s2) = {a3}
+// s2.Difference(s1) = {a4, a5}
+func (g Generic) Difference(g2 Generic) Generic {
+	result := NewGeneric()
+	for _, v := range g {
+		if !g2.Has(v) {
+			result.Insert(v)
+		}
+	}
+	return result
+}
+
+// Union returns a new set which includes items in either s1 or s2.
+// For example:
+// s1 = {a1, a2}
+// s2 = {a3, a4}
+// s1.Union(s2) = {a1, a2, a3, a4}
+// s2.Union(s1) = {a1, a2, a3, a4}
+func (g Generic) Union(s2 Generic) Generic {
+	result := NewGeneric()
+	for _, v := range g {
+		result.Insert(v)
+	}
+	for _, v := range s2 {
+		result.Insert(v)
+	}
+	return result
+}
+
+// Intersection returns a new set which includes the items in BOTH s1 and s2.
+// For example:
+// s1 = {a1, a2}
+// s2 = {a2, a3}
+// s1.Intersection(s2) = {a2}
+func (g Generic) Intersection(s2 Generic) Generic {
+	var walk, other Generic
+	result := NewGeneric()
+	if g.Len() < s2.Len() {
+		walk = g
+		other = s2
+	} else {
+		walk = s2
+		other = g
+	}
+	for _, v := range walk {
+		if other.Has(v) {
+			result.Insert(v)
+		}
+	}
+	return result
+}
+
+// IsSuperset returns true if and only if s1 is a superset of s2.
+func (g Generic) IsSuperset(s2 Generic) bool {
+	for _, v := range s2 {
+		if !g.Has(v) {
+			return false
+		}
+	}
+	return true
+}
+
+// Equal returns true if and only if s1 is equal (as a set) to s2.
+// Two sets are equal if their membership is identical.
+// (In practice, this means same elements, order doesn't matter)
+func (g Generic) Equal(s2 Generic) bool {
+	return len(g) == len(s2) && g.IsSuperset(s2)
+}
+
+type sortableSliceOfGeneric []string
+
+func (s sortableSliceOfGeneric) Len() int           { return len(s) }
+func (s sortableSliceOfGeneric) Less(i, j int) bool { return lessString(s[i], s[j]) }
+func (s sortableSliceOfGeneric) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+
+// ListKeys returns the keys of the contents as a sorted string slice.
+func (g Generic) ListKeys() []string {
+	res := make(sortableSliceOfGeneric, 0, len(g))
+	for key := range g {
+		res = append(res, key)
+	}
+	sort.Sort(res)
+	return []string(res)
+}
+
+// List returns the contents as a slice of SetObjects, sorted by key.
+func (g Generic) List() []SetObject {
+	keys := g.ListKeys()
+	res := make([]SetObject, 0, len(keys))
+	for _, k := range keys {
+		s := g[k]
+		res = append(res, s)
+	}
+	return res
+}
+
+// UnsortedListKeys returns the keys of the contents as a slice in random order.
+func (g Generic) UnsortedListKeys() []string {
+	res := make([]string, 0, len(g))
+	for key := range g {
+		res = append(res, key)
+	}
+	return res
+}
+
+// UnsortedList returns the contents as a slice in random order.
+func (g Generic) UnsortedList() []SetObject { + res := make([]SetObject, 0, len(g)) + for _, v := range g { + res = append(res, v) + } + return res +} + +// Returns a single element from the set. +func (g Generic) PopAny() (SetObject, bool) { + for _, v := range g { + g.Delete(v) + return v, true + } + var zeroValue SetObject + return zeroValue, false +} + +// Len returns the size of the set. +func (g Generic) Len() int { + return len(g) +} + +func lessString(lhs, rhs string) bool { + return lhs < rhs +} diff --git a/sets/generic_set_test.go b/sets/generic_set_test.go new file mode 100644 index 0000000000..9d9f165ed5 --- /dev/null +++ b/sets/generic_set_test.go @@ -0,0 +1,116 @@ +package sets + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type GenericVal string + +func (g GenericVal) GetID() string { + return string(g) +} + +func TestGenericSet(t *testing.T) { + assert.Equal(t, []string{"a", "b"}, NewGeneric(GenericVal("a"), GenericVal("b")).ListKeys()) + assert.Equal(t, []string{"a", "b"}, NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("a")).ListKeys()) + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + { + g := g1.Intersection(g2) + assert.Equal(t, []string{"a", "b"}, g.ListKeys()) + } + { + g := g2.Intersection(g1) + assert.Equal(t, []string{"a", "b"}, g.ListKeys()) + } + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g := g2.Difference(g1) + assert.Equal(t, []string{"c"}, g.ListKeys()) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g := g1.Difference(g2) + assert.Equal(t, []string{}, g.ListKeys()) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.True(t, g1.Has(GenericVal("a"))) + assert.False(t, g1.Has(GenericVal("c"))) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.False(t, g1.HasAll(GenericVal("a"), GenericVal("b"), GenericVal("c"))) + assert.True(t, g1.HasAll(GenericVal("a"), GenericVal("b"))) + } + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g := g1.Union(g2) + assert.Equal(t, []string{"a", "b", "c"}, g.ListKeys()) + } + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + assert.True(t, g2.IsSuperset(g1)) + assert.False(t, g1.IsSuperset(g2)) + } + + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.True(t, g1.Has(GenericVal("a"))) + assert.False(t, g1.Has(GenericVal("c"))) + assert.Equal(t, []SetObject{ + GenericVal("a"), + GenericVal("b"), + }, g1.List()) + g2 := NewGeneric(g1.UnsortedList()...) 
+ assert.True(t, g1.Equal(g2)) + } + { + g1 := NewGeneric(GenericVal("a"), GenericVal("b")) + g2 := NewGeneric(GenericVal("a"), GenericVal("b"), GenericVal("c")) + g3 := NewGeneric(GenericVal("a"), GenericVal("b")) + assert.False(t, g1.Equal(g2)) + assert.True(t, g1.Equal(g3)) + + assert.Equal(t, 2, g1.Len()) + g1.Insert(GenericVal("b")) + assert.Equal(t, 2, g1.Len()) + assert.True(t, g1.Equal(g3)) + g1.Insert(GenericVal("c")) + assert.Equal(t, 3, g1.Len()) + assert.True(t, g1.Equal(g2)) + assert.True(t, g1.HasAny(GenericVal("a"), GenericVal("d"))) + assert.False(t, g1.HasAny(GenericVal("f"), GenericVal("d"))) + g1.Delete(GenericVal("f")) + assert.True(t, g1.Equal(g2)) + g1.Delete(GenericVal("c")) + assert.True(t, g1.Equal(g3)) + + { + p, ok := g1.PopAny() + assert.NotNil(t, p) + assert.True(t, ok) + } + { + p, ok := g1.PopAny() + assert.NotNil(t, p) + assert.True(t, ok) + } + { + p, ok := g1.PopAny() + assert.Nil(t, p) + assert.False(t, ok) + } + } +} diff --git a/storage/cached_rawstore.go b/storage/cached_rawstore.go new file mode 100644 index 0000000000..2b539d7bb1 --- /dev/null +++ b/storage/cached_rawstore.go @@ -0,0 +1,123 @@ +package storage + +import ( + "bytes" + "context" + "io" + "runtime/debug" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/coocood/freecache" + "github.com/lyft/flytestdlib/ioutils" + "github.com/lyft/flytestdlib/logger" +) + +const neverExpire = 0 + +// TODO Freecache has bunch of metrics it calculates. Lets write a prom collector to publish these metrics +type cacheMetrics struct { + CacheHit prometheus.Counter + CacheMiss prometheus.Counter + CacheWriteError prometheus.Counter + FetchLatency promutils.StopWatch +} + +type cachedRawStore struct { + RawStore + cache *freecache.Cache + scope promutils.Scope + metrics *cacheMetrics +} + +// Gets metadata about the reference. This should generally be a light weight operation. +func (s *cachedRawStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + key := []byte(reference) + if oRaw, err := s.cache.Get(key); err == nil { + s.metrics.CacheHit.Inc() + // Found, Cache hit + size := int64(len(oRaw)) + // return size in metadata + return StowMetadata{exists: true, size: size}, nil + } + s.metrics.CacheMiss.Inc() + return s.RawStore.Head(ctx, reference) +} + +// Retrieves a byte array from the Blob store or an error +func (s *cachedRawStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + key := []byte(reference) + if oRaw, err := s.cache.Get(key); err == nil { + // Found, Cache hit + s.metrics.CacheHit.Inc() + return ioutils.NewBytesReadCloser(oRaw), nil + } + s.metrics.CacheMiss.Inc() + reader, err := s.RawStore.ReadRaw(ctx, reference) + if err != nil { + return nil, err + } + + defer func() { + err = reader.Close() + if err != nil { + logger.Warnf(ctx, "Failed to close reader [%v]. Error: %v", reference, err) + } + }() + + b, err := ioutils.ReadAll(reader, s.metrics.FetchLatency.Start()) + if err != nil { + return nil, err + } + + err = s.cache.Set(key, b, 0) + if err != nil { + // TODO Ignore errors in writing to cache? + logger.Debugf(ctx, "Failed to Cache the metadata") + } + + return ioutils.NewBytesReadCloser(b), nil +} + +// Stores a raw byte array. 
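+// Note: the incoming reader is teed into an in-memory buffer so the payload can be inserted into the cache once
+// the underlying store accepts the write; large writes therefore incur an extra in-memory copy.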
+func (s *cachedRawStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + var buf bytes.Buffer + teeReader := io.TeeReader(raw, &buf) + err := s.RawStore.WriteRaw(ctx, reference, size, opts, teeReader) + if err != nil { + return err + } + + err = s.cache.Set([]byte(reference), buf.Bytes(), neverExpire) + if err != nil { + s.metrics.CacheWriteError.Inc() + } + + // TODO ignore errors? + return err +} + +// Creates a CachedStore if Caching is enabled, otherwise returns a RawStore +func newCachedRawStore(cfg *Config, store RawStore, scope promutils.Scope) RawStore { + if cfg.Cache.MaxSizeMegabytes > 0 { + c := &cachedRawStore{ + RawStore: store, + cache: freecache.NewCache(cfg.Cache.MaxSizeMegabytes * 1024 * 1024), + scope: scope, + metrics: &cacheMetrics{ + FetchLatency: scope.MustNewStopWatch("remote_fetch", "Total Time to read from remote metastore", time.Millisecond), + CacheHit: scope.MustNewCounter("cache_hit", "Number of times metadata was found in cache"), + CacheMiss: scope.MustNewCounter("cache_miss", "Number of times metadata was not found in cache and remote fetch was required"), + CacheWriteError: scope.MustNewCounter("cache_write_err", "Failed to write to cache"), + }, + } + if cfg.Cache.TargetGCPercent > 0 { + debug.SetGCPercent(cfg.Cache.TargetGCPercent) + } + return c + } + return store +} diff --git a/storage/cached_rawstore_test.go b/storage/cached_rawstore_test.go new file mode 100644 index 0000000000..316f999bf7 --- /dev/null +++ b/storage/cached_rawstore_test.go @@ -0,0 +1,182 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "runtime/debug" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/ioutils" + "github.com/stretchr/testify/assert" +) + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + +func TestNewCachedStore(t *testing.T) { + + t.Run("CachingDisabled", func(t *testing.T) { + testScope := promutils.NewTestScope() + cfg := &Config{} + assert.Nil(t, newCachedRawStore(cfg, nil, testScope)) + store, err := NewInMemoryRawStore(cfg, testScope) + assert.NoError(t, err) + assert.Equal(t, store, newCachedRawStore(cfg, store, testScope)) + }) + + t.Run("CachingEnabled", func(t *testing.T) { + testScope := promutils.NewTestScope() + cfg := &Config{ + Cache: CachingConfig{ + MaxSizeMegabytes: 1, + TargetGCPercent: 20, + }, + } + store, err := NewInMemoryRawStore(cfg, testScope) + assert.NoError(t, err) + cStore := newCachedRawStore(cfg, store, testScope) + assert.Equal(t, 20, debug.SetGCPercent(100)) + assert.NotNil(t, cStore) + assert.NotNil(t, cStore.(*cachedRawStore).cache) + }) +} + +func dummyCacheStore(t *testing.T, store RawStore, scope promutils.Scope) *cachedRawStore { + cfg := &Config{ + Cache: CachingConfig{ + MaxSizeMegabytes: 1, + TargetGCPercent: 20, + }, + } + cStore := newCachedRawStore(cfg, store, scope) + assert.NotNil(t, cStore) + return cStore.(*cachedRawStore) +} + +type dummyStore struct { + copyImpl + HeadCb func(ctx context.Context, reference DataReference) (Metadata, error) + ReadRawCb func(ctx context.Context, reference DataReference) (io.ReadCloser, error) + WriteRawCb func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error +} + +func (d *dummyStore) GetBaseContainerFQN(ctx context.Context) 
DataReference { + return "dummy" +} + +func (d *dummyStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + return d.HeadCb(ctx, reference) +} + +func (d *dummyStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + return d.ReadRawCb(ctx, reference) +} + +func (d *dummyStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + return d.WriteRawCb(ctx, reference, size, opts, raw) +} + +func TestCachedRawStore(t *testing.T) { + ctx := context.TODO() + k1 := DataReference("k1") + k2 := DataReference("k2") + d1 := []byte("abc") + d2 := []byte("xyz") + writeCalled := false + readCalled := false + store := &dummyStore{ + HeadCb: func(ctx context.Context, reference DataReference) (Metadata, error) { + if reference == "k1" { + return MemoryMetadata{exists: true, size: int64(len(d1))}, nil + } + return MemoryMetadata{}, fmt.Errorf("err") + }, + WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + if writeCalled { + assert.FailNow(t, "Should not be writeCalled") + } + writeCalled = true + if reference == "k2" { + b, err := ioutil.ReadAll(raw) + assert.NoError(t, err) + assert.Equal(t, d2, b) + return nil + } + return fmt.Errorf("err") + }, + ReadRawCb: func(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + if readCalled { + assert.FailNow(t, "Should not be invoked again") + } + readCalled = true + if reference == "k1" { + return ioutils.NewBytesReadCloser(d1), nil + } + return nil, fmt.Errorf("err") + }, + } + testScope := promutils.NewTestScope() + + store.copyImpl = newCopyImpl(store, testScope.NewSubScope("copy")) + + cStore := dummyCacheStore(t, store, testScope.NewSubScope("x")) + + t.Run("HeadExists", func(t *testing.T) { + m, err := cStore.Head(ctx, k1) + assert.NoError(t, err) + assert.Equal(t, int64(len(d1)), m.Size()) + assert.True(t, m.Exists()) + }) + + t.Run("HeadNotExists", func(t *testing.T) { + m, err := cStore.Head(ctx, k2) + assert.Error(t, err) + assert.False(t, m.Exists()) + }) + + t.Run("ReadCachePopulate", func(t *testing.T) { + o, err := cStore.ReadRaw(ctx, k1) + assert.NoError(t, err) + b, err := ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, d1, b) + assert.True(t, readCalled) + readCalled = false + o, err = cStore.ReadRaw(ctx, k1) + assert.NoError(t, err) + b, err = ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, d1, b) + assert.False(t, readCalled) + }) + + t.Run("ReadFail", func(t *testing.T) { + readCalled = false + _, err := cStore.ReadRaw(ctx, k2) + assert.Error(t, err) + assert.True(t, readCalled) + }) + + t.Run("WriteAndRead", func(t *testing.T) { + readCalled = false + assert.NoError(t, cStore.WriteRaw(ctx, k2, int64(len(d2)), Options{}, bytes.NewReader(d2))) + assert.True(t, writeCalled) + + o, err := cStore.ReadRaw(ctx, k2) + assert.NoError(t, err) + b, err := ioutil.ReadAll(o) + assert.NoError(t, err) + assert.Equal(t, d2, b) + assert.False(t, readCalled) + }) + +} diff --git a/storage/config.go b/storage/config.go new file mode 100644 index 0000000000..59db0197c3 --- /dev/null +++ b/storage/config.go @@ -0,0 +1,85 @@ +package storage + +import ( + "context" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" +) + +//go:generate pflags Config + +// Defines the storage config type. +type Type = string + +// The reserved config section key for storage. 
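+// A YAML config file populates it under the (case-insensitive) section name (a sketch; the values shown are
+// illustrative):
+//
+//	storage:
+//	  type: minio
+//	  connection:
+//	    endpoint: http://localhost:9000
+//	    auth-type: accesskey
+//	    access-key: minio
+//	    secret-key: miniostorage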
+const configSectionKey = "Storage"
+
+const (
+	TypeMemory Type = "mem"
+	TypeS3     Type = "s3"
+	TypeLocal  Type = "local"
+	TypeMinio  Type = "minio"
+)
+
+const (
+	KiB int64 = 1024
+	MiB int64 = 1024 * KiB
+)
+
+var (
+	ConfigSection = config.MustRegisterSection(configSectionKey, &Config{})
+)
+
+// A common storage config.
+type Config struct {
+	Type          Type             `json:"type" pflag:"\"s3\",Sets the type of storage to configure [s3/minio/local/mem]."`
+	Connection    ConnectionConfig `json:"connection"`
+	InitContainer string           `json:"container" pflag:",Initial container to create -if it doesn't exist-.'"`
+	// Caching improves read performance: metadata and previously read inputs are served from an in-memory cache.
+	// The cache can be large, so size it deliberately.
+	// TODO provide some default config choices
+	// If this section is skipped, caching is disabled.
+	Cache  CachingConfig `json:"cache"`
+	Limits LimitsConfig  `json:"limits" pflag:",Sets limits for stores."`
+}
+
+// Defines connection configurations.
+type ConnectionConfig struct {
+	Endpoint   config.URL `json:"endpoint" pflag:",URL for storage client to connect to."`
+	AuthType   string     `json:"auth-type" pflag:"\"iam\",Auth Type to use [iam,accesskey]."`
+	AccessKey  string     `json:"access-key" pflag:",Access key to use. Only required when authtype is set to accesskey."`
+	SecretKey  string     `json:"secret-key" pflag:",Secret to use when accesskey is set."`
+	Region     string     `json:"region" pflag:"\"us-east-1\",Region to connect to."`
+	DisableSSL bool       `json:"disable-ssl" pflag:",Disables SSL connection. Should only be used for development."`
+}
+
+type CachingConfig struct {
+	// Maximum size of the cache where the Blob store data is cached in-memory.
+	// Refer to https://github.com/coocood/freecache to understand how to size the value.
+	// If not specified or set to 0, the cache is not used.
+	// NOTE: if object sizes are larger than 1/1024 of the cache size, the entry will not be written to the cache.
+	// Also refer to https://github.com/coocood/freecache/issues/17 to understand how to size the cache.
	MaxSizeMegabytes int `json:"max_size_mbs" pflag:",Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used"`
+	// Sets the garbage collection target percentage:
+	// a collection is triggered when the ratio of freshly allocated data
+	// to live data remaining after the previous collection reaches this percentage.
+	// Refer to https://golang.org/pkg/runtime/debug/#SetGCPercent
+	// If not specified or set to 0, the GC percent is not tweaked.
+	TargetGCPercent int `json:"target_gc_percent" pflag:",Sets the garbage collection target percentage."`
+}
+
+// Specifies limits for storage package.
+type LimitsConfig struct {
+	GetLimitMegabytes int64 `json:"maxDownloadMBs" pflag:"2,Maximum allowed download size (in MBs) per call."`
+}
+
+// Retrieve current global config for storage.
+func GetConfig() *Config {
+	if c, ok := ConfigSection.GetConfig().(*Config); ok {
+		return c
+	}
+
+	logger.Warnf(context.TODO(), "Failed to retrieve config section [%v].", configSectionKey)
+	return nil
+}
diff --git a/storage/config_flags.go b/storage/config_flags.go
new file mode 100755
index 0000000000..9a74efdd64
--- /dev/null
+++ b/storage/config_flags.go
@@ -0,0 +1,28 @@
+// Code generated by go generate; DO NOT EDIT.
+// This file was generated by robots.
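+//
+// The generated flag set is typically merged into a command's flags (a sketch; "cmd" and the prefix are
+// illustrative):
+//
+//	cmd.PersistentFlags().AddFlagSet(storage.Config{}.GetPFlagSet("storage."))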
+ +package storage + +import ( + "fmt" + + "github.com/spf13/pflag" +) + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "type"), "s3", "Sets the type of storage to configure [s3/minio/local/mem].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.endpoint"), "", "URL for storage client to connect to.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.auth-type"), "iam", "Auth Type to use [iam, accesskey].") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.access-key"), *new(string), "Access key to use. Only required when authtype is set to accesskey.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.secret-key"), *new(string), "Secret to use when accesskey is set.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connection.region"), "us-east-1", "Region to connect to.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "connection.disable-ssl"), *new(bool), "Disables SSL connection. Should only be used for development.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "container"), *new(string), "Initial container to create -if it doesn't exist-.'") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), *new(int), "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), *new(int), "Sets the garbage collection target percentage.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), 2, "Maximum allowed download size (in MBs) per call.") + return cmdFlags +} diff --git a/storage/config_flags_test.go b/storage/config_flags_test.go new file mode 100755 index 0000000000..2f39f00006 --- /dev/null +++ b/storage/config_flags_test.go @@ -0,0 +1,344 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package storage + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("type"); err == nil { + assert.Equal(t, string("s3"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("type", testValue) + if vString, err := cmdFlags.GetString("type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { + assert.Equal(t, string(""), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.endpoint", testValue) + if vString, err := cmdFlags.GetString("connection.endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.auth-type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.auth-type"); err == nil { + assert.Equal(t, string("iam"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.auth-type", testValue) + if vString, err := cmdFlags.GetString("connection.auth-type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.AuthType) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.access-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.access-key"); err == nil { + assert.Equal(t, string(*new(string)), 
vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.access-key", testValue) + if vString, err := cmdFlags.GetString("connection.access-key"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.AccessKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.secret-key", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.secret-key"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.secret-key", testValue) + if vString, err := cmdFlags.GetString("connection.secret-key"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.SecretKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.region", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("connection.region"); err == nil { + assert.Equal(t, string("us-east-1"), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.region", testValue) + if vString, err := cmdFlags.GetString("connection.region"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Connection.Region) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_connection.disable-ssl", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("connection.disable-ssl"); err == nil { + assert.Equal(t, bool(*new(bool)), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("connection.disable-ssl", testValue) + if vBool, err := cmdFlags.GetBool("connection.disable-ssl"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Connection.DisableSSL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_container", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("container"); err == nil { + assert.Equal(t, string(*new(string)), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("container", testValue) + if vString, err := cmdFlags.GetString("container"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.InitContainer) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cache.max_size_mbs", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("cache.max_size_mbs"); err == nil { + assert.Equal(t, int(*new(int)), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cache.max_size_mbs", testValue) + if vInt, err := cmdFlags.GetInt("cache.max_size_mbs"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), 
&actual.Cache.MaxSizeMegabytes)
+
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+	})
+	t.Run("Test_cache.target_gc_percent", func(t *testing.T) {
+		t.Run("DefaultValue", func(t *testing.T) {
+			// Test that default value is set properly
+			if vInt, err := cmdFlags.GetInt("cache.target_gc_percent"); err == nil {
+				assert.Equal(t, int(*new(int)), vInt)
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+
+		t.Run("Override", func(t *testing.T) {
+			testValue := "1"
+
+			cmdFlags.Set("cache.target_gc_percent", testValue)
+			if vInt, err := cmdFlags.GetInt("cache.target_gc_percent"); err == nil {
+				testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Cache.TargetGCPercent)
+
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+	})
+	t.Run("Test_limits.maxDownloadMBs", func(t *testing.T) {
+		t.Run("DefaultValue", func(t *testing.T) {
+			// Test that default value is set properly
+			if vInt64, err := cmdFlags.GetInt64("limits.maxDownloadMBs"); err == nil {
+				assert.Equal(t, int64(2), vInt64)
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+
+		t.Run("Override", func(t *testing.T) {
+			testValue := "1"
+
+			cmdFlags.Set("limits.maxDownloadMBs", testValue)
+			if vInt64, err := cmdFlags.GetInt64("limits.maxDownloadMBs"); err == nil {
+				testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.Limits.GetLimitMegabytes)
+
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+	})
+}
diff --git a/storage/config_test.go b/storage/config_test.go
new file mode 100644
index 0000000000..93a5fe887c
--- /dev/null
+++ b/storage/config_test.go
@@ -0,0 +1,45 @@
+package storage
+
+import (
+	"flag"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"reflect"
+	"testing"
+
+	"github.com/ghodss/yaml"
+	"github.com/lyft/flytestdlib/config"
+	"github.com/lyft/flytestdlib/internal/utils"
+	"github.com/stretchr/testify/assert"
+)
+
+// Make sure existing config file(s) parse correctly before overriding them with this flag!
+var update = flag.Bool("update", false, "Updates testdata")
+
+func TestMarshal(t *testing.T) {
+	expected := Config{
+		Type: "s3",
+		Connection: ConnectionConfig{
+			Endpoint:   config.URL{URL: utils.MustParseURL("http://minio:9000")},
+			AuthType:   "accesskey",
+			AccessKey:  "minio",
+			SecretKey:  "miniostorage",
+			Region:     "us-east-1",
+			DisableSSL: true,
+		},
+	}
+
+	if *update {
+		t.Log("Updating config file.")
+		raw, err := yaml.Marshal(expected)
+		assert.NoError(t, err)
+		assert.NoError(t, ioutil.WriteFile(filepath.Join("testdata", "config.yaml"), raw, os.ModePerm))
+	}
+
+	actual := Config{}
+	raw, err := ioutil.ReadFile(filepath.Join("testdata", "config.yaml"))
+	assert.NoError(t, err)
+	assert.NoError(t, yaml.Unmarshal(raw, &actual))
+	assert.True(t, reflect.DeepEqual(expected, actual))
+}
diff --git a/storage/copy_impl.go b/storage/copy_impl.go
new file mode 100644
index 0000000000..43f97f026f
--- /dev/null
+++ b/storage/copy_impl.go
@@ -0,0 +1,60 @@
+package storage
+
+import (
+	"context"
+	"io"
+	"time"
+
+	"github.com/lyft/flytestdlib/ioutils"
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/lyft/flytestdlib/promutils/labeled"
+)
+
+type copyImpl struct {
+	rawStore RawStore
+	metrics  copyMetrics
+}
+
+type copyMetrics struct {
+	CopyLatency          labeled.StopWatch
+	ComputeLengthLatency labeled.StopWatch
+}
+
+// A naive implementation of copy that reads all data locally, then writes it to the destination.
+// TODO: We should upstream an API change to stow to implement copy more natively, e.g. use S3 copy:
+// https://docs.aws.amazon.com/AmazonS3/latest/dev/CopyingObjectUsingREST.html
+func (c copyImpl) CopyRaw(ctx context.Context, source, destination DataReference, opts Options) error {
+	rc, err := c.rawStore.ReadRaw(ctx, source)
+	if err != nil {
+		return err
+	}
+
+	length := int64(0)
+	if _, isSeeker := rc.(io.Seeker); !isSeeker {
+		// If the returned ReadCloser doesn't implement the Seeker interface, then the underlying writer won't be able
+		// to calculate content length on its own. Some implementations (e.g. S3 Stow Store) will error if it can't.
+		var raw []byte
+		raw, err = ioutils.ReadAll(rc, c.metrics.ComputeLengthLatency.Start(ctx))
+		if err != nil {
+			return err
+		}
+
+		length = int64(len(raw))
+
+		// Computing the length above drained rc, so replay the buffered bytes rather than the now-exhausted reader.
+		return c.rawStore.WriteRaw(ctx, destination, length, Options{}, ioutils.NewBytesReadCloser(raw))
+	}
+
+	return c.rawStore.WriteRaw(ctx, destination, length, Options{}, rc)
+}
+
+func newCopyMetrics(scope promutils.Scope) copyMetrics {
+	return copyMetrics{
+		CopyLatency:          labeled.NewStopWatch("overall", "Overall copy latency", time.Millisecond, scope, labeled.EmitUnlabeledMetric),
+		ComputeLengthLatency: labeled.NewStopWatch("length", "Latency involved in computing length of content before writing.", time.Millisecond, scope, labeled.EmitUnlabeledMetric),
+	}
+}
+
+func newCopyImpl(store RawStore, metricsScope promutils.Scope) copyImpl {
+	return copyImpl{
+		rawStore: store,
+		metrics:  newCopyMetrics(metricsScope.NewSubScope("copy")),
+	}
+}
diff --git a/storage/copy_impl_test.go b/storage/copy_impl_test.go
new file mode 100644
index 0000000000..fc8d78cd1a
--- /dev/null
+++ b/storage/copy_impl_test.go
@@ -0,0 +1,82 @@
+package storage
+
+import (
+	"context"
+	"io"
+	"testing"
+
+	"github.com/lyft/flytestdlib/ioutils"
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/stretchr/testify/assert"
+)
+
+type notSeekerReader struct {
+	bytesCount int
+}
+
+func (notSeekerReader) Close() error {
+	return nil
+}
+
+func (r *notSeekerReader) Read(p []byte) (n int, err error) {
+	if len(p) < 1 {
+		return 0, nil
+	}
+
+	p[0] = byte(10)
+
+	r.bytesCount--
+	if r.bytesCount <= 0 {
+		return 0, io.EOF
+	}
+
+	return 1, nil
+}
+
+func newNotSeekerReader(bytesCount int) *notSeekerReader {
+	return &notSeekerReader{
+		bytesCount: bytesCount,
+	}
+}
+
+func TestCopyRaw(t *testing.T) {
+	t.Run("Called", func(t *testing.T) {
+		readerCalled := false
+		writerCalled := false
+		store := dummyStore{
+			ReadRawCb: func(ctx context.Context, reference DataReference) (closer io.ReadCloser, e error) {
+				readerCalled = true
+				return ioutils.NewBytesReadCloser([]byte{}), nil
+			},
+			WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error {
+				writerCalled = true
+				return nil
+			},
+		}
+
+		copier := newCopyImpl(&store, promutils.NewTestScope())
+		assert.NoError(t, copier.CopyRaw(context.Background(), DataReference("source.pb"), DataReference("dest.pb"), Options{}))
+		assert.True(t, readerCalled)
+		assert.True(t, writerCalled)
+	})
+
+	t.Run("Not Seeker", func(t *testing.T) {
+		readerCalled := false
+		writerCalled := false
+		store := dummyStore{
+			ReadRawCb: func(ctx context.Context, reference DataReference) (closer io.ReadCloser, e error) {
+				readerCalled = true
+				return newNotSeekerReader(10), nil
+			},
+			WriteRawCb: func(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error {
+				writerCalled = true
+				return nil
+			},
+		}
+
+		copier := newCopyImpl(&store, promutils.NewTestScope())
+		assert.NoError(t, copier.CopyRaw(context.Background(), DataReference("source.pb"), DataReference("dest.pb"), Options{}))
+		assert.True(t, readerCalled)
+		assert.True(t, writerCalled)
+	})
+}
diff --git a/storage/localstore.go b/storage/localstore.go
new file mode 100644
index 0000000000..450c102a24
--- /dev/null
+++ b/storage/localstore.go
@@ -0,0 +1,48 @@
+package storage
+
+import (
+	"fmt"
+
+	"github.com/lyft/flytestdlib/promutils"
+
+	"github.com/graymeta/stow"
+	"github.com/graymeta/stow/local"
+)
+
+func getLocalStowConfigMap(cfg *Config) stow.ConfigMap {
+	stowConfig := stow.ConfigMap{}
+	if endpoint := cfg.Connection.Endpoint.String(); endpoint != "" {
+		stowConfig[local.ConfigKeyPath] = endpoint
+	}
+
+	return stowConfig
+}
+
+// Creates a data store backed by the stow local raw store.
+func newLocalRawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) {
+	if cfg.InitContainer == "" {
+		return nil, fmt.Errorf("initContainer is required")
+	}
+
+	loc, err := stow.Dial(local.Kind, getLocalStowConfigMap(cfg))
+
+	if err != nil {
+		return emptyStore, fmt.Errorf("unable to configure the storage for local. Error: %v", err)
+	}
+
+	c, err := loc.Container(cfg.InitContainer)
+	if err != nil {
+		if IsNotFound(err) {
+			c, err = loc.CreateContainer(cfg.InitContainer)
+			if err != nil {
+				if !IsExists(err) {
+					return emptyStore, fmt.Errorf("unable to initialize container [%v]. Error: %v", cfg.InitContainer, err)
+				}
+
+				// The container already exists (e.g. created concurrently); look it up again so we don't hand a nil
+				// container to the store.
+				if c, err = loc.Container(cfg.InitContainer); err != nil {
+					return emptyStore, err
+				}
+			}
+
+			return NewStowRawStore(DataReference(c.Name()), c, metricsScope)
+		}
+
+		return emptyStore, err
+	}
+
+	return NewStowRawStore(DataReference(c.Name()), c, metricsScope)
+}
diff --git a/storage/localstore_test.go b/storage/localstore_test.go
new file mode 100644
index 0000000000..31f8aabeb2
--- /dev/null
+++ b/storage/localstore_test.go
@@ -0,0 +1,66 @@
+package storage
+
+import (
+	"context"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/lyft/flytestdlib/promutils"
+
+	"github.com/lyft/flytestdlib/config"
+	"github.com/lyft/flytestdlib/internal/utils"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNewLocalStore(t *testing.T) {
+	t.Run("Valid config", func(t *testing.T) {
+		testScope := promutils.NewTestScope()
+		store, err := newLocalRawStore(&Config{
+			Connection: ConnectionConfig{
+				Endpoint: config.URL{URL: utils.MustParseURL("./")},
+			},
+			InitContainer: "testdata",
+		}, testScope.NewSubScope("x"))
+
+		assert.NoError(t, err)
+		assert.NotNil(t, store)
+
+		// Stow local store expects the full path after the container portion (looks like a bug to me)
+		rc, err := store.ReadRaw(context.TODO(), DataReference("file://testdata/testdata/config.yaml"))
+		assert.NoError(t, err)
+		assert.NotNil(t, rc)
+		assert.NoError(t, rc.Close())
+	})
+
+	t.Run("Invalid config", func(t *testing.T) {
+		testScope := promutils.NewTestScope()
+		_, err := newLocalRawStore(&Config{}, testScope)
+		assert.Error(t, err)
+	})
+
+	t.Run("Initialize container", func(t *testing.T) {
+		testScope := promutils.NewTestScope()
+		tmpDir, err := ioutil.TempDir("", "stdlib_local")
+		assert.NoError(t, err)
+
+		stats, err := os.Stat(tmpDir)
+		assert.NoError(t, err)
+		assert.NotNil(t, stats)
+
+		store, err := newLocalRawStore(&Config{
+			Connection: ConnectionConfig{
+				Endpoint: config.URL{URL: utils.MustParseURL(tmpDir)},
+			},
+			InitContainer: "tmp",
+		}, testScope.NewSubScope("y"))
+
+		assert.NoError(t, err)
+		assert.NotNil(t, store)
+
+		stats, err = os.Stat(filepath.Join(tmpDir, "tmp"))
+		assert.NoError(t, err)
+		assert.True(t, stats.IsDir())
+	})
+}
diff --git a/storage/mem_store.go b/storage/mem_store.go
new file mode 100644
index 0000000000..cc0c5854c0
--- /dev/null
+++ b/storage/mem_store.go
@@ -0,0 +1,74 @@
+package
storage + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "os" + + "github.com/lyft/flytestdlib/promutils" +) + +type rawFile = []byte + +type InMemoryStore struct { + copyImpl + cache map[DataReference]rawFile +} + +type MemoryMetadata struct { + exists bool + size int64 +} + +func (m MemoryMetadata) Size() int64 { + return m.size +} + +func (m MemoryMetadata) Exists() bool { + return m.exists +} + +func (s *InMemoryStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + data, found := s.cache[reference] + return MemoryMetadata{exists: found, size: int64(len(data))}, nil +} + +func (s *InMemoryStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + if raw, found := s.cache[reference]; found { + return ioutil.NopCloser(bytes.NewReader(raw)), nil + } + + return nil, os.ErrNotExist +} + +func (s *InMemoryStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) ( + err error) { + + rawBytes, err := ioutil.ReadAll(raw) + if err != nil { + return err + } + + s.cache[reference] = rawBytes + return nil +} + +func (s *InMemoryStore) Clear(ctx context.Context) error { + s.cache = map[DataReference]rawFile{} + return nil +} + +func (s *InMemoryStore) GetBaseContainerFQN(ctx context.Context) DataReference { + return DataReference("") +} + +func NewInMemoryRawStore(_ *Config, scope promutils.Scope) (RawStore, error) { + self := &InMemoryStore{ + cache: map[DataReference]rawFile{}, + } + + self.copyImpl = newCopyImpl(self, scope) + return self, nil +} diff --git a/storage/mem_store_test.go b/storage/mem_store_test.go new file mode 100644 index 0000000000..fdfe2b724a --- /dev/null +++ b/storage/mem_store_test.go @@ -0,0 +1,78 @@ +package storage + +import ( + "bytes" + "context" + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/stretchr/testify/assert" +) + +func TestInMemoryStore_Head(t *testing.T) { + t.Run("Empty store", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + metadata, err := s.Head(context.TODO(), DataReference("hello")) + assert.NoError(t, err) + assert.False(t, metadata.Exists()) + }) + + t.Run("Existing Item", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + err = s.WriteRaw(context.TODO(), DataReference("hello"), 0, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + + metadata, err := s.Head(context.TODO(), DataReference("hello")) + assert.NoError(t, err) + assert.True(t, metadata.Exists()) + }) +} + +func TestInMemoryStore_ReadRaw(t *testing.T) { + t.Run("Empty store", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + + raw, err := s.ReadRaw(context.TODO(), DataReference("hello")) + assert.Error(t, err) + assert.Nil(t, raw) + }) + + t.Run("Existing Item", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, err) + + err = s.WriteRaw(context.TODO(), DataReference("hello"), 0, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + + _, err = s.ReadRaw(context.TODO(), DataReference("hello")) + assert.NoError(t, err) + }) +} + +func TestInMemoryStore_Clear(t *testing.T) { + testScope := promutils.NewTestScope() + m, err := NewInMemoryRawStore(&Config{}, testScope) + assert.NoError(t, 
err)
+
+	mStore := m.(*InMemoryStore)
+	err = m.WriteRaw(context.TODO(), DataReference("hello"), 0, Options{}, bytes.NewReader([]byte("world")))
+	assert.NoError(t, err)
+
+	_, err = m.ReadRaw(context.TODO(), DataReference("hello"))
+	assert.NoError(t, err)
+
+	err = mStore.Clear(context.TODO())
+	assert.NoError(t, err)
+
+	_, err = m.ReadRaw(context.TODO(), DataReference("hello"))
+	assert.Error(t, err)
+	assert.True(t, IsNotFound(err))
+}
diff --git a/storage/protobuf_store.go b/storage/protobuf_store.go
new file mode 100644
index 0000000000..ba11d3311c
--- /dev/null
+++ b/storage/protobuf_store.go
@@ -0,0 +1,85 @@
+package storage
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/lyft/flytestdlib/logger"
+
+	"github.com/lyft/flytestdlib/ioutils"
+	"github.com/lyft/flytestdlib/promutils"
+	"github.com/prometheus/client_golang/prometheus"
+
+	"github.com/golang/protobuf/proto"
+	errs "github.com/pkg/errors"
+)
+
+type protoMetrics struct {
+	FetchLatency     promutils.StopWatch
+	MarshalTime      promutils.StopWatch
+	UnmarshalTime    promutils.StopWatch
+	MarshalFailure   prometheus.Counter
+	UnmarshalFailure prometheus.Counter
+}
+
+// Implements ProtobufStore to marshal and unmarshal protobufs to/from a RawStore
+type DefaultProtobufStore struct {
+	RawStore
+	metrics *protoMetrics
+}
+
+func (s DefaultProtobufStore) ReadProtobuf(ctx context.Context, reference DataReference, msg proto.Message) error {
+	rc, err := s.ReadRaw(ctx, reference)
+	if err != nil {
+		return errs.Wrap(err, fmt.Sprintf("path:%v", reference))
+	}
+
+	defer func() {
+		err = rc.Close()
+		if err != nil {
+			logger.Warnf(ctx, "Failed to close reference [%v]. Error: %v", reference, err)
+		}
+	}()
+
+	docContents, err := ioutils.ReadAll(rc, s.metrics.FetchLatency.Start())
+	if err != nil {
+		return errs.Wrap(err, fmt.Sprintf("readAll: %v", reference))
+	}
+
+	t := s.metrics.UnmarshalTime.Start()
+	err = proto.Unmarshal(docContents, msg)
+	t.Stop()
+	if err != nil {
+		s.metrics.UnmarshalFailure.Inc()
+		return errs.Wrap(err, fmt.Sprintf("unmarshal: %v", reference))
+	}
+
+	return nil
+}
+
+func (s DefaultProtobufStore) WriteProtobuf(ctx context.Context, reference DataReference, opts Options, msg proto.Message) error {
+	t := s.metrics.MarshalTime.Start()
+	raw, err := proto.Marshal(msg)
+	t.Stop()
+	if err != nil {
+		s.metrics.MarshalFailure.Inc()
+		return err
+	}
+
+	return s.WriteRaw(ctx, reference, int64(len(raw)), opts, bytes.NewReader(raw))
+}
+
+func NewDefaultProtobufStore(store RawStore, metricsScope promutils.Scope) DefaultProtobufStore {
+	return DefaultProtobufStore{
+		RawStore: store,
+		metrics: &protoMetrics{
+			FetchLatency:     metricsScope.MustNewStopWatch("proto_fetch", "Time to read data before unmarshalling", time.Millisecond),
+			MarshalTime:      metricsScope.MustNewStopWatch("marshal", "Time incurred in marshalling data before writing", time.Millisecond),
+			UnmarshalTime:    metricsScope.MustNewStopWatch("unmarshal", "Time incurred in unmarshalling received data", time.Millisecond),
+			MarshalFailure:   metricsScope.MustNewCounter("marshal_failure", "Failures when marshalling"),
+			UnmarshalFailure: metricsScope.MustNewCounter("unmarshal_failure", "Failures when unmarshalling"),
+		},
+	}
+}
diff --git a/storage/protobuf_store_test.go b/storage/protobuf_store_test.go
new file mode 100644
index 0000000000..160239bb73
--- /dev/null
+++ b/storage/protobuf_store_test.go
@@ -0,0 +1,41 @@
+package storage
+
+import (
+	"context"
+	"testing"
+
+	"github.com/lyft/flytestdlib/promutils"
+
+	"github.com/golang/protobuf/proto"
+ "github.com/stretchr/testify/assert" +) + +type mockProtoMessage struct { + X int64 `protobuf:"varint,2,opt,name=x,json=x,proto3" json:"x,omitempty"` +} + +func (mockProtoMessage) Reset() { +} + +func (m mockProtoMessage) String() string { + return proto.CompactTextString(m) +} + +func (mockProtoMessage) ProtoMessage() { +} + +func TestDefaultProtobufStore_ReadProtobuf(t *testing.T) { + t.Run("Read after Write", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewDataStore(&Config{Type: TypeMemory}, testScope) + assert.NoError(t, err) + + err = s.WriteProtobuf(context.TODO(), DataReference("hello"), Options{}, &mockProtoMessage{X: 5}) + assert.NoError(t, err) + + m := &mockProtoMessage{} + err = s.ReadProtobuf(context.TODO(), DataReference("hello"), m) + assert.NoError(t, err) + assert.Equal(t, int64(5), m.X) + }) +} diff --git a/storage/rawstores.go b/storage/rawstores.go new file mode 100644 index 0000000000..4a6d35573c --- /dev/null +++ b/storage/rawstores.go @@ -0,0 +1,40 @@ +package storage + +import ( + "fmt" + + "github.com/lyft/flytestdlib/promutils" +) + +type dataStoreCreateFn func(cfg *Config, metricsScope promutils.Scope) (RawStore, error) + +var stores = map[string]dataStoreCreateFn{ + TypeMemory: NewInMemoryRawStore, + TypeLocal: newLocalRawStore, + TypeMinio: newS3RawStore, + TypeS3: newS3RawStore, +} + +// Creates a new Data Store with the supplied config. +func NewDataStore(cfg *Config, metricsScope promutils.Scope) (s *DataStore, err error) { + var rawStore RawStore + if fn, found := stores[cfg.Type]; found { + rawStore, err = fn(cfg, metricsScope) + if err != nil { + return &emptyStore, err + } + + protoStore := NewDefaultProtobufStore(newCachedRawStore(cfg, rawStore, metricsScope), metricsScope) + return NewCompositeDataStore(URLPathConstructor{}, protoStore), nil + } + + return &emptyStore, fmt.Errorf("type is of an invalid value [%v]", cfg.Type) +} + +// Composes a new DataStore. 
+func NewCompositeDataStore(refConstructor ReferenceConstructor, composedProtobufStore ComposedProtobufStore) *DataStore {
+	return &DataStore{
+		ReferenceConstructor:  refConstructor,
+		ComposedProtobufStore: composedProtobufStore,
+	}
+}
diff --git a/storage/s3store.go b/storage/s3store.go
new file mode 100644
index 0000000000..3c96730e0a
--- /dev/null
+++ b/storage/s3store.go
@@ -0,0 +1,102 @@
+package storage
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/lyft/flytestdlib/promutils"
+
+	"github.com/aws/aws-sdk-go/aws/awserr"
+	awsS3 "github.com/aws/aws-sdk-go/service/s3"
+	"github.com/lyft/flytestdlib/logger"
+	"github.com/pkg/errors"
+
+	"github.com/graymeta/stow"
+	"github.com/graymeta/stow/s3"
+)
+
+func getStowConfigMap(cfg *Config) stow.ConfigMap {
+	// Non-nullable fields
+	stowConfig := stow.ConfigMap{
+		s3.ConfigAuthType: cfg.Connection.AuthType,
+		s3.ConfigRegion:   cfg.Connection.Region,
+	}
+
+	// Fields that differ between minio and real S3
+	if endpoint := cfg.Connection.Endpoint.String(); endpoint != "" {
+		stowConfig[s3.ConfigEndpoint] = endpoint
+	}
+
+	if accessKey := cfg.Connection.AccessKey; accessKey != "" {
+		stowConfig[s3.ConfigAccessKeyID] = accessKey
+	}
+
+	if secretKey := cfg.Connection.SecretKey; secretKey != "" {
+		stowConfig[s3.ConfigSecretKey] = secretKey
+	}
+
+	if disableSsl := cfg.Connection.DisableSSL; disableSsl {
+		stowConfig[s3.ConfigDisableSSL] = "True"
+	}
+
+	return stowConfig
+}
+
+func s3FQN(bucket string) DataReference {
+	return DataReference(fmt.Sprintf("s3://%s", bucket))
+}
+
+func newS3RawStore(cfg *Config, metricsScope promutils.Scope) (RawStore, error) {
+	if cfg.InitContainer == "" {
+		return nil, fmt.Errorf("initContainer is required")
+	}
+
+	loc, err := stow.Dial(s3.Kind, getStowConfigMap(cfg))
+
+	if err != nil {
+		return emptyStore, fmt.Errorf("unable to configure the storage for s3. Error: %v", err)
+	}
+
+	c, err := loc.Container(cfg.InitContainer)
+	if err != nil {
+		if IsNotFound(err) || awsBucketIsNotFound(err) {
+			c, err := loc.CreateContainer(cfg.InitContainer)
+			if err == nil {
+				return NewStowRawStore(s3FQN(c.Name()), c, metricsScope)
+			}
+
+			// If the container's already created (e.g. a concurrent initializer won the race), look it up and move
+			// on. Otherwise, fail with error.
+			if awsBucketAlreadyExists(err) {
+				logger.Infof(context.TODO(), "Storage init-container already exists [%v].", cfg.InitContainer)
+				if c, err = loc.Container(cfg.InitContainer); err == nil {
+					return NewStowRawStore(s3FQN(c.Name()), c, metricsScope)
+				}
+			}
+
+			return emptyStore, fmt.Errorf("unable to initialize container [%v].
Error: %v", cfg.InitContainer, err) + } + return emptyStore, err + } + + return NewStowRawStore(s3FQN(c.Name()), c, metricsScope) +} + +func awsBucketIsNotFound(err error) bool { + if IsNotFound(err) { + return true + } + + if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { + return awsErr.Code() == awsS3.ErrCodeNoSuchBucket + } + + return false +} + +func awsBucketAlreadyExists(err error) bool { + if IsExists(err) { + return true + } + + if awsErr, errOk := errors.Cause(err).(awserr.Error); errOk { + return awsErr.Code() == awsS3.ErrCodeBucketAlreadyOwnedByYou + } + + return false +} diff --git a/storage/s3stsore_test.go b/storage/s3stsore_test.go new file mode 100644 index 0000000000..2e8674a834 --- /dev/null +++ b/storage/s3stsore_test.go @@ -0,0 +1,26 @@ +package storage + +import ( + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewS3RawStore(t *testing.T) { + t.Run("Missing required config", func(t *testing.T) { + testScope := promutils.NewTestScope() + _, err := NewDataStore(&Config{ + Type: TypeMinio, + InitContainer: "some-container", + Connection: ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, + }, + }, testScope) + + assert.Error(t, err) + }) +} diff --git a/storage/storage.go b/storage/storage.go new file mode 100644 index 0000000000..ffbc2c62b0 --- /dev/null +++ b/storage/storage.go @@ -0,0 +1,95 @@ +// Defines extensible storage interface. +// This package registers "storage" config section that maps to Config struct. Use NewDataStore(cfg) to initialize a +// DataStore with the provided config. The package provides default implementation to access local, S3 (and minio), +// and In-Memory storage. Use NewCompositeDataStore to swap any portions of the DataStore interface with an external +// implementation (e.g. a cached protobuf store). The underlying storage is provided by extensible "stow" library. You +// can use NewStowRawStore(cfg) to create a Raw store based on any other stow-supported configs (e.g. Azure Blob Storage) +package storage + +import ( + "context" + "strings" + + "io" + "net/url" + + "github.com/golang/protobuf/proto" +) + +// Defines a reference to data location. +type DataReference string + +var emptyStore = DataStore{} + +// Holder for recording storage options. It is used to pass Metadata (like headers for S3) and also tags or labels for +// objects +type Options struct { + Metadata map[string]interface{} +} + +// Placeholder for data reference metadata. +type Metadata interface { + Exists() bool + Size() int64 +} + +// A simplified interface for accessing and storing data in one of the Cloud stores. +// Today we rely on Stow for multi-cloud support, but this interface abstracts that part +type DataStore struct { + ComposedProtobufStore + ReferenceConstructor +} + +// Defines a low level interface for accessing and storing bytes. +type RawStore interface { + // returns a FQN DataReference with the configured base init container + GetBaseContainerFQN(ctx context.Context) DataReference + + // Gets metadata about the reference. This should generally be a light weight operation. + Head(ctx context.Context, reference DataReference) (Metadata, error) + + // Retrieves a byte array from the Blob store or an error + ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) + + // Stores a raw byte array. 
+ WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error + + // Copies from source to destination. + CopyRaw(ctx context.Context, source, destination DataReference, opts Options) error +} + +// Defines an interface for building data reference paths. +type ReferenceConstructor interface { + // Creates a new dataReference that matches the storage structure. + ConstructReference(ctx context.Context, reference DataReference, nestedKeys ...string) (DataReference, error) +} + +// Defines an interface for reading and writing protobuf messages +type ProtobufStore interface { + // Retrieves the entire blob from blobstore and unmarshals it to the passed protobuf + ReadProtobuf(ctx context.Context, reference DataReference, msg proto.Message) error + + // Serializes and stores the protobuf. + WriteProtobuf(ctx context.Context, reference DataReference, opts Options, msg proto.Message) error +} + +// A ProtobufStore needs a RawStore to get the RawData. This interface provides all the necessary components to make +// Protobuf fetching work +type ComposedProtobufStore interface { + RawStore + ProtobufStore +} + +// Splits the data reference into parts. +func (r DataReference) Split() (scheme, container, key string, err error) { + u, err := url.Parse(string(r)) + if err != nil { + return "", "", "", err + } + + return u.Scheme, u.Host, strings.Trim(u.Path, "/"), nil +} + +func (r DataReference) String() string { + return string(r) +} diff --git a/storage/storage_test.go b/storage/storage_test.go new file mode 100644 index 0000000000..1895b0ac5f --- /dev/null +++ b/storage/storage_test.go @@ -0,0 +1,52 @@ +package storage + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/stretchr/testify/assert" +) + +func TestDataReference_Split(t *testing.T) { + input := DataReference("s3://container/path/to/file") + scheme, container, key, err := input.Split() + + assert.NoError(t, err) + assert.Equal(t, "s3", scheme) + assert.Equal(t, "container", container) + assert.Equal(t, "path/to/file", key) +} + +func ExampleNewDataStore() { + testScope := promutils.NewTestScope() + ctx := context.Background() + fmt.Println("Creating in memory data store.") + store, err := NewDataStore(&Config{ + Type: TypeMemory, + }, testScope.NewSubScope("exp_new")) + + if err != nil { + fmt.Printf("Failed to create data store. Error: %v", err) + } + + ref, err := store.ConstructReference(ctx, DataReference("root"), "subkey", "subkey2") + if err != nil { + fmt.Printf("Failed to construct data reference. Error: %v", err) + } + + fmt.Printf("Constructed data reference [%v] and writing data to it.\n", ref) + + dataToStore := "hello world" + err = store.WriteRaw(ctx, ref, int64(len(dataToStore)), Options{}, strings.NewReader(dataToStore)) + if err != nil { + fmt.Printf("Failed to write data. Error: %v", err) + } + + // Output: + // Creating in memory data store. + // Constructed data reference [/root/subkey/subkey2] and writing data to it. 
+} diff --git a/storage/stow_store.go b/storage/stow_store.go new file mode 100644 index 0000000000..b13170bf09 --- /dev/null +++ b/storage/stow_store.go @@ -0,0 +1,174 @@ +package storage + +import ( + "context" + "io" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/graymeta/stow" + errs "github.com/pkg/errors" +) + +type stowMetrics struct { + BadReference prometheus.Counter + BadContainer prometheus.Counter + + HeadFailure prometheus.Counter + HeadLatency promutils.StopWatch + + ReadFailure prometheus.Counter + ReadOpenLatency promutils.StopWatch + + WriteFailure prometheus.Counter + WriteLatency promutils.StopWatch +} + +// Implements DataStore to talk to stow location store. +type StowStore struct { + stow.Container + copyImpl + metrics *stowMetrics + containerBaseFQN DataReference +} + +type StowMetadata struct { + exists bool + size int64 +} + +func (s StowMetadata) Size() int64 { + return s.size +} + +func (s StowMetadata) Exists() bool { + return s.exists +} + +func (s *StowStore) getContainer(container string) (c stow.Container, err error) { + if s.Container.Name() != container { + s.metrics.BadContainer.Inc() + return nil, stow.ErrNotFound + } + + return s.Container, nil +} + +func (s *StowStore) Head(ctx context.Context, reference DataReference) (Metadata, error) { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc() + return nil, err + } + + container, err := s.getContainer(c) + if err != nil { + return nil, err + } + + t := s.metrics.HeadLatency.Start() + item, err := container.Item(k) + if err == nil { + if _, err = item.Metadata(); err == nil { + size, err := item.Size() + if err == nil { + t.Stop() + return StowMetadata{ + exists: true, + size: size, + }, nil + } + } + } + s.metrics.HeadFailure.Inc() + if IsNotFound(err) { + return StowMetadata{exists: false}, nil + } + return StowMetadata{exists: false}, errs.Wrapf(err, "path:%v", k) +} + +func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc() + return nil, err + } + + container, err := s.getContainer(c) + if err != nil { + return nil, err + } + + t := s.metrics.ReadOpenLatency.Start() + item, err := container.Item(k) + if err != nil { + s.metrics.ReadFailure.Inc() + return nil, err + } + t.Stop() + + sizeBytes, err := item.Size() + if err != nil { + return nil, err + } + + if sizeBytes/MiB > GetConfig().Limits.GetLimitMegabytes { + return nil, ErrExceedsLimit + } + + return item.Open() +} + +func (s *StowStore) WriteRaw(ctx context.Context, reference DataReference, size int64, opts Options, raw io.Reader) error { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc() + return err + } + + container, err := s.getContainer(c) + if err != nil { + return err + } + + t := s.metrics.WriteLatency.Start() + _, err = container.Put(k, raw, size, opts.Metadata) + if err != nil { + s.metrics.WriteFailure.Inc() + return errs.Wrapf(err, "Failed to write data [%vb] to path [%v].", size, k) + } + t.Stop() + + return nil +} + +func (s *StowStore) GetBaseContainerFQN(ctx context.Context) DataReference { + return s.containerBaseFQN +} + +func NewStowRawStore(containerBaseFQN DataReference, container stow.Container, metricsScope promutils.Scope) (*StowStore, error) { + self := &StowStore{ + Container: container, + containerBaseFQN: containerBaseFQN, + metrics: &stowMetrics{ + 
BadReference: metricsScope.MustNewCounter("bad_key", "Indicates the provided storage reference/key is incorrectly formatted"), + BadContainer: metricsScope.MustNewCounter("bad_container", "Indicates request for a container that has not been initialized"), + + HeadFailure: metricsScope.MustNewCounter("head_failure", "Indicates failure in HEAD for a given reference"), + HeadLatency: metricsScope.MustNewStopWatch("head", "Indicates time to fetch metadata using the Head API", time.Millisecond), + + ReadFailure: metricsScope.MustNewCounter("read_failure", "Indicates failure in GET for a given reference"), + ReadOpenLatency: metricsScope.MustNewStopWatch("read_open", "Indicates time to first byte when reading", time.Millisecond), + + WriteFailure: metricsScope.MustNewCounter("write_failure", "Indicates failure in storing/PUT for a given reference"), + WriteLatency: metricsScope.MustNewStopWatch("write", "Time to write an object irrespective of size", time.Millisecond), + }, + } + + self.copyImpl = newCopyImpl(self, metricsScope) + + return self, nil +} diff --git a/storage/stow_store_test.go b/storage/stow_store_test.go new file mode 100644 index 0000000000..f0fd282340 --- /dev/null +++ b/storage/stow_store_test.go @@ -0,0 +1,133 @@ +package storage + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "net/url" + "testing" + "time" + + "github.com/lyft/flytestdlib/promutils" + + "github.com/graymeta/stow" + "github.com/stretchr/testify/assert" +) + +type mockStowContainer struct { + id string + items map[string]mockStowItem +} + +func (m mockStowContainer) ID() string { + return m.id +} + +func (m mockStowContainer) Name() string { + return m.id +} + +func (m mockStowContainer) Item(id string) (stow.Item, error) { + if item, found := m.items[id]; found { + return item, nil + } + + return nil, stow.ErrNotFound +} + +func (mockStowContainer) Items(prefix, cursor string, count int) ([]stow.Item, string, error) { + return []stow.Item{}, "", nil +} + +func (mockStowContainer) RemoveItem(id string) error { + return nil +} + +func (m *mockStowContainer) Put(name string, r io.Reader, size int64, metadata map[string]interface{}) (stow.Item, error) { + item := mockStowItem{url: name, size: size} + m.items[name] = item + return item, nil +} + +func newMockStowContainer(id string) *mockStowContainer { + return &mockStowContainer{ + id: id, + items: map[string]mockStowItem{}, + } +} + +type mockStowItem struct { + url string + size int64 +} + +func (m mockStowItem) ID() string { + return m.url +} + +func (m mockStowItem) Name() string { + return m.url +} + +func (m mockStowItem) URL() *url.URL { + u, err := url.Parse(m.url) + if err != nil { + panic(err) + } + + return u +} + +func (m mockStowItem) Size() (int64, error) { + return m.size, nil +} + +func (mockStowItem) Open() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader([]byte{})), nil +} + +func (mockStowItem) ETag() (string, error) { + return "", nil +} + +func (mockStowItem) LastMod() (time.Time, error) { + return time.Now(), nil +} + +func (mockStowItem) Metadata() (map[string]interface{}, error) { + return map[string]interface{}{}, nil +} + +func TestStowStore_ReadRaw(t *testing.T) { + t.Run("Happy Path", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewStowRawStore(s3FQN("container"), newMockStowContainer("container"), testScope) + assert.NoError(t, err) + err = s.WriteRaw(context.TODO(), DataReference("s3://container/path"), 0, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + 
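+		// Head should now see the object that was just written.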
metadata, err := s.Head(context.TODO(), DataReference("s3://container/path")) + assert.NoError(t, err) + assert.True(t, metadata.Exists()) + raw, err := s.ReadRaw(context.TODO(), DataReference("s3://container/path")) + assert.NoError(t, err) + rawBytes, err := ioutil.ReadAll(raw) + assert.NoError(t, err) + assert.Equal(t, 0, len(rawBytes)) + assert.Equal(t, DataReference("s3://container"), s.GetBaseContainerFQN(context.TODO())) + }) + + t.Run("Exceeds limit", func(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewStowRawStore(s3FQN("container"), newMockStowContainer("container"), testScope) + assert.NoError(t, err) + err = s.WriteRaw(context.TODO(), DataReference("s3://container/path"), 3*MiB, Options{}, bytes.NewReader([]byte{})) + assert.NoError(t, err) + metadata, err := s.Head(context.TODO(), DataReference("s3://container/path")) + assert.NoError(t, err) + assert.True(t, metadata.Exists()) + _, err = s.ReadRaw(context.TODO(), DataReference("s3://container/path")) + assert.Error(t, err) + assert.True(t, IsExceedsLimit(err)) + }) +} diff --git a/storage/testdata/config.yaml b/storage/testdata/config.yaml new file mode 100755 index 0000000000..d8664ca347 --- /dev/null +++ b/storage/testdata/config.yaml @@ -0,0 +1,14 @@ +cache: + max_size_mbs: 0 + target_gc_percent: 0 +connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://minio:9000 + region: us-east-1 + secret-key: miniostorage +container: "" +limits: + maxDownloadMBs: 0 +type: s3 diff --git a/storage/url_path.go b/storage/url_path.go new file mode 100644 index 0000000000..94e26e317a --- /dev/null +++ b/storage/url_path.go @@ -0,0 +1,44 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/pkg/errors" + + "net/url" + "os" + "path/filepath" + + "github.com/lyft/flytestdlib/logger" +) + +// Implements ReferenceConstructor that assumes paths are URL-compatible. 
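+// For example (mirroring url_path_test.go), nested keys are joined onto the base reference:
+//
+//   ref, _ := URLPathConstructor{}.ConstructReference(ctx, DataReference("hello"), "key1", "key2/", "key3")
+//   // ref == "/hello/key1/key2/key3"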
+type URLPathConstructor struct { +} + +func ensureEndingPathSeparator(path DataReference) DataReference { + if len(path) > 0 && path[len(path)-1] == os.PathSeparator { + return path + } + + return path + "/" +} + +func (URLPathConstructor) ConstructReference(ctx context.Context, reference DataReference, nestedKeys ...string) (DataReference, error) { + u, err := url.Parse(string(ensureEndingPathSeparator(reference))) + if err != nil { + logger.Errorf(ctx, "Failed to parse prefix: %v", reference) + return "", errors.Wrap(err, fmt.Sprintf("Reference is of an invalid format [%v]", reference)) + } + + rel, err := url.Parse(filepath.Join(nestedKeys...)) + if err != nil { + logger.Errorf(ctx, "Failed to parse nested keys: %v", reference) + return "", errors.Wrap(err, fmt.Sprintf("Reference is of an invalid format [%v]", reference)) + } + + u = u.ResolveReference(rel) + + return DataReference(u.String()), nil +} diff --git a/storage/url_path_test.go b/storage/url_path_test.go new file mode 100644 index 0000000000..5dc24661cf --- /dev/null +++ b/storage/url_path_test.go @@ -0,0 +1,15 @@ +package storage + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUrlPathConstructor_ConstructReference(t *testing.T) { + s := URLPathConstructor{} + r, err := s.ConstructReference(context.TODO(), DataReference("hello"), "key1", "key2/", "key3") + assert.NoError(t, err) + assert.Equal(t, "/hello/key1/key2/key3", r.String()) +} diff --git a/storage/utils.go b/storage/utils.go new file mode 100644 index 0000000000..62fb2aa22c --- /dev/null +++ b/storage/utils.go @@ -0,0 +1,34 @@ +package storage + +import ( + "fmt" + "os" + + "github.com/graymeta/stow" + "github.com/pkg/errors" +) + +var ErrExceedsLimit = fmt.Errorf("limit exceeded") + +// Gets a value indicating whether the underlying error is a Not Found error. +func IsNotFound(err error) bool { + if root := errors.Cause(err); root == stow.ErrNotFound || os.IsNotExist(root) { + return true + } + + return false +} + +// Gets a value indicating whether the underlying error is "already exists" error. +func IsExists(err error) bool { + if root := errors.Cause(err); os.IsExist(root) { + return true + } + + return false +} + +// Gets a value indicating whether the root cause of error is a "limit exceeded" error. +func IsExceedsLimit(err error) bool { + return errors.Cause(err) == ErrExceedsLimit +} diff --git a/tests/config_test.go b/tests/config_test.go new file mode 100644 index 0000000000..15c890cf08 --- /dev/null +++ b/tests/config_test.go @@ -0,0 +1,78 @@ +package tests + +import ( + "context" + "flag" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/lyft/flytestdlib/config/viper" + + "github.com/ghodss/yaml" + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/internal/utils" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/assert" +) + +// Make sure existing config file(s) parse correctly before overriding them with this flag! 
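+// To regenerate the golden file, run the tests with the flag set, e.g. `go test ./tests -update` (assuming the
+// standard go test flag forwarding).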
+var update = flag.Bool("update", false, "Updates testdata") + +func TestStorageAndLoggerConfig(t *testing.T) { + type CompositeConfig struct { + Storage storage.Config `json:"storage"` + Logger logger.Config `json:"logger"` + } + + expected := CompositeConfig{ + Storage: storage.Config{ + Type: "s3", + Connection: storage.ConnectionConfig{ + Endpoint: config.URL{URL: utils.MustParseURL("http://minio:9000")}, + AuthType: "accesskey", + AccessKey: "minio", + SecretKey: "miniostorage", + Region: "us-east-1", + DisableSSL: true, + }, + }, + Logger: logger.Config{ + Level: logger.DebugLevel, + }, + } + + configPath := filepath.Join("testdata", "combined.yaml") + if *update { + t.Log("Updating golden files.") + raw, err := yaml.Marshal(expected) + assert.NoError(t, err) + assert.NoError(t, ioutil.WriteFile(configPath, raw, os.ModePerm)) + } + + actual := CompositeConfig{} + /* #nosec */ + raw, err := ioutil.ReadFile(configPath) + assert.NoError(t, err) + assert.NoError(t, yaml.Unmarshal(raw, &actual)) + assert.True(t, reflect.DeepEqual(expected, actual)) +} + +func TestParseExistingConfig(t *testing.T) { + accessor := viper.NewAccessor(config.Options{ + SearchPaths: []string{filepath.Join("testdata", "combined.yaml")}, + }) + + assert.NoError(t, accessor.UpdateConfig(context.TODO())) + + assert.NotNil(t, storage.ConfigSection) + + if _, ok := storage.ConfigSection.GetConfig().(*storage.Config); ok { + assert.True(t, ok) + } else { + assert.FailNow(t, "Retrieved section is not of type storage.") + } +} diff --git a/tests/testdata/combined.yaml b/tests/testdata/combined.yaml new file mode 100755 index 0000000000..f167b1ab33 --- /dev/null +++ b/tests/testdata/combined.yaml @@ -0,0 +1,19 @@ +logger: + level: 5 + mute: false + show-source: false +storage: + cache: + max_size_mbs: 0 + target_gc_percent: 0 + connection: + access-key: minio + auth-type: accesskey + disable-ssl: true + endpoint: http://minio:9000 + region: us-east-1 + secret-key: miniostorage + container: "" + limits: + maxDownloadMBs: 0 + type: s3 diff --git a/utils/auto_refresh_cache.go b/utils/auto_refresh_cache.go new file mode 100644 index 0000000000..a231bbd7c2 --- /dev/null +++ b/utils/auto_refresh_cache.go @@ -0,0 +1,99 @@ +package utils + +import ( + "context" + "sync" + "time" + + "github.com/lyft/flytestdlib/logger" + "k8s.io/apimachinery/pkg/util/wait" +) + +// AutoRefreshCache with regular GetOrCreate and Delete along with background asynchronous refresh. Caller provides +// callbacks for create, refresh and delete item. +// The cache doesn't provide apis to update items. +type AutoRefreshCache interface { + // starts background refresh of items + Start(ctx context.Context) + + // Get item by id if exists else null + Get(id string) CacheItem + + // Get object if exists else create it + GetOrCreate(item CacheItem) (CacheItem, error) +} + +type CacheItem interface { + ID() string +} + +type CacheSyncItem func(ctx context.Context, obj CacheItem) (CacheItem, error) + +func NewAutoRefreshCache(syncCb CacheSyncItem, syncRateLimiter RateLimiter, resyncPeriod time.Duration) AutoRefreshCache { + cache := &autoRefreshCache{ + syncCb: syncCb, + syncRateLimiter: syncRateLimiter, + resyncPeriod: resyncPeriod, + } + + return cache +} + +// Thread-safe general purpose auto-refresh cache that watches for updates asynchronously for the keys after they are added to +// the cache. An item can be inserted only once. +// +// Get reads from sync.map while refresh is invoked on a snapshot of keys. Cache eventually catches up on deleted items. 
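+// (A nil item returned from the sync callback is treated as a delete: sync below removes the key from the map.)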
+// +// Sync is run as a fixed-interval-scheduled-task, and is skipped if sync from previous cycle is still running. +type autoRefreshCache struct { + syncCb CacheSyncItem + syncMap sync.Map + syncRateLimiter RateLimiter + resyncPeriod time.Duration +} + +func (w *autoRefreshCache) Start(ctx context.Context) { + go wait.Until(func() { w.sync(ctx) }, w.resyncPeriod, ctx.Done()) +} + +func (w *autoRefreshCache) Get(id string) CacheItem { + if val, ok := w.syncMap.Load(id); ok { + return val.(CacheItem) + } + return nil +} + +// Return the item if exists else create it. +// Create should be invoked only once. recreating the object is not supported. +func (w *autoRefreshCache) GetOrCreate(item CacheItem) (CacheItem, error) { + if val, ok := w.syncMap.Load(item.ID()); ok { + return val.(CacheItem), nil + } + + w.syncMap.Store(item.ID(), item) + return item, nil +} + +func (w *autoRefreshCache) sync(ctx context.Context) { + w.syncMap.Range(func(key, value interface{}) bool { + if w.syncRateLimiter != nil { + err := w.syncRateLimiter.Wait(ctx) + if err != nil { + logger.Warnf(ctx, "unexpected failure in rate-limiter wait %v", key) + return true + } + } + item, err := w.syncCb(ctx, value.(CacheItem)) + if err != nil { + logger.Error(ctx, "failed to get latest copy of the item %v", key) + } + + if item == nil { + w.syncMap.Delete(key) + } else { + w.syncMap.Store(key, item) + } + + return true + }) +} diff --git a/utils/auto_refresh_cache_test.go b/utils/auto_refresh_cache_test.go new file mode 100644 index 0000000000..85a09ce508 --- /dev/null +++ b/utils/auto_refresh_cache_test.go @@ -0,0 +1,108 @@ +package utils + +import ( + "context" + "testing" + "time" + + atomic2 "sync/atomic" + + "github.com/lyft/flytestdlib/atomic" + "github.com/stretchr/testify/assert" +) + +type testCacheItem struct { + val int + deleted atomic.Bool + resyncPeriod time.Duration +} + +func (m *testCacheItem) ID() string { + return "id" +} + +func (m *testCacheItem) moveNext() { + // change value and spare enough time for cache to process the change. + m.val++ + time.Sleep(m.resyncPeriod * 5) +} + +func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { + if m.deleted.Load() { + return nil, nil + } + return m, nil +} + +type testAutoIncrementItem struct { + val int32 +} + +func (a *testAutoIncrementItem) ID() string { + return "autoincrement" +} + +func (a *testAutoIncrementItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) { + atomic2.AddInt32(&a.val, 1) + return a, nil +} + +func TestCache(t *testing.T) { + testResyncPeriod := time.Millisecond + rateLimiter := NewRateLimiter("mockLimiter", 100, 1) + + item := &testCacheItem{val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false)} + cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod) + + //ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + cache.Start(ctx) + + // create + _, err := cache.GetOrCreate(item) + assert.NoError(t, err, "unexpected GetOrCreate failure") + + // synced? + item.moveNext() + m := cache.Get(item.ID()).(*testCacheItem) + assert.Equal(t, 1, m.val) + + // synced again? + item.moveNext() + m = cache.Get(item.ID()).(*testCacheItem) + assert.Equal(t, 2, m.val) + + // removed? + item.moveNext() + item.deleted.Store(true) + time.Sleep(testResyncPeriod * 10) // spare enough time to process remove! 
diff --git a/utils/auto_refresh_cache_test.go b/utils/auto_refresh_cache_test.go
new file mode 100644
index 0000000000..85a09ce508
--- /dev/null
+++ b/utils/auto_refresh_cache_test.go
@@ -0,0 +1,108 @@
+package utils
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	atomic2 "sync/atomic"
+
+	"github.com/lyft/flytestdlib/atomic"
+	"github.com/stretchr/testify/assert"
+)
+
+type testCacheItem struct {
+	val          int
+	deleted      atomic.Bool
+	resyncPeriod time.Duration
+}
+
+func (m *testCacheItem) ID() string {
+	return "id"
+}
+
+func (m *testCacheItem) moveNext() {
+	// change the value and allow enough time for the cache to process the change.
+	m.val++
+	time.Sleep(m.resyncPeriod * 5)
+}
+
+func (m *testCacheItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) {
+	if m.deleted.Load() {
+		return nil, nil
+	}
+	return m, nil
+}
+
+type testAutoIncrementItem struct {
+	val int32
+}
+
+func (a *testAutoIncrementItem) ID() string {
+	return "autoincrement"
+}
+
+func (a *testAutoIncrementItem) syncItem(ctx context.Context, obj CacheItem) (CacheItem, error) {
+	atomic2.AddInt32(&a.val, 1)
+	return a, nil
+}
+
+func TestCache(t *testing.T) {
+	testResyncPeriod := time.Millisecond
+	rateLimiter := NewRateLimiter("mockLimiter", 100, 1)
+
+	item := &testCacheItem{val: 0, resyncPeriod: testResyncPeriod, deleted: atomic.NewBool(false)}
+	cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cache.Start(ctx)
+
+	// create
+	_, err := cache.GetOrCreate(item)
+	assert.NoError(t, err, "unexpected GetOrCreate failure")
+
+	// synced?
+	item.moveNext()
+	m := cache.Get(item.ID()).(*testCacheItem)
+	assert.Equal(t, 1, m.val)
+
+	// synced again?
+	item.moveNext()
+	m = cache.Get(item.ID()).(*testCacheItem)
+	assert.Equal(t, 2, m.val)
+
+	// removed?
+	item.moveNext()
+	item.deleted.Store(true)
+	time.Sleep(testResyncPeriod * 10) // allow enough time to process the removal!
+	val := cache.Get(item.ID())
+
+	assert.Nil(t, val)
+	cancel()
+}
+
+func TestCacheContextCancel(t *testing.T) {
+	testResyncPeriod := time.Millisecond
+	rateLimiter := NewRateLimiter("mockLimiter", 10000, 1)
+
+	item := &testAutoIncrementItem{val: 0}
+	cache := NewAutoRefreshCache(item.syncItem, rateLimiter, testResyncPeriod)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cache.Start(ctx)
+	_, err := cache.GetOrCreate(item)
+	assert.NoError(t, err, "failed to add item to cache")
+	time.Sleep(testResyncPeriod * 10) // allow the cache to run a few sync cycles
+	cancel()
+
+	// Get item
+	m, err := cache.GetOrCreate(item)
+	val1 := atomic2.LoadInt32(&m.(*testAutoIncrementItem).val)
+	assert.NoError(t, err, "unexpected GetOrCreate failure")
+
+	// wait a few more resync periods and check that nothing has changed, as auto-refresh is stopped
+	time.Sleep(testResyncPeriod * 20)
+	val2 := atomic2.LoadInt32(&m.(*testAutoIncrementItem).val)
+	assert.Equal(t, val1, val2)
+}
diff --git a/utils/auto_refresh_example_test.go b/utils/auto_refresh_example_test.go
new file mode 100644
index 0000000000..1c47d39ec5
--- /dev/null
+++ b/utils/auto_refresh_example_test.go
@@ -0,0 +1,88 @@
+package utils
+
+import (
+	"context"
+	"fmt"
+	"time"
+)
+
+type ExampleItemStatus string
+
+const (
+	ExampleStatusNotStarted ExampleItemStatus = "Not-started"
+	ExampleStatusStarted    ExampleItemStatus = "Started"
+	ExampleStatusSucceeded  ExampleItemStatus = "Completed"
+)
+
+type ExampleCacheItem struct {
+	status ExampleItemStatus
+	id     string
+}
+
+func (e *ExampleCacheItem) ID() string {
+	return e.id
+}
+
+type ExampleService struct {
+	jobStatus map[string]ExampleItemStatus
+}
+
+func newExampleService() *ExampleService {
+	return &ExampleService{jobStatus: make(map[string]ExampleItemStatus)}
+}
+
+// getStatus advances the status of the given job and returns the latest one.
+func (f *ExampleService) getStatus(id string) *ExampleCacheItem {
+	if _, ok := f.jobStatus[id]; !ok {
+		f.jobStatus[id] = ExampleStatusStarted
+	}
+	f.jobStatus[id] = ExampleStatusSucceeded
+	return &ExampleCacheItem{f.jobStatus[id], id}
+}
+
+func ExampleNewAutoRefreshCache() {
+	// This auto-refresh cache can be used for cases where keys are created by the caller but processed by an
+	// external service, and we want to asynchronously keep track of their progress.
+	exampleService := newExampleService()
+
+	// define a sync callback that the cache can use to auto-refresh in the background
+	syncItemCb := func(ctx context.Context, obj CacheItem) (CacheItem, error) {
+		return exampleService.getStatus(obj.(*ExampleCacheItem).ID()), nil
+	}
+
+	// define the resync period: how often we want the cache to refresh. It can be as low as we want, but the
+	// cache is still constrained by the time it takes to run the sync callback for each item.
+	resyncPeriod := time.Millisecond
+
+	// Since the number of items in the cache is dynamic, the rate limiter is our knob to control the resources
+	// we spend on sync.
+	rateLimiter := NewRateLimiter("ExampleRateLimiter", 10000, 1)
+
+	// Since the cache refreshes itself asynchronously, it may not immediately notice that an object has been
+	// deleted, so delete logic built on the cache should account for this shortcoming (e.g. not-exists may be a
+	// valid error during removal if it is based on the status in the cache).
+	cache := NewAutoRefreshCache(syncItemCb, rateLimiter, resyncPeriod)
+
+	// start the cache with a context; cancelling the context stops the cache
+	ctx, cancel := context.WithCancel(context.Background())
+	cache.Start(ctx)
+
+	// create objects that go through a couple of state transitions to reach the final state
+	item1 := &ExampleCacheItem{status: ExampleStatusNotStarted, id: "item1"}
+	item2 := &ExampleCacheItem{status: ExampleStatusNotStarted, id: "item2"}
+	_, err1 := cache.GetOrCreate(item1)
+	_, err2 := cache.GetOrCreate(item2)
+	if err1 != nil || err2 != nil {
+		fmt.Printf("unexpected error in create; err1: %v, err2: %v", err1, err2)
+	}
+
+	// wait for the cache to go through a few refresh cycles and then check the status
+	time.Sleep(resyncPeriod * 10)
+	fmt.Printf("Current status for item1 is %v", cache.Get(item1.ID()).(*ExampleCacheItem).status)
+
+	// stop the cache
+	cancel()
+
+	// Output:
+	// Current status for item1 is Completed
+}
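+
+// Note that Get returns nil once an item has been evicted, so code outside a controlled example should guard the
+// type assertion. A minimal sketch of that defensive pattern:
+//
+//	if v := cache.Get("item1"); v != nil {
+//		if item, ok := v.(*ExampleCacheItem); ok {
+//			fmt.Println(item.status)
+//		}
+//	}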
diff --git a/utils/rate_limiter.go b/utils/rate_limiter.go
new file mode 100644
index 0000000000..6a28b21da9
--- /dev/null
+++ b/utils/rate_limiter.go
@@ -0,0 +1,35 @@
+package utils
+
+import (
+	"context"
+
+	"github.com/lyft/flytestdlib/logger"
+	"golang.org/x/time/rate"
+)
+
+// RateLimiter is the interface for a blocking rate limiter.
+type RateLimiter interface {
+	Wait(ctx context.Context) error
+}
+
+type rateLimiter struct {
+	name    string
+	limiter *rate.Limiter
+}
+
+// Wait blocks until the next token is available, as governed by the configured TPS and burst values.
+func (r *rateLimiter) Wait(ctx context.Context) error {
+	logger.Debugf(ctx, "Waiting for a token from rate limiter %s", r.name)
+	return r.limiter.Wait(ctx)
+}
+
+// NewRateLimiter creates a new rate limiter with the given TPS and burst values.
+func NewRateLimiter(rateLimiterName string, tps float64, burst int) RateLimiter {
+	return &rateLimiter{
+		name:    rateLimiterName,
+		limiter: rate.NewLimiter(rate.Limit(tps), burst),
+	}
+}
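+
+// A short usage sketch (the values are illustrative): a limiter admitting ~10 events per second with bursts of up
+// to 5, so 15 Wait calls complete in roughly one second:
+//
+//	limiter := NewRateLimiter("example", 10, 5)
+//	for i := 0; i < 15; i++ {
+//		if err := limiter.Wait(ctx); err != nil {
+//			return err // the context was cancelled, or waiting would exceed its deadline
+//		}
+//	}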
diff --git a/utils/rate_limiter_test.go b/utils/rate_limiter_test.go
new file mode 100644
index 0000000000..3aaf1df522
--- /dev/null
+++ b/utils/rate_limiter_test.go
@@ -0,0 +1,39 @@
+package utils
+
+import (
+	"context"
+	"math"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestInfiniteRateLimiter(t *testing.T) {
+	infRateLimiter := NewRateLimiter("test_rate_limiter", math.MaxFloat64, 0)
+	start := time.Now()
+
+	for i := 0; i < 100; i++ {
+		err := infRateLimiter.Wait(context.Background())
+		assert.NoError(t, err, "unexpected failure in wait")
+	}
+	assert.True(t, time.Since(start) < 100*time.Millisecond)
+}
+
+func TestRateLimiter(t *testing.T) {
+	rateLimiter := NewRateLimiter("test_rate_limiter", 1, 1)
+	start := time.Now()
+
+	for i := 0; i < 5; i++ {
+		err := rateLimiter.Wait(context.Background())
+		assert.Nil(t, err)
+	}
+	assert.True(t, time.Since(start) > 3*time.Second)
+	assert.True(t, time.Since(start) < 5*time.Second)
+}
+
+func TestInvalidRateLimitConfig(t *testing.T) {
+	rateLimiter := NewRateLimiter("test_rate_limiter", 1, 0)
+	err := rateLimiter.Wait(context.Background())
+	assert.NotNil(t, err)
+}
diff --git a/utils/sequencer.go b/utils/sequencer.go
new file mode 100644
index 0000000000..1d41dde367
--- /dev/null
+++ b/utils/sequencer.go
@@ -0,0 +1,39 @@
+package utils
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+// Sequencer is a thread-safe incremental integer counter. Note that it is a process-wide singleton, so
+// GetSequencer().GetNext() may not start at 1.
+type Sequencer interface {
+	GetNext() uint64
+	GetCur() uint64
+}
+
+type sequencer struct {
+	val *uint64
+}
+
+var once sync.Once
+var instance Sequencer
+
+func GetSequencer() Sequencer {
+	once.Do(func() {
+		val := uint64(0)
+		instance = &sequencer{val: &val}
+	})
+	return instance
+}
+
+// GetNext returns the next sequence number, 1 higher than the last, and sets it as the current one.
+func (s sequencer) GetNext() uint64 {
+	return atomic.AddUint64(s.val, 1)
+}
+
+// GetCur returns the current sequence number.
+func (s sequencer) GetCur() uint64 {
+	return atomic.LoadUint64(s.val)
+}
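+
+// A brief usage sketch: because the sequencer is a singleton, concurrent callers share one sequence and always
+// observe unique, monotonically increasing values:
+//
+//	seq := GetSequencer()
+//	first := seq.GetNext() // unique across all goroutines in the process
+//	cur := seq.GetCur()    // latest value handed out so far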
diff --git a/utils/sequencer_test.go b/utils/sequencer_test.go
new file mode 100644
index 0000000000..2444048764
--- /dev/null
+++ b/utils/sequencer_test.go
@@ -0,0 +1,54 @@
+package utils
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestSequencer(t *testing.T) {
+	size := 3
+	sequencer := GetSequencer()
+	curVal := sequencer.GetCur() + 1
+	// arithmetic series: sum = n(a0 + aN) / 2
+	expectedSum := uint64(size) * (curVal + curVal + uint64(size-1)) / 2
+	numbers := make(chan uint64, size)
+
+	var wg sync.WaitGroup
+	wg.Add(size)
+
+	iter := 0
+	for iter < size {
+		go func() {
+			numbers <- sequencer.GetNext()
+			wg.Done()
+		}()
+		iter++
+	}
+	wg.Wait()
+	close(numbers)
+
+	unique, sum := uniqueAndSum(numbers)
+	assert.True(t, unique, "sequencer generated duplicate numbers")
+	assert.Equal(t, expectedSum, sum, "sequencer generated sequence numbers with gaps (expected sum %d, got %d)", expectedSum, sum)
+}
+
+func uniqueAndSum(list chan uint64) (bool, uint64) {
+	set := make(map[uint64]struct{})
+	var sum uint64
+
+	for elem := range list {
+		fmt.Printf("list value: %d\n", elem)
+		if _, ok := set[elem]; ok {
+			return false, sum
+		}
+		set[elem] = struct{}{}
+		sum += elem
+	}
+	return true, sum
+}
diff --git a/version/version.go b/version/version.go
new file mode 100644
index 0000000000..ab3e4cf112
--- /dev/null
+++ b/version/version.go
@@ -0,0 +1,29 @@
+package version
+
+import (
+	"time"
+
+	"github.com/sirupsen/logrus"
+)
+
+// This module provides the ability to inject Build (git SHA) and Version information at compile time.
+// To set these values, invoke go build as follows:
+// go build -ldflags "-X github.com/lyft/flytestdlib/version.Build=xyz -X github.com/lyft/flytestdlib/version.Version=1.2.3"
+// NOTE: If the version is set and server.StartProfilingServerWithDefaultHandlers is initialized, then `/version`
+// will provide the build and version information.
+var (
+	// Build specifies the git SHA of the build
+	Build = "unknown"
+	// Version of the build; should follow semver
+	Version = "unknown"
+	// BuildTime is the build timestamp
+	BuildTime = time.Now()
+)
+
+// LogBuildInformation logs the build information for the current app; the app name should be provided. To inject
+// the build and version information, refer to the top-level comment in this file.
+func LogBuildInformation(appName string) {
+	logrus.Info("------------------------------------------------------------------------")
+	logrus.Infof("App [%s], Version [%s], BuildSHA [%s], BuildTS [%s]", appName, Version, Build, BuildTime.String())
+	logrus.Info("------------------------------------------------------------------------")
+}
diff --git a/version/version_test.go b/version/version_test.go
new file mode 100644
index 0000000000..d4ddb2def4
--- /dev/null
+++ b/version/version_test.go
@@ -0,0 +1,29 @@
+package version
+
+import (
+	"bytes"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/magiconair/properties/assert"
+	"github.com/sirupsen/logrus"
+)
+
+type dFormat struct {
+}
+
+func (dFormat) Format(e *logrus.Entry) ([]byte, error) {
+	return []byte(e.Message), nil
+}
+
+func TestLogBuildInformation(t *testing.T) {
+	n := time.Now()
+	BuildTime = n
+	buf := bytes.NewBufferString("")
+	logrus.SetFormatter(dFormat{})
+	logrus.SetOutput(buf)
+	LogBuildInformation("hello")
+	assert.Equal(t, buf.String(), fmt.Sprintf("------------------------------------------------------------------------App [hello], Version [unknown], BuildSHA [unknown], BuildTS [%s]------------------------------------------------------------------------", n.String()))
+}
diff --git a/yamlutils/yaml_json.go b/yamlutils/yaml_json.go
new file mode 100644
index 0000000000..152c6ce5f4
--- /dev/null
+++ b/yamlutils/yaml_json.go
@@ -0,0 +1,17 @@
+package yamlutils
+
+import (
+	"io/ioutil"
+	"path/filepath"
+
+	"github.com/ghodss/yaml"
+)
+
+// ReadYamlFileAsJSON reads the YAML file at the given path and returns its contents converted to JSON.
+func ReadYamlFileAsJSON(path string) ([]byte, error) {
+	r, err := ioutil.ReadFile(filepath.Clean(path))
+	if err != nil {
+		return nil, err
+	}
+
+	return yaml.YAMLToJSON(r)
+}
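+
+// A short usage sketch (config.yaml is a hypothetical input file):
+//
+//	raw, err := ReadYamlFileAsJSON("config.yaml")
+//	if err != nil {
+//		return err
+//	}
+//	var cfg map[string]interface{}
+//	if err := json.Unmarshal(raw, &cfg); err != nil {
+//		return err
+//	}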