diff --git a/CHANGELOG.md b/CHANGELOG.md index 00979390e536..ad595da7a29c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,14 +2,18 @@ DEPRECATIONS/CHANGES: - * The AWS authentication backend now allows binds for inputs, as either a + * The AWS authentication backend now allows binds for inputs as either a comma-delimited string or a string array. However, to keep consistency with input and output, when reading a role the binds will now be returned as string arrays rather than strings. IMPROVEMENTS: + * auth/approle: Allow array input for bound_cidr_list [4078] * auth/aws: Allow using lists in role bind parameters [GH-3907] + * auth/aws: Allow binding by EC2 instance IDs [GH-3816] + * secret/transit: Allow selecting signature algorithm as well as hash + algorithm when signing/verifying [GH-4018] * server: Make sure `tls_disable_client_cert` is actually a true value rather than just set [GH-4049] * storage/gcs: Allow specifying chunk size for transfers, which can reduce @@ -19,8 +23,18 @@ IMPROVEMENTS: BUG FIXES: + * auth/aws: Fix honoring `max_ttl` when a corresponding role `ttl` is not also + set [GH-4107] + * auth/okta: Fix honoring configured `max_ttl` value [GH-4110] + * auth/token: If a periodic token being issued has a period greater than the + max_lease_ttl configured on the token store mount, truncate it. This matches + renewal behavior; before it was inconsistent between issuance and renewal. + [GH-4112] * cli: Improve error messages around `vault auth help` when there is no CLI helper for a particular method [GH-4056] + * cli: Fix autocomplete installation when using Fish as the shell [GH-4094] + * secret/database: Properly honor mount-tuned max TTL [GH-4051] + * secret/ssh: Return `key_bits` value when reading a role [GH-4098] ## 0.9.5 (February 26th, 2018) diff --git a/api/renewer.go b/api/renewer.go index 2a72ebe2ef66..b50cf814f376 100644 --- a/api/renewer.go +++ b/api/renewer.go @@ -162,8 +162,8 @@ func (r *Renewer) Stop() { } // Renew starts a background process for renewing this secret. When the secret -// is has auth data, this attempts to renew the auth (token). When the secret -// has a lease, this attempts to renew the lease. +// has auth data, this attempts to renew the auth (token). When the secret has +// a lease, this attempts to renew the lease. func (r *Renewer) Renew() { var result error if r.secret.Auth != nil { diff --git a/api/sys_auth.go b/api/sys_auth.go index 937b0eafb4b9..fd9c5c59a3bb 100644 --- a/api/sys_auth.go +++ b/api/sys_auth.go @@ -91,9 +91,11 @@ type EnableAuthOptions struct { } type AuthConfigInput struct { - DefaultLeaseTTL string `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"` - MaxLeaseTTL string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` - PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` + DefaultLeaseTTL string `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"` + MaxLeaseTTL string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` + PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` + AuditNonHMACRequestKeys []string `json:"audit_non_hmac_request_keys,omitempty" structs:"audit_non_hmac_request_keys" mapstructure:"audit_non_hmac_request_keys"` + AuditNonHMACResponseKeys []string `json:"audit_non_hmac_response_keys,omitempty" structs:"audit_non_hmac_response_keys" mapstructure:"audit_non_hmac_response_keys"` } type AuthMount struct { @@ -106,7 +108,9 @@ type AuthMount struct { } type AuthConfigOutput struct { - DefaultLeaseTTL int `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"` - MaxLeaseTTL int `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` - PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` + DefaultLeaseTTL int `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"` + MaxLeaseTTL int `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` + PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` + AuditNonHMACRequestKeys []string `json:"audit_non_hmac_request_keys,omitempty" structs:"audit_non_hmac_request_keys" mapstructure:"audit_non_hmac_request_keys"` + AuditNonHMACResponseKeys []string `json:"audit_non_hmac_response_keys,omitempty" structs:"audit_non_hmac_response_keys" mapstructure:"audit_non_hmac_response_keys"` } diff --git a/audit/audit.go b/audit/audit.go index 6adf3b8bcb26..fed7033500ab 100644 --- a/audit/audit.go +++ b/audit/audit.go @@ -27,7 +27,7 @@ type Backend interface { // GetHash is used to return the given data with the backend's hash, // so that a caller can determine if a value in the audit log matches // an expected plaintext value - GetHash(string) (string, error) + GetHash(context.Context, string) (string, error) // Reload is called on SIGHUP for supporting backends. Reload(context.Context) error diff --git a/audit/format.go b/audit/format.go index aaa92a5730a9..f226c95856c8 100644 --- a/audit/format.go +++ b/audit/format.go @@ -1,6 +1,7 @@ package audit import ( + "context" "fmt" "io" "strings" @@ -16,7 +17,7 @@ import ( type AuditFormatWriter interface { WriteRequest(io.Writer, *AuditRequestEntry) error WriteResponse(io.Writer, *AuditResponseEntry) error - Salt() (*salt.Salt, error) + Salt(context.Context) (*salt.Salt, error) } // AuditFormatter implements the Formatter interface, and allows the underlying @@ -27,7 +28,7 @@ type AuditFormatter struct { var _ Formatter = (*AuditFormatter)(nil) -func (f *AuditFormatter) FormatRequest(w io.Writer, config FormatterConfig, in *LogInput) error { +func (f *AuditFormatter) FormatRequest(ctx context.Context, w io.Writer, config FormatterConfig, in *LogInput) error { if in == nil || in.Request == nil { return fmt.Errorf("request to request-audit a nil request") } @@ -40,7 +41,7 @@ func (f *AuditFormatter) FormatRequest(w io.Writer, config FormatterConfig, in * return fmt.Errorf("no format writer specified") } - salt, err := f.Salt() + salt, err := f.Salt(ctx) if err != nil { return errwrap.Wrapf("error fetching salt: {{err}}", err) } @@ -151,7 +152,7 @@ func (f *AuditFormatter) FormatRequest(w io.Writer, config FormatterConfig, in * return f.AuditFormatWriter.WriteRequest(w, reqEntry) } -func (f *AuditFormatter) FormatResponse(w io.Writer, config FormatterConfig, in *LogInput) error { +func (f *AuditFormatter) FormatResponse(ctx context.Context, w io.Writer, config FormatterConfig, in *LogInput) error { if in == nil || in.Request == nil { return fmt.Errorf("request to response-audit a nil request") } @@ -164,7 +165,7 @@ func (f *AuditFormatter) FormatResponse(w io.Writer, config FormatterConfig, in return fmt.Errorf("no format writer specified") } - salt, err := f.Salt() + salt, err := f.Salt(ctx) if err != nil { return errwrap.Wrapf("error fetching salt: {{err}}", err) } diff --git a/audit/format_json.go b/audit/format_json.go index 0a5c9d90bdfa..f42ab20d387f 100644 --- a/audit/format_json.go +++ b/audit/format_json.go @@ -1,6 +1,7 @@ package audit import ( + "context" "encoding/json" "fmt" "io" @@ -12,7 +13,7 @@ import ( // a JSON format. type JSONFormatWriter struct { Prefix string - SaltFunc func() (*salt.Salt, error) + SaltFunc func(context.Context) (*salt.Salt, error) } func (f *JSONFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { @@ -47,6 +48,6 @@ func (f *JSONFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry) return enc.Encode(resp) } -func (f *JSONFormatWriter) Salt() (*salt.Salt, error) { - return f.SaltFunc() +func (f *JSONFormatWriter) Salt(ctx context.Context) (*salt.Salt, error) { + return f.SaltFunc(ctx) } diff --git a/audit/format_json_test.go b/audit/format_json_test.go index 90c44a09fa12..9ab20ac35e09 100644 --- a/audit/format_json_test.go +++ b/audit/format_json_test.go @@ -2,6 +2,7 @@ package audit import ( "bytes" + "context" "encoding/json" "strings" "testing" @@ -17,11 +18,11 @@ import ( ) func TestFormatJSON_formatRequest(t *testing.T) { - salter, err := salt.NewSalt(nil, nil) + salter, err := salt.NewSalt(context.Background(), nil, nil) if err != nil { t.Fatal(err) } - saltFunc := func() (*salt.Salt, error) { + saltFunc := func(context.Context) (*salt.Salt, error) { return salter, nil } @@ -90,7 +91,7 @@ func TestFormatJSON_formatRequest(t *testing.T) { Request: tc.Req, OuterErr: tc.Err, } - if err := formatter.FormatRequest(&buf, config, in); err != nil { + if err := formatter.FormatRequest(context.Background(), &buf, config, in); err != nil { t.Fatalf("bad: %s\nerr: %s", name, err) } diff --git a/audit/format_jsonx.go b/audit/format_jsonx.go index 792e5524c384..30937464df5e 100644 --- a/audit/format_jsonx.go +++ b/audit/format_jsonx.go @@ -1,6 +1,7 @@ package audit import ( + "context" "encoding/json" "fmt" "io" @@ -13,7 +14,7 @@ import ( // a XML format. type JSONxFormatWriter struct { Prefix string - SaltFunc func() (*salt.Salt, error) + SaltFunc func(context.Context) (*salt.Salt, error) } func (f *JSONxFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { @@ -68,6 +69,6 @@ func (f *JSONxFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry) return err } -func (f *JSONxFormatWriter) Salt() (*salt.Salt, error) { - return f.SaltFunc() +func (f *JSONxFormatWriter) Salt(ctx context.Context) (*salt.Salt, error) { + return f.SaltFunc(ctx) } diff --git a/audit/format_jsonx_test.go b/audit/format_jsonx_test.go index c9096430b7a4..8775ef570b1d 100644 --- a/audit/format_jsonx_test.go +++ b/audit/format_jsonx_test.go @@ -2,6 +2,7 @@ package audit import ( "bytes" + "context" "strings" "testing" "time" @@ -15,11 +16,11 @@ import ( ) func TestFormatJSONx_formatRequest(t *testing.T) { - salter, err := salt.NewSalt(nil, nil) + salter, err := salt.NewSalt(context.Background(), nil, nil) if err != nil { t.Fatal(err) } - saltFunc := func() (*salt.Salt, error) { + saltFunc := func(context.Context) (*salt.Salt, error) { return salter, nil } @@ -94,7 +95,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) { Request: tc.Req, OuterErr: tc.Err, } - if err := formatter.FormatRequest(&buf, config, in); err != nil { + if err := formatter.FormatRequest(context.Background(), &buf, config, in); err != nil { t.Fatalf("bad: %s\nerr: %s", name, err) } diff --git a/audit/format_test.go b/audit/format_test.go index 54da56154f66..e0bd68f0b341 100644 --- a/audit/format_test.go +++ b/audit/format_test.go @@ -1,6 +1,7 @@ package audit import ( + "context" "io" "io/ioutil" "testing" @@ -22,12 +23,12 @@ func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) err return nil } -func (n *noopFormatWriter) Salt() (*salt.Salt, error) { +func (n *noopFormatWriter) Salt(ctx context.Context) (*salt.Salt, error) { if n.salt != nil { return n.salt, nil } var err error - n.salt, err = salt.NewSalt(nil, nil) + n.salt, err = salt.NewSalt(ctx, nil, nil) if err != nil { return nil, err } @@ -40,14 +41,14 @@ func TestFormatRequestErrors(t *testing.T) { AuditFormatWriter: &noopFormatWriter{}, } - if err := formatter.FormatRequest(ioutil.Discard, config, &LogInput{}); err == nil { + if err := formatter.FormatRequest(context.Background(), ioutil.Discard, config, &LogInput{}); err == nil { t.Fatal("expected error due to nil request") } in := &LogInput{ Request: &logical.Request{}, } - if err := formatter.FormatRequest(nil, config, in); err == nil { + if err := formatter.FormatRequest(context.Background(), nil, config, in); err == nil { t.Fatal("expected error due to nil writer") } } @@ -58,14 +59,14 @@ func TestFormatResponseErrors(t *testing.T) { AuditFormatWriter: &noopFormatWriter{}, } - if err := formatter.FormatResponse(ioutil.Discard, config, &LogInput{}); err == nil { + if err := formatter.FormatResponse(context.Background(), ioutil.Discard, config, &LogInput{}); err == nil { t.Fatal("expected error due to nil request") } in := &LogInput{ Request: &logical.Request{}, } - if err := formatter.FormatResponse(nil, config, in); err == nil { + if err := formatter.FormatResponse(context.Background(), nil, config, in); err == nil { t.Fatal("expected error due to nil writer") } } diff --git a/audit/formatter.go b/audit/formatter.go index d296e55b778d..7702a1ee5d64 100644 --- a/audit/formatter.go +++ b/audit/formatter.go @@ -1,6 +1,7 @@ package audit import ( + "context" "io" ) @@ -10,8 +11,8 @@ import ( // // It is recommended that you pass data through Hash prior to formatting it. type Formatter interface { - FormatRequest(io.Writer, FormatterConfig, *LogInput) error - FormatResponse(io.Writer, FormatterConfig, *LogInput) error + FormatRequest(context.Context, io.Writer, FormatterConfig, *LogInput) error + FormatResponse(context.Context, io.Writer, FormatterConfig, *LogInput) error } type FormatterConfig struct { diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 21cf05b1e901..f42fa5040235 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -99,7 +99,7 @@ func TestHashString(t *testing.T) { Key: "salt", Value: []byte("foo"), }) - localSalt, err := salt.NewSalt(inmemStorage, &salt.Config{ + localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ HMAC: sha256.New, HMACType: "hmac-sha256", }) @@ -206,7 +206,7 @@ func TestHash(t *testing.T) { Key: "salt", Value: []byte("foo"), }) - localSalt, err := salt.NewSalt(inmemStorage, &salt.Config{ + localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ HMAC: sha256.New, HMACType: "hmac-sha256", }) diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index 7bf066d1c034..bd69c3e2bb79 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -143,7 +143,7 @@ type Backend struct { var _ audit.Backend = (*Backend)(nil) -func (b *Backend) Salt() (*salt.Salt, error) { +func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { b.saltMutex.RLock() if b.salt != nil { defer b.saltMutex.RUnlock() @@ -155,7 +155,7 @@ func (b *Backend) Salt() (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(b.saltView, b.saltConfig) + salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) if err != nil { return nil, err } @@ -163,30 +163,30 @@ func (b *Backend) Salt() (*salt.Salt, error) { return salt, nil } -func (b *Backend) GetHash(data string) (string, error) { - salt, err := b.Salt() +func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { + salt, err := b.Salt(ctx) if err != nil { return "", err } return audit.HashString(salt, data), nil } -func (b *Backend) LogRequest(_ context.Context, in *audit.LogInput) error { +func (b *Backend) LogRequest(ctx context.Context, in *audit.LogInput) error { b.fileLock.Lock() defer b.fileLock.Unlock() switch b.path { case "stdout": - return b.formatter.FormatRequest(os.Stdout, b.formatConfig, in) + return b.formatter.FormatRequest(ctx, os.Stdout, b.formatConfig, in) case "discard": - return b.formatter.FormatRequest(ioutil.Discard, b.formatConfig, in) + return b.formatter.FormatRequest(ctx, ioutil.Discard, b.formatConfig, in) } if err := b.open(); err != nil { return err } - if err := b.formatter.FormatRequest(b.f, b.formatConfig, in); err == nil { + if err := b.formatter.FormatRequest(ctx, b.f, b.formatConfig, in); err == nil { return nil } @@ -198,26 +198,26 @@ func (b *Backend) LogRequest(_ context.Context, in *audit.LogInput) error { return err } - return b.formatter.FormatRequest(b.f, b.formatConfig, in) + return b.formatter.FormatRequest(ctx, b.f, b.formatConfig, in) } -func (b *Backend) LogResponse(_ context.Context, in *audit.LogInput) error { +func (b *Backend) LogResponse(ctx context.Context, in *audit.LogInput) error { b.fileLock.Lock() defer b.fileLock.Unlock() switch b.path { case "stdout": - return b.formatter.FormatResponse(os.Stdout, b.formatConfig, in) + return b.formatter.FormatResponse(ctx, os.Stdout, b.formatConfig, in) case "discard": - return b.formatter.FormatResponse(ioutil.Discard, b.formatConfig, in) + return b.formatter.FormatResponse(ctx, ioutil.Discard, b.formatConfig, in) } if err := b.open(); err != nil { return err } - if err := b.formatter.FormatResponse(b.f, b.formatConfig, in); err == nil { + if err := b.formatter.FormatResponse(ctx, b.f, b.formatConfig, in); err == nil { return nil } @@ -229,7 +229,7 @@ func (b *Backend) LogResponse(_ context.Context, in *audit.LogInput) error { return err } - return b.formatter.FormatResponse(b.f, b.formatConfig, in) + return b.formatter.FormatResponse(ctx, b.f, b.formatConfig, in) } // The file lock must be held before calling this diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index d99d28f574fa..e0d5b2271b18 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -123,8 +123,8 @@ type Backend struct { var _ audit.Backend = (*Backend)(nil) -func (b *Backend) GetHash(data string) (string, error) { - salt, err := b.Salt() +func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { + salt, err := b.Salt(ctx) if err != nil { return "", err } @@ -133,7 +133,7 @@ func (b *Backend) GetHash(data string) (string, error) { func (b *Backend) LogRequest(ctx context.Context, in *audit.LogInput) error { var buf bytes.Buffer - if err := b.formatter.FormatRequest(&buf, b.formatConfig, in); err != nil { + if err := b.formatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { return err } @@ -156,7 +156,7 @@ func (b *Backend) LogRequest(ctx context.Context, in *audit.LogInput) error { func (b *Backend) LogResponse(ctx context.Context, in *audit.LogInput) error { var buf bytes.Buffer - if err := b.formatter.FormatResponse(&buf, b.formatConfig, in); err != nil { + if err := b.formatter.FormatResponse(ctx, &buf, b.formatConfig, in); err != nil { return err } @@ -223,7 +223,7 @@ func (b *Backend) Reload(ctx context.Context) error { return err } -func (b *Backend) Salt() (*salt.Salt, error) { +func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { b.saltMutex.RLock() if b.salt != nil { defer b.saltMutex.RUnlock() @@ -235,7 +235,7 @@ func (b *Backend) Salt() (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(b.saltView, b.saltConfig) + salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) if err != nil { return nil, err } diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index 2df25f229822..68d8a361ab1e 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -110,17 +110,17 @@ type Backend struct { var _ audit.Backend = (*Backend)(nil) -func (b *Backend) GetHash(data string) (string, error) { - salt, err := b.Salt() +func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { + salt, err := b.Salt(ctx) if err != nil { return "", err } return audit.HashString(salt, data), nil } -func (b *Backend) LogRequest(_ context.Context, in *audit.LogInput) error { +func (b *Backend) LogRequest(ctx context.Context, in *audit.LogInput) error { var buf bytes.Buffer - if err := b.formatter.FormatRequest(&buf, b.formatConfig, in); err != nil { + if err := b.formatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { return err } @@ -129,9 +129,9 @@ func (b *Backend) LogRequest(_ context.Context, in *audit.LogInput) error { return err } -func (b *Backend) LogResponse(_ context.Context, in *audit.LogInput) error { +func (b *Backend) LogResponse(ctx context.Context, in *audit.LogInput) error { var buf bytes.Buffer - if err := b.formatter.FormatResponse(&buf, b.formatConfig, in); err != nil { + if err := b.formatter.FormatResponse(ctx, &buf, b.formatConfig, in); err != nil { return err } @@ -144,7 +144,7 @@ func (b *Backend) Reload(_ context.Context) error { return nil } -func (b *Backend) Salt() (*salt.Salt, error) { +func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { b.saltMutex.RLock() if b.salt != nil { defer b.saltMutex.RUnlock() @@ -156,7 +156,7 @@ func (b *Backend) Salt() (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(b.saltView, b.saltConfig) + salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) if err != nil { return nil, err } diff --git a/builtin/credential/app-id/backend.go b/builtin/credential/app-id/backend.go index 2cea93edabc9..32ec15875447 100644 --- a/builtin/credential/app-id/backend.go +++ b/builtin/credential/app-id/backend.go @@ -93,7 +93,7 @@ type backend struct { MapUserId *framework.PathMap } -func (b *backend) Salt() (*salt.Salt, error) { +func (b *backend) Salt(ctx context.Context) (*salt.Salt, error) { b.SaltMutex.RLock() if b.salt != nil { defer b.SaltMutex.RUnlock() @@ -105,7 +105,7 @@ func (b *backend) Salt() (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(b.view, &salt.Config{ + salt, err := salt.NewSalt(ctx, b.view, &salt.Config{ HashFunc: salt.SHA1Hash, Location: salt.DefaultLocation, }) diff --git a/builtin/credential/app-id/backend_test.go b/builtin/credential/app-id/backend_test.go index bff8bc77b484..e25fa9cbb7aa 100644 --- a/builtin/credential/app-id/backend_test.go +++ b/builtin/credential/app-id/backend_test.go @@ -54,7 +54,7 @@ func TestBackend_basic(t *testing.T) { if len(keys) != 1 { t.Fatalf("expected 1 key, got %d", len(keys)) } - bSalt, err := b.Salt() + bSalt, err := b.Salt(context.Background()) if err != nil { t.Fatal(err) } diff --git a/builtin/credential/approle/backend.go b/builtin/credential/approle/backend.go index 11c62648b60d..6b914685ef55 100644 --- a/builtin/credential/approle/backend.go +++ b/builtin/credential/approle/backend.go @@ -104,7 +104,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { return b, nil } -func (b *backend) Salt() (*salt.Salt, error) { +func (b *backend) Salt(ctx context.Context) (*salt.Salt, error) { b.saltMutex.RLock() if b.salt != nil { defer b.saltMutex.RUnlock() @@ -116,7 +116,7 @@ func (b *backend) Salt() (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(b.view, &salt.Config{ + salt, err := salt.NewSalt(ctx, b.view, &salt.Config{ HashFunc: salt.SHA256Hash, Location: salt.DefaultLocation, }) diff --git a/builtin/credential/approle/path_role.go b/builtin/credential/approle/path_role.go index 0d3417cf4185..08b3ff4f9d6a 100644 --- a/builtin/credential/approle/path_role.go +++ b/builtin/credential/approle/path_role.go @@ -9,6 +9,7 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/cidrutil" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/locksutil" "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/helper/strutil" @@ -50,7 +51,10 @@ type roleStorageEntry struct { BindSecretID bool `json:"bind_secret_id" structs:"bind_secret_id" mapstructure:"bind_secret_id"` // A constraint, if set, specifies the CIDR blocks from which logins should be allowed - BoundCIDRList string `json:"bound_cidr_list" structs:"bound_cidr_list" mapstructure:"bound_cidr_list"` + BoundCIDRListOld string `json:"bound_cidr_list,omitempty"` + + // A constraint, if set, specifies the CIDR blocks from which logins should be allowed + BoundCIDRList []string `json:"bound_cidr_list_list" structs:"bound_cidr_list" mapstructure:"bound_cidr_list"` // Period, if set, indicates that the token generated using this role // should never expire. The token should be renewed within the duration @@ -113,9 +117,9 @@ func rolePaths(b *backend) []*framework.Path { Description: "Impose secret_id to be presented when logging in using this role. Defaults to 'true'.", }, "bound_cidr_list": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Comma separated list of CIDR blocks, if set, specifies blocks of IP -addresses which can perform the login operation`, + Type: framework.TypeCommaStringSlice, + Description: `Comma separated string or list of CIDR blocks. If set, specifies the blocks of +IP addresses which can perform the login operation.`, }, "policies": &framework.FieldSchema{ Type: framework.TypeCommaStringSlice, @@ -198,9 +202,9 @@ TTL will be set to the value of this parameter.`, Description: "Name of the role.", }, "bound_cidr_list": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Comma separated list of CIDR blocks, if set, specifies blocks of IP -addresses which can perform the login operation`, + Type: framework.TypeCommaStringSlice, + Description: `Comma separated string or list of CIDR blocks. If set, specifies the blocks of +IP addresses which can perform the login operation.`, }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -391,8 +395,8 @@ be renewed. Defaults to 0, in which case the value will fall back to the system/ formatted string containing the metadata in key value pairs.`, }, "cidr_list": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Comma separated list of CIDR blocks enforcing secret IDs to be used from + Type: framework.TypeCommaStringSlice, + Description: `Comma separated string or list of CIDR blocks enforcing secret IDs to be used from specific set of IP addresses. If 'bound_cidr_list' is set on the role, then the list of CIDR blocks listed here should be a subset of the CIDR blocks listed on the role.`, @@ -496,8 +500,8 @@ the role.`, formatted string containing metadata in key value pairs.`, }, "cidr_list": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Comma separated list of CIDR blocks enforcing secret IDs to be used from + Type: framework.TypeCommaStringSlice, + Description: `Comma separated string or list of CIDR blocks enforcing secret IDs to be used from specific set of IP addresses. If 'bound_cidr_list' is set on the role, then the list of CIDR blocks listed here should be a subset of the CIDR blocks listed on the role.`, @@ -633,7 +637,7 @@ func validateRoleConstraints(role *roleStorageEntry) error { // At least one constraint should be enabled on the role switch { case role.BindSecretID: - case role.BoundCIDRList != "": + case len(role.BoundCIDRList) != 0: default: return fmt.Errorf("at least one constraint should be enabled on the role") } @@ -718,6 +722,27 @@ func (b *backend) roleEntry(ctx context.Context, s logical.Storage, roleName str return nil, err } + needsUpgrade := false + + if role.BoundCIDRListOld != "" { + role.BoundCIDRList = strings.Split(role.BoundCIDRListOld, ",") + role.BoundCIDRListOld = "" + needsUpgrade = true + } + + if needsUpgrade && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) { + entry, err := logical.StorageEntryJSON("role/"+strings.ToLower(roleName), &role) + if err != nil { + return nil, err + } + if err := s.Put(ctx, entry); err != nil { + // Only perform upgrades on replication primary + if !strings.Contains(err.Error(), logical.ErrReadOnly.Error()) { + return nil, err + } + } + } + return &role, nil } @@ -774,13 +799,13 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request } if boundCIDRListRaw, ok := data.GetOk("bound_cidr_list"); ok { - role.BoundCIDRList = strings.TrimSpace(boundCIDRListRaw.(string)) + role.BoundCIDRList = boundCIDRListRaw.([]string) } else if req.Operation == logical.CreateOperation { - role.BoundCIDRList = data.Get("bound_cidr_list").(string) + role.BoundCIDRList = data.Get("bound_cidr_list").([]string) } - if role.BoundCIDRList != "" { - valid, err := cidrutil.ValidateCIDRListString(role.BoundCIDRList, ",") + if len(role.BoundCIDRList) != 0 { + valid, err := cidrutil.ValidateCIDRListSlice(role.BoundCIDRList) if err != nil { return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err) } @@ -1242,19 +1267,17 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical. return nil, nil } - role.BoundCIDRList = strings.TrimSpace(data.Get("bound_cidr_list").(string)) - if role.BoundCIDRList == "" { + role.BoundCIDRList = data.Get("bound_cidr_list").([]string) + if len(role.BoundCIDRList) == 0 { return logical.ErrorResponse("missing bound_cidr_list"), nil } - if role.BoundCIDRList != "" { - valid, err := cidrutil.ValidateCIDRListString(role.BoundCIDRList, ",") - if err != nil { - return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err) - } - if !valid { - return logical.ErrorResponse("failed to validate CIDR blocks"), nil - } + valid, err := cidrutil.ValidateCIDRListSlice(role.BoundCIDRList) + if err != nil { + return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err) + } + if !valid { + return logical.ErrorResponse("failed to validate CIDR blocks"), nil } return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "") @@ -1302,7 +1325,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical. } // Deleting a field implies setting the value to it's default value. - role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string) + role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").([]string) return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "") } @@ -1990,11 +2013,11 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req return logical.ErrorResponse("bind_secret_id is not set on the role"), nil } - cidrList := data.Get("cidr_list").(string) + secretIDCIDRs := data.Get("cidr_list").([]string) // Validate the list of CIDR blocks - if cidrList != "" { - valid, err := cidrutil.ValidateCIDRListString(cidrList, ",") + if len(secretIDCIDRs) != 0 { + valid, err := cidrutil.ValidateCIDRListSlice(secretIDCIDRs) if err != nil { return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err) } @@ -2003,9 +2026,6 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req } } - // Parse the CIDR blocks into a slice - secretIDCIDRs := strutil.ParseDedupLowercaseAndSortStrings(cidrList, ",") - // Ensure that the CIDRs on the secret ID are a subset of that of role's if err := verifyCIDRRoleSecretIDSubset(secretIDCIDRs, role.BoundCIDRList); err != nil { return nil, err @@ -2052,7 +2072,7 @@ func (b *backend) setRoleIDEntry(ctx context.Context, s logical.Storage, roleID lock.Lock() defer lock.Unlock() - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return err } @@ -2080,7 +2100,7 @@ func (b *backend) roleIDEntry(ctx context.Context, s logical.Storage, roleID str var result roleIDStorageEntry - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return nil, err } @@ -2108,7 +2128,7 @@ func (b *backend) roleIDEntryDelete(ctx context.Context, s logical.Storage, role lock.Lock() defer lock.Unlock() - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return err } diff --git a/builtin/credential/approle/path_role_test.go b/builtin/credential/approle/path_role_test.go index 65c2ecb95273..1e07c32484fe 100644 --- a/builtin/credential/approle/path_role_test.go +++ b/builtin/credential/approle/path_role_test.go @@ -12,6 +12,105 @@ import ( "github.com/mitchellh/mapstructure" ) +func TestApprole_UpgradeBoundCIDRList(t *testing.T) { + var resp *logical.Response + var err error + + b, storage := createBackendWithStorage(t) + + roleData := map[string]interface{}{ + "policies": []string{"default"}, + "bind_secret_id": true, + "bound_cidr_list": []string{"127.0.0.1/18", "192.178.1.2/24"}, + } + + // Create a role with bound_cidr_list set + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Path: "role/testrole", + Operation: logical.CreateOperation, + Storage: storage, + Data: roleData, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + // Read the role and check that the bound_cidr_list is set properly + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Path: "role/testrole", + Operation: logical.ReadOperation, + Storage: storage, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + expected := []string{"127.0.0.1/18", "192.178.1.2/24"} + actual := resp.Data["bound_cidr_list"].([]string) + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("bad: bound_cidr_list; expected: %#v\nactual: %#v\n", expected, actual) + } + + // Modify the storage entry of the role to hold the old style string typed bound_cidr_list + role := &roleStorageEntry{ + RoleID: "testroleid", + HMACKey: "testhmackey", + Policies: []string{"default"}, + BindSecretID: true, + BoundCIDRListOld: "127.0.0.1/18,192.178.1.2/24", + } + err = b.setRoleEntry(context.Background(), storage, "testrole", role, "") + if err != nil { + t.Fatal(err) + } + + // Read the role. The upgrade code should have migrated the old type to the new type + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Path: "role/testrole", + Operation: logical.ReadOperation, + Storage: storage, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("bad: bound_cidr_list; expected: %#v\nactual: %#v\n", expected, actual) + } + + // Create a secret-id by supplying a subset of the role's CIDR blocks with the new type + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Path: "role/testrole/secret-id", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "cidr_list": []string{"127.0.0.1/24"}, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if resp.Data["secret_id"].(string) == "" { + t.Fatalf("failed to generate secret-id") + } + + // Check that the backwards compatibility for the string type is not broken + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Path: "role/testrole/secret-id", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "cidr_list": "127.0.0.1/24", + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + if resp.Data["secret_id"].(string) == "" { + t.Fatalf("failed to generate secret-id") + } +} + func TestApprole_RoleNameLowerCasing(t *testing.T) { var resp *logical.Response var err error @@ -858,8 +957,9 @@ func TestAppRole_RoleCRUD(t *testing.T) { "token_ttl": 400, "token_max_ttl": 500, "token_num_uses": 600, - "bound_cidr_list": "127.0.0.1/32,127.0.0.1/16", + "bound_cidr_list": []string{"127.0.0.1/32", "127.0.0.1/16"}, } + var expectedStruct roleStorageEntry err = mapstructure.Decode(expected, &expectedStruct) if err != nil { diff --git a/builtin/credential/approle/path_tidy_user_id.go b/builtin/credential/approle/path_tidy_user_id.go index 437d5c8137e3..23a380153c46 100644 --- a/builtin/credential/approle/path_tidy_user_id.go +++ b/builtin/credential/approle/path_tidy_user_id.go @@ -104,7 +104,7 @@ func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error { // the corresponding accessor from the accessorMap. This will leave // only the dangling accessors in the map which can then be cleaned // up later. - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { lock.Unlock() return err diff --git a/builtin/credential/approle/validation.go b/builtin/credential/approle/validation.go index 2052e70e5e85..559e14140c62 100644 --- a/builtin/credential/approle/validation.go +++ b/builtin/credential/approle/validation.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/cidrutil" "github.com/hashicorp/vault/helper/locksutil" - "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -143,13 +142,13 @@ func (b *backend) validateCredentials(ctx context.Context, req *logical.Request, } } - if role.BoundCIDRList != "" { + if len(role.BoundCIDRList) != 0 { // If 'bound_cidr_list' was set, verify the CIDR restrictions if req.Connection == nil || req.Connection.RemoteAddr == "" { return nil, "", metadata, "", fmt.Errorf("failed to get connection information"), nil } - belongs, err := cidrutil.IPBelongsToCIDRBlocksString(req.Connection.RemoteAddr, role.BoundCIDRList, ",") + belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, role.BoundCIDRList) if err != nil { return nil, "", metadata, "", nil, errwrap.Wrapf("failed to verify the CIDR restrictions set on the role: {{err}}", err) } @@ -163,7 +162,7 @@ func (b *backend) validateCredentials(ctx context.Context, req *logical.Request, // validateBindSecretID is used to determine if the given SecretID is a valid one. func (b *backend) validateBindSecretID(ctx context.Context, req *logical.Request, roleName, secretID, - hmacKey, roleBoundCIDRList string) (bool, map[string]string, error) { + hmacKey string, roleBoundCIDRList []string) (bool, map[string]string, error) { secretIDHMAC, err := createHMAC(hmacKey, secretID) if err != nil { return false, nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err) @@ -281,17 +280,14 @@ func (b *backend) validateBindSecretID(ctx context.Context, req *logical.Request // verifyCIDRRoleSecretIDSubset checks if the CIDR blocks set on the secret ID // are a subset of CIDR blocks set on the role -func verifyCIDRRoleSecretIDSubset(secretIDCIDRs []string, roleBoundCIDRList string) error { +func verifyCIDRRoleSecretIDSubset(secretIDCIDRs []string, roleBoundCIDRList []string) error { if len(secretIDCIDRs) != 0 { - // Parse the CIDRs on role as a slice - roleCIDRs := strutil.ParseDedupLowercaseAndSortStrings(roleBoundCIDRList, ",") - // If there are no CIDR blocks on the role, then the subset // requirement would be satisfied - if len(roleCIDRs) != 0 { - subset, err := cidrutil.SubsetBlocks(roleCIDRs, secretIDCIDRs) + if len(roleBoundCIDRList) != 0 { + subset, err := cidrutil.SubsetBlocks(roleBoundCIDRList, secretIDCIDRs) if !subset || err != nil { - return fmt.Errorf("failed to verify subset relationship between CIDR blocks on the role %q and CIDR blocks on the secret ID %q: %v", roleCIDRs, secretIDCIDRs, err) + return fmt.Errorf("failed to verify subset relationship between CIDR blocks on the role %q and CIDR blocks on the secret ID %q: %v", roleBoundCIDRList, secretIDCIDRs, err) } } } @@ -480,7 +476,7 @@ func (b *backend) secretIDAccessorEntry(ctx context.Context, s logical.Storage, var result secretIDAccessorStorageEntry // Create index entry, mapping the accessor to the token ID - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return nil, err } @@ -513,7 +509,7 @@ func (b *backend) createSecretIDAccessorEntry(ctx context.Context, s logical.Sto entry.SecretIDAccessor = accessorUUID // Create index entry, mapping the accessor to the token ID - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return err } @@ -536,7 +532,7 @@ func (b *backend) createSecretIDAccessorEntry(ctx context.Context, s logical.Sto // deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID. func (b *backend) deleteSecretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) error { - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return err } diff --git a/builtin/credential/aws/backend_test.go b/builtin/credential/aws/backend_test.go index cdaa8aac5b45..bb186df7ec3b 100644 --- a/builtin/credential/aws/backend_test.go +++ b/builtin/credential/aws/backend_test.go @@ -1070,6 +1070,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. "nonce": "vault-client-nonce", } + parsedIdentityDoc, err := b.parseIdentityDocument(context.Background(), storage, pkcs7) + if err != nil { + t.Fatal(err) + } + // Perform the login operation with a AMI ID that is not matching // the bound on the role. loginRequest := &logical.Request{ @@ -1079,14 +1084,15 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. Data: loginInput, } - // Place the wrong AMI ID in the role data. + // Baseline role data that should succeed permit login data := map[string]interface{}{ - "auth_type": "ec2", - "policies": "root", - "max_ttl": "120s", - "bound_ami_id": []string{"wrong_ami_id", amiID, "wrong_ami_id2"}, - "bound_account_id": accountID, - "bound_iam_role_arn": iamARN, + "auth_type": "ec2", + "policies": "root", + "max_ttl": "120s", + "bound_ami_id": []string{"wrong_ami_id", amiID, "wrong_ami_id2"}, + "bound_account_id": accountID, + "bound_iam_role_arn": iamARN, + "bound_ec2_instance_id": []string{parsedIdentityDoc.InstanceID, "i-1234567"}, } roleReq := &logical.Request{ @@ -1129,7 +1135,15 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. t.Fatal(err) } - // place a substring of the IAM role ARN + // Place correct IAM role ARN, but incorrect instance ID + data["bound_iam_role_arn"] = []string{"wrong_iam_role_arn_1", iamARN, "wrong_iam_role_arn_2"} + data["bound_ec2_instance_id"] = "i-1234567" + if err := updateRoleExpectLoginFail(roleReq, loginRequest); err != nil { + t.Fatal(err) + } + + // Place correct instance ID, but substring of the IAM role ARN + data["bound_ec2_instance_id"] = []string{parsedIdentityDoc.InstanceID, "i-1234567"} data["bound_iam_role_arn"] = []string{"wrong_iam_role_arn", iamARN[:len(iamARN)-2], "wrong_iam_role_arn_2"} if err := updateRoleExpectLoginFail(roleReq, loginRequest); err != nil { t.Fatal(err) @@ -1143,7 +1157,7 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. t.Fatal(err) } - // place a globbed IAM role ARN + // globbed IAM role ARN data["bound_iam_role_arn"] = []string{"wrong_iam_role_arn_1", fmt.Sprintf("%s*", iamARN[:len(iamARN)-2]), "wrong_iam_role_arn_2"} resp, err := b.HandleRequest(context.Background(), roleReq) if err != nil || (resp != nil && resp.IsError()) { @@ -1176,6 +1190,9 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. if instanceID == "" { t.Fatalf("instance ID not present in the response object") } + if instanceID != parsedIdentityDoc.InstanceID { + t.Fatalf("instance ID in response (%q) did not match instance ID from identity document (%q)", instanceID, parsedIdentityDoc.InstanceID) + } _, ok := resp.Auth.Metadata["nonce"] if ok { diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index 8bd584e2f8fd..0a78b8b10fb2 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -384,6 +384,11 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context, return nil, fmt.Errorf("nil identityDoc") } + // Verify that the instance ID matches one of the ones set by the role + if len(roleEntry.BoundEc2InstanceIDs) > 0 && !strutil.StrListContains(roleEntry.BoundEc2InstanceIDs, *instance.InstanceId) { + return fmt.Errorf("instance ID %q does not belong to the role %q", *instance.InstanceId, roleName), nil + } + // Verify that the AccountID of the instance trying to login matches the // AccountID specified as a constraint on role if len(roleEntry.BoundAccountIDs) > 0 && !strutil.StrListContains(roleEntry.BoundAccountIDs, identityDoc.AccountID) { @@ -821,12 +826,14 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request, resp.Auth.Metadata["nonce"] = clientNonce } - if roleEntry.MaxTTL > time.Duration(0) { - // Cap TTL to shortestMaxTTL - if resp.Auth.TTL > shortestMaxTTL { - resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (shortestMaxTTL / time.Second))) - resp.Auth.TTL = shortestMaxTTL - } + // In this case no role value was set so pull in what will be assigned by + // Core for comparison + if resp.Auth.TTL == 0 { + resp.Auth.TTL = b.System().DefaultLeaseTTL() + } + if resp.Auth.TTL > shortestMaxTTL { + resp.Auth.TTL = shortestMaxTTL + resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", resp.Auth.TTL, shortestMaxTTL)) } return resp, nil @@ -1329,6 +1336,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, }, } + if resp.Auth.TTL == 0 { + resp.Auth.TTL = b.System().DefaultLeaseTTL() + } if roleEntry.MaxTTL > time.Duration(0) { // Cap maxTTL to the sysview's max TTL maxTTL := roleEntry.MaxTTL @@ -1338,7 +1348,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, // Cap TTL to MaxTTL if resp.Auth.TTL > maxTTL { - resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second))) + resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", resp.Auth.TTL, maxTTL)) resp.Auth.TTL = maxTTL } } diff --git a/builtin/credential/aws/path_role.go b/builtin/credential/aws/path_role.go index ac88f4452736..d31ab3fd72be 100644 --- a/builtin/credential/aws/path_role.go +++ b/builtin/credential/aws/path_role.go @@ -71,6 +71,13 @@ with an IAM instance profile ARN which has a prefix that matches one of the values specified by this parameter. The value is prefix-matched (as though it were a glob ending in '*'). This is only applicable when auth_type is ec2 or inferred_entity_type is ec2_instance.`, + }, + "bound_ec2_instance_id": { + Type: framework.TypeCommaStringSlice, + Description: `If set, defines a constraint on the EC2 instances to have one of the +given instance IDs. Can be a list or comma-separated string of EC2 instance +IDs. This is only applicable when auth_type is ec2 or inferred_entity_type is +ec2_instance.`, }, "resolve_aws_unique_ids": { Type: framework.TypeBool, @@ -545,6 +552,10 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request roleEntry.BoundIamInstanceProfileARNs = boundIamInstanceProfileARNRaw.([]string) } + if boundEc2InstanceIDRaw, ok := data.GetOk("bound_ec2_instance_id"); ok { + roleEntry.BoundEc2InstanceIDs = boundEc2InstanceIDRaw.([]string) + } + if boundIamPrincipalARNRaw, ok := data.GetOk("bound_iam_principal_arn"); ok { principalARNs := boundIamPrincipalARNRaw.([]string) roleEntry.BoundIamPrincipalARNs = principalARNs @@ -618,56 +629,63 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request if len(roleEntry.BoundAccountIDs) > 0 { if !allowEc2Binds { - return logical.ErrorResponse(fmt.Sprintf("specified bound_account_id but not allowing ec2 auth_type or inferring %s", ec2EntityType)), nil + return logical.ErrorResponse(fmt.Sprintf("specified bound_account_id but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil } numBinds++ } if len(roleEntry.BoundRegions) > 0 { if roleEntry.AuthType != ec2AuthType { - return logical.ErrorResponse("specified bound_region but not allowing ec2 auth_type"), nil + return logical.ErrorResponse("specified bound_region but not specifying ec2 auth_type"), nil } numBinds++ } if len(roleEntry.BoundAmiIDs) > 0 { if !allowEc2Binds { - return logical.ErrorResponse(fmt.Sprintf("specified bound_ami_id but not allowing ec2 auth_type or inferring %s", ec2EntityType)), nil + return logical.ErrorResponse(fmt.Sprintf("specified bound_ami_id but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil } numBinds++ } if len(roleEntry.BoundIamInstanceProfileARNs) > 0 { if !allowEc2Binds { - return logical.ErrorResponse(fmt.Sprintf("specified bound_iam_instance_profile_arn but not allowing ec2 auth_type or inferring %s", ec2EntityType)), nil + return logical.ErrorResponse(fmt.Sprintf("specified bound_iam_instance_profile_arn but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil + } + numBinds++ + } + + if len(roleEntry.BoundEc2InstanceIDs) > 0 { + if !allowEc2Binds { + return logical.ErrorResponse(fmt.Sprintf("specified bound_ec2_instance_id but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil } numBinds++ } if len(roleEntry.BoundIamRoleARNs) > 0 { if !allowEc2Binds { - return logical.ErrorResponse(fmt.Sprintf("specified bound_iam_role_arn but not allowing ec2 auth_type or inferring %s", ec2EntityType)), nil + return logical.ErrorResponse(fmt.Sprintf("specified bound_iam_role_arn but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil } numBinds++ } if len(roleEntry.BoundIamPrincipalARNs) > 0 { if roleEntry.AuthType != iamAuthType { - return logical.ErrorResponse("specified bound_iam_principal_arn but not allowing iam auth_type"), nil + return logical.ErrorResponse("specified bound_iam_principal_arn but not specifying iam auth_type"), nil } numBinds++ } if len(roleEntry.BoundVpcIDs) > 0 { if !allowEc2Binds { - return logical.ErrorResponse(fmt.Sprintf("specified bound_vpc_id but not allowing ec2 auth_type or inferring %s", ec2EntityType)), nil + return logical.ErrorResponse(fmt.Sprintf("specified bound_vpc_id but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil } numBinds++ } if len(roleEntry.BoundSubnetIDs) > 0 { if !allowEc2Binds { - return logical.ErrorResponse(fmt.Sprintf("specified bound_subnet_id but not allowing ec2 auth_type or inferring %s", ec2EntityType)), nil + return logical.ErrorResponse(fmt.Sprintf("specified bound_subnet_id but not specifying ec2 auth_type or inferring %s", ec2EntityType)), nil } numBinds++ } @@ -791,6 +809,7 @@ type awsRoleEntry struct { AuthType string `json:"auth_type" ` BoundAmiIDs []string `json:"bound_ami_id_list"` BoundAccountIDs []string `json:"bound_account_id_list"` + BoundEc2InstanceIDs []string `json:"bound_ec2_instance_id_list"` BoundIamPrincipalARNs []string `json:"bound_iam_principal_arn_list"` BoundIamPrincipalIDs []string `json:"bound_iam_principal_id_list"` BoundIamRoleARNs []string `json:"bound_iam_role_arn_list"` @@ -827,6 +846,7 @@ func (r *awsRoleEntry) ToResponseData() map[string]interface{} { "auth_type": r.AuthType, "bound_ami_id": r.BoundAmiIDs, "bound_account_id": r.BoundAccountIDs, + "bound_ec2_instance_id": r.BoundEc2InstanceIDs, "bound_iam_principal_arn": r.BoundIamPrincipalARNs, "bound_iam_principal_id": r.BoundIamPrincipalIDs, "bound_iam_role_arn": r.BoundIamRoleARNs, diff --git a/builtin/credential/aws/path_role_test.go b/builtin/credential/aws/path_role_test.go index 045352f59a28..1fa2ab178e65 100644 --- a/builtin/credential/aws/path_role_test.go +++ b/builtin/credential/aws/path_role_test.go @@ -570,6 +570,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) { "bound_iam_instance_profile_arn": "arn:aws:iam::123456789012:instance-profile/MyInstancePro*", "bound_subnet_id": "testsubnetid", "bound_vpc_id": "testvpcid", + "bound_ec2_instance_id": "i-12345678901234567,i-76543210987654321", "role_tag": "testtag", "resolve_aws_unique_ids": false, "allow_instance_migration": true, @@ -600,6 +601,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) { "bound_ami_id": []string{"testamiid"}, "bound_account_id": []string{"testaccountid"}, "bound_region": []string{"testregion"}, + "bound_ec2_instance_id": []string{"i-12345678901234567", "i-76543210987654321"}, "bound_iam_principal_arn": []string{}, "bound_iam_principal_id": []string{}, "bound_iam_role_arn": []string{"arn:aws:iam::123456789012:role/MyRole"}, diff --git a/builtin/credential/okta/path_login.go b/builtin/credential/okta/path_login.go index 331af5a776a7..a7a719e2423d 100644 --- a/builtin/credential/okta/path_login.go +++ b/builtin/credential/okta/path_login.go @@ -96,6 +96,21 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew }, } + if resp.Auth.TTL == 0 { + resp.Auth.TTL = b.System().DefaultLeaseTTL() + } + if cfg.MaxTTL > 0 { + maxTTL := cfg.MaxTTL + if maxTTL > b.System().MaxLeaseTTL() { + maxTTL = b.System().MaxLeaseTTL() + } + + if resp.Auth.TTL > maxTTL { + resp.Auth.TTL = maxTTL + resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", resp.Auth.TTL, maxTTL)) + } + } + for _, groupName := range groupNames { if groupName == "" { continue diff --git a/builtin/logical/aws/backend_test.go b/builtin/logical/aws/backend_test.go index ad919b1bee23..1973a2ef8197 100644 --- a/builtin/logical/aws/backend_test.go +++ b/builtin/logical/aws/backend_test.go @@ -7,7 +7,6 @@ import ( "fmt" "log" "os" - "strings" "testing" "time" @@ -16,6 +15,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/logical" logicaltest "github.com/hashicorp/vault/logical/testing" @@ -41,15 +41,21 @@ func TestBackend_basic(t *testing.T) { } func TestBackend_basicSTS(t *testing.T) { + accessKey := &awsAccessKey{} logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: true, PreCheck: func() { testAccPreCheck(t) + createUser(t, accessKey) createRole(t) + // Sleep sometime because AWS is eventually consistent + // Both the createUser and createRole depend on this + log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") + time.Sleep(10 * time.Second) }, Backend: getBackend(t), Steps: []logicaltest.TestStep{ - testAccStepConfig(t), + testAccStepConfigWithCreds(t, accessKey), testAccStepWritePolicy(t, "test", testPolicy), testAccStepReadSTS(t, "test"), testAccStepWriteArnPolicyRef(t, "test", testPolicyArn), @@ -57,7 +63,9 @@ func TestBackend_basicSTS(t *testing.T) { testAccStepWriteArnRoleRef(t, testRoleName), testAccStepReadSTS(t, testRoleName), }, - Teardown: teardown, + Teardown: func() error { + return teardown(accessKey) + }, }) } @@ -81,14 +89,6 @@ func TestBackend_policyCrud(t *testing.T) { } func testAccPreCheck(t *testing.T) { - if v := os.Getenv("AWS_ACCESS_KEY_ID"); v == "" { - t.Fatal("AWS_ACCESS_KEY_ID must be set for acceptance tests") - } - - if v := os.Getenv("AWS_SECRET_ACCESS_KEY"); v == "" { - t.Fatal("AWS_SECRET_ACCESS_KEY must be set for acceptance tests") - } - if v := os.Getenv("AWS_DEFAULT_REGION"); v == "" { log.Println("[INFO] Test: Using us-west-2 as test region") os.Setenv("AWS_DEFAULT_REGION", "us-west-2") @@ -97,7 +97,7 @@ func testAccPreCheck(t *testing.T) { if v := os.Getenv("AWS_ACCOUNT_ID"); v == "" { accountId, err := getAccountId() if err != nil { - t.Fatal("AWS_ACCOUNT_ID could not be read from iam:GetUser for acceptance tests") + t.Fatalf("AWS_ACCOUNT_ID could not be read from iam:GetUser for acceptance tests: %#v", err) } log.Printf("[INFO] Test: Used %s as AWS_ACCOUNT_ID", accountId) os.Setenv("AWS_ACCOUNT_ID", accountId) @@ -105,27 +105,23 @@ func testAccPreCheck(t *testing.T) { } func getAccountId() (string, error) { - creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), - os.Getenv("AWS_SECRET_ACCESS_KEY"), - "") - awsConfig := &aws.Config{ - Credentials: creds, - Region: aws.String("us-east-1"), - HTTPClient: cleanhttp.DefaultClient(), + Region: aws.String("us-east-1"), + HTTPClient: cleanhttp.DefaultClient(), } - svc := iam.New(session.New(awsConfig)) + svc := sts.New(session.New(awsConfig)) - params := &iam.GetUserInput{} - res, err := svc.GetUser(params) + params := &sts.GetCallerIdentityInput{} + res, err := svc.GetCallerIdentity(params) if err != nil { return "", err } + if res == nil { + return "", fmt.Errorf("got nil response from GetCallerIdentity") + } - // split "arn:aws:iam::012345678912:user/username" - accountId := strings.Split(*res.User.Arn, ":")[4] - return accountId, nil + return *res.Account, nil } const testRoleName = "Vault-Acceptance-Test-AWS-Assume-Role" @@ -144,12 +140,9 @@ func createRole(t *testing.T) { ] } ` - creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), "") - awsConfig := &aws.Config{ - Credentials: creds, - Region: aws.String("us-east-1"), - HTTPClient: cleanhttp.DefaultClient(), + Region: aws.String("us-east-1"), + HTTPClient: cleanhttp.DefaultClient(), } svc := iam.New(session.New(awsConfig)) trustPolicy := fmt.Sprintf(testRoleAssumePolicy, os.Getenv("AWS_ACCOUNT_ID")) @@ -176,19 +169,94 @@ func createRole(t *testing.T) { if err != nil { t.Fatalf("AWS CreateRole failed: %v", err) } - - // Sleep sometime because AWS is eventually consistent - log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") - time.Sleep(10 * time.Second) } -func teardown() error { - creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), "") +const testUserName = "Vault-Acceptance-Test-AWS-FederationToken" + +func createUser(t *testing.T, accessKey *awsAccessKey) { + // The sequence of user creation actions is carefully chosen to minimize + // impact of stolen IAM user credentials + // 1. Create user, without any permissions or credentials. At this point, + // nobody cares if creds compromised because this user can do nothing. + // 2. Attach the timebomb policy. This grants no access but puts a time limit + // on validitity of compromised credentials. If this fails, nobody cares + // because the user has no permissions to do anything anyway + // 3. Attach the AdminAccess policy. The IAM user still has no credentials to + // do anything + // 4. Generate API creds to get an actual access key and secret key + timebombPolicyTemplate := `{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Deny", + "Action": "*", + "Resource": "*", + "Condition": { + "DateGreaterThan": { + "aws:CurrentTime": "%s" + } + } + } + ] + } + ` + validity := time.Duration(2 * time.Hour) + expiry := time.Now().Add(validity) + timebombPolicy := fmt.Sprintf(timebombPolicyTemplate, expiry.Format(time.RFC3339)) + awsConfig := &aws.Config{ + Region: aws.String("us-east-1"), + HTTPClient: cleanhttp.DefaultClient(), + } + svc := iam.New(session.New(awsConfig)) + + createUserInput := &iam.CreateUserInput{ + UserName: aws.String(testUserName), + } + log.Printf("[INFO] AWS CreateUser: %s", testUserName) + _, err := svc.CreateUser(createUserInput) + if err != nil { + t.Fatalf("AWS CreateUser failed: %v", err) + } + + putPolicyInput := &iam.PutUserPolicyInput{ + PolicyDocument: aws.String(timebombPolicy), + PolicyName: aws.String("SelfDestructionTimebomb"), + UserName: aws.String(testUserName), + } + _, err = svc.PutUserPolicy(putPolicyInput) + if err != nil { + t.Fatalf("AWS PutUserPolicy failed: %v", err) + } + + attachUserPolicyInput := &iam.AttachUserPolicyInput{ + PolicyArn: aws.String("arn:aws:iam::aws:policy/AdministratorAccess"), + UserName: aws.String(testUserName), + } + _, err = svc.AttachUserPolicy(attachUserPolicyInput) + if err != nil { + t.Fatalf("AWS AttachUserPolicy failed, %v", err) + } + + createAccessKeyInput := &iam.CreateAccessKeyInput{ + UserName: aws.String(testUserName), + } + createAccessKeyOutput, err := svc.CreateAccessKey(createAccessKeyInput) + if err != nil { + t.Fatalf("AWS CreateAccessKey failed: %v", err) + } + if createAccessKeyOutput == nil { + t.Fatalf("AWS CreateAccessKey returned nil") + } + genAccessKey := createAccessKeyOutput.AccessKey + + accessKey.AccessKeyId = *genAccessKey.AccessKeyId + accessKey.SecretAccessKey = *genAccessKey.SecretAccessKey +} +func teardown(accessKey *awsAccessKey) error { awsConfig := &aws.Config{ - Credentials: creds, - Region: aws.String("us-east-1"), - HTTPClient: cleanhttp.DefaultClient(), + Region: aws.String("us-east-1"), + HTTPClient: cleanhttp.DefaultClient(), } svc := iam.New(session.New(awsConfig)) @@ -214,6 +282,45 @@ func teardown() error { return err } + userDetachment := &iam.DetachUserPolicyInput{ + PolicyArn: aws.String("arn:aws:iam::aws:policy/AdministratorAccess"), + UserName: aws.String(testUserName), + } + _, err = svc.DetachUserPolicy(userDetachment) + if err != nil { + log.Printf("[WARN] AWS DetachUserPolicy failed: %v", err) + return err + } + + deleteAccessKeyInput := &iam.DeleteAccessKeyInput{ + AccessKeyId: aws.String(accessKey.AccessKeyId), + UserName: aws.String(testUserName), + } + _, err = svc.DeleteAccessKey(deleteAccessKeyInput) + if err != nil { + log.Printf("[WARN] AWS DeleteAccessKey failed: %v", err) + return err + } + + deleteUserPolicyInput := &iam.DeleteUserPolicyInput{ + PolicyName: aws.String("SelfDestructionTimebomb"), + UserName: aws.String(testUserName), + } + _, err = svc.DeleteUserPolicy(deleteUserPolicyInput) + if err != nil { + log.Printf("[WARN] AWS DeleteUserPolicy failed: %v", err) + return err + } + deleteUserInput := &iam.DeleteUserInput{ + UserName: aws.String(testUserName), + } + log.Printf("[INFO] AWS DeleteUser: %s", testUserName) + _, err = svc.DeleteUser(deleteUserInput) + if err != nil { + log.Printf("[WARN] AWS DeleteUser failed: %v", err) + return err + } + return nil } @@ -222,9 +329,26 @@ func testAccStepConfig(t *testing.T) logicaltest.TestStep { Operation: logical.UpdateOperation, Path: "config/root", Data: map[string]interface{}{ - "access_key": os.Getenv("AWS_ACCESS_KEY_ID"), - "secret_key": os.Getenv("AWS_SECRET_ACCESS_KEY"), - "region": os.Getenv("AWS_DEFAULT_REGION"), + "region": os.Getenv("AWS_DEFAULT_REGION"), + }, + } +} + +func testAccStepConfigWithCreds(t *testing.T, accessKey *awsAccessKey) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "config/root", + Data: map[string]interface{}{ + "region": os.Getenv("AWS_DEFAULT_REGION"), + }, + PreFlight: func(req *logical.Request) error { + // Values in Data above get eagerly evaluated due to the testing framework. + // In particular, they get evaluated before accessKey gets set by CreateUser + // and thus would fail. By moving to a closure in a PreFlight, we ensure that + // the creds get evaluated lazily after they've been properly set + req.Data["access_key"] = accessKey.AccessKeyId + req.Data["secret_key"] = accessKey.SecretAccessKey + return nil }, } } @@ -243,10 +367,6 @@ func testAccStepReadUser(t *testing.T, name string) logicaltest.TestStep { } log.Printf("[WARN] Generated credentials: %v", d) - // Sleep sometime because AWS is eventually consistent - log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") - time.Sleep(10 * time.Second) - // Build a client and verify that the credentials work creds := credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, "") awsConfig := &aws.Config{ @@ -257,12 +377,19 @@ func testAccStepReadUser(t *testing.T, name string) logicaltest.TestStep { client := ec2.New(session.New(awsConfig)) log.Printf("[WARN] Verifying that the generated credentials work...") - _, err := client.DescribeInstances(&ec2.DescribeInstancesInput{}) - if err != nil { - return err + retryCount := 0 + success := false + var err error + for !success && retryCount < 10 { + _, err = client.DescribeInstances(&ec2.DescribeInstancesInput{}) + if err == nil { + return nil + } + time.Sleep(time.Second) + retryCount++ } - return nil + return err }, } } @@ -458,3 +585,8 @@ func testAccStepWriteArnRoleRef(t *testing.T, roleName string) logicaltest.TestS }, } } + +type awsAccessKey struct { + AccessKeyId string + SecretAccessKey string +} diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 7f66f9eaab10..aaab1b766639 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -74,9 +74,16 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { } } - ttl := role.DefaultTTL - if ttl == 0 || (role.MaxTTL > 0 && ttl > role.MaxTTL) { - ttl = role.MaxTTL + ttl := b.System().DefaultLeaseTTL() + if role.DefaultTTL != 0 { + ttl = role.DefaultTTL + } + maxTTL := b.System().MaxLeaseTTL() + if role.MaxTTL != 0 && role.MaxTTL < maxTTL { + maxTTL = role.MaxTTL + } + if ttl > maxTTL { + ttl = maxTTL } expiration := time.Now().Add(ttl) diff --git a/builtin/logical/ssh/backend.go b/builtin/logical/ssh/backend.go index 9f12d58ad3c8..0616b8f783da 100644 --- a/builtin/logical/ssh/backend.go +++ b/builtin/logical/ssh/backend.go @@ -75,7 +75,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { return &b, nil } -func (b *backend) Salt() (*salt.Salt, error) { +func (b *backend) Salt(ctx context.Context) (*salt.Salt, error) { b.saltMutex.RLock() if b.salt != nil { defer b.saltMutex.RUnlock() @@ -87,7 +87,7 @@ func (b *backend) Salt() (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(b.view, &salt.Config{ + salt, err := salt.NewSalt(ctx, b.view, &salt.Config{ HashFunc: salt.SHA256Hash, Location: salt.DefaultLocation, }) diff --git a/builtin/logical/ssh/path_creds_create.go b/builtin/logical/ssh/path_creds_create.go index 8761e6f2bf21..236cb8197f48 100644 --- a/builtin/logical/ssh/path_creds_create.go +++ b/builtin/logical/ssh/path_creds_create.go @@ -194,7 +194,7 @@ func (b *backend) GenerateDynamicCredential(ctx context.Context, req *logical.Re } // Add the public key to authorized_keys file in target machine - err = b.installPublicKeyInTarget(role.AdminUser, username, ip, role.Port, hostKey.Key, dynamicPublicKey, role.InstallScript, true) + err = b.installPublicKeyInTarget(ctx, role.AdminUser, username, ip, role.Port, hostKey.Key, dynamicPublicKey, role.InstallScript, true) if err != nil { return "", "", fmt.Errorf("failed to add public key to authorized_keys file in target: %v", err) } @@ -202,12 +202,12 @@ func (b *backend) GenerateDynamicCredential(ctx context.Context, req *logical.Re } // Generates a UUID OTP and its salted value based on the salt of the backend. -func (b *backend) GenerateSaltedOTP() (string, string, error) { +func (b *backend) GenerateSaltedOTP(ctx context.Context) (string, string, error) { str, err := uuid.GenerateUUID() if err != nil { return "", "", err } - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return "", "", err } @@ -217,7 +217,7 @@ func (b *backend) GenerateSaltedOTP() (string, string, error) { // Generates an UUID OTP and creates an entry for the same in storage backend with its salted string. func (b *backend) GenerateOTPCredential(ctx context.Context, req *logical.Request, sshOTPEntry *sshOTP) (string, error) { - otp, otpSalted, err := b.GenerateSaltedOTP() + otp, otpSalted, err := b.GenerateSaltedOTP(ctx) if err != nil { return "", err } @@ -230,7 +230,7 @@ func (b *backend) GenerateOTPCredential(ctx context.Context, req *logical.Reques // OTP is generated. It is very unlikely that this is the case and this // code is just for safety. for err == nil && entry != nil { - otp, otpSalted, err = b.GenerateSaltedOTP() + otp, otpSalted, err = b.GenerateSaltedOTP(ctx) if err != nil { return "", err } diff --git a/builtin/logical/ssh/path_roles.go b/builtin/logical/ssh/path_roles.go index 0d21d17b6ec0..44954a91e287 100644 --- a/builtin/logical/ssh/path_roles.go +++ b/builtin/logical/ssh/path_roles.go @@ -530,6 +530,7 @@ func (b *backend) parseRole(role *sshRole) (map[string]interface{}, error) { "allow_user_key_ids": role.AllowUserKeyIDs, "key_id_format": role.KeyIDFormat, "key_type": role.KeyType, + "key_bits": role.KeyBits, "default_critical_options": role.DefaultCriticalOptions, "default_extensions": role.DefaultExtensions, } diff --git a/builtin/logical/ssh/path_verify.go b/builtin/logical/ssh/path_verify.go index 94477a662f5c..d15d2ec6cbbd 100644 --- a/builtin/logical/ssh/path_verify.go +++ b/builtin/logical/ssh/path_verify.go @@ -59,7 +59,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * // Create the salt of OTP because entry would have been create with the // salt and not directly of the OTP. Salt will yield the same value which // because the seed is the same, the backend salt. - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return nil, err } diff --git a/builtin/logical/ssh/secret_dynamic_key.go b/builtin/logical/ssh/secret_dynamic_key.go index 2a7083e4aa4f..fd85ac0950c2 100644 --- a/builtin/logical/ssh/secret_dynamic_key.go +++ b/builtin/logical/ssh/secret_dynamic_key.go @@ -64,7 +64,7 @@ func (b *backend) secretDynamicKeyRevoke(ctx context.Context, req *logical.Reque // Remove the public key from authorized_keys file in target machine // The last param 'false' indicates that the key should be uninstalled. - err = b.installPublicKeyInTarget(intSec.AdminUser, intSec.Username, intSec.IP, intSec.Port, hostKey.Key, intSec.DynamicPublicKey, intSec.InstallScript, false) + err = b.installPublicKeyInTarget(ctx, intSec.AdminUser, intSec.Username, intSec.IP, intSec.Port, hostKey.Key, intSec.DynamicPublicKey, intSec.InstallScript, false) if err != nil { return nil, fmt.Errorf("error removing public key from authorized_keys file in target") } diff --git a/builtin/logical/ssh/secret_otp.go b/builtin/logical/ssh/secret_otp.go index c0a71ef3194c..40f75a308105 100644 --- a/builtin/logical/ssh/secret_otp.go +++ b/builtin/logical/ssh/secret_otp.go @@ -34,7 +34,7 @@ func (b *backend) secretOTPRevoke(ctx context.Context, req *logical.Request, d * return nil, fmt.Errorf("secret is missing internal data") } - salt, err := b.Salt() + salt, err := b.Salt(ctx) if err != nil { return nil, err } diff --git a/builtin/logical/ssh/util.go b/builtin/logical/ssh/util.go index 78a9621c5206..62eaf19e4961 100644 --- a/builtin/logical/ssh/util.go +++ b/builtin/logical/ssh/util.go @@ -46,10 +46,10 @@ func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, er // authorized_keys file is hard coded to resemble Linux. // // The last param 'install' if false, uninstalls the key. -func (b *backend) installPublicKeyInTarget(adminUser, username, ip string, port int, hostkey, dynamicPublicKey, installScript string, install bool) error { +func (b *backend) installPublicKeyInTarget(ctx context.Context, adminUser, username, ip string, port int, hostkey, dynamicPublicKey, installScript string, install bool) error { // Transfer the newly generated public key to remote host under a random // file name. This is to avoid name collisions from other requests. - _, publicKeyFileName, err := b.GenerateSaltedOTP() + _, publicKeyFileName, err := b.GenerateSaltedOTP(ctx) if err != nil { return err } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 493fea1cc8bb..7e42192a0b79 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -178,8 +178,8 @@ func testTransit_RSA(t *testing.T, keyType string) { } signReq.Data = map[string]interface{}{ - "input": plaintext, - "algorithm": "invalid", + "input": plaintext, + "hash_algorithm": "invalid", } resp, err = b.HandleRequest(context.Background(), signReq) if err != nil { @@ -190,8 +190,8 @@ func testTransit_RSA(t *testing.T, keyType string) { } signReq.Data = map[string]interface{}{ - "input": plaintext, - "algorithm": "sha2-512", + "input": plaintext, + "hash_algorithm": "sha2-512", } resp, err = b.HandleRequest(context.Background(), signReq) if err != nil || (resp != nil && resp.IsError()) { @@ -212,9 +212,9 @@ func testTransit_RSA(t *testing.T, keyType string) { } verifyReq.Data = map[string]interface{}{ - "input": plaintext, - "signature": signature, - "algorithm": "sha2-512", + "input": plaintext, + "signature": signature, + "hash_algorithm": "sha2-512", } resp, err = b.HandleRequest(context.Background(), verifyReq) if err != nil || (resp != nil && resp.IsError()) { diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index cf60ca5853fb..4a0edbf0ffb9 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -33,7 +33,7 @@ func (b *backend) pathSign() *framework.Path { derivation is enabled; currently only available with ed25519 keys.`, }, - "algorithm": &framework.FieldSchema{ + "hash_algorithm": &framework.FieldSchema{ Type: framework.TypeString, Default: "sha2-256", Description: `Hash algorithm to use (POST body parameter). Valid values are: @@ -47,6 +47,12 @@ Defaults to "sha2-256". Not valid for all key types, including ed25519.`, }, + "algorithm": &framework.FieldSchema{ + Type: framework.TypeString, + Default: "sha2-256", + Description: `Deprecated: use "hash_algorithm" instead.`, + }, + "urlalgorithm": &framework.FieldSchema{ Type: framework.TypeString, Description: `Hash algorithm to use (POST URL parameter)`, @@ -63,6 +69,11 @@ to the min_encryption_version configured on the key.`, Type: framework.TypeBool, Description: `Set to 'true' when the input is already hashed. If the key type is 'rsa-2048' or 'rsa-4096', then the algorithm used to hash the input should be indicated by the 'algorithm' parameter.`, }, + "signature_algorithm": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `The signature algorithm to use for signing. Currently only applies to RSA key types. +Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -109,7 +120,7 @@ derivation is enabled; currently only available with ed25519 keys.`, Description: `Hash algorithm to use (POST URL parameter)`, }, - "algorithm": &framework.FieldSchema{ + "hash_algorithm": &framework.FieldSchema{ Type: framework.TypeString, Default: "sha2-256", Description: `Hash algorithm to use (POST body parameter). Valid values are: @@ -121,11 +132,21 @@ derivation is enabled; currently only available with ed25519 keys.`, Defaults to "sha2-256". Not valid for all key types.`, }, + "algorithm": &framework.FieldSchema{ + Type: framework.TypeString, + Default: "sha2-256", + Description: `Deprecated: use "hash_algorithm" instead.`, + }, "prehashed": &framework.FieldSchema{ Type: framework.TypeBool, Description: `Set to 'true' when the input is already hashed. If the key type is 'rsa-2048' or 'rsa-4096', then the algorithm used to hash the input should be indicated by the 'algorithm' parameter.`, }, + "signature_algorithm": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `The signature algorithm to use for signature verification. Currently only applies to RSA key types. +Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -141,11 +162,15 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr name := d.Get("name").(string) ver := d.Get("key_version").(int) inputB64 := d.Get("input").(string) - algorithm := d.Get("urlalgorithm").(string) - if algorithm == "" { - algorithm = d.Get("algorithm").(string) + hashAlgorithm := d.Get("urlalgorithm").(string) + if hashAlgorithm == "" { + hashAlgorithm = d.Get("hash_algorithm").(string) + if hashAlgorithm == "" { + hashAlgorithm = d.Get("algorithm").(string) + } } prehashed := d.Get("prehashed").(bool) + sigAlgorithm := d.Get("signature_algorithm").(string) input, err := base64.StdEncoding.DecodeString(inputB64) if err != nil { @@ -179,7 +204,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr if p.Type.HashSignatureInput() && !prehashed { var hf hash.Hash - switch algorithm { + switch hashAlgorithm { case "sha2-224": hf = sha256.New224() case "sha2-256": @@ -189,13 +214,13 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr case "sha2-512": hf = sha512.New() default: - return logical.ErrorResponse(fmt.Sprintf("unsupported algorithm %s", algorithm)), nil + return logical.ErrorResponse(fmt.Sprintf("unsupported hash algorithm %s", hashAlgorithm)), nil } hf.Write(input) input = hf.Sum(nil) } - sig, err := p.Sign(ver, context, input, algorithm) + sig, err := p.Sign(ver, context, input, hashAlgorithm, sigAlgorithm) if err != nil { return nil, err } @@ -234,11 +259,15 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * name := d.Get("name").(string) inputB64 := d.Get("input").(string) - algorithm := d.Get("urlalgorithm").(string) - if algorithm == "" { - algorithm = d.Get("algorithm").(string) + hashAlgorithm := d.Get("urlalgorithm").(string) + if hashAlgorithm == "" { + hashAlgorithm = d.Get("hash_algorithm").(string) + if hashAlgorithm == "" { + hashAlgorithm = d.Get("algorithm").(string) + } } prehashed := d.Get("prehashed").(bool) + sigAlgorithm := d.Get("signature_algorithm").(string) input, err := base64.StdEncoding.DecodeString(inputB64) if err != nil { @@ -272,7 +301,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * if p.Type.HashSignatureInput() && !prehashed { var hf hash.Hash - switch algorithm { + switch hashAlgorithm { case "sha2-224": hf = sha256.New224() case "sha2-256": @@ -282,13 +311,13 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * case "sha2-512": hf = sha512.New() default: - return logical.ErrorResponse(fmt.Sprintf("unsupported algorithm %s", algorithm)), nil + return logical.ErrorResponse(fmt.Sprintf("unsupported hash algorithm %s", hashAlgorithm)), nil } hf.Write(input) input = hf.Sum(nil) } - valid, err := p.VerifySignature(context, input, sig, algorithm) + valid, err := p.VerifySignature(context, input, sig, hashAlgorithm, sigAlgorithm) if err != nil { switch err.(type) { case errutil.UserError: diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index c9c9dc673501..e8fc3a7bd228 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -158,11 +158,11 @@ func TestTransit_SignVerify_P256(t *testing.T) { verifyRequest(req, false, "/sha2-224", sig) // Reset and test algorithm selection in the data - req.Data["algorithm"] = "sha2-224" + req.Data["hash_algorithm"] = "sha2-224" sig = signRequest(req, false, "") verifyRequest(req, false, "", sig) - req.Data["algorithm"] = "sha2-384" + req.Data["hash_algorithm"] = "sha2-384" sig = signRequest(req, false, "") verifyRequest(req, false, "", sig) @@ -173,18 +173,18 @@ func TestTransit_SignVerify_P256(t *testing.T) { // Test 512 and save sig for later to ensure we can't validate once min // decryption version is set - req.Data["algorithm"] = "sha2-512" + req.Data["hash_algorithm"] = "sha2-512" sig = signRequest(req, false, "") verifyRequest(req, false, "", sig) v1sig := sig // Test bad algorithm - req.Data["algorithm"] = "foobar" + req.Data["hash_algorithm"] = "foobar" signRequest(req, true, "") // Test bad input - req.Data["algorithm"] = "sha2-256" + req.Data["hash_algorithm"] = "sha2-256" req.Data["input"] = "foobar" signRequest(req, true, "") @@ -204,7 +204,7 @@ func TestTransit_SignVerify_P256(t *testing.T) { } req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" - req.Data["algorithm"] = "sha2-256" + req.Data["hash_algorithm"] = "sha2-256" // Make sure signing still works fine sig = signRequest(req, false, "") verifyRequest(req, false, "", sig) diff --git a/command/auth_enable.go b/command/auth_enable.go index 328524eb81cb..c6b7486bd0a8 100644 --- a/command/auth_enable.go +++ b/command/auth_enable.go @@ -1,6 +1,7 @@ package command import ( + "flag" "fmt" "strings" "time" @@ -16,13 +17,15 @@ var _ cli.CommandAutocomplete = (*AuthEnableCommand)(nil) type AuthEnableCommand struct { *BaseCommand - flagDescription string - flagPath string - flagDefaultLeaseTTL time.Duration - flagMaxLeaseTTL time.Duration - flagPluginName string - flagLocal bool - flagSealWrap bool + flagDescription string + flagPath string + flagDefaultLeaseTTL time.Duration + flagMaxLeaseTTL time.Duration + flagAuditNonHMACRequestKeys []string + flagAuditNonHMACResponseKeys []string + flagPluginName string + flagLocal bool + flagSealWrap bool } func (c *AuthEnableCommand) Synopsis() string { @@ -96,6 +99,20 @@ func (c *AuthEnableCommand) Flags() *FlagSets { "TTL.", }) + f.StringSliceVar(&StringSliceVar{ + Name: flagNameAuditNonHMACRequestKeys, + Target: &c.flagAuditNonHMACRequestKeys, + Usage: "Comma-separated string or list of keys that will not be HMAC'd by audit" + + "devices in the request data object.", + }) + + f.StringSliceVar(&StringSliceVar{ + Name: flagNameAuditNonHMACResponseKeys, + Target: &c.flagAuditNonHMACResponseKeys, + Usage: "Comma-separated string or list of keys that will not be HMAC'd by audit" + + "devices in the response data object.", + }) + f.StringVar(&StringVar{ Name: "plugin-name", Target: &c.flagPluginName, @@ -170,7 +187,7 @@ func (c *AuthEnableCommand) Run(args []string) int { // Append a trailing slash to indicate it's a path in output authPath = ensureTrailingSlash(authPath) - if err := client.Sys().EnableAuthWithOptions(authPath, &api.EnableAuthOptions{ + authOpts := &api.EnableAuthOptions{ Type: authType, Description: c.flagDescription, Local: c.flagLocal, @@ -180,7 +197,20 @@ func (c *AuthEnableCommand) Run(args []string) int { MaxLeaseTTL: c.flagMaxLeaseTTL.String(), PluginName: c.flagPluginName, }, - }); err != nil { + } + + // Set these values only if they are provided in the CLI + f.Visit(func(fl *flag.Flag) { + if fl.Name == flagNameAuditNonHMACRequestKeys { + authOpts.Config.AuditNonHMACRequestKeys = c.flagAuditNonHMACRequestKeys + } + + if fl.Name == flagNameAuditNonHMACResponseKeys { + authOpts.Config.AuditNonHMACRequestKeys = c.flagAuditNonHMACResponseKeys + } + }) + + if err := client.Sys().EnableAuthWithOptions(authPath, authOpts); err != nil { c.UI.Error(fmt.Sprintf("Error enabling %s auth: %s", authType, err)) return 2 } diff --git a/command/base.go b/command/base.go index 26c62dd64fdb..21c7518b760e 100644 --- a/command/base.go +++ b/command/base.go @@ -332,6 +332,12 @@ func (f *FlagSets) Args() []string { return f.mainSet.Args() } +// Visit visits the flags in lexicographical order, calling fn for each. It +// visits only those flags that have been set. +func (f *FlagSets) Visit(fn func(*flag.Flag)) { + f.mainSet.Visit(fn) +} + // Help builds custom help for this command, grouping by flag set. func (fs *FlagSets) Help() string { var out bytes.Buffer diff --git a/command/commands.go b/command/commands.go index 71d7d5bfd0b0..366d162809cb 100644 --- a/command/commands.go +++ b/command/commands.go @@ -71,6 +71,11 @@ const ( EnvVaultCLINoColor = `VAULT_CLI_NO_COLOR` // EnvVaultFormat is the output format EnvVaultFormat = `VAULT_FORMAT` + + // flagNameAuditNonHMACRequestKeys is the flag name used for auth/secrets enable + flagNameAuditNonHMACRequestKeys = "audit-non-hmac-request-keys" + // flagNameAuditNonHMACResponseKeys is the flag name used for auth/secrets enable + flagNameAuditNonHMACResponseKeys = "audit-non-hmac-response-keys" ) var ( diff --git a/command/secrets_enable.go b/command/secrets_enable.go index e31d77ec2456..ad464a5ef518 100644 --- a/command/secrets_enable.go +++ b/command/secrets_enable.go @@ -1,6 +1,7 @@ package command import ( + "flag" "fmt" "strings" "time" @@ -16,14 +17,16 @@ var _ cli.CommandAutocomplete = (*SecretsEnableCommand)(nil) type SecretsEnableCommand struct { *BaseCommand - flagDescription string - flagPath string - flagDefaultLeaseTTL time.Duration - flagMaxLeaseTTL time.Duration - flagForceNoCache bool - flagPluginName string - flagLocal bool - flagSealWrap bool + flagDescription string + flagPath string + flagDefaultLeaseTTL time.Duration + flagMaxLeaseTTL time.Duration + flagAuditNonHMACRequestKeys []string + flagAuditNonHMACResponseKeys []string + flagForceNoCache bool + flagPluginName string + flagLocal bool + flagSealWrap bool } func (c *SecretsEnableCommand) Synopsis() string { @@ -104,6 +107,20 @@ func (c *SecretsEnableCommand) Flags() *FlagSets { "TTL.", }) + f.StringSliceVar(&StringSliceVar{ + Name: flagNameAuditNonHMACRequestKeys, + Target: &c.flagAuditNonHMACRequestKeys, + Usage: "Comma-separated string or list of keys that will not be HMAC'd by audit" + + "devices in the request data object.", + }) + + f.StringSliceVar(&StringSliceVar{ + Name: flagNameAuditNonHMACResponseKeys, + Target: &c.flagAuditNonHMACResponseKeys, + Usage: "Comma-separated string or list of keys that will not be HMAC'd by audit" + + "devices in the response data object.", + }) + f.BoolVar(&BoolVar{ Name: "force-no-cache", Target: &c.flagForceNoCache, @@ -202,6 +219,17 @@ func (c *SecretsEnableCommand) Run(args []string) int { }, } + // Set these values only if they are provided in the CLI + f.Visit(func(fl *flag.Flag) { + if fl.Name == flagNameAuditNonHMACRequestKeys { + mountInput.Config.AuditNonHMACRequestKeys = c.flagAuditNonHMACRequestKeys + } + + if fl.Name == flagNameAuditNonHMACResponseKeys { + mountInput.Config.AuditNonHMACRequestKeys = c.flagAuditNonHMACResponseKeys + } + }) + if err := client.Sys().Mount(mountPath, mountInput); err != nil { c.UI.Error(fmt.Sprintf("Error enabling: %s", err)) return 2 diff --git a/helper/cidrutil/cidr.go b/helper/cidrutil/cidr.go index 8031bb8eb53f..2d89d846864d 100644 --- a/helper/cidrutil/cidr.go +++ b/helper/cidrutil/cidr.go @@ -31,30 +31,6 @@ func IPBelongsToCIDR(ipAddr string, cidr string) (bool, error) { return true, nil } -// IPBelongsToCIDRBlocksString checks if the given IP is encompassed by any of -// the given CIDR blocks, when the input is a string composed by joining all -// the CIDR blocks using a separator. The input is separated based on the given -// separator and the IP is checked to be belonged by any CIDR block. -func IPBelongsToCIDRBlocksString(ipAddr string, cidrList, separator string) (bool, error) { - if ipAddr == "" { - return false, fmt.Errorf("missing IP address") - } - - if cidrList == "" { - return false, fmt.Errorf("missing CIDR list") - } - - if separator == "" { - return false, fmt.Errorf("missing separator") - } - - if ip := net.ParseIP(ipAddr); ip == nil { - return false, fmt.Errorf("invalid IP address") - } - - return IPBelongsToCIDRBlocksSlice(ipAddr, strutil.ParseDedupLowercaseAndSortStrings(cidrList, separator)) -} - // IPBelongsToCIDRBlocksSlice checks if the given IP is encompassed by any of the given // CIDR blocks func IPBelongsToCIDRBlocksSlice(ipAddr string, cidrs []string) (bool, error) { diff --git a/helper/cidrutil/cidr_test.go b/helper/cidrutil/cidr_test.go index f6d5849c3e45..220afecc1ffa 100644 --- a/helper/cidrutil/cidr_test.go +++ b/helper/cidrutil/cidr_test.go @@ -42,50 +42,6 @@ func TestCIDRUtil_IPBelongsToCIDR(t *testing.T) { } } -func TestCIDRUtil_IPBelongsToCIDRBlocksString(t *testing.T) { - ip := "192.168.27.29" - cidrList := "172.169.100.200/18,192.168.0.0/16,10.10.20.20/24" - - belongs, err := IPBelongsToCIDRBlocksString(ip, cidrList, ",") - if err != nil { - t.Fatal(err) - } - if !belongs { - t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList) - } - - ip = "10.197.192.6" - cidrList = "1.2.3.0/8,10.197.192.0/18,10.197.193.0/24" - - belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",") - if err != nil { - t.Fatal(err) - } - if !belongs { - t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList) - } - - ip = "192.168.27.29" - cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24" - - belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",") - if err == nil { - t.Fatalf("expected an error") - } - - ip = "30.40.50.60" - cidrList = "172.169.100.200/18,192.168.0.0/16,10.10.20.20/24" - - belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",") - if err != nil { - t.Fatal(err) - } - if belongs { - t.Fatalf("expected IP %q to not belong to one of the CIDRs in %q", ip, cidrList) - } - -} - func TestCIDRUtil_IPBelongsToCIDRBlocksSlice(t *testing.T) { ip := "192.168.27.29" cidrList := []string{"172.169.100.200/18", "192.168.0.0/16", "10.10.20.20/24"} diff --git a/helper/keysutil/encrypted_key_storage.go b/helper/keysutil/encrypted_key_storage.go new file mode 100644 index 000000000000..00f961cea6a2 --- /dev/null +++ b/helper/keysutil/encrypted_key_storage.go @@ -0,0 +1,275 @@ +package keysutil + +import ( + "context" + "encoding/base64" + "errors" + paths "path" + "strings" + + "github.com/golang/go/src/math/big" + + "github.com/hashicorp/golang-lru" + "github.com/hashicorp/vault/logical" +) + +const ( + // DefaultCacheSize is used if no cache size is specified for + // NewEncryptedKeyStorage. This value is the number of cache entries to + // store, not the size in bytes of the cache. + DefaultCacheSize = 16 * 1024 + + // DefaultPrefix is used if no prefix is specified for + // NewEncryptedKeyStorage. Prefix must be defined so we can provide context + // for the base folder. + DefaultPrefix = "encryptedkeys/" + + // EncryptedKeyPolicyVersionTpl is a template that can be used to minimize + // the amount of data that's stored with the ciphertext. + EncryptedKeyPolicyVersionTpl = "{{version}}:" +) + +var ( + // ErrPolicyDerivedKeys is returned if the provided policy does not use + // derived keys. This is a requirement for this storage implementation. + ErrPolicyDerivedKeys = errors.New("key policy must use derived keys") + + // ErrPolicyConvergentEncryption is returned if the provided policy does not use + // convergent encryption. This is a requirement for this storage implementation. + ErrPolicyConvergentEncryption = errors.New("key policy must use convergent encryption") + + // ErrPolicyConvergentVersion is returned if the provided policy does not use + // a new enough convergent version. This is a requirement for this storage + // implementation. + ErrPolicyConvergentVersion = errors.New("key policy must use convergent version > 2") + + // ErrNilStorage is returned if the provided storage is nil. + ErrNilStorage = errors.New("nil storage provided") + + // ErrNilPolicy is returned if the provided policy is nil. + ErrNilPolicy = errors.New("nil policy provided") +) + +// EncryptedKeyStorageConfig is used to configure an EncryptedKeyStorage object. +type EncryptedKeyStorageConfig struct { + // Storage is the underlying storage to wrap requests to. + Storage logical.Storage + + // Policy is the key policy to use to encrypt the key paths. + Policy *Policy + + // Prefix is the storage prefix for this instance of the EncryptedKeyStorage + // object. This is stored in plaintext. If not set the DefaultPrefix will be + // used. + Prefix string + + // CacheSize is the number of elements to cache. If not set the + // DetaultCacheSize will be used. + CacheSize int +} + +// NewEncryptedKeyStorage takes an EncryptedKeyStorageConfig and returns a new +// EncryptedKeyStorage object. +func NewEncryptedKeyStorage(config EncryptedKeyStorageConfig) (*EncryptedKeyStorage, error) { + if config.Policy == nil { + return nil, ErrNilPolicy + } + + if !config.Policy.Derived { + return nil, ErrPolicyDerivedKeys + } + + if !config.Policy.ConvergentEncryption { + return nil, ErrPolicyConvergentEncryption + } + + if config.Policy.ConvergentVersion < 2 { + return nil, ErrPolicyConvergentVersion + } + + if config.Storage == nil { + return nil, ErrNilStorage + } + + if config.Prefix == "" { + config.Prefix = DefaultPrefix + } + + if !strings.HasSuffix(config.Prefix, "/") { + config.Prefix += "/" + } + + size := config.CacheSize + if size <= 0 { + size = DefaultCacheSize + } + + cache, err := lru.New2Q(size) + if err != nil { + return nil, err + } + + return &EncryptedKeyStorage{ + policy: config.Policy, + s: config.Storage, + prefix: config.Prefix, + lru: cache, + }, nil +} + +// EncryptedKeyStorage implements the logical.Storage interface and ensures the +// storage paths are encrypted in the underlying storage. +type EncryptedKeyStorage struct { + policy *Policy + s logical.Storage + lru *lru.TwoQueueCache + + prefix string +} + +// List implements the logical.Storage List method, and decrypts all the items +// in a path prefix. This can only operate on full folder structures so the +// prefix should end in a "/". +func (s *EncryptedKeyStorage) List(ctx context.Context, prefix string) ([]string, error) { + encPrefix, err := s.encryptPath(prefix) + if err != nil { + return nil, err + } + + keys, err := s.s.List(ctx, encPrefix+"/") + if err != nil { + return keys, err + } + + decryptedKeys := make([]string, len(keys)) + + // The context for the decryption operations will be the object's prefix + // joined with the provided prefix. Join cleans the path ensuring there + // isn't a trailing "/". + context := []byte(paths.Join(s.prefix, prefix)) + + for i, k := range keys { + raw, ok := s.lru.Get(k) + if ok { + // cache HIT, we can bail early and skip the decode & decrypt operations. + decryptedKeys[i] = raw.(string) + continue + } + + // If a folder is included in the keys it will have a trailing "/". + // We need to remove this before decoding/decrypting and add it back + // later. + appendSlash := strings.HasSuffix(k, "/") + if appendSlash { + k = strings.TrimSuffix(k, "/") + } + + decoded := Base62Decode(k) + if len(decoded) == 0 { + return nil, errors.New("Could not decode key") + } + + // Decrypt the data with the object's key policy. + encodedPlaintext, err := s.policy.Decrypt(context, nil, string(decoded[:])) + if err != nil { + return nil, err + } + + // The plaintext is still base64 encoded, decode it. + decoded, err = base64.StdEncoding.DecodeString(encodedPlaintext) + if err != nil { + return nil, err + } + + plaintext := string(decoded[:]) + + // Add the slash back to the plaintext value + if appendSlash { + plaintext += "/" + k += "/" + } + + // We want to store the unencoded version of the key in the cache. + // This will make it more performent when it's a HIT. + s.lru.Add(k, plaintext) + + decryptedKeys[i] = plaintext + } + + return decryptedKeys, nil +} + +// Get implements the logical.Storage Get method. +func (s *EncryptedKeyStorage) Get(ctx context.Context, path string) (*logical.StorageEntry, error) { + encPath, err := s.encryptPath(path) + if err != nil { + return nil, err + } + + return s.s.Get(ctx, encPath) +} + +// Put implements the logical.Storage Put method. +func (s *EncryptedKeyStorage) Put(ctx context.Context, entry *logical.StorageEntry) error { + encPath, err := s.encryptPath(entry.Key) + if err != nil { + return err + } + e := &logical.StorageEntry{} + *e = *entry + + e.Key = encPath + + return s.s.Put(ctx, e) +} + +// Delete implements the logical.Storage Delete method. +func (s *EncryptedKeyStorage) Delete(ctx context.Context, path string) error { + encPath, err := s.encryptPath(path) + if err != nil { + return err + } + + return s.s.Delete(ctx, encPath) +} + +// encryptPath takes a plaintext path and encrypts each path section (separated +// by "/") with the object's key policy. The context for each encryption is the +// plaintext path prefix for the key. +func (s *EncryptedKeyStorage) encryptPath(path string) (string, error) { + path = paths.Clean(path) + + // Trim the prefix if it starts with a "/" + path = strings.TrimPrefix(path, "/") + + parts := strings.Split(path, "/") + + encPath := s.prefix + context := s.prefix + for _, p := range parts { + encoded := base64.StdEncoding.EncodeToString([]byte(p)) + ciphertext, err := s.policy.Encrypt(0, []byte(context), nil, encoded) + if err != nil { + return "", err + } + + encPath = paths.Join(encPath, Base62Encode([]byte(ciphertext))) + context = paths.Join(context, p) + } + + return encPath, nil +} + +func Base62Encode(buf []byte) string { + encoder := &big.Int{} + + encoder.SetBytes(buf) + return encoder.Text(62) +} + +func Base62Decode(input string) []byte { + decoder := &big.Int{} + + decoder.SetString(input, 62) + return decoder.Bytes() +} diff --git a/helper/keysutil/encrypted_key_storage_test.go b/helper/keysutil/encrypted_key_storage_test.go new file mode 100644 index 000000000000..f700ff589592 --- /dev/null +++ b/helper/keysutil/encrypted_key_storage_test.go @@ -0,0 +1,358 @@ +package keysutil + +import ( + "context" + "fmt" + "reflect" + "sync" + "testing" + + "github.com/hashicorp/vault/logical" +) + +var compilerOpt []string + +func TestBase58(t *testing.T) { + tCases := []struct { + in string + out string + }{ + { + "", + "0", + }, + { + "foo", + "sapp", + }, + { + "5d5746d044b9a9429249966c9e3fee178ca679b91487b11d4b73c9865202104c", + "cozMP2pOYdDiNGeFQ2afKAOGIzO0HVpJ8OPFXuVPNbHasFyenK9CzIIPuOG7EFWOCy4YWvKGZa671N4kRSoaxZ", + }, + { + "5ba33e16d742f3c785f6e7e8bb6f5fe82346ffa1c47aa8e95da4ddd5a55bb334", + "cotpEJPnhuTRofLi4lDe5iKw2fkSGc6TpUYeuWoBp8eLYJBWLRUVDZI414OjOCWXKZ0AI8gqNMoxd4eLOklwYk", + }, + { + " ", + "w", + }, + { + "-", + "J", + }, + { + "0", + "M", + }, + { + "1", + "N", + }, + { + "-1", + "30B", + }, + { + "11", + "3h7", + }, + { + "abc", + "qMin", + }, + { + "1234598760", + "1a0AFzKIPnihTq", + }, + { + "abcdefghijklmnopqrstuvwxyz", + "hUBXsgd3F2swSlEgbVi2p0Ncr6kzVeJTLaW", + }, + } + + for _, c := range tCases { + e := Base62Encode([]byte(c.in)) + d := string(Base62Decode(e)) + + if d != c.in { + t.Fatalf("decoded value didn't match input %#v %#v", c.in, d) + } + + if e != c.out { + t.Fatalf("encoded value didn't match expected %#v, %#v", e, c.out) + } + } + + d := Base62Decode("!0000/") + if len(d) != 0 { + t.Fatalf("Decode of invalid string should be empty, got %#v", d) + } +} + +func TestEncrytedKeysStorage_BadPolicy(t *testing.T) { + s := &logical.InmemStorage{} + policy := &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: false, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: true, + ConvergentVersion: 2, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + _, err := NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: s, + Policy: policy, + Prefix: "prefix", + }) + if err != ErrPolicyDerivedKeys { + t.Fatalf("Unexpected Error: %s", err) + } + + policy = &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: true, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: false, + ConvergentVersion: 2, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + _, err = NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: s, + Policy: policy, + Prefix: "prefix", + }) + if err != ErrPolicyConvergentEncryption { + t.Fatalf("Unexpected Error: %s", err) + } + + policy = &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: true, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: true, + ConvergentVersion: 1, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + _, err = NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: s, + Policy: policy, + Prefix: "prefix", + }) + if err != ErrPolicyConvergentVersion { + t.Fatalf("Unexpected Error: %s", err) + } + + policy = &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: true, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: true, + ConvergentVersion: 2, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + _, err = NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: nil, + Policy: policy, + Prefix: "prefix", + }) + if err != ErrNilStorage { + t.Fatalf("Unexpected Error: %s", err) + } +} + +func TestEncrytedKeysStorage_CRUD(t *testing.T) { + s := &logical.InmemStorage{} + policy := &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: true, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: true, + ConvergentVersion: 2, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + ctx := context.Background() + + err := policy.Rotate(ctx, s) + if err != nil { + t.Fatal(err) + } + + es, err := NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: s, + Policy: policy, + Prefix: "prefix", + }) + if err != nil { + t.Fatal(err) + } + + err = es.Put(ctx, &logical.StorageEntry{ + Key: "test/foo", + Value: []byte("test"), + }) + if err != nil { + t.Fatal(err) + } + + err = es.Put(ctx, &logical.StorageEntry{ + Key: "test/foo1/test", + Value: []byte("test"), + }) + if err != nil { + t.Fatal(err) + } + + keys, err := es.List(ctx, "test/") + if err != nil { + t.Fatal(err) + } + + // Test prefixed with "/" + keys, err = es.List(ctx, "/test/") + if err != nil { + t.Fatal(err) + } + + if len(keys) != 2 || keys[0] != "foo1/" || keys[1] != "foo" { + t.Fatalf("bad keys: %#v", keys) + } + + // Test the cached value is correct + keys, err = es.List(ctx, "test/") + if err != nil { + t.Fatal(err) + } + + if len(keys) != 2 || keys[0] != "foo1/" || keys[1] != "foo" { + t.Fatalf("bad keys: %#v", keys) + } + + data, err := es.Get(ctx, "test/foo") + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(data.Value, []byte("test")) { + t.Fatalf("bad data: %#v", data) + } + + err = es.Delete(ctx, "test/foo") + if err != nil { + t.Fatal(err) + } + + data, err = es.Get(ctx, "test/foo") + if err != nil { + t.Fatal(err) + } + if data != nil { + t.Fatal("data should be nil") + } + +} + +func BenchmarkEncrytedKeyStorage_List(b *testing.B) { + s := &logical.InmemStorage{} + policy := &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: true, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: true, + ConvergentVersion: 2, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + ctx := context.Background() + + err := policy.Rotate(ctx, s) + if err != nil { + b.Fatal(err) + } + + es, err := NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: s, + Policy: policy, + Prefix: "prefix", + }) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < 10000; i++ { + err = es.Put(ctx, &logical.StorageEntry{ + Key: fmt.Sprintf("test/%d", i), + Value: []byte("test"), + }) + if err != nil { + b.Fatal(err) + } + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + keys, err := es.List(ctx, "test/") + if err != nil { + b.Fatal(err) + } + compilerOpt = keys + } +} + +func BenchmarkEncrytedKeyStorage_Put(b *testing.B) { + s := &logical.InmemStorage{} + policy := &Policy{ + Name: "metadata", + Type: KeyType_AES256_GCM96, + Derived: true, + KDF: Kdf_hkdf_sha256, + ConvergentEncryption: true, + ConvergentVersion: 2, + VersionTemplate: EncryptedKeyPolicyVersionTpl, + versionPrefixCache: &sync.Map{}, + } + + ctx := context.Background() + + err := policy.Rotate(ctx, s) + if err != nil { + b.Fatal(err) + } + + es, err := NewEncryptedKeyStorage(EncryptedKeyStorageConfig{ + Storage: s, + Policy: policy, + Prefix: "prefix", + }) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + err = es.Put(ctx, &logical.StorageEntry{ + Key: fmt.Sprintf("test/%d", i), + Value: []byte("test"), + }) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/helper/keysutil/lock_manager.go b/helper/keysutil/lock_manager.go index f1c505a7b49c..778c28cbcbb0 100644 --- a/helper/keysutil/lock_manager.go +++ b/helper/keysutil/lock_manager.go @@ -239,6 +239,9 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag name = keyData.Policy.Name + // set the policy version cache + keyData.Policy.versionPrefixCache = &sync.Map{} + lockType := exclusive lock := lm.policyLock(name, lockType) defer lm.UnlockPolicy(lock, lockType) @@ -387,6 +390,7 @@ func (lm *LockManager) getPolicyCommon(ctx context.Context, req PolicyRequest, l Derived: req.Derived, Exportable: req.Exportable, AllowPlaintextBackup: req.AllowPlaintextBackup, + versionPrefixCache: &sync.Map{}, } if req.Derived { p.KDF = Kdf_hkdf_sha256 @@ -496,21 +500,5 @@ func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage } func (lm *LockManager) getStoredPolicy(ctx context.Context, storage logical.Storage, name string) (*Policy, error) { - // Check if the policy already exists - raw, err := storage.Get(ctx, "policy/"+name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - var policy Policy - err = jsonutil.DecodeJSON(raw.Value, &policy) - if err != nil { - return nil, err - } - - return &policy, nil + return LoadPolicy(ctx, storage, "policy/"+name) } diff --git a/helper/keysutil/policy.go b/helper/keysutil/policy.go index e5770904a072..759d083773da 100644 --- a/helper/keysutil/policy.go +++ b/helper/keysutil/policy.go @@ -20,8 +20,10 @@ import ( "fmt" "io" "math/big" + "path" "strconv" "strings" + "sync" "time" "golang.org/x/crypto/chacha20poly1305" @@ -52,7 +54,14 @@ const ( KeyType_ChaCha20_Poly1305 ) -const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)" +const ( + // ErrTooOld is returned whtn the ciphertext or signatures's key version is + // too old. + ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)" + + // DefaultVersionTemplate is used when no version template is provided. + DefaultVersionTemplate = "vault:v{{version}}:" +) type RestoreInfo struct { Time time.Time `json:"time"` @@ -196,6 +205,85 @@ func (kem deprecatedKeyEntryMap) UnmarshalJSON(data []byte) error { // keyEntryMap is used to allow JSON marshal/unmarshal type keyEntryMap map[string]KeyEntry +// PolicyConfig is used to create a new policy +type PolicyConfig struct { + // The name of the policy + Name string `json:"name"` + + // The type of key + Type KeyType + + // Derived keys MUST provide a context and the master underlying key is + // never used. If convergent encryption is true, the context will be used + // as the nonce as well. + Derived bool + KDF int + ConvergentEncryption bool + + // Whether the key is exportable + Exportable bool + + // Whether the key is allowed to be deleted + DeletionAllowed bool + + // AllowPlaintextBackup allows taking backup of the policy in plaintext + AllowPlaintextBackup bool + + // VersionTemplate is used to prefix the ciphertext with information about + // the key version. It must inclide {{version}} and a delimiter between the + // version prefix and the ciphertext. + VersionTemplate string + + // StoragePrefix is used to add a prefix when storing and retrieving the + // policy object. + StoragePrefix string +} + +// NewPolicy takes a policy config and returns a Policy with those settings. +func NewPolicy(config PolicyConfig) *Policy { + var convergentVersion int + if config.ConvergentEncryption { + convergentVersion = 2 + } + + return &Policy{ + Name: config.Name, + Type: config.Type, + Derived: config.Derived, + KDF: config.KDF, + ConvergentEncryption: config.ConvergentEncryption, + ConvergentVersion: convergentVersion, + Exportable: config.Exportable, + DeletionAllowed: config.DeletionAllowed, + AllowPlaintextBackup: config.AllowPlaintextBackup, + VersionTemplate: config.VersionTemplate, + StoragePrefix: config.StoragePrefix, + versionPrefixCache: &sync.Map{}, + } +} + +// LoadPolicy will load a policy from the provided storage path and set the +// necessary un-exported variables. It is particularly useful when accessing a +// policy without the lock manager. +func LoadPolicy(ctx context.Context, s logical.Storage, path string) (*Policy, error) { + raw, err := s.Get(ctx, path) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + var policy Policy + err = jsonutil.DecodeJSON(raw.Value, &policy) + if err != nil { + return nil, err + } + + policy.versionPrefixCache = &sync.Map{} + return &policy, nil +} + // Policy is the struct used to store metadata type Policy struct { Name string `json:"name"` @@ -244,6 +332,19 @@ type Policy struct { // AllowPlaintextBackup allows taking backup of the policy in plaintext AllowPlaintextBackup bool `json:"allow_plaintext_backup"` + + // VersionTemplate is used to prefix the ciphertext with information about + // the key version. It must inclide {{version}} and a delimiter between the + // version prefix and the ciphertext. + VersionTemplate string `json:"version_template"` + + // StoragePrefix is used to add a prefix when storing and retrieving the + // policy object. + StoragePrefix string `json:"storage_prefix"` + + // versionPrefixCache stores caches of verison prefix strings and the split + // version template. + versionPrefixCache *sync.Map } // ArchivedKeys stores old keys. This is used to keep the key loading time sane @@ -255,7 +356,7 @@ type archivedKeys struct { func (p *Policy) LoadArchive(ctx context.Context, storage logical.Storage) (*archivedKeys, error) { archive := &archivedKeys{} - raw, err := storage.Get(ctx, "archive/"+p.Name) + raw, err := storage.Get(ctx, path.Join(p.StoragePrefix, "archive", p.Name)) if err != nil { return nil, err } @@ -280,7 +381,7 @@ func (p *Policy) storeArchive(ctx context.Context, storage logical.Storage, arch // Write the policy into storage err = storage.Put(ctx, &logical.StorageEntry{ - Key: "archive/" + p.Name, + Key: path.Join(p.StoragePrefix, "archive", p.Name), Value: buf, }) if err != nil { @@ -405,7 +506,7 @@ func (p *Policy) Persist(ctx context.Context, storage logical.Storage) (retErr e // Write the policy into storage err = storage.Put(ctx, &logical.StorageEntry{ - Key: "policy/" + p.Name, + Key: path.Join(p.StoragePrefix, "policy", p.Name), Value: buf, }) if err != nil { @@ -703,7 +804,7 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, encoded := base64.StdEncoding.EncodeToString(ciphertext) // Prepend some information - encoded = "vault:v" + strconv.Itoa(ver) + ":" + encoded + encoded = p.getVersionPrefix(ver) + encoded return encoded, nil } @@ -713,8 +814,13 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)} } + tplParts, err := p.getTemplateParts() + if err != nil { + return "", err + } + // Verify the prefix - if !strings.HasPrefix(value, "vault:v") { + if !strings.HasPrefix(value, tplParts[0]) { return "", errutil.UserError{Err: "invalid ciphertext: no prefix"} } @@ -722,7 +828,7 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.UserError{Err: "invalid convergent nonce supplied"} } - splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) + splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, tplParts[0]), tplParts[1], 2) if len(splitVerCiphertext) != 2 { return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"} } @@ -836,7 +942,7 @@ func (p *Policy) HMACKey(version int) ([]byte, error) { return p.Keys[strconv.Itoa(version)].HMACKey, nil } -func (p *Policy) Sign(ver int, context, input []byte, algorithm string) (*SigningResult, error) { +func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm, sigAlgorithm string) (*SigningResult, error) { if !p.Type.SigningSupported() { return nil, fmt.Errorf("message signing not supported for key type %v", p.Type) } @@ -905,7 +1011,7 @@ func (p *Policy) Sign(ver int, context, input []byte, algorithm string) (*Signin key := p.Keys[strconv.Itoa(ver)].RSAKey var algo crypto.Hash - switch algorithm { + switch hashAlgorithm { case "sha2-224": algo = crypto.SHA224 case "sha2-256": @@ -915,12 +1021,26 @@ func (p *Policy) Sign(ver int, context, input []byte, algorithm string) (*Signin case "sha2-512": algo = crypto.SHA512 default: - return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported algorithm %s", algorithm)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported hash algorithm %s", hashAlgorithm)} } - sig, err = rsa.SignPSS(rand.Reader, key, algo, input, nil) - if err != nil { - return nil, err + if sigAlgorithm == "" { + sigAlgorithm = "pss" + } + + switch sigAlgorithm { + case "pss": + sig, err = rsa.SignPSS(rand.Reader, key, algo, input, nil) + if err != nil { + return nil, err + } + case "pkcs1v15": + sig, err = rsa.SignPKCS1v15(rand.Reader, key, algo, input) + if err != nil { + return nil, err + } + default: + return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)} } default: @@ -929,26 +1049,30 @@ func (p *Policy) Sign(ver int, context, input []byte, algorithm string) (*Signin // Convert to base64 encoded := base64.StdEncoding.EncodeToString(sig) - res := &SigningResult{ - Signature: "vault:v" + strconv.Itoa(ver) + ":" + encoded, + Signature: p.getVersionPrefix(ver) + encoded, PublicKey: pubKey, } return res, nil } -func (p *Policy) VerifySignature(context, input []byte, sig, algorithm string) (bool, error) { +func (p *Policy) VerifySignature(context, input []byte, sig, hashAlgorithm string, sigAlgorithm string) (bool, error) { if !p.Type.SigningSupported() { return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)} } + tplParts, err := p.getTemplateParts() + if err != nil { + return false, err + } + // Verify the prefix - if !strings.HasPrefix(sig, "vault:v") { + if !strings.HasPrefix(sig, tplParts[0]) { return false, errutil.UserError{Err: "invalid signature: no prefix"} } - splitVerSig := strings.SplitN(strings.TrimPrefix(sig, "vault:v"), ":", 2) + splitVerSig := strings.SplitN(strings.TrimPrefix(sig, tplParts[0]), tplParts[1], 2) if len(splitVerSig) != 2 { return false, errutil.UserError{Err: "invalid signature: wrong number of fields"} } @@ -1011,7 +1135,7 @@ func (p *Policy) VerifySignature(context, input []byte, sig, algorithm string) ( key := p.Keys[strconv.Itoa(ver)].RSAKey var algo crypto.Hash - switch algorithm { + switch hashAlgorithm { case "sha2-224": algo = crypto.SHA224 case "sha2-256": @@ -1021,10 +1145,21 @@ func (p *Policy) VerifySignature(context, input []byte, sig, algorithm string) ( case "sha2-512": algo = crypto.SHA512 default: - return false, errutil.InternalError{Err: fmt.Sprintf("unsupported algorithm %s", algorithm)} + return false, errutil.InternalError{Err: fmt.Sprintf("unsupported hash algorithm %s", hashAlgorithm)} + } + + if sigAlgorithm == "" { + sigAlgorithm = "pss" } - err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, nil) + switch sigAlgorithm { + case "pss": + err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, nil) + case "pkcs1v15": + err = rsa.VerifyPKCS1v15(&key.PublicKey, algo, input, sigBytes) + default: + return false, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)} + } return err == nil, nil @@ -1195,3 +1330,40 @@ func (p *Policy) Backup(ctx context.Context, storage logical.Storage) (out strin return base64.StdEncoding.EncodeToString(encodedBackup), nil } + +func (p *Policy) getTemplateParts() ([]string, error) { + partsRaw, ok := p.versionPrefixCache.Load("template-parts") + if ok { + return partsRaw.([]string), nil + } + + template := p.VersionTemplate + if template == "" { + template = DefaultVersionTemplate + } + + tplParts := strings.Split(template, "{{version}}") + if len(tplParts) != 2 { + return nil, errutil.InternalError{Err: "error parsing version template"} + } + + p.versionPrefixCache.Store("template-parts", tplParts) + return tplParts, nil +} + +func (p *Policy) getVersionPrefix(ver int) string { + prefixRaw, ok := p.versionPrefixCache.Load(ver) + if ok { + return prefixRaw.(string) + } + + template := p.VersionTemplate + if template == "" { + template = DefaultVersionTemplate + } + + prefix := strings.Replace(template, "{{version}}", strconv.Itoa(ver), -1) + p.versionPrefixCache.Store(ver, prefix) + + return prefix +} diff --git a/helper/keysutil/policy_test.go b/helper/keysutil/policy_test.go index d1e745fea17f..3f7f20917abe 100644 --- a/helper/keysutil/policy_test.go +++ b/helper/keysutil/policy_test.go @@ -471,6 +471,7 @@ func Test_BadUpgrade(t *testing.T) { k.CreationTime = o.CreationTime k.HMACKey = o.HMACKey p.Keys["1"] = k + p.versionPrefixCache = nil if !reflect.DeepEqual(orig, p) { t.Fatalf("not equal:\n%#v\n%#v", orig, p) diff --git a/helper/salt/salt.go b/helper/salt/salt.go index 24fd208bd547..450d9c6e7360 100644 --- a/helper/salt/salt.go +++ b/helper/salt/salt.go @@ -51,7 +51,7 @@ type Config struct { } // NewSalt creates a new salt based on the configuration -func NewSalt(view logical.Storage, config *Config) (*Salt, error) { +func NewSalt(ctx context.Context, view logical.Storage, config *Config) (*Salt, error) { // Setup the configuration if config == nil { config = &Config{} @@ -76,7 +76,7 @@ func NewSalt(view logical.Storage, config *Config) (*Salt, error) { var raw *logical.StorageEntry var err error if view != nil { - raw, err = view.Get(context.Background(), config.Location) + raw, err = view.Get(ctx, config.Location) if err != nil { return nil, fmt.Errorf("failed to read salt: %v", err) } @@ -99,7 +99,7 @@ func NewSalt(view logical.Storage, config *Config) (*Salt, error) { Key: config.Location, Value: []byte(s.salt), } - if err := view.Put(context.Background(), raw); err != nil { + if err := view.Put(ctx, raw); err != nil { return nil, fmt.Errorf("failed to persist salt: %v", err) } } diff --git a/helper/salt/salt_test.go b/helper/salt/salt_test.go index 2470dc8e2094..25359c08d1e7 100644 --- a/helper/salt/salt_test.go +++ b/helper/salt/salt_test.go @@ -14,7 +14,7 @@ func TestSalt(t *testing.T) { inm := &logical.InmemStorage{} conf := &Config{} - salt, err := NewSalt(inm, conf) + salt, err := NewSalt(context.Background(), inm, conf) if err != nil { t.Fatalf("err: %v", err) } @@ -33,7 +33,7 @@ func TestSalt(t *testing.T) { } // Create a new salt, should restore - salt2, err := NewSalt(inm, conf) + salt2, err := NewSalt(context.Background(), inm, conf) if err != nil { t.Fatalf("err: %v", err) } diff --git a/http/sys_auth_test.go b/http/sys_auth_test.go index 58e70963a772..a806450f36a4 100644 --- a/http/sys_auth_test.go +++ b/http/sys_auth_test.go @@ -31,6 +31,7 @@ func TestSysAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -42,6 +43,7 @@ func TestSysAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -93,6 +95,7 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -103,6 +106,7 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -114,6 +118,7 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -124,6 +129,7 @@ func TestSysEnableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -176,6 +182,7 @@ func TestSysDisableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "description": "token based credentials", "type": "token", @@ -187,6 +194,7 @@ func TestSysDisableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), + "plugin_name": "", }, "description": "token based credentials", "type": "token", diff --git a/logical/framework/path_map.go b/logical/framework/path_map.go index 6e369e24a3f5..83aa0bafaa1f 100644 --- a/logical/framework/path_map.go +++ b/logical/framework/path_map.go @@ -22,7 +22,7 @@ type PathMap struct { Schema map[string]*FieldSchema CaseSensitive bool Salt *saltpkg.Salt - SaltFunc func() (*saltpkg.Salt, error) + SaltFunc func(context.Context) (*saltpkg.Salt, error) once sync.Once } @@ -58,7 +58,7 @@ func (p *PathMap) pathStruct(ctx context.Context, s logical.Storage, k string) ( salt := p.Salt var err error if p.SaltFunc != nil { - salt, err = p.SaltFunc() + salt, err = p.SaltFunc(ctx) if err != nil { return nil, err } diff --git a/logical/framework/path_map_test.go b/logical/framework/path_map_test.go index b1cce0923f2c..97ab774aee3d 100644 --- a/logical/framework/path_map_test.go +++ b/logical/framework/path_map_test.go @@ -143,7 +143,7 @@ func TestPathMap_routes(t *testing.T) { func TestPathMap_Salted(t *testing.T) { storage := new(logical.InmemStorage) - salt, err := saltpkg.NewSalt(storage, &saltpkg.Config{ + salt, err := saltpkg.NewSalt(context.Background(), storage, &saltpkg.Config{ HashFunc: saltpkg.SHA1Hash, }) if err != nil { @@ -335,14 +335,14 @@ func testSalting(t *testing.T, ctx context.Context, storage logical.Storage, sal func TestPathMap_SaltFunc(t *testing.T) { storage := new(logical.InmemStorage) - salt, err := saltpkg.NewSalt(storage, &saltpkg.Config{ + salt, err := saltpkg.NewSalt(context.Background(), storage, &saltpkg.Config{ HashFunc: saltpkg.SHA1Hash, }) if err != nil { t.Fatalf("err: %v", err) } - saltFunc := func() (*saltpkg.Salt, error) { + saltFunc := func(context.Context) (*saltpkg.Salt, error) { return salt, nil } diff --git a/logical/plugin/backend.go b/logical/plugin/backend.go index b79798fb5da2..713f5d3c398a 100644 --- a/logical/plugin/backend.go +++ b/logical/plugin/backend.go @@ -3,6 +3,7 @@ package plugin import ( "context" "net/rpc" + "sync/atomic" "google.golang.org/grpc" @@ -42,10 +43,17 @@ func (b BackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err } func (p *BackendPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &backendGRPCPluginClient{ + ret := &backendGRPCPluginClient{ client: pb.NewBackendClient(c), clientConn: c, broker: broker, + cleanupCh: make(chan struct{}), doneCtx: ctx, - }, nil + } + + // Create the value and set the type + ret.server = new(atomic.Value) + ret.server.Store((*grpc.Server)(nil)) + + return ret, nil } diff --git a/logical/plugin/grpc_backend_client.go b/logical/plugin/grpc_backend_client.go index 7df90837d1c3..c980f795be2f 100644 --- a/logical/plugin/grpc_backend_client.go +++ b/logical/plugin/grpc_backend_client.go @@ -3,6 +3,7 @@ package plugin import ( "context" "errors" + "sync/atomic" "google.golang.org/grpc" @@ -28,8 +29,13 @@ type backendGRPCPluginClient struct { system logical.SystemView logger log.Logger + // This is used to signal to the Cleanup function that it can proceed + // because we have a defined server + cleanupCh chan struct{} + // server is the grpc server used for serving storage and sysview requests. - server *grpc.Server + server *atomic.Value + // clientConn is the underlying grpc connection to the server, we store it // so it can be cleaned up. clientConn *grpc.ClientConn @@ -139,8 +145,16 @@ func (b *backendGRPCPluginClient) Cleanup(ctx context.Context) { defer cancel() b.client.Cleanup(ctx, &pb.Empty{}) - if b.server != nil { - b.server.GracefulStop() + + // This will block until Setup has run the function to create a new server + // in b.server. If we stop here before it has a chance to actually start + // listening, when it starts listening it will immediatley error out and + // exit, which is fine. Overall this ensures that we do not miss stopping + // the server if it ends up being created after Cleanup is called. + <-b.cleanupCh + server := b.server.Load() + if server != nil { + server.(*grpc.Server).GracefulStop() } b.clientConn.Close() } @@ -184,7 +198,8 @@ func (b *backendGRPCPluginClient) Setup(ctx context.Context, config *logical.Bac s := grpc.NewServer(opts...) pb.RegisterSystemViewServer(s, sysView) pb.RegisterStorageServer(s, storage) - b.server = s + b.server.Store(s) + close(b.cleanupCh) return s } brokerID := b.broker.NextId() diff --git a/vault/audit_broker.go b/vault/audit_broker.go index 3584f8a6f625..d39856d8d6dd 100644 --- a/vault/audit_broker.go +++ b/vault/audit_broker.go @@ -60,7 +60,7 @@ func (a *AuditBroker) IsRegistered(name string) bool { } // GetHash returns a hash using the salt of the given backend -func (a *AuditBroker) GetHash(name string, input string) (string, error) { +func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (string, error) { a.RLock() defer a.RUnlock() be, ok := a.backends[name] @@ -68,7 +68,7 @@ func (a *AuditBroker) GetHash(name string, input string) (string, error) { return "", fmt.Errorf("unknown audit backend %s", name) } - return be.backend.GetHash(input) + return be.backend.GetHash(ctx, input) } // LogRequest is used to ensure all the audit backends have an opportunity to @@ -110,7 +110,7 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *audit.LogInput, header anyLogged := false for name, be := range a.backends { in.Request.Headers = nil - transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash) + transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash) if thErr != nil { a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr) continue @@ -166,7 +166,7 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *audit.LogInput, heade anyLogged := false for name, be := range a.backends { in.Request.Headers = nil - transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash) + transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash) if thErr != nil { a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr) continue diff --git a/vault/audit_test.go b/vault/audit_test.go index 8ea249ce8b4e..820766b894d7 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -66,7 +66,7 @@ func (n *NoopAudit) LogResponse(ctx context.Context, in *audit.LogInput) error { return n.RespErr } -func (n *NoopAudit) Salt() (*salt.Salt, error) { +func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) { n.saltMutex.RLock() if n.salt != nil { defer n.saltMutex.RUnlock() @@ -78,7 +78,7 @@ func (n *NoopAudit) Salt() (*salt.Salt, error) { if n.salt != nil { return n.salt, nil } - salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig) + salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) if err != nil { return nil, err } @@ -86,8 +86,8 @@ func (n *NoopAudit) Salt() (*salt.Salt, error) { return salt, nil } -func (n *NoopAudit) GetHash(data string) (string, error) { - salt, err := n.Salt() +func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) { + salt, err := n.Salt(ctx) if err != nil { return "", err } diff --git a/vault/audited_headers.go b/vault/audited_headers.go index 46e85b0e7462..ce5c2b1b3d14 100644 --- a/vault/audited_headers.go +++ b/vault/audited_headers.go @@ -89,7 +89,7 @@ func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error // ApplyConfig returns a map of approved headers and their values, either // hmac'ed or plaintext -func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) (string, error)) (result map[string][]string, retErr error) { +func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, hashFunc func(context.Context, string) (string, error)) (result map[string][]string, retErr error) { // Grab a read lock a.RLock() defer a.RUnlock() @@ -111,7 +111,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc // Optionally hmac the values if settings.HMAC { for i, el := range hVals { - hVal, err := hashFunc(el) + hVal, err := hashFunc(ctx, el) if err != nil { return nil, err } diff --git a/vault/audited_headers_test.go b/vault/audited_headers_test.go index a673ba0bc88a..473b0bb7c8da 100644 --- a/vault/audited_headers_test.go +++ b/vault/audited_headers_test.go @@ -167,9 +167,9 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { "Content-Type": []string{"json"}, } - hashFunc := func(s string) (string, error) { return "hashed", nil } + hashFunc := func(ctx context.Context, s string) (string, error) { return "hashed", nil } - result, err := conf.ApplyConfig(reqHeaders, hashFunc) + result, err := conf.ApplyConfig(context.Background(), reqHeaders, hashFunc) if err != nil { t.Fatal(err) } @@ -213,16 +213,16 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { "Content-Type": []string{"json"}, } - salter, err := salt.NewSalt(nil, nil) + salter, err := salt.NewSalt(context.Background(), nil, nil) if err != nil { b.Fatal(err) } - hashFunc := func(s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil } + hashFunc := func(ctx context.Context, s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil } // Reset the timer since we did a lot above b.ResetTimer() for i := 0; i < b.N; i++ { - conf.ApplyConfig(reqHeaders, hashFunc) + conf.ApplyConfig(context.Background(), reqHeaders, hashFunc) } } diff --git a/vault/core.go b/vault/core.go index ef43d4dc65e3..3e97c0de1768 100644 --- a/vault/core.go +++ b/vault/core.go @@ -190,10 +190,12 @@ type Core struct { stateLock sync.RWMutex sealed bool - standby bool - standbyDoneCh chan struct{} - standbyStopCh chan struct{} - manualStepDownCh chan struct{} + standby bool + standbyDoneCh chan struct{} + standbyStopCh chan struct{} + manualStepDownCh chan struct{} + keepHALockOnStepDown uint32 + heldHALock physical.Lock // unlockInfo has the keys provided to Unseal until the threshold number of parts is available, as well as the operation nonce unlockInfo *unlockInformation @@ -626,6 +628,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { // problem. It is only used to gracefully quit in the case of HA so that failover // happens as quickly as possible. func (c *Core) Shutdown() error { + c.logger.Trace("core: shutdown called") c.stateLock.RLock() // Tell any requests that know about this to stop if c.activeContextCancelFunc != nil { @@ -633,15 +636,13 @@ func (c *Core) Shutdown() error { } c.stateLock.RUnlock() + c.logger.Trace("core: shutdown initiating internal seal") // Seal the Vault, causes a leader stepdown - retChan := make(chan error) - go func() { - c.stateLock.Lock() - defer c.stateLock.Unlock() - retChan <- c.sealInternal() - }() + c.stateLock.Lock() + defer c.stateLock.Unlock() - return <-retChan + c.logger.Trace("core: shutdown running internal seal") + return c.sealInternal(false) } // CORSConfig returns the current CORS configuration @@ -1237,9 +1238,9 @@ func (c *Core) unsealInternal(ctx context.Context, masterKey []byte) (bool, erro } else { // Go to standby mode, wait until we are active to unseal c.standbyDoneCh = make(chan struct{}) - c.standbyStopCh = make(chan struct{}) c.manualStepDownCh = make(chan struct{}) - go c.runStandby(c.standbyDoneCh, c.standbyStopCh, c.manualStepDownCh) + c.standbyStopCh = make(chan struct{}) + go c.runStandby(c.standbyDoneCh, c.manualStepDownCh, c.standbyStopCh) } // Success! @@ -1406,19 +1407,15 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr c.stateLock.RUnlock() //Seal the Vault - retChan := make(chan error) - go func() { - c.stateLock.Lock() - defer c.stateLock.Unlock() - retChan <- c.sealInternal() - }() + c.stateLock.Lock() + defer c.stateLock.Unlock() + sealErr := c.sealInternal(false) - funcErr := <-retChan - if funcErr != nil { - retErr = multierror.Append(retErr, funcErr) + if sealErr != nil { + retErr = multierror.Append(retErr, sealErr) } - return retErr + return } // StepDown is used to step down from leadership @@ -1515,7 +1512,7 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) { // sealInternal is an internal method used to seal the vault. It does not do // any authorization checking. The stateLock must be held prior to calling. -func (c *Core) sealInternal() error { +func (c *Core) sealInternal(keepLock bool) error { if c.sealed { return nil } @@ -1539,13 +1536,21 @@ func (c *Core) sealInternal() error { return fmt.Errorf("internal error") } } else { - // Signal the standby goroutine to shutdown, wait for completion + if keepLock { + atomic.StoreUint32(&c.keepHALockOnStepDown, 1) + } + // If we are trying to acquire the lock, force it to return with nil so + // runStandby will exit + // If we are active, signal the standby goroutine to shut down and wait + // for completion. We have the state lock here so nothing else should + // be toggling standby status. close(c.standbyStopCh) + c.logger.Trace("core: finished triggering standbyStopCh for runStandby") - // Release the lock while we wait to avoid deadlocking - c.stateLock.Unlock() + // Wait for runStandby to stop <-c.standbyDoneCh - c.stateLock.Lock() + atomic.StoreUint32(&c.keepHALockOnStepDown, 0) + c.logger.Trace("core: runStandby done") } c.logger.Debug("core: sealing barrier") @@ -1744,7 +1749,7 @@ func stopReplicationImpl(c *Core) error { // runStandby is a long running routine that is used when an HA backend // is enabled. It waits until we are leader and switches this Vault to // active. -func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { +func (c *Core) runStandby(doneCh, manualStepDownCh, stopCh chan struct{}) { defer close(doneCh) defer close(manualStepDownCh) c.logger.Info("core: entering standby mode") @@ -1758,18 +1763,29 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { checkLeaderStop := make(chan struct{}) go c.periodicLeaderRefresh(checkLeaderDone, checkLeaderStop) defer func() { + c.logger.Trace("core: closed periodic key rotation checker stop channel") close(keyRotateStop) <-keyRotateDone close(checkLeaderStop) + c.logger.Trace("core: closed periodic leader refresh stop channel") <-checkLeaderDone + c.logger.Trace("core: periodic leader refresh returned") }() + var manualStepDown bool for { // Check for a shutdown select { case <-stopCh: + c.logger.Trace("core: stop channel triggered in runStandby") return default: + // If we've just down, we could instantly grab the lock again. Give + // the other nodes a chance. + if manualStepDown { + time.Sleep(manualStepDownSleepPeriod) + manualStepDown = false + } } // Create a lock @@ -1799,7 +1815,42 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { // Grab the lock as we need it for cluster setup, which needs to happen // before advertising; - c.stateLock.Lock() + + lockGrabbedCh := make(chan struct{}) + go func() { + // Grab the lock + c.stateLock.Lock() + // If stopCh has been closed, which only happens while the + // stateLock is held, we have actually terminated, so we just + // instantly give up the lock, otherwise we notify that it's ready + // for consumption + select { + case <-stopCh: + c.stateLock.Unlock() + default: + close(lockGrabbedCh) + } + }() + + select { + case <-stopCh: + lock.Unlock() + metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime) + return + case <-lockGrabbedCh: + // We now have the lock and can use it + } + + if c.sealed { + c.logger.Warn("core: grabbed HA lock but already sealed, exiting") + lock.Unlock() + c.stateLock.Unlock() + metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime) + return + } + + // Store the lock so that we can manually clear it later if needed + c.heldHALock = lock // We haven't run postUnseal yet so we have nothing meaningful to use here ctx := context.Background() @@ -1818,10 +1869,11 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { // statelock and have this shut us down; sealInternal has a // workflow where it watches for the stopCh to close so we want // to return from here - go c.Shutdown() c.logger.Error("core: error performing key upgrades", "error", err) - c.stateLock.Unlock() + go c.Shutdown() + c.heldHALock = nil lock.Unlock() + c.stateLock.Unlock() metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime) return } @@ -1834,18 +1886,20 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil)) if err := c.setupCluster(ctx); err != nil { + c.heldHALock = nil + lock.Unlock() c.stateLock.Unlock() c.logger.Error("core: cluster setup failed", "error", err) - lock.Unlock() metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime) continue } // Advertise as leader if err := c.advertiseLeader(ctx, uuid, leaderLostCh); err != nil { + c.heldHALock = nil + lock.Unlock() c.stateLock.Unlock() c.logger.Error("core: leader advertisement setup failed", "error", err) - lock.Unlock() metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime) continue } @@ -1855,6 +1909,7 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { if err == nil { c.standby = false } + c.stateLock.Unlock() // Handle a failure to unseal @@ -1866,12 +1921,18 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { } // Monitor a loss of leadership - var manualStepDown bool + releaseHALock := true + grabStateLock := true select { case <-leaderLostCh: c.logger.Warn("core: leadership lost, stopping active operation") case <-stopCh: - c.logger.Warn("core: stopping active operation") + // This case comes from sealInternal; we will already be having the + // state lock held so we do toggle grabStateLock to false + if atomic.LoadUint32(&c.keepHALockOnStepDown) == 1 { + releaseHALock = false + } + grabStateLock = false case <-manualStepDownCh: c.logger.Warn("core: stepping down from active operation to standby") manualStepDown = true @@ -1879,35 +1940,33 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { metrics.MeasureSince([]string{"core", "leadership_lost"}, activeTime) - // Clear ourself as leader - if err := c.clearLeader(uuid); err != nil { - c.logger.Error("core: clearing leader advertisement failed", "error", err) - } - // Tell any requests that know about this to stop if c.activeContextCancelFunc != nil { c.activeContextCancelFunc() } // Attempt the pre-seal process - c.stateLock.Lock() + if grabStateLock { + c.stateLock.Lock() + } c.standby = true preSealErr := c.preSeal() - c.stateLock.Unlock() + if grabStateLock { + c.stateLock.Unlock() + } - // Give up leadership - lock.Unlock() + if releaseHALock { + if err := c.clearLeader(uuid); err != nil { + c.logger.Error("core: clearing leader advertisement failed", "error", err) + } + c.heldHALock.Unlock() + c.heldHALock = nil + } // Check for a failure to prepare to seal if preSealErr != nil { c.logger.Error("core: pre-seal teardown failed", "error", err) } - - // If we've merely stepped down, we could instantly grab the lock - // again. Give the other nodes a chance. - if manualStepDown { - time.Sleep(manualStepDownSleepPeriod) - } } } @@ -1917,10 +1976,23 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) { // the result. func (c *Core) periodicLeaderRefresh(doneCh, stopCh chan struct{}) { defer close(doneCh) + var opCount int32 for { select { case <-time.After(leaderCheckInterval): - c.Leader() + count := atomic.AddInt32(&opCount, 1) + if count > 1 { + atomic.AddInt32(&opCount, -1) + continue + } + // We do this in a goroutine because otherwise if this refresh is + // called while we're shutting down the call to Leader() can + // deadlock, which then means stopCh can never been seen and we can + // block shutdown + go func() { + defer atomic.AddInt32(&opCount, -1) + c.Leader() + }() case <-stopCh: return } @@ -1930,30 +2002,40 @@ func (c *Core) periodicLeaderRefresh(doneCh, stopCh chan struct{}) { // periodicCheckKeyUpgrade is used to watch for key rotation events as a standby func (c *Core) periodicCheckKeyUpgrade(ctx context.Context, doneCh, stopCh chan struct{}) { defer close(doneCh) + var opCount int32 for { select { case <-time.After(keyRotateCheckInterval): - // Only check if we are a standby - c.stateLock.RLock() - standby := c.standby - c.stateLock.RUnlock() - if !standby { + count := atomic.AddInt32(&opCount, 1) + if count > 1 { + atomic.AddInt32(&opCount, -1) continue } - // Check for a poison pill. If we can read it, it means we have stale - // keys (e.g. from replication being activated) and we need to seal to - // be unsealed again. - entry, _ := c.barrier.Get(ctx, poisonPillPath) - if entry != nil && len(entry.Value) > 0 { - c.logger.Warn("core: encryption keys have changed out from underneath us (possibly due to replication enabling), must be unsealed again") - go c.Shutdown() - continue - } + go func() { + defer atomic.AddInt32(&opCount, -1) + // Only check if we are a standby + c.stateLock.RLock() + standby := c.standby + c.stateLock.RUnlock() + if !standby { + return + } - if err := c.checkKeyUpgrades(ctx); err != nil { - c.logger.Error("core: key rotation periodic upgrade check failed", "error", err) - } + // Check for a poison pill. If we can read it, it means we have stale + // keys (e.g. from replication being activated) and we need to seal to + // be unsealed again. + entry, _ := c.barrier.Get(ctx, poisonPillPath) + if entry != nil && len(entry.Value) > 0 { + c.logger.Warn("core: encryption keys have changed out from underneath us (possibly due to replication enabling), must be unsealed again") + go c.Shutdown() + return + } + + if err := c.checkKeyUpgrades(ctx); err != nil { + c.logger.Error("core: key rotation periodic upgrade check failed", "error", err) + } + }() case <-stopCh: return } diff --git a/vault/expiration.go b/vault/expiration.go index 006ff001df0e..2175784b2f9a 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -219,7 +219,7 @@ func (m *ExpirationManager) Tidy() error { isValid, ok = tokenCache[le.ClientToken] if !ok { - saltedID, err := m.tokenStore.SaltID(le.ClientToken) + saltedID, err := m.tokenStore.SaltID(m.quitContext, le.ClientToken) if err != nil { tidyErrors = multierror.Append(tidyErrors, fmt.Errorf("failed to lookup salt id: %v", err)) return @@ -563,7 +563,7 @@ func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error { } if te.Path != "" { - saltedID, err := m.tokenStore.SaltID(te.ID) + saltedID, err := m.tokenStore.SaltID(m.quitContext, te.ID) if err != nil { return err } @@ -715,7 +715,7 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke defer metrics.MeasureSince([]string{"expire", "renew-token"}, time.Now()) // Compute the Lease ID - saltedID, err := m.tokenStore.SaltID(token) + saltedID, err := m.tokenStore.SaltID(m.quitContext, token) if err != nil { return nil, err } @@ -768,14 +768,14 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke // framework.LeaseExtend call against the request. Also, cap period value to // the sys/mount max value. if resp.Auth.Period > sysView.MaxLeaseTTL() { - retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL())) + retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", int64(resp.Auth.TTL.Seconds()), int64(sysView.MaxLeaseTTL().Seconds()))) resp.Auth.Period = sysView.MaxLeaseTTL() } resp.Auth.TTL = resp.Auth.Period case resp.Auth.TTL > time.Duration(0): // Cap TTL value to the sys/mount max value if resp.Auth.TTL > sysView.MaxLeaseTTL() { - retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL())) + retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", int64(resp.Auth.TTL.Seconds()), int64(sysView.MaxLeaseTTL().Seconds()))) resp.Auth.TTL = sysView.MaxLeaseTTL() } } @@ -891,7 +891,7 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro return fmt.Errorf("expiration: %s", consts.ErrPathContainsParentReferences) } - saltedID, err := m.tokenStore.SaltID(auth.ClientToken) + saltedID, err := m.tokenStore.SaltID(m.quitContext, auth.ClientToken) if err != nil { return err } @@ -928,7 +928,7 @@ func (m *ExpirationManager) FetchLeaseTimesByToken(source, token string) (*lease defer metrics.MeasureSince([]string{"expire", "fetch-lease-times-by-token"}, time.Now()) // Compute the Lease ID - saltedID, err := m.tokenStore.SaltID(token) + saltedID, err := m.tokenStore.SaltID(m.quitContext, token) if err != nil { return nil, err } @@ -1180,12 +1180,12 @@ func (m *ExpirationManager) deleteEntry(leaseID string) error { // createIndexByToken creates a secondary index from the token to a lease entry func (m *ExpirationManager) createIndexByToken(token, leaseID string) error { - saltedID, err := m.tokenStore.SaltID(token) + saltedID, err := m.tokenStore.SaltID(m.quitContext, token) if err != nil { return err } - leaseSaltedID, err := m.tokenStore.SaltID(leaseID) + leaseSaltedID, err := m.tokenStore.SaltID(m.quitContext, leaseID) if err != nil { return err } @@ -1202,12 +1202,12 @@ func (m *ExpirationManager) createIndexByToken(token, leaseID string) error { // indexByToken looks up the secondary index from the token to a lease entry func (m *ExpirationManager) indexByToken(token, leaseID string) (*logical.StorageEntry, error) { - saltedID, err := m.tokenStore.SaltID(token) + saltedID, err := m.tokenStore.SaltID(m.quitContext, token) if err != nil { return nil, err } - leaseSaltedID, err := m.tokenStore.SaltID(leaseID) + leaseSaltedID, err := m.tokenStore.SaltID(m.quitContext, leaseID) if err != nil { return nil, err } @@ -1222,12 +1222,12 @@ func (m *ExpirationManager) indexByToken(token, leaseID string) (*logical.Storag // removeIndexByToken removes the secondary index from the token to a lease entry func (m *ExpirationManager) removeIndexByToken(token, leaseID string) error { - saltedID, err := m.tokenStore.SaltID(token) + saltedID, err := m.tokenStore.SaltID(m.quitContext, token) if err != nil { return err } - leaseSaltedID, err := m.tokenStore.SaltID(leaseID) + leaseSaltedID, err := m.tokenStore.SaltID(m.quitContext, leaseID) if err != nil { return err } @@ -1241,7 +1241,7 @@ func (m *ExpirationManager) removeIndexByToken(token, leaseID string) error { // lookupByToken is used to lookup all the leaseID's via the func (m *ExpirationManager) lookupByToken(token string) ([]string, error) { - saltedID, err := m.tokenStore.SaltID(token) + saltedID, err := m.tokenStore.SaltID(m.quitContext, token) if err != nil { return nil, err } diff --git a/vault/logical_system.go b/vault/logical_system.go index 508747415622..c632c8805bc5 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1443,15 +1443,22 @@ func (b *SystemBackend) handleMountTable(ctx context.Context, req *logical.Reque "type": entry.Type, "description": entry.Description, "accessor": entry.Accessor, - "config": map[string]interface{}{ - "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), - "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), - "force_no_cache": entry.Config.ForceNoCache, - "plugin_name": entry.Config.PluginName, - }, - "local": entry.Local, - "seal_wrap": entry.SealWrap, + "local": entry.Local, + "seal_wrap": entry.SealWrap, + } + entryConfig := map[string]interface{}{ + "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), + "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), + "force_no_cache": entry.Config.ForceNoCache, + "plugin_name": entry.Config.PluginName, + } + if rawVal, ok := entry.synthesizedConfigCache.Load("audit_non_hmac_request_keys"); ok { + entryConfig["audit_non_hmac_request_keys"] = rawVal.([]string) + } + if rawVal, ok := entry.synthesizedConfigCache.Load("audit_non_hmac_response_keys"); ok { + entryConfig["audit_non_hmac_response_keys"] = rawVal.([]string) } + info["config"] = entryConfig resp.Data[entry.Path] = info } @@ -1553,6 +1560,14 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d config.ForceNoCache = true } + if len(apiConfig.AuditNonHMACRequestKeys) > 0 { + config.AuditNonHMACRequestKeys = apiConfig.AuditNonHMACRequestKeys + } + + if len(apiConfig.AuditNonHMACResponseKeys) > 0 { + config.AuditNonHMACResponseKeys = apiConfig.AuditNonHMACResponseKeys + } + // Create the mount entry me := &MountEntry{ Table: mountTableType, @@ -2028,13 +2043,21 @@ func (b *SystemBackend) handleAuthTable(ctx context.Context, req *logical.Reques "type": entry.Type, "description": entry.Description, "accessor": entry.Accessor, - "config": map[string]interface{}{ - "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), - "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), - }, - "local": entry.Local, - "seal_wrap": entry.SealWrap, + "local": entry.Local, + "seal_wrap": entry.SealWrap, + } + entryConfig := map[string]interface{}{ + "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), + "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), + "plugin_name": entry.Config.PluginName, + } + if rawVal, ok := entry.synthesizedConfigCache.Load("audit_non_hmac_request_keys"); ok { + entryConfig["audit_non_hmac_request_keys"] = rawVal.([]string) + } + if rawVal, ok := entry.synthesizedConfigCache.Load("audit_non_hmac_response_keys"); ok { + entryConfig["audit_non_hmac_response_keys"] = rawVal.([]string) } + info["config"] = entryConfig resp.Data[entry.Path] = info } return resp, nil @@ -2129,6 +2152,14 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque path = sanitizeMountPath(path) + if len(apiConfig.AuditNonHMACRequestKeys) > 0 { + config.AuditNonHMACRequestKeys = apiConfig.AuditNonHMACRequestKeys + } + + if len(apiConfig.AuditNonHMACResponseKeys) > 0 { + config.AuditNonHMACResponseKeys = apiConfig.AuditNonHMACResponseKeys + } + // Create the mount entry me := &MountEntry{ Table: credentialTableType, @@ -2387,7 +2418,7 @@ func (b *SystemBackend) handleAuditHash(ctx context.Context, req *logical.Reques path = sanitizeMountPath(path) - hash, err := b.Core.auditBroker.GetHash(path, input) + hash, err := b.Core.auditBroker.GetHash(ctx, path, input) if err != nil { return logical.ErrorResponse(err.Error()), nil } diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 403d4892b526..f6993b8c3813 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -174,11 +174,12 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun if sealed { t.Fatal("should not be sealed") } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + if testMount { // Mount the plugin at the same path after plugin is re-added to the catalog // and expect an error due to existing path. @@ -286,11 +287,12 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc if sealed { t.Fatal("should not be sealed") } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + // Re-add the plugin to the catalog switch btype { case logical.TypeLogical: @@ -394,10 +396,11 @@ func TestSystemBackend_Plugin_SealUnseal(t *testing.T) { if sealed { t.Fatal("should not be sealed") } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) } + + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, cluster.Cores[0].Core) } func TestSystemBackend_Plugin_reload(t *testing.T) { diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index a5bfd37043cd..5ee285bfe4d5 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1388,6 +1388,7 @@ func TestSystemBackend_authTable(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": int64(0), "max_lease_ttl": int64(0), + "plugin_name": "", }, "local": false, "seal_wrap": false, @@ -1438,6 +1439,7 @@ func TestSystemBackend_enableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": int64(2100), "max_lease_ttl": int64(2700), + "plugin_name": "", }, "local": true, "seal_wrap": true, @@ -1449,6 +1451,7 @@ func TestSystemBackend_enableAuth(t *testing.T) { "config": map[string]interface{}{ "default_lease_ttl": int64(0), "max_lease_ttl": int64(0), + "plugin_name": "", }, "local": false, "seal_wrap": false, diff --git a/vault/router.go b/vault/router.go index d02204f73d69..c3dae498ace6 100644 --- a/vault/router.go +++ b/vault/router.go @@ -19,7 +19,7 @@ type Router struct { root *radix.Tree mountUUIDCache *radix.Tree mountAccessorCache *radix.Tree - tokenStoreSaltFunc func() (*salt.Salt, error) + tokenStoreSaltFunc func(context.Context) (*salt.Salt, error) // storagePrefix maps the prefix used for storage (ala the BarrierView) // to the backend. This is used to map a key back into the backend that owns it. // For example, logical/uuid1/foobar -> secrets/ (kv backend) + foobar @@ -447,7 +447,7 @@ func (r *Router) routeCommon(ctx context.Context, req *logical.Request, existenc case strings.HasPrefix(originalPath, "cubbyhole/"): // In order for the token store to revoke later, we need to have the same // salted ID, so we double-salt what's going to the cubbyhole backend - salt, err := r.tokenStoreSaltFunc() + salt, err := r.tokenStoreSaltFunc(ctx) if err != nil { return nil, false, false, err } diff --git a/vault/testing.go b/vault/testing.go index b91b581f4add..2bbd91950bc8 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -603,8 +603,8 @@ type noopAudit struct { saltMutex sync.RWMutex } -func (n *noopAudit) GetHash(data string) (string, error) { - salt, err := n.Salt() +func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { + salt, err := n.Salt(ctx) if err != nil { return "", err } @@ -629,7 +629,7 @@ func (n *noopAudit) Invalidate(_ context.Context) { n.salt = nil } -func (n *noopAudit) Salt() (*salt.Salt, error) { +func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) { n.saltMutex.RLock() if n.salt != nil { defer n.saltMutex.RUnlock() @@ -641,7 +641,7 @@ func (n *noopAudit) Salt() (*salt.Salt, error) { if n.salt != nil { return n.salt, nil } - salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig) + salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig) if err != nil { return nil, err } diff --git a/vault/token_store.go b/vault/token_store.go index c65b9a3cabe3..0f6e00d28bf3 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -491,7 +491,7 @@ func (ts *TokenStore) Invalidate(ctx context.Context, key string) { } } -func (ts *TokenStore) Salt() (*salt.Salt, error) { +func (ts *TokenStore) Salt(ctx context.Context) (*salt.Salt, error) { ts.saltLock.RLock() if ts.salt != nil { defer ts.saltLock.RUnlock() @@ -503,7 +503,7 @@ func (ts *TokenStore) Salt() (*salt.Salt, error) { if ts.salt != nil { return ts.salt, nil } - salt, err := salt.NewSalt(ts.view, &salt.Config{ + salt, err := salt.NewSalt(ctx, ts.view, &salt.Config{ HashFunc: salt.SHA1Hash, Location: salt.DefaultLocation, }) @@ -667,8 +667,8 @@ func (ts *TokenStore) SetExpirationManager(exp *ExpirationManager) { } // SaltID is used to apply a salt and hash to an ID to make sure its not reversible -func (ts *TokenStore) SaltID(id string) (string, error) { - s, err := ts.Salt() +func (ts *TokenStore) SaltID(ctx context.Context, id string) (string, error) { + s, err := ts.Salt(ctx) if err != nil { return "", err } @@ -731,7 +731,7 @@ func (ts *TokenStore) createAccessor(ctx context.Context, entry *TokenEntry) err entry.Accessor = accessorUUID // Create index entry, mapping the accessor to the token ID - saltID, err := ts.SaltID(entry.Accessor) + saltID, err := ts.SaltID(ctx, entry.Accessor) if err != nil { return err } @@ -766,7 +766,7 @@ func (ts *TokenStore) create(ctx context.Context, entry *TokenEntry) error { entry.ID = entryUUID } - saltedId, err := ts.SaltID(entry.ID) + saltedId, err := ts.SaltID(ctx, entry.ID) if err != nil { return err } @@ -795,7 +795,7 @@ func (ts *TokenStore) store(ctx context.Context, entry *TokenEntry) error { // storeCommon handles the actual storage of an entry, possibly generating // secondary indexes func (ts *TokenStore) storeCommon(ctx context.Context, entry *TokenEntry, writeSecondary bool) error { - saltedId, err := ts.SaltID(entry.ID) + saltedId, err := ts.SaltID(ctx, entry.ID) if err != nil { return err } @@ -822,7 +822,7 @@ func (ts *TokenStore) storeCommon(ctx context.Context, entry *TokenEntry, writeS } // Create the index entry - parentSaltedID, err := ts.SaltID(entry.Parent) + parentSaltedID, err := ts.SaltID(ctx, entry.Parent) if err != nil { return err } @@ -875,7 +875,7 @@ func (ts *TokenStore) UseToken(ctx context.Context, te *TokenEntry) (*TokenEntry defer lock.Unlock() // Call lookupSalted instead of Lookup to avoid deadlocking since Lookup grabs a read lock - saltedID, err := ts.SaltID(te.ID) + saltedID, err := ts.SaltID(ctx, te.ID) if err != nil { return nil, err } @@ -932,7 +932,7 @@ func (ts *TokenStore) Lookup(ctx context.Context, id string) (*TokenEntry, error lock.RLock() defer lock.RUnlock() - saltedID, err := ts.SaltID(id) + saltedID, err := ts.SaltID(ctx, id) if err != nil { return nil, err } @@ -951,7 +951,7 @@ func (ts *TokenStore) lookupTainted(ctx context.Context, id string) (*TokenEntry lock.RLock() defer lock.RUnlock() - saltedID, err := ts.SaltID(id) + saltedID, err := ts.SaltID(ctx, id) if err != nil { return nil, err } @@ -1051,7 +1051,7 @@ func (ts *TokenStore) Revoke(ctx context.Context, id string) error { return fmt.Errorf("cannot revoke blank token") } - saltedID, err := ts.SaltID(id) + saltedID, err := ts.SaltID(ctx, id) if err != nil { return err } @@ -1145,7 +1145,7 @@ func (ts *TokenStore) revokeSalted(ctx context.Context, saltedId string) (ret er // Clear the secondary index if any if entry.Parent != "" { - parentSaltedID, err := ts.SaltID(entry.Parent) + parentSaltedID, err := ts.SaltID(ctx, entry.Parent) if err != nil { return err } @@ -1158,7 +1158,7 @@ func (ts *TokenStore) revokeSalted(ctx context.Context, saltedId string) (ret er // Clear the accessor index if any if entry.Accessor != "" { - accessorSaltedID, err := ts.SaltID(entry.Accessor) + accessorSaltedID, err := ts.SaltID(ctx, entry.Accessor) if err != nil { return err } @@ -1188,7 +1188,7 @@ func (ts *TokenStore) RevokeTree(ctx context.Context, id string) error { } // Get the salted ID - saltedId, err := ts.SaltID(id) + saltedId, err := ts.SaltID(ctx, id) if err != nil { return err } @@ -1251,7 +1251,7 @@ func (ts *TokenStore) handleCreateAgainstRole(ctx context.Context, req *logical. } func (ts *TokenStore) lookupByAccessor(ctx context.Context, accessor string, tainted bool) (accessorEntry, error) { - saltedID, err := ts.SaltID(accessor) + saltedID, err := ts.SaltID(ctx, accessor) if err != nil { return accessorEntry{}, err } @@ -1272,7 +1272,7 @@ func (ts *TokenStore) lookupBySaltedAccessor(ctx context.Context, saltedAccessor err = jsonutil.DecodeJSON(entry.Value, &aEntry) // If we hit an error, assume it's a pre-struct straight token ID if err != nil { - saltedID, err := ts.SaltID(string(entry.Value)) + saltedID, err := ts.SaltID(ctx, string(entry.Value)) if err != nil { return accessorEntry{}, err } @@ -1395,7 +1395,7 @@ func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data // Look up tainted variants so we only find entries that truly don't // exist - saltedId, err := ts.SaltID(accessorEntry.TokenID) + saltedId, err := ts.SaltID(ctx, accessorEntry.TokenID) if err != nil { tidyErrors = multierror.Append(tidyErrors, fmt.Errorf("failed to read salt id: %v", err)) lock.RUnlock() @@ -1893,6 +1893,12 @@ func (ts *TokenStore) handleCreateCommon(ctx context.Context, req *logical.Reque sysView := ts.System() if periodToUse > 0 { + // Cap period value to the sys/mount max value; this matches behavior + // in expiration manager for renewals + if periodToUse > sysView.MaxLeaseTTL() { + resp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", int64(periodToUse.Seconds()), int64(sysView.MaxLeaseTTL().Seconds()))) + periodToUse = sysView.MaxLeaseTTL() + } te.TTL = periodToUse } else { // Set the default lease if not provided, root tokens are exempt @@ -2086,7 +2092,7 @@ func (ts *TokenStore) handleLookup(ctx context.Context, req *logical.Request, da defer lock.RUnlock() // Lookup the token - saltedId, err := ts.SaltID(id) + saltedId, err := ts.SaltID(ctx, id) if err != nil { return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index fa865c7e2199..c2396a2fd57f 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -58,7 +58,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { t.Fatal(err) } - saltedId, err := ts.SaltID(entry.ID) + saltedId, err := ts.SaltID(context.Background(), entry.ID) if err != nil { t.Fatal(err) } @@ -304,7 +304,7 @@ func TestTokenStore_HandleRequest_ListAccessors(t *testing.T) { } // Revoke root to make the number of accessors match - salted, err := ts.SaltID(root) + salted, err := ts.SaltID(context.Background(), root) if err != nil { t.Fatal(err) } @@ -339,7 +339,7 @@ func TestTokenStore_HandleRequest_ListAccessors(t *testing.T) { if aEntry.TokenID == "" || aEntry.AccessorID == "" { t.Fatalf("error, accessor entry looked up is empty, but no error thrown") } - salted, err := ts.SaltID(accessor) + salted, err := ts.SaltID(context.Background(), accessor) if err != nil { t.Fatal(err) } @@ -522,7 +522,7 @@ func TestTokenStore_CreateLookup_ExpirationInRestoreMode(t *testing.T) { } // Replace the lease with a lease with an expire time in the past - saltedID, err := ts.SaltID(ent.ID) + saltedID, err := ts.SaltID(context.Background(), ent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -3254,7 +3254,7 @@ func TestTokenStore_RevokeUseCountToken(t *testing.T) { } tut := resp.Auth.ClientToken - saltTut, err := ts.SaltID(tut) + saltTut, err := ts.SaltID(context.Background(), tut) if err != nil { t.Fatal(err) } @@ -3452,7 +3452,7 @@ func TestTokenStore_HandleTidyCase1(t *testing.T) { // cubbyhole and by not deleting its secondary index, its accessor and // associated leases. - saltedTut, err := ts.SaltID(tut) + saltedTut, err := ts.SaltID(context.Background(), tut) if err != nil { t.Fatal(err) } @@ -3594,7 +3594,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) { } // Now, delete the token entry. The leases should still exist. - saltedTut, err := ts.SaltID(tut) + saltedTut, err := ts.SaltID(context.Background(), tut) if err != nil { t.Fatal(err) } diff --git a/vendor/github.com/golang/go/LICENSE b/vendor/github.com/golang/go/LICENSE new file mode 100644 index 000000000000..6a66aea5eafe --- /dev/null +++ b/vendor/github.com/golang/go/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/golang/go/PATENTS b/vendor/github.com/golang/go/PATENTS new file mode 100644 index 000000000000..733099041f84 --- /dev/null +++ b/vendor/github.com/golang/go/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/github.com/golang/go/src/math/big/accuracy_string.go b/vendor/github.com/golang/go/src/math/big/accuracy_string.go new file mode 100644 index 000000000000..24ef7f107700 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/accuracy_string.go @@ -0,0 +1,17 @@ +// generated by stringer -type=Accuracy; DO NOT EDIT + +package big + +import "fmt" + +const _Accuracy_name = "BelowExactAbove" + +var _Accuracy_index = [...]uint8{0, 5, 10, 15} + +func (i Accuracy) String() string { + i -= -1 + if i < 0 || i+1 >= Accuracy(len(_Accuracy_index)) { + return fmt.Sprintf("Accuracy(%d)", i+-1) + } + return _Accuracy_name[_Accuracy_index[i]:_Accuracy_index[i+1]] +} diff --git a/vendor/github.com/golang/go/src/math/big/arith.go b/vendor/github.com/golang/go/src/math/big/arith.go new file mode 100644 index 000000000000..ad352403a7c5 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith.go @@ -0,0 +1,260 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file provides Go implementations of elementary multi-precision +// arithmetic operations on word vectors. Needed for platforms without +// assembly implementations of these routines. + +package big + +import "math/bits" + +// A Word represents a single digit of a multi-precision unsigned integer. +type Word uint + +const ( + _S = _W / 8 // word size in bytes + + _W = bits.UintSize // word size in bits + _B = 1 << _W // digit base + _M = _B - 1 // digit mask + + _W2 = _W / 2 // half word size in bits + _B2 = 1 << _W2 // half digit base + _M2 = _B2 - 1 // half digit mask +) + +// ---------------------------------------------------------------------------- +// Elementary operations on words +// +// These operations are used by the vector operations below. + +// z1<<_W + z0 = x+y+c, with c == 0 or 1 +func addWW_g(x, y, c Word) (z1, z0 Word) { + yc := y + c + z0 = x + yc + if z0 < x || yc < y { + z1 = 1 + } + return +} + +// z1<<_W + z0 = x-y-c, with c == 0 or 1 +func subWW_g(x, y, c Word) (z1, z0 Word) { + yc := y + c + z0 = x - yc + if z0 > x || yc < y { + z1 = 1 + } + return +} + +// z1<<_W + z0 = x*y +// Adapted from Warren, Hacker's Delight, p. 132. +func mulWW_g(x, y Word) (z1, z0 Word) { + x0 := x & _M2 + x1 := x >> _W2 + y0 := y & _M2 + y1 := y >> _W2 + w0 := x0 * y0 + t := x1*y0 + w0>>_W2 + w1 := t & _M2 + w2 := t >> _W2 + w1 += x0 * y1 + z1 = x1*y1 + w2 + w1>>_W2 + z0 = x * y + return +} + +// z1<<_W + z0 = x*y + c +func mulAddWWW_g(x, y, c Word) (z1, z0 Word) { + z1, zz0 := mulWW_g(x, y) + if z0 = zz0 + c; z0 < zz0 { + z1++ + } + return +} + +// nlz returns the number of leading zeros in x. +// Wraps bits.LeadingZeros call for convenience. +func nlz(x Word) uint { + return uint(bits.LeadingZeros(uint(x))) +} + +// q = (u1<<_W + u0 - r)/y +// Adapted from Warren, Hacker's Delight, p. 152. +func divWW_g(u1, u0, v Word) (q, r Word) { + if u1 >= v { + return 1<<_W - 1, 1<<_W - 1 + } + + s := nlz(v) + v <<= s + + vn1 := v >> _W2 + vn0 := v & _M2 + un32 := u1<>(_W-s) + un10 := u0 << s + un1 := un10 >> _W2 + un0 := un10 & _M2 + q1 := un32 / vn1 + rhat := un32 - q1*vn1 + + for q1 >= _B2 || q1*vn0 > _B2*rhat+un1 { + q1-- + rhat += vn1 + if rhat >= _B2 { + break + } + } + + un21 := un32*_B2 + un1 - q1*v + q0 := un21 / vn1 + rhat = un21 - q0*vn1 + + for q0 >= _B2 || q0*vn0 > _B2*rhat+un0 { + q0-- + rhat += vn1 + if rhat >= _B2 { + break + } + } + + return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s +} + +// Keep for performance debugging. +// Using addWW_g is likely slower. +const use_addWW_g = false + +// The resulting carry c is either 0 or 1. +func addVV_g(z, x, y []Word) (c Word) { + if use_addWW_g { + for i := range z { + c, z[i] = addWW_g(x[i], y[i], c) + } + return + } + + for i, xi := range x[:len(z)] { + yi := y[i] + zi := xi + yi + c + z[i] = zi + // see "Hacker's Delight", section 2-12 (overflow detection) + c = (xi&yi | (xi|yi)&^zi) >> (_W - 1) + } + return +} + +// The resulting carry c is either 0 or 1. +func subVV_g(z, x, y []Word) (c Word) { + if use_addWW_g { + for i := range z { + c, z[i] = subWW_g(x[i], y[i], c) + } + return + } + + for i, xi := range x[:len(z)] { + yi := y[i] + zi := xi - yi - c + z[i] = zi + // see "Hacker's Delight", section 2-12 (overflow detection) + c = (yi&^xi | (yi|^xi)&zi) >> (_W - 1) + } + return +} + +// The resulting carry c is either 0 or 1. +func addVW_g(z, x []Word, y Word) (c Word) { + if use_addWW_g { + c = y + for i := range z { + c, z[i] = addWW_g(x[i], c, 0) + } + return + } + + c = y + for i, xi := range x[:len(z)] { + zi := xi + c + z[i] = zi + c = xi &^ zi >> (_W - 1) + } + return +} + +func subVW_g(z, x []Word, y Word) (c Word) { + if use_addWW_g { + c = y + for i := range z { + c, z[i] = subWW_g(x[i], c, 0) + } + return + } + + c = y + for i, xi := range x[:len(z)] { + zi := xi - c + z[i] = zi + c = (zi &^ xi) >> (_W - 1) + } + return +} + +func shlVU_g(z, x []Word, s uint) (c Word) { + if n := len(z); n > 0 { + ŝ := _W - s + w1 := x[n-1] + c = w1 >> ŝ + for i := n - 1; i > 0; i-- { + w := w1 + w1 = x[i-1] + z[i] = w<>ŝ + } + z[0] = w1 << s + } + return +} + +func shrVU_g(z, x []Word, s uint) (c Word) { + if n := len(z); n > 0 { + ŝ := _W - s + w1 := x[0] + c = w1 << ŝ + for i := 0; i < n-1; i++ { + w := w1 + w1 = x[i+1] + z[i] = w>>s | w1<<ŝ + } + z[n-1] = w1 >> s + } + return +} + +func mulAddVWW_g(z, x []Word, y, r Word) (c Word) { + c = r + for i := range z { + c, z[i] = mulAddWWW_g(x[i], y, c) + } + return +} + +// TODO(gri) Remove use of addWW_g here and then we can remove addWW_g and subWW_g. +func addMulVVW_g(z, x []Word, y Word) (c Word) { + for i := range z { + z1, z0 := mulAddWWW_g(x[i], y, z[i]) + c, z[i] = addWW_g(z0, c, 0) + c += z1 + } + return +} + +func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) { + r = xn + for i := len(z) - 1; i >= 0; i-- { + z[i], r = divWW_g(r, x[i], y) + } + return +} diff --git a/vendor/github.com/golang/go/src/math/big/arith_386.s b/vendor/github.com/golang/go/src/math/big/arith_386.s new file mode 100644 index 000000000000..6c080f074a3c --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_386.s @@ -0,0 +1,271 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func mulWW(x, y Word) (z1, z0 Word) +TEXT ·mulWW(SB),NOSPLIT,$0 + MOVL x+0(FP), AX + MULL y+4(FP) + MOVL DX, z1+8(FP) + MOVL AX, z0+12(FP) + RET + + +// func divWW(x1, x0, y Word) (q, r Word) +TEXT ·divWW(SB),NOSPLIT,$0 + MOVL x1+0(FP), DX + MOVL x0+4(FP), AX + DIVL y+8(FP) + MOVL AX, q+12(FP) + MOVL DX, r+16(FP) + RET + + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), CX + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + MOVL $0, DX // c = 0 + JMP E1 + +L1: MOVL (SI)(BX*4), AX + ADDL DX, DX // restore CF + ADCL (CX)(BX*4), AX + SBBL DX, DX // save CF + MOVL AX, (DI)(BX*4) + ADDL $1, BX // i++ + +E1: CMPL BX, BP // i < n + JL L1 + + NEGL DX + MOVL DX, c+36(FP) + RET + + +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SBBL instead of ADCL and label names) +TEXT ·subVV(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), CX + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + MOVL $0, DX // c = 0 + JMP E2 + +L2: MOVL (SI)(BX*4), AX + ADDL DX, DX // restore CF + SBBL (CX)(BX*4), AX + SBBL DX, DX // save CF + MOVL AX, (DI)(BX*4) + ADDL $1, BX // i++ + +E2: CMPL BX, BP // i < n + JL L2 + + NEGL DX + MOVL DX, c+36(FP) + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), AX // c = y + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + JMP E3 + +L3: ADDL (SI)(BX*4), AX + MOVL AX, (DI)(BX*4) + SBBL AX, AX // save CF + NEGL AX + ADDL $1, BX // i++ + +E3: CMPL BX, BP // i < n + JL L3 + + MOVL AX, c+28(FP) + RET + + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), AX // c = y + MOVL z_len+4(FP), BP + MOVL $0, BX // i = 0 + JMP E4 + +L4: MOVL (SI)(BX*4), DX + SUBL AX, DX + MOVL DX, (DI)(BX*4) + SBBL AX, AX // save CF + NEGL AX + ADDL $1, BX // i++ + +E4: CMPL BX, BP // i < n + JL L4 + + MOVL AX, c+28(FP) + RET + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVL z_len+4(FP), BX // i = z + SUBL $1, BX // i-- + JL X8b // i < 0 (n <= 0) + + // n > 0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL s+24(FP), CX + MOVL (SI)(BX*4), AX // w1 = x[n-1] + MOVL $0, DX + SHLL CX, DX:AX // w1>>ŝ + MOVL DX, c+28(FP) + + CMPL BX, $0 + JLE X8a // i <= 0 + + // i > 0 +L8: MOVL AX, DX // w = w1 + MOVL -4(SI)(BX*4), AX // w1 = x[i-1] + SHLL CX, DX:AX // w<>ŝ + MOVL DX, (DI)(BX*4) // z[i] = w<>ŝ + SUBL $1, BX // i-- + JG L8 // i > 0 + + // i <= 0 +X8a: SHLL CX, AX // w1< 0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL s+24(FP), CX + MOVL (SI), AX // w1 = x[0] + MOVL $0, DX + SHRL CX, DX:AX // w1<<ŝ + MOVL DX, c+28(FP) + + MOVL $0, BX // i = 0 + JMP E9 + + // i < n-1 +L9: MOVL AX, DX // w = w1 + MOVL 4(SI)(BX*4), AX // w1 = x[i+1] + SHRL CX, DX:AX // w>>s | w1<<ŝ + MOVL DX, (DI)(BX*4) // z[i] = w>>s | w1<<ŝ + ADDL $1, BX // i++ + +E9: CMPL BX, BP + JL L9 // i < n-1 + + // i >= n-1 +X9a: SHRL CX, AX // w1>>s + MOVL AX, (DI)(BP*4) // z[n-1] = w1>>s + RET + +X9b: MOVL $0, c+28(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), BP + MOVL r+28(FP), CX // c = r + MOVL z_len+4(FP), BX + LEAL (DI)(BX*4), DI + LEAL (SI)(BX*4), SI + NEGL BX // i = -n + JMP E5 + +L5: MOVL (SI)(BX*4), AX + MULL BP + ADDL CX, AX + ADCL $0, DX + MOVL AX, (DI)(BX*4) + MOVL DX, CX + ADDL $1, BX // i++ + +E5: CMPL BX, $0 // i < 0 + JL L5 + + MOVL CX, c+32(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL x+12(FP), SI + MOVL y+24(FP), BP + MOVL z_len+4(FP), BX + LEAL (DI)(BX*4), DI + LEAL (SI)(BX*4), SI + NEGL BX // i = -n + MOVL $0, CX // c = 0 + JMP E6 + +L6: MOVL (SI)(BX*4), AX + MULL BP + ADDL CX, AX + ADCL $0, DX + ADDL AX, (DI)(BX*4) + ADCL $0, DX + MOVL DX, CX + ADDL $1, BX // i++ + +E6: CMPL BX, $0 // i < 0 + JL L6 + + MOVL CX, c+28(FP) + RET + + +// func divWVW(z* Word, xn Word, x []Word, y Word) (r Word) +TEXT ·divWVW(SB),NOSPLIT,$0 + MOVL z+0(FP), DI + MOVL xn+12(FP), DX // r = xn + MOVL x+16(FP), SI + MOVL y+28(FP), CX + MOVL z_len+4(FP), BX // i = z + JMP E7 + +L7: MOVL (SI)(BX*4), AX + DIVL CX + MOVL AX, (DI)(BX*4) + +E7: SUBL $1, BX // i-- + JGE L7 // i >= 0 + + MOVL DX, r+32(FP) + RET diff --git a/vendor/github.com/golang/go/src/math/big/arith_amd64.s b/vendor/github.com/golang/go/src/math/big/arith_amd64.s new file mode 100644 index 000000000000..9a2405ee1c24 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_amd64.s @@ -0,0 +1,450 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func mulWW(x, y Word) (z1, z0 Word) +TEXT ·mulWW(SB),NOSPLIT,$0 + MOVQ x+0(FP), AX + MULQ y+8(FP) + MOVQ DX, z1+16(FP) + MOVQ AX, z0+24(FP) + RET + + +// func divWW(x1, x0, y Word) (q, r Word) +TEXT ·divWW(SB),NOSPLIT,$0 + MOVQ x1+0(FP), DX + MOVQ x0+8(FP), AX + DIVQ y+16(FP) + MOVQ AX, q+24(FP) + MOVQ DX, r+32(FP) + RET + +// The carry bit is saved with SBBQ Rx, Rx: if the carry was set, Rx is -1, otherwise it is 0. +// It is restored with ADDQ Rx, Rx: if Rx was -1 the carry is set, otherwise it is cleared. +// This is faster than using rotate instructions. + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z+0(FP), R10 + + MOVQ $0, CX // c = 0 + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V1 // if n < 0 goto V1 + +U1: // n >= 0 + // regular loop body unrolled 4x + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + ADCQ 0(R9)(SI*8), R11 + ADCQ 8(R9)(SI*8), R12 + ADCQ 16(R9)(SI*8), R13 + ADCQ 24(R9)(SI*8), R14 + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U1 // if n >= 0 goto U1 + +V1: ADDQ $4, DI // n += 4 + JLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + ADCQ 0(R9)(SI*8), R11 + MOVQ R11, 0(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L1 // if n > 0 goto L1 + +E1: NEGQ CX + MOVQ CX, c+72(FP) // return c + RET + + +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SBBQ instead of ADCQ and label names) +TEXT ·subVV(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z+0(FP), R10 + + MOVQ $0, CX // c = 0 + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V2 // if n < 0 goto V2 + +U2: // n >= 0 + // regular loop body unrolled 4x + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + SBBQ 0(R9)(SI*8), R11 + SBBQ 8(R9)(SI*8), R12 + SBBQ 16(R9)(SI*8), R13 + SBBQ 24(R9)(SI*8), R14 + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U2 // if n >= 0 goto U2 + +V2: ADDQ $4, DI // n += 4 + JLE E2 // if n <= 0 goto E2 + +L2: // n > 0 + ADDQ CX, CX // restore CF + MOVQ 0(R8)(SI*8), R11 + SBBQ 0(R9)(SI*8), R11 + MOVQ R11, 0(R10)(SI*8) + SBBQ CX, CX // save CF + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L2 // if n > 0 goto L2 + +E2: NEGQ CX + MOVQ CX, c+72(FP) // return c + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), CX // c = y + MOVQ z+0(FP), R10 + + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V3 // if n < 4 goto V3 + +U3: // n >= 0 + // regular loop body unrolled 4x + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + ADDQ CX, R11 + ADCQ $0, R12 + ADCQ $0, R13 + ADCQ $0, R14 + SBBQ CX, CX // save CF + NEGQ CX + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U3 // if n >= 0 goto U3 + +V3: ADDQ $4, DI // n += 4 + JLE E3 // if n <= 0 goto E3 + +L3: // n > 0 + ADDQ 0(R8)(SI*8), CX + MOVQ CX, 0(R10)(SI*8) + SBBQ CX, CX // save CF + NEGQ CX + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L3 // if n > 0 goto L3 + +E3: MOVQ CX, c+56(FP) // return c + RET + + +// func subVW(z, x []Word, y Word) (c Word) +// (same as addVW except for SUBQ/SBBQ instead of ADDQ/ADCQ and label names) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), DI + MOVQ x+24(FP), R8 + MOVQ y+48(FP), CX // c = y + MOVQ z+0(FP), R10 + + MOVQ $0, SI // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUBQ $4, DI // n -= 4 + JL V4 // if n < 4 goto V4 + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVQ 0(R8)(SI*8), R11 + MOVQ 8(R8)(SI*8), R12 + MOVQ 16(R8)(SI*8), R13 + MOVQ 24(R8)(SI*8), R14 + SUBQ CX, R11 + SBBQ $0, R12 + SBBQ $0, R13 + SBBQ $0, R14 + SBBQ CX, CX // save CF + NEGQ CX + MOVQ R11, 0(R10)(SI*8) + MOVQ R12, 8(R10)(SI*8) + MOVQ R13, 16(R10)(SI*8) + MOVQ R14, 24(R10)(SI*8) + + ADDQ $4, SI // i += 4 + SUBQ $4, DI // n -= 4 + JGE U4 // if n >= 0 goto U4 + +V4: ADDQ $4, DI // n += 4 + JLE E4 // if n <= 0 goto E4 + +L4: // n > 0 + MOVQ 0(R8)(SI*8), R11 + SUBQ CX, R11 + MOVQ R11, 0(R10)(SI*8) + SBBQ CX, CX // save CF + NEGQ CX + + ADDQ $1, SI // i++ + SUBQ $1, DI // n-- + JG L4 // if n > 0 goto L4 + +E4: MOVQ CX, c+56(FP) // return c + RET + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVQ z_len+8(FP), BX // i = z + SUBQ $1, BX // i-- + JL X8b // i < 0 (n <= 0) + + // n > 0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ s+48(FP), CX + MOVQ (R8)(BX*8), AX // w1 = x[n-1] + MOVQ $0, DX + SHLQ CX, DX:AX // w1>>ŝ + MOVQ DX, c+56(FP) + + CMPQ BX, $0 + JLE X8a // i <= 0 + + // i > 0 +L8: MOVQ AX, DX // w = w1 + MOVQ -8(R8)(BX*8), AX // w1 = x[i-1] + SHLQ CX, DX:AX // w<>ŝ + MOVQ DX, (R10)(BX*8) // z[i] = w<>ŝ + SUBQ $1, BX // i-- + JG L8 // i > 0 + + // i <= 0 +X8a: SHLQ CX, AX // w1< 0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ s+48(FP), CX + MOVQ (R8), AX // w1 = x[0] + MOVQ $0, DX + SHRQ CX, DX:AX // w1<<ŝ + MOVQ DX, c+56(FP) + + MOVQ $0, BX // i = 0 + JMP E9 + + // i < n-1 +L9: MOVQ AX, DX // w = w1 + MOVQ 8(R8)(BX*8), AX // w1 = x[i+1] + SHRQ CX, DX:AX // w>>s | w1<<ŝ + MOVQ DX, (R10)(BX*8) // z[i] = w>>s | w1<<ŝ + ADDQ $1, BX // i++ + +E9: CMPQ BX, R11 + JL L9 // i < n-1 + + // i >= n-1 +X9a: SHRQ CX, AX // w1>>s + MOVQ AX, (R10)(R11*8) // z[n-1] = w1>>s + RET + +X9b: MOVQ $0, c+56(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ r+56(FP), CX // c = r + MOVQ z_len+8(FP), R11 + MOVQ $0, BX // i = 0 + + CMPQ R11, $4 + JL E5 + +U5: // i+4 <= n + // regular loop body unrolled 4x + MOVQ (0*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (0*8)(R10)(BX*8) + MOVQ DX, CX + MOVQ (1*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (1*8)(R10)(BX*8) + MOVQ DX, CX + MOVQ (2*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (2*8)(R10)(BX*8) + MOVQ DX, CX + MOVQ (3*8)(R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (3*8)(R10)(BX*8) + MOVQ DX, CX + ADDQ $4, BX // i += 4 + + LEAQ 4(BX), DX + CMPQ DX, R11 + JLE U5 + JMP E5 + +L5: MOVQ (R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + MOVQ AX, (R10)(BX*8) + MOVQ DX, CX + ADDQ $1, BX // i++ + +E5: CMPQ BX, R11 // i < n + JL L5 + + MOVQ CX, c+64(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVQ z+0(FP), R10 + MOVQ x+24(FP), R8 + MOVQ y+48(FP), R9 + MOVQ z_len+8(FP), R11 + MOVQ $0, BX // i = 0 + MOVQ $0, CX // c = 0 + MOVQ R11, R12 + ANDQ $-2, R12 + CMPQ R11, $2 + JAE A6 + JMP E6 + +A6: + MOVQ (R8)(BX*8), AX + MULQ R9 + ADDQ (R10)(BX*8), AX + ADCQ $0, DX + ADDQ CX, AX + ADCQ $0, DX + MOVQ DX, CX + MOVQ AX, (R10)(BX*8) + + MOVQ (8)(R8)(BX*8), AX + MULQ R9 + ADDQ (8)(R10)(BX*8), AX + ADCQ $0, DX + ADDQ CX, AX + ADCQ $0, DX + MOVQ DX, CX + MOVQ AX, (8)(R10)(BX*8) + + ADDQ $2, BX + CMPQ BX, R12 + JL A6 + JMP E6 + +L6: MOVQ (R8)(BX*8), AX + MULQ R9 + ADDQ CX, AX + ADCQ $0, DX + ADDQ AX, (R10)(BX*8) + ADCQ $0, DX + MOVQ DX, CX + ADDQ $1, BX // i++ + +E6: CMPQ BX, R11 // i < n + JL L6 + + MOVQ CX, c+56(FP) + RET + + +// func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) +TEXT ·divWVW(SB),NOSPLIT,$0 + MOVQ z+0(FP), R10 + MOVQ xn+24(FP), DX // r = xn + MOVQ x+32(FP), R8 + MOVQ y+56(FP), R9 + MOVQ z_len+8(FP), BX // i = z + JMP E7 + +L7: MOVQ (R8)(BX*8), AX + DIVQ R9 + MOVQ AX, (R10)(BX*8) + +E7: SUBQ $1, BX // i-- + JGE L7 // i >= 0 + + MOVQ DX, r+64(FP) + RET diff --git a/vendor/github.com/golang/go/src/math/big/arith_amd64p32.s b/vendor/github.com/golang/go/src/math/big/arith_amd64p32.s new file mode 100644 index 000000000000..0a672386ccd7 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_amd64p32.s @@ -0,0 +1,40 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +#include "textflag.h" + +TEXT ·mulWW(SB),NOSPLIT,$0 + JMP ·mulWW_g(SB) + +TEXT ·divWW(SB),NOSPLIT,$0 + JMP ·divWW_g(SB) + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + +TEXT ·divWVW(SB),NOSPLIT,$0 + JMP ·divWVW_g(SB) diff --git a/vendor/github.com/golang/go/src/math/big/arith_arm.s b/vendor/github.com/golang/go/src/math/big/arith_arm.s new file mode 100644 index 000000000000..ba65fd2b1fa5 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_arm.s @@ -0,0 +1,294 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + ADD.S $0, R0 // clear carry flag + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + B E1 +L1: + MOVW.P 4(R2), R5 + MOVW.P 4(R3), R6 + ADC.S R6, R5 + MOVW.P R5, 4(R1) +E1: + TEQ R1, R4 + BNE L1 + + MOVW $0, R0 + MOVW.CS $1, R0 + MOVW R0, c+36(FP) + RET + + +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SBC instead of ADC and label names) +TEXT ·subVV(SB),NOSPLIT,$0 + SUB.S $0, R0 // clear borrow flag + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + B E2 +L2: + MOVW.P 4(R2), R5 + MOVW.P 4(R3), R6 + SBC.S R6, R5 + MOVW.P R5, 4(R1) +E2: + TEQ R1, R4 + BNE L2 + + MOVW $0, R0 + MOVW.CC $1, R0 + MOVW R0, c+36(FP) + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + TEQ R1, R4 + BNE L3a + MOVW R3, c+28(FP) + RET +L3a: + MOVW.P 4(R2), R5 + ADD.S R3, R5 + MOVW.P R5, 4(R1) + B E3 +L3: + MOVW.P 4(R2), R5 + ADC.S $0, R5 + MOVW.P R5, 4(R1) +E3: + TEQ R1, R4 + BNE L3 + + MOVW $0, R0 + MOVW.CS $1, R0 + MOVW R0, c+28(FP) + RET + + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R4 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R4<<2, R1, R4 + TEQ R1, R4 + BNE L4a + MOVW R3, c+28(FP) + RET +L4a: + MOVW.P 4(R2), R5 + SUB.S R3, R5 + MOVW.P R5, 4(R1) + B E4 +L4: + MOVW.P 4(R2), R5 + SBC.S $0, R5 + MOVW.P R5, 4(R1) +E4: + TEQ R1, R4 + BNE L4 + + MOVW $0, R0 + MOVW.CC $1, R0 + MOVW R0, c+28(FP) + RET + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVW z_len+4(FP), R5 + TEQ $0, R5 + BEQ X7 + + MOVW z+0(FP), R1 + MOVW x+12(FP), R2 + ADD R5<<2, R2, R2 + ADD R5<<2, R1, R5 + MOVW s+24(FP), R3 + TEQ $0, R3 // shift 0 is special + BEQ Y7 + ADD $4, R1 // stop one word early + MOVW $32, R4 + SUB R3, R4 + MOVW $0, R7 + + MOVW.W -4(R2), R6 + MOVW R6<>R4, R6 + MOVW R6, c+28(FP) + B E7 + +L7: + MOVW.W -4(R2), R6 + ORR R6>>R4, R7 + MOVW.W R7, -4(R5) + MOVW R6<>R3, R7 + MOVW R6<>R3, R7 +E6: + TEQ R1, R5 + BNE L6 + + MOVW R7, 0(R1) + RET + +Y6: // copy loop, because shift 0 == shift 32 + MOVW.P 4(R2), R6 + MOVW.P R6, 4(R1) + TEQ R1, R5 + BNE Y6 + +X6: + MOVW $0, R1 + MOVW R1, c+28(FP) + RET + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVW $0, R0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R5 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + MOVW r+28(FP), R4 + ADD R5<<2, R1, R5 + B E8 + + // word loop +L8: + MOVW.P 4(R2), R6 + MULLU R6, R3, (R7, R6) + ADD.S R4, R6 + ADC R0, R7 + MOVW.P R6, 4(R1) + MOVW R7, R4 +E8: + TEQ R1, R5 + BNE L8 + + MOVW R4, c+32(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVW $0, R0 + MOVW z+0(FP), R1 + MOVW z_len+4(FP), R5 + MOVW x+12(FP), R2 + MOVW y+24(FP), R3 + ADD R5<<2, R1, R5 + MOVW $0, R4 + B E9 + + // word loop +L9: + MOVW.P 4(R2), R6 + MULLU R6, R3, (R7, R6) + ADD.S R4, R6 + ADC R0, R7 + MOVW 0(R1), R4 + ADD.S R4, R6 + ADC R0, R7 + MOVW.P R6, 4(R1) + MOVW R7, R4 +E9: + TEQ R1, R5 + BNE L9 + + MOVW R4, c+28(FP) + RET + + +// func divWVW(z* Word, xn Word, x []Word, y Word) (r Word) +TEXT ·divWVW(SB),NOSPLIT,$0 + // ARM has no multiword division, so use portable code. + B ·divWVW_g(SB) + + +// func divWW(x1, x0, y Word) (q, r Word) +TEXT ·divWW(SB),NOSPLIT,$0 + // ARM has no multiword division, so use portable code. + B ·divWW_g(SB) + + +// func mulWW(x, y Word) (z1, z0 Word) +TEXT ·mulWW(SB),NOSPLIT,$0 + MOVW x+0(FP), R1 + MOVW y+4(FP), R2 + MULLU R1, R2, (R4, R3) + MOVW R4, z1+8(FP) + MOVW R3, z0+12(FP) + RET diff --git a/vendor/github.com/golang/go/src/math/big/arith_arm64.s b/vendor/github.com/golang/go/src/math/big/arith_arm64.s new file mode 100644 index 000000000000..397b4630a846 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_arm64.s @@ -0,0 +1,167 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// TODO: Consider re-implementing using Advanced SIMD +// once the assembler supports those instructions. + +// func mulWW(x, y Word) (z1, z0 Word) +TEXT ·mulWW(SB),NOSPLIT,$0 + MOVD x+0(FP), R0 + MOVD y+8(FP), R1 + MUL R0, R1, R2 + UMULH R0, R1, R3 + MOVD R3, z1+16(FP) + MOVD R2, z0+24(FP) + RET + + +// func divWW(x1, x0, y Word) (q, r Word) +TEXT ·divWW(SB),NOSPLIT,$0 + B ·divWW_g(SB) // ARM64 has no multiword division + + +// func addVV(z, x, y []Word) (c Word) +TEXT ·addVV(SB),NOSPLIT,$0 + MOVD z+0(FP), R3 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R1 + MOVD y+48(FP), R2 + ADDS $0, R0 // clear carry flag +loop: + CBZ R0, done // careful not to touch the carry flag + MOVD.P 8(R1), R4 + MOVD.P 8(R2), R5 + ADCS R4, R5 + MOVD.P R5, 8(R3) + SUB $1, R0 + B loop +done: + CSET HS, R0 // extract carry flag + MOVD R0, c+72(FP) + RET + + +// func subVV(z, x, y []Word) (c Word) +TEXT ·subVV(SB),NOSPLIT,$0 + MOVD z+0(FP), R3 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R1 + MOVD y+48(FP), R2 + CMP R0, R0 // set carry flag +loop: + CBZ R0, done // careful not to touch the carry flag + MOVD.P 8(R1), R4 + MOVD.P 8(R2), R5 + SBCS R5, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 + B loop +done: + CSET LO, R0 // extract carry flag + MOVD R0, c+72(FP) + RET + + +// func addVW(z, x []Word, y Word) (c Word) +TEXT ·addVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R3 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R1 + MOVD y+48(FP), R2 + CBZ R0, return_y + MOVD.P 8(R1), R4 + ADDS R2, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 +loop: + CBZ R0, done // careful not to touch the carry flag + MOVD.P 8(R1), R4 + ADCS $0, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 + B loop +done: + CSET HS, R0 // extract carry flag + MOVD R0, c+56(FP) + RET +return_y: // z is empty; copy y to c + MOVD R2, c+56(FP) + RET + + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R3 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R1 + MOVD y+48(FP), R2 + CBZ R0, rety + MOVD.P 8(R1), R4 + SUBS R2, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 +loop: + CBZ R0, done // careful not to touch the carry flag + MOVD.P 8(R1), R4 + SBCS $0, R4 + MOVD.P R4, 8(R3) + SUB $1, R0 + B loop +done: + CSET LO, R0 // extract carry flag + MOVD R0, c+56(FP) + RET +rety: // z is empty; copy y to c + MOVD R2, c+56(FP) + RET + + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + B ·shlVU_g(SB) + + +// func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB),NOSPLIT,$0 + B ·shrVU_g(SB) + + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVD z+0(FP), R1 + MOVD z_len+8(FP), R0 + MOVD x+24(FP), R2 + MOVD y+48(FP), R3 + MOVD r+56(FP), R4 +loop: + CBZ R0, done + MOVD.P 8(R2), R5 + UMULH R5, R3, R7 + MUL R5, R3, R6 + ADDS R4, R6 + ADC $0, R7 + MOVD.P R6, 8(R1) + MOVD R7, R4 + SUB $1, R0 + B loop +done: + MOVD R4, c+64(FP) + RET + + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB),NOSPLIT,$0 + B ·addMulVVW_g(SB) + + +// func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) +TEXT ·divWVW(SB),NOSPLIT,$0 + B ·divWVW_g(SB) diff --git a/vendor/github.com/golang/go/src/math/big/arith_decl.go b/vendor/github.com/golang/go/src/math/big/arith_decl.go new file mode 100644 index 000000000000..41e592334c37 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_decl.go @@ -0,0 +1,20 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +package big + +// implemented in arith_$GOARCH.s +func mulWW(x, y Word) (z1, z0 Word) +func divWW(x1, x0, y Word) (q, r Word) +func addVV(z, x, y []Word) (c Word) +func subVV(z, x, y []Word) (c Word) +func addVW(z, x []Word, y Word) (c Word) +func subVW(z, x []Word, y Word) (c Word) +func shlVU(z, x []Word, s uint) (c Word) +func shrVU(z, x []Word, s uint) (c Word) +func mulAddVWW(z, x []Word, y, r Word) (c Word) +func addMulVVW(z, x []Word, y Word) (c Word) +func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) diff --git a/vendor/github.com/golang/go/src/math/big/arith_decl_pure.go b/vendor/github.com/golang/go/src/math/big/arith_decl_pure.go new file mode 100644 index 000000000000..4ae49c123d9f --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_decl_pure.go @@ -0,0 +1,51 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build math_big_pure_go + +package big + +func mulWW(x, y Word) (z1, z0 Word) { + return mulWW_g(x, y) +} + +func divWW(x1, x0, y Word) (q, r Word) { + return divWW_g(x1, x0, y) +} + +func addVV(z, x, y []Word) (c Word) { + return addVV_g(z, x, y) +} + +func subVV(z, x, y []Word) (c Word) { + return subVV_g(z, x, y) +} + +func addVW(z, x []Word, y Word) (c Word) { + return addVW_g(z, x, y) +} + +func subVW(z, x []Word, y Word) (c Word) { + return subVW_g(z, x, y) +} + +func shlVU(z, x []Word, s uint) (c Word) { + return shlVU_g(z, x, s) +} + +func shrVU(z, x []Word, s uint) (c Word) { + return shrVU_g(z, x, s) +} + +func mulAddVWW(z, x []Word, y, r Word) (c Word) { + return mulAddVWW_g(z, x, y, r) +} + +func addMulVVW(z, x []Word, y Word) (c Word) { + return addMulVVW_g(z, x, y) +} + +func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { + return divWVW_g(z, xn, x, y) +} diff --git a/vendor/github.com/golang/go/src/math/big/arith_decl_s390x.go b/vendor/github.com/golang/go/src/math/big/arith_decl_s390x.go new file mode 100644 index 000000000000..0f11481f6d2e --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_decl_s390x.go @@ -0,0 +1,23 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go + +package big + +func addVV_check(z, x, y []Word) (c Word) +func addVV_vec(z, x, y []Word) (c Word) +func addVV_novec(z, x, y []Word) (c Word) +func subVV_check(z, x, y []Word) (c Word) +func subVV_vec(z, x, y []Word) (c Word) +func subVV_novec(z, x, y []Word) (c Word) +func addVW_check(z, x []Word, y Word) (c Word) +func addVW_vec(z, x []Word, y Word) (c Word) +func addVW_novec(z, x []Word, y Word) (c Word) +func subVW_check(z, x []Word, y Word) (c Word) +func subVW_vec(z, x []Word, y Word) (c Word) +func subVW_novec(z, x []Word, y Word) (c Word) +func hasVectorFacility() bool + +var hasVX = hasVectorFacility() diff --git a/vendor/github.com/golang/go/src/math/big/arith_mips64x.s b/vendor/github.com/golang/go/src/math/big/arith_mips64x.s new file mode 100644 index 000000000000..983510ee3d42 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_mips64x.s @@ -0,0 +1,43 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go,mips64 !math_big_pure_go,mips64le + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·mulWW(SB),NOSPLIT,$0 + JMP ·mulWW_g(SB) + +TEXT ·divWW(SB),NOSPLIT,$0 + JMP ·divWW_g(SB) + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + +TEXT ·divWVW(SB),NOSPLIT,$0 + JMP ·divWVW_g(SB) diff --git a/vendor/github.com/golang/go/src/math/big/arith_mipsx.s b/vendor/github.com/golang/go/src/math/big/arith_mipsx.s new file mode 100644 index 000000000000..54cafbd9c0c8 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_mipsx.s @@ -0,0 +1,43 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go,mips !math_big_pure_go,mipsle + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·mulWW(SB),NOSPLIT,$0 + JMP ·mulWW_g(SB) + +TEXT ·divWW(SB),NOSPLIT,$0 + JMP ·divWW_g(SB) + +TEXT ·addVV(SB),NOSPLIT,$0 + JMP ·addVV_g(SB) + +TEXT ·subVV(SB),NOSPLIT,$0 + JMP ·subVV_g(SB) + +TEXT ·addVW(SB),NOSPLIT,$0 + JMP ·addVW_g(SB) + +TEXT ·subVW(SB),NOSPLIT,$0 + JMP ·subVW_g(SB) + +TEXT ·shlVU(SB),NOSPLIT,$0 + JMP ·shlVU_g(SB) + +TEXT ·shrVU(SB),NOSPLIT,$0 + JMP ·shrVU_g(SB) + +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + JMP ·mulAddVWW_g(SB) + +TEXT ·addMulVVW(SB),NOSPLIT,$0 + JMP ·addMulVVW_g(SB) + +TEXT ·divWVW(SB),NOSPLIT,$0 + JMP ·divWVW_g(SB) diff --git a/vendor/github.com/golang/go/src/math/big/arith_ppc64x.s b/vendor/github.com/golang/go/src/math/big/arith_ppc64x.s new file mode 100644 index 000000000000..74db48933f8a --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_ppc64x.s @@ -0,0 +1,197 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go,ppc64 !math_big_pure_go,ppc64le + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +// func mulWW(x, y Word) (z1, z0 Word) +TEXT ·mulWW(SB), NOSPLIT, $0 + MOVD x+0(FP), R4 + MOVD y+8(FP), R5 + MULHDU R4, R5, R6 + MULLD R4, R5, R7 + MOVD R6, z1+16(FP) + MOVD R7, z0+24(FP) + RET + +// func addVV(z, y, y []Word) (c Word) +// z[i] = x[i] + y[i] for all i, carrying +TEXT ·addVV(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R7 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R10 + + MOVD R0, R4 + MOVD R0, R6 // R6 will be the address index + ADDC R4, R4 // clear CA + MOVD R7, CTR + + CMP R0, R7 + BEQ done + +loop: + MOVD (R8)(R6), R11 // x[i] + MOVD (R9)(R6), R12 // y[i] + ADDE R12, R11, R15 // x[i] + y[i] + CA + MOVD R15, (R10)(R6) // z[i] + + ADD $8, R6 + BC 16, 0, loop // bdnz + +done: + ADDZE R4 + MOVD R4, c+72(FP) + RET + +// func subVV(z, x, y []Word) (c Word) +// z[i] = x[i] - y[i] for all i, carrying +TEXT ·subVV(SB), NOSPLIT, $0 + MOVD z_len+8(FP), R7 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R10 + + MOVD R0, R4 // c = 0 + MOVD R0, R6 + SUBC R0, R0 // clear CA + MOVD R7, CTR + + CMP R0, R7 + BEQ sublend + +// amd64 saves and restores CF, but I believe they only have to do that because all of +// their math operations clobber it - we should just be able to recover it at the end. +subloop: + MOVD (R8)(R6), R11 // x[i] + MOVD (R9)(R6), R12 // y[i] + + SUBE R12, R11, R15 + MOVD R15, (R10)(R6) + + ADD $8, R6 + BC 16, 0, subloop // bdnz + +sublend: + + ADDZE R4 + XOR $1, R4 + MOVD R4, c+72(FP) + RET + +TEXT ·addVW(SB), NOSPLIT, $0 + BR ·addVW_g(SB) + +TEXT ·subVW(SB), NOSPLIT, $0 + BR ·subVW_g(SB) + +TEXT ·shlVU(SB), NOSPLIT, $0 + BR ·shlVU_g(SB) + +TEXT ·shrVU(SB), NOSPLIT, $0 + BR ·shrVU_g(SB) + +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB), NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R9 // R9 = y + MOVD r+56(FP), R4 // R4 = r = c + MOVD z_len+8(FP), R11 // R11 = z_len + + MOVD R0, R3 // R3 will be the index register + CMP R0, R11 + MOVD R11, CTR // Initialize loop counter + BEQ done + +loop: + MOVD (R8)(R3), R20 // x[i] + MULLD R9, R20, R6 // R6 = z0 = Low-order(x[i]*y) + MULHDU R9, R20, R7 // R7 = z1 = High-order(x[i]*y) + ADDC R4, R6 // Compute sum for z1 and z0 + ADDZE R7 + MOVD R6, (R10)(R3) // z[i] + MOVD R7, R4 // c + ADD $8, R3 + BC 16, 0, loop // bdnz + +done: + MOVD R4, c+64(FP) + RET + +// func addMulVVW(z, x []Word, y Word) (c Word) +TEXT ·addMulVVW(SB), NOSPLIT, $0 + MOVD z+0(FP), R10 // R10 = z[] + MOVD x+24(FP), R8 // R8 = x[] + MOVD y+48(FP), R9 // R9 = y + MOVD z_len+8(FP), R22 // R22 = z_len + + MOVD R0, R3 // R3 will be the index register + CMP R0, R22 + MOVD R0, R4 // R4 = c = 0 + MOVD R22, CTR // Initialize loop counter + BEQ done + +loop: + MOVD (R8)(R3), R20 // Load x[i] + MOVD (R10)(R3), R21 // Load z[i] + MULLD R9, R20, R6 // R6 = Low-order(x[i]*y) + MULHDU R9, R20, R7 // R7 = High-order(x[i]*y) + ADDC R21, R6 // R6 = z0 + ADDZE R7 // R7 = z1 + ADDC R4, R6 // R6 = z0 + c + 0 + ADDZE R7, R4 // c += z1 + MOVD R6, (R10)(R3) // Store z[i] + ADD $8, R3 + BC 16, 0, loop // bdnz + +done: + MOVD R4, c+56(FP) + RET + +// func divWW(x1, x0, y Word) (q, r Word) +TEXT ·divWW(SB), NOSPLIT, $0 + MOVD x1+0(FP), R4 + MOVD x0+8(FP), R5 + MOVD y+16(FP), R6 + + CMPU R4, R6 + BGE divbigger + + // from the programmer's note in ch. 3 of the ISA manual, p.74 + DIVDEU R6, R4, R3 + DIVDU R6, R5, R7 + MULLD R6, R3, R8 + MULLD R6, R7, R20 + SUB R20, R5, R10 + ADD R7, R3, R3 + SUB R8, R10, R4 + CMPU R4, R10 + BLT adjust + CMPU R4, R6 + BLT end + +adjust: + MOVD $1, R21 + ADD R21, R3, R3 + SUB R6, R4, R4 + +end: + MOVD R3, q+24(FP) + MOVD R4, r+32(FP) + + RET + +divbigger: + MOVD $-1, R7 + MOVD R7, q+24(FP) + MOVD R7, r+32(FP) + RET + +TEXT ·divWVW(SB), NOSPLIT, $0 + BR ·divWVW_g(SB) diff --git a/vendor/github.com/golang/go/src/math/big/arith_s390x.s b/vendor/github.com/golang/go/src/math/big/arith_s390x.s new file mode 100644 index 000000000000..4520d161d779 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/arith_s390x.s @@ -0,0 +1,1239 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !math_big_pure_go,s390x + +#include "textflag.h" + +// This file provides fast assembly versions for the elementary +// arithmetic operations on vectors implemented in arith.go. + +TEXT ·hasVectorFacility(SB),NOSPLIT,$24-1 + MOVD $x-24(SP), R1 + XC $24, 0(R1), 0(R1) // clear the storage + MOVD $2, R0 // R0 is the number of double words stored -1 + WORD $0xB2B01000 // STFLE 0(R1) + XOR R0, R0 // reset the value of R0 + MOVBZ z-8(SP), R1 + AND $0x40, R1 + BEQ novector +vectorinstalled: + // check if the vector instruction has been enabled + VLEIB $0, $0xF, V16 + VLGVB $0, V16, R1 + CMPBNE R1, $0xF, novector + MOVB $1, ret+0(FP) // have vx + RET +novector: + MOVB $0, ret+0(FP) // no vx + RET + +TEXT ·mulWW(SB),NOSPLIT,$0 + MOVD x+0(FP), R3 + MOVD y+8(FP), R4 + MULHDU R3, R4 + MOVD R10, z1+16(FP) + MOVD R11, z0+24(FP) + RET + +// func divWW(x1, x0, y Word) (q, r Word) +TEXT ·divWW(SB),NOSPLIT,$0 + MOVD x1+0(FP), R10 + MOVD x0+8(FP), R11 + MOVD y+16(FP), R5 + WORD $0xb98700a5 // dlgr r10,r5 + MOVD R11, q+24(FP) + MOVD R10, r+32(FP) + RET + +// DI = R3, CX = R4, SI = r10, r8 = r8, r9=r9, r10 = r2 , r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + use R11 +// func addVV(z, x, y []Word) (c Word) + + +TEXT ·addVV(SB),NOSPLIT,$0 + MOVD addvectorfacility+0x00(SB),R1 + BR (R1) + +TEXT ·addVV_check(SB),NOSPLIT, $0 + MOVB ·hasVX(SB), R1 + CMPBEQ R1, $1, vectorimpl // vectorfacility = 1, vector supported + MOVD $addvectorfacility+0x00(SB), R1 + MOVD $·addVV_novec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·addVV_novec(SB), 0(R1) + BR ·addVV_novec(SB) +vectorimpl: + MOVD $addvectorfacility+0x00(SB), R1 + MOVD $·addVV_vec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·addVV_vec(SB), 0(R1) + BR ·addVV_vec(SB) + +GLOBL addvectorfacility+0x00(SB), NOPTR, $8 +DATA addvectorfacility+0x00(SB)/8, $·addVV_check(SB) + +TEXT ·addVV_vec(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 + BLT v1 + SUB $12, R3 // n -= 16 + BLT A1 // if n < 0 goto A1 + + MOVD R8, R5 + MOVD R9, R6 + MOVD R2, R7 + // n >= 0 + // regular loop body unrolled 16x + VZERO V0 // c = 0 +UU1: VLM 0(R5), V1, V4 // 64-bytes into V1..V8 + ADD $64, R5 + VPDI $0x4,V1,V1,V1 // flip the doublewords to big-endian order + VPDI $0x4,V2,V2,V2 // flip the doublewords to big-endian order + + + VLM 0(R6), V9, V12 // 64-bytes into V9..V16 + ADD $64, R6 + VPDI $0x4,V9,V9,V9 // flip the doublewords to big-endian order + VPDI $0x4,V10,V10,V10 // flip the doublewords to big-endian order + + VACCCQ V1, V9, V0, V25 + VACQ V1, V9, V0, V17 + VACCCQ V2, V10, V25, V26 + VACQ V2, V10, V25, V18 + + + VLM 0(R5), V5, V6 // 32-bytes into V1..V8 + VLM 0(R6), V13, V14 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4,V3,V3,V3 // flip the doublewords to big-endian order + VPDI $0x4,V4,V4,V4 // flip the doublewords to big-endian order + VPDI $0x4,V11,V11,V11 // flip the doublewords to big-endian order + VPDI $0x4,V12,V12,V12 // flip the doublewords to big-endian order + + VACCCQ V3, V11, V26, V27 + VACQ V3, V11, V26, V19 + VACCCQ V4, V12, V27, V28 + VACQ V4, V12, V27, V20 + + VLM 0(R5), V7, V8 // 32-bytes into V1..V8 + VLM 0(R6), V15, V16 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4,V5,V5,V5 // flip the doublewords to big-endian order + VPDI $0x4,V6,V6,V6 // flip the doublewords to big-endian order + VPDI $0x4,V13,V13,V13 // flip the doublewords to big-endian order + VPDI $0x4,V14,V14,V14 // flip the doublewords to big-endian order + + VACCCQ V5, V13, V28, V29 + VACQ V5, V13, V28, V21 + VACCCQ V6, V14, V29, V30 + VACQ V6, V14, V29, V22 + + VPDI $0x4,V7,V7,V7 // flip the doublewords to big-endian order + VPDI $0x4,V8,V8,V8 // flip the doublewords to big-endian order + VPDI $0x4,V15,V15,V15 // flip the doublewords to big-endian order + VPDI $0x4,V16,V16,V16 // flip the doublewords to big-endian order + + VACCCQ V7, V15, V30, V31 + VACQ V7, V15, V30, V23 + VACCCQ V8, V16, V31, V0 //V0 has carry-over + VACQ V8, V16, V31, V24 + + VPDI $0x4,V17,V17,V17 // flip the doublewords to big-endian order + VPDI $0x4,V18,V18,V18 // flip the doublewords to big-endian order + VPDI $0x4,V19,V19,V19 // flip the doublewords to big-endian order + VPDI $0x4,V20,V20,V20 // flip the doublewords to big-endian order + VPDI $0x4,V21,V21,V21 // flip the doublewords to big-endian order + VPDI $0x4,V22,V22,V22 // flip the doublewords to big-endian order + VPDI $0x4,V23,V23,V23 // flip the doublewords to big-endian order + VPDI $0x4,V24,V24,V24 // flip the doublewords to big-endian order + VSTM V17, V24, 0(R7) // 128-bytes into z + ADD $128, R7 + ADD $128, R10 // i += 16 + SUB $16, R3 // n -= 16 + BGE UU1 // if n >= 0 goto U1 + VLGVG $1, V0, R4 // put cf into R4 + NEG R4, R4 // save cf + +A1: ADD $12, R3 // n += 16 + + + // s/JL/JMP/ below to disable the unrolled loop + BLT v1 // if n < 0 goto v1 + +U1: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + ADDC R4, R4 // restore CF + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD 8(R9)(R10*1), R11 + ADDE R11, R6 + MOVD 16(R9)(R10*1), R11 + ADDE R11, R7 + MOVD 24(R9)(R10*1), R11 + ADDE R11, R1 + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1 // if n >= 0 goto U1 + +v1: ADD $4, R3 // n += 4 + BLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + ADDC R4, R4 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1 // if n > 0 goto L1 + +E1: NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + +TEXT ·addVV_novec(SB),NOSPLIT,$0 +novec: + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v1n // if n < 0 goto v1n +U1n: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + ADDC R4, R4 // restore CF + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD 8(R9)(R10*1), R11 + ADDE R11, R6 + MOVD 16(R9)(R10*1), R11 + ADDE R11, R7 + MOVD 24(R9)(R10*1), R11 + ADDE R11, R1 + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1n // if n >= 0 goto U1n + +v1n: ADD $4, R3 // n += 4 + BLE E1n // if n <= 0 goto E1n + +L1n: // n > 0 + ADDC R4, R4 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + ADDE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + ADDE R4, R4 // save CF + NEG R4, R4 + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1n // if n > 0 goto L1n + +E1n: NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + + +TEXT ·subVV(SB),NOSPLIT,$0 + MOVD subvectorfacility+0x00(SB),R1 + BR (R1) + +TEXT ·subVV_check(SB),NOSPLIT,$0 + MOVB ·hasVX(SB), R1 + CMPBEQ R1, $1, vectorimpl // vectorfacility = 1, vector supported + MOVD $subvectorfacility+0x00(SB), R1 + MOVD $·subVV_novec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·subVV_novec(SB), 0(R1) + BR ·subVV_novec(SB) +vectorimpl: + MOVD $subvectorfacility+0x00(SB), R1 + MOVD $·subVV_vec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·subVV_vec(SB), 0(R1) + BR ·subVV_vec(SB) + +GLOBL subvectorfacility+0x00(SB), NOPTR, $8 +DATA subvectorfacility+0x00(SB)/8, $·subVV_check(SB) + +// DI = R3, CX = R4, SI = r10, r8 = r8, r9=r9, r10 = r2 , r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + use R11 +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SUBC/SUBE instead of ADDC/ADDE and label names) +TEXT ·subVV_vec(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v1 // if n < 0 goto v1 + SUB $12, R3 // n -= 16 + BLT A1 // if n < 0 goto A1 + + MOVD R8, R5 + MOVD R9, R6 + MOVD R2, R7 + + // n >= 0 + // regular loop body unrolled 16x + VZERO V0 // cf = 0 + MOVD $1, R4 // for 390 subtraction cf starts as 1 (no borrow) + VLVGG $1, R4, V0 //put carry into V0 + +UU1: VLM 0(R5), V1, V4 // 64-bytes into V1..V8 + ADD $64, R5 + VPDI $0x4,V1,V1,V1 // flip the doublewords to big-endian order + VPDI $0x4,V2,V2,V2 // flip the doublewords to big-endian order + + + VLM 0(R6), V9, V12 // 64-bytes into V9..V16 + ADD $64, R6 + VPDI $0x4,V9,V9,V9 // flip the doublewords to big-endian order + VPDI $0x4,V10,V10,V10 // flip the doublewords to big-endian order + + VSBCBIQ V1, V9, V0, V25 + VSBIQ V1, V9, V0, V17 + VSBCBIQ V2, V10, V25, V26 + VSBIQ V2, V10, V25, V18 + + + VLM 0(R5), V5, V6 // 32-bytes into V1..V8 + VLM 0(R6), V13, V14 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4,V3,V3,V3 // flip the doublewords to big-endian order + VPDI $0x4,V4,V4,V4 // flip the doublewords to big-endian order + VPDI $0x4,V11,V11,V11 // flip the doublewords to big-endian order + VPDI $0x4,V12,V12,V12 // flip the doublewords to big-endian order + + VSBCBIQ V3, V11, V26, V27 + VSBIQ V3, V11, V26, V19 + VSBCBIQ V4, V12, V27, V28 + VSBIQ V4, V12, V27, V20 + + VLM 0(R5), V7, V8 // 32-bytes into V1..V8 + VLM 0(R6), V15, V16 // 32-bytes into V9..V16 + ADD $32, R5 + ADD $32, R6 + + VPDI $0x4,V5,V5,V5 // flip the doublewords to big-endian order + VPDI $0x4,V6,V6,V6 // flip the doublewords to big-endian order + VPDI $0x4,V13,V13,V13 // flip the doublewords to big-endian order + VPDI $0x4,V14,V14,V14 // flip the doublewords to big-endian order + + VSBCBIQ V5, V13, V28, V29 + VSBIQ V5, V13, V28, V21 + VSBCBIQ V6, V14, V29, V30 + VSBIQ V6, V14, V29, V22 + + VPDI $0x4,V7,V7,V7 // flip the doublewords to big-endian order + VPDI $0x4,V8,V8,V8 // flip the doublewords to big-endian order + VPDI $0x4,V15,V15,V15 // flip the doublewords to big-endian order + VPDI $0x4,V16,V16,V16 // flip the doublewords to big-endian order + + VSBCBIQ V7, V15, V30, V31 + VSBIQ V7, V15, V30, V23 + VSBCBIQ V8, V16, V31, V0 //V0 has carry-over + VSBIQ V8, V16, V31, V24 + + VPDI $0x4,V17,V17,V17 // flip the doublewords to big-endian order + VPDI $0x4,V18,V18,V18 // flip the doublewords to big-endian order + VPDI $0x4,V19,V19,V19 // flip the doublewords to big-endian order + VPDI $0x4,V20,V20,V20 // flip the doublewords to big-endian order + VPDI $0x4,V21,V21,V21 // flip the doublewords to big-endian order + VPDI $0x4,V22,V22,V22 // flip the doublewords to big-endian order + VPDI $0x4,V23,V23,V23 // flip the doublewords to big-endian order + VPDI $0x4,V24,V24,V24 // flip the doublewords to big-endian order + VSTM V17, V24, 0(R7) // 128-bytes into z + ADD $128, R7 + ADD $128, R10 // i += 16 + SUB $16, R3 // n -= 16 + BGE UU1 // if n >= 0 goto U1 + VLGVG $1, V0, R4 // put cf into R4 + SUB $1, R4 // save cf + +A1: ADD $12, R3 // n += 16 + BLT v1 // if n < 0 goto v1 + +U1: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD 8(R9)(R10*1), R11 + SUBE R11, R6 + MOVD 16(R9)(R10*1), R11 + SUBE R11, R7 + MOVD 24(R9)(R10*1), R11 + SUBE R11, R1 + MOVD R0, R4 + SUBE R4, R4 // save CF + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1 // if n >= 0 goto U1n + +v1: ADD $4, R3 // n += 4 + BLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + SUBE R4, R4 // save CF + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1 // if n > 0 goto L1n + +E1: NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + + +// DI = R3, CX = R4, SI = r10, r8 = r8, r9=r9, r10 = r2 , r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + use R11 +// func subVV(z, x, y []Word) (c Word) +// (same as addVV except for SUBC/SUBE instead of ADDC/ADDE and label names) +TEXT ·subVV_novec(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z+0(FP), R2 + + MOVD $0, R4 // c = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v1 // if n < 0 goto v1 + +U1: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD 8(R9)(R10*1), R11 + SUBE R11, R6 + MOVD 16(R9)(R10*1), R11 + SUBE R11, R7 + MOVD 24(R9)(R10*1), R11 + SUBE R11, R1 + MOVD R0, R4 + SUBE R4, R4 // save CF + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + + ADD $32, R10 // i += 4 + SUB $4, R3 // n -= 4 + BGE U1 // if n >= 0 goto U1 + +v1: ADD $4, R3 // n += 4 + BLE E1 // if n <= 0 goto E1 + +L1: // n > 0 + MOVD R0, R11 + SUBC R4, R11 // restore CF + MOVD 0(R8)(R10*1), R5 + MOVD 0(R9)(R10*1), R11 + SUBE R11, R5 + MOVD R5, 0(R2)(R10*1) + MOVD R0, R4 + SUBE R4, R4 // save CF + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L1 // if n > 0 goto L1 + +E1: NEG R4, R4 + MOVD R4, c+72(FP) // return c + RET + +TEXT ·addVW(SB),NOSPLIT,$0 + MOVD addwvectorfacility+0x00(SB),R1 + BR (R1) + +TEXT ·addVW_check(SB),NOSPLIT,$0 + MOVB ·hasVX(SB), R1 + CMPBEQ R1, $1, vectorimpl // vectorfacility = 1, vector supported + MOVD $addwvectorfacility+0x00(SB), R1 + MOVD $·addVW_novec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·addVW_novec(SB), 0(R1) + BR ·addVW_novec(SB) +vectorimpl: + MOVD $addwvectorfacility+0x00(SB), R1 + MOVD $·addVW_vec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·addVW_vec(SB), 0(R1) + BR ·addVW_vec(SB) + +GLOBL addwvectorfacility+0x00(SB), NOPTR, $8 +DATA addwvectorfacility+0x00(SB)/8, $·addVW_check(SB) + + +// func addVW_vec(z, x []Word, y Word) (c Word) +TEXT ·addVW_vec(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R4 // c = y + MOVD z+0(FP), R2 + + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + MOVD R8, R5 + MOVD R2, R7 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v10 // if n < 0 goto v10 + SUB $12, R3 + BLT A10 + + // n >= 0 + // regular loop body unrolled 16x + + VZERO V0 // prepare V0 to be final carry register + VZERO V9 // to ensure upper half is zero + VLVGG $1, R4, V9 +UU1: VLM 0(R5), V1, V4 // 64-bytes into V1..V4 + ADD $64, R5 + VPDI $0x4,V1,V1,V1 // flip the doublewords to big-endian order + VPDI $0x4,V2,V2,V2 // flip the doublewords to big-endian order + + + VACCCQ V1, V9, V0, V25 + VACQ V1, V9, V0, V17 + VZERO V9 + VACCCQ V2, V9, V25, V26 + VACQ V2, V9, V25, V18 + + + VLM 0(R5), V5, V6 // 32-bytes into V5..V6 + ADD $32, R5 + + VPDI $0x4,V3,V3,V3 // flip the doublewords to big-endian order + VPDI $0x4,V4,V4,V4 // flip the doublewords to big-endian order + + VACCCQ V3, V9, V26, V27 + VACQ V3, V9, V26, V19 + VACCCQ V4, V9, V27, V28 + VACQ V4, V9, V27, V20 + + VLM 0(R5), V7, V8 // 32-bytes into V7..V8 + ADD $32, R5 + + VPDI $0x4,V5,V5,V5 // flip the doublewords to big-endian order + VPDI $0x4,V6,V6,V6 // flip the doublewords to big-endian order + + VACCCQ V5, V9, V28, V29 + VACQ V5, V9, V28, V21 + VACCCQ V6, V9, V29, V30 + VACQ V6, V9, V29, V22 + + VPDI $0x4,V7,V7,V7 // flip the doublewords to big-endian order + VPDI $0x4,V8,V8,V8 // flip the doublewords to big-endian order + + VACCCQ V7, V9, V30, V31 + VACQ V7, V9, V30, V23 + VACCCQ V8, V9, V31, V0 //V0 has carry-over + VACQ V8, V9, V31, V24 + + VPDI $0x4,V17,V17,V17 // flip the doublewords to big-endian order + VPDI $0x4,V18,V18,V18 // flip the doublewords to big-endian order + VPDI $0x4,V19,V19,V19 // flip the doublewords to big-endian order + VPDI $0x4,V20,V20,V20 // flip the doublewords to big-endian order + VPDI $0x4,V21,V21,V21 // flip the doublewords to big-endian order + VPDI $0x4,V22,V22,V22 // flip the doublewords to big-endian order + VPDI $0x4,V23,V23,V23 // flip the doublewords to big-endian order + VPDI $0x4,V24,V24,V24 // flip the doublewords to big-endian order + VSTM V17, V24, 0(R7) // 128-bytes into z + ADD $128, R7 + ADD $128, R10 // i += 16 + SUB $16, R3 // n -= 16 + BGE UU1 // if n >= 0 goto U1 + VLGVG $1, V0, R4 // put cf into R4 in case we branch to v10 + +A10: ADD $12, R3 // n += 16 + + + // s/JL/JMP/ below to disable the unrolled loop + + BLT v10 // if n < 0 goto v10 + + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + ADDC R4, R5 + ADDE R0, R6 + ADDE R0, R7 + ADDE R0, R1 + ADDE R0, R0 + MOVD R0, R4 // save CF + SUB R0, R0 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 -> i +=32 + SUB $4, R3 // n -= 4 + BGE U4 // if n >= 0 goto U4 + +v10: ADD $4, R3 // n += 4 + BLE E10 // if n <= 0 goto E4 + + +L4: // n > 0 + MOVD 0(R8)(R10*1), R5 + ADDC R4, R5 + ADDE R0, R0 + MOVD R0, R4 // save CF + SUB R0, R0 + MOVD R5, 0(R2)(R10*1) + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L4 // if n > 0 goto L4 + +E10: MOVD R4, c+56(FP) // return c + + RET + + +TEXT ·addVW_novec(SB),NOSPLIT,$0 +//DI = R3, CX = R4, SI = r10, r8 = r8, r10 = r2 , r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R4 // c = y + MOVD z+0(FP), R2 + MOVD $0, R0 // make sure it's 0 + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v4 // if n < 4 goto v4 + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + ADDC R4, R5 + ADDE R0, R6 + ADDE R0, R7 + ADDE R0, R1 + ADDE R0, R0 + MOVD R0, R4 // save CF + SUB R0, R0 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 -> i +=32 + SUB $4, R3 // n -= 4 + BGE U4 // if n >= 0 goto U4 + +v4: ADD $4, R3 // n += 4 + BLE E4 // if n <= 0 goto E4 + +L4: // n > 0 + MOVD 0(R8)(R10*1), R5 + ADDC R4, R5 + ADDE R0, R0 + MOVD R0, R4 // save CF + SUB R0, R0 + MOVD R5, 0(R2)(R10*1) + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L4 // if n > 0 goto L4 + +E4: MOVD R4, c+56(FP) // return c + + RET + +TEXT ·subVW(SB),NOSPLIT,$0 + MOVD subwvectorfacility+0x00(SB),R1 + BR (R1) + +TEXT ·subVW_check(SB),NOSPLIT,$0 + MOVB ·hasVX(SB), R1 + CMPBEQ R1, $1, vectorimpl // vectorfacility = 1, vector supported + MOVD $subwvectorfacility+0x00(SB), R1 + MOVD $·subVW_novec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·subVW_novec(SB), 0(R1) + BR ·subVW_novec(SB) +vectorimpl: + MOVD $subwvectorfacility+0x00(SB), R1 + MOVD $·subVW_vec(SB), R2 + MOVD R2, 0(R1) + //MOVD $·subVW_vec(SB), 0(R1) + BR ·subVW_vec(SB) + +GLOBL subwvectorfacility+0x00(SB), NOPTR, $8 +DATA subwvectorfacility+0x00(SB)/8, $·subVW_check(SB) + +// func subVW(z, x []Word, y Word) (c Word) +TEXT ·subVW_vec(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R4 // c = y + MOVD z+0(FP), R2 + + MOVD $0, R0 // make sure it's zero + MOVD $0, R10 // i = 0 + MOVD R8, R5 + MOVD R2, R7 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v11 // if n < 0 goto v11 + SUB $12, R3 + BLT A11 + + VZERO V0 + MOVD $1, R6 // prepare V0 to be final carry register + VLVGG $1, R6, V0 // borrow is initially "no borrow" + VZERO V9 // to ensure upper half is zero + VLVGG $1, R4, V9 + + // n >= 0 + // regular loop body unrolled 16x + + +UU1: VLM 0(R5), V1, V4 // 64-bytes into V1..V4 + ADD $64, R5 + VPDI $0x4,V1,V1,V1 // flip the doublewords to big-endian order + VPDI $0x4,V2,V2,V2 // flip the doublewords to big-endian order + + + VSBCBIQ V1, V9, V0, V25 + VSBIQ V1, V9, V0, V17 + VZERO V9 + VSBCBIQ V2, V9, V25, V26 + VSBIQ V2, V9, V25, V18 + + VLM 0(R5), V5, V6 // 32-bytes into V5..V6 + ADD $32, R5 + + VPDI $0x4,V3,V3,V3 // flip the doublewords to big-endian order + VPDI $0x4,V4,V4,V4 // flip the doublewords to big-endian order + + + VSBCBIQ V3, V9, V26, V27 + VSBIQ V3, V9, V26, V19 + VSBCBIQ V4, V9, V27, V28 + VSBIQ V4, V9, V27, V20 + + VLM 0(R5), V7, V8 // 32-bytes into V7..V8 + ADD $32, R5 + + VPDI $0x4,V5,V5,V5 // flip the doublewords to big-endian order + VPDI $0x4,V6,V6,V6 // flip the doublewords to big-endian order + + VSBCBIQ V5, V9, V28, V29 + VSBIQ V5, V9, V28, V21 + VSBCBIQ V6, V9, V29, V30 + VSBIQ V6, V9, V29, V22 + + VPDI $0x4,V7,V7,V7 // flip the doublewords to big-endian order + VPDI $0x4,V8,V8,V8 // flip the doublewords to big-endian order + + VSBCBIQ V7, V9, V30, V31 + VSBIQ V7, V9, V30, V23 + VSBCBIQ V8, V9, V31, V0 // V0 has carry-over + VSBIQ V8, V9, V31, V24 + + VPDI $0x4,V17,V17,V17 // flip the doublewords to big-endian order + VPDI $0x4,V18,V18,V18 // flip the doublewords to big-endian order + VPDI $0x4,V19,V19,V19 // flip the doublewords to big-endian order + VPDI $0x4,V20,V20,V20 // flip the doublewords to big-endian order + VPDI $0x4,V21,V21,V21 // flip the doublewords to big-endian order + VPDI $0x4,V22,V22,V22 // flip the doublewords to big-endian order + VPDI $0x4,V23,V23,V23 // flip the doublewords to big-endian order + VPDI $0x4,V24,V24,V24 // flip the doublewords to big-endian order + VSTM V17, V24, 0(R7) // 128-bytes into z + ADD $128, R7 + ADD $128, R10 // i += 16 + SUB $16, R3 // n -= 16 + BGE UU1 // if n >= 0 goto U1 + VLGVG $1, V0, R4 // put cf into R4 in case we branch to v10 + SUB $1, R4 // save cf + NEG R4, R4 +A11: ADD $12, R3 // n += 16 + + BLT v11 // if n < 0 goto v11 + + // n >= 0 + // regular loop body unrolled 4x + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + SUBC R4, R5 //SLGR -> SUBC + SUBE R0, R6 //SLBGR -> SUBE + SUBE R0, R7 + SUBE R0, R1 + SUBE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 -> i +=32 + SUB $4, R3 // n -= 4 + BGE U4 // if n >= 0 goto U4 + +v11: ADD $4, R3 // n += 4 + BLE E11 // if n <= 0 goto E4 + +L4: // n > 0 + + MOVD 0(R8)(R10*1), R5 + SUBC R4, R5 + SUBE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L4 // if n > 0 goto L4 + +E11: MOVD R4, c+56(FP) // return c + + RET + +//DI = R3, CX = R4, SI = r10, r8 = r8, r10 = r2 , r11 = r5, r12 = r6, r13 = r7, r14 = r1 (R0 set to 0) +// func subVW(z, x []Word, y Word) (c Word) +// (same as addVW except for SUBC/SUBE instead of ADDC/ADDE and label names) +TEXT ·subVW_novec(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R3 + MOVD x+24(FP), R8 + MOVD y+48(FP), R4 // c = y + MOVD z+0(FP), R2 + MOVD $0, R0 // make sure it's 0 + MOVD $0, R10 // i = 0 + + // s/JL/JMP/ below to disable the unrolled loop + SUB $4, R3 // n -= 4 + BLT v4 // if n < 4 goto v4 + +U4: // n >= 0 + // regular loop body unrolled 4x + MOVD 0(R8)(R10*1), R5 + MOVD 8(R8)(R10*1), R6 + MOVD 16(R8)(R10*1), R7 + MOVD 24(R8)(R10*1), R1 + SUBC R4, R5 //SLGR -> SUBC + SUBE R0, R6 //SLBGR -> SUBE + SUBE R0, R7 + SUBE R0, R1 + SUBE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + MOVD R6, 8(R2)(R10*1) + MOVD R7, 16(R2)(R10*1) + MOVD R1, 24(R2)(R10*1) + + ADD $32, R10 // i += 4 -> i +=32 + SUB $4, R3 // n -= 4 + BGE U4 // if n >= 0 goto U4 + +v4: ADD $4, R3 // n += 4 + BLE E4 // if n <= 0 goto E4 + +L4: // n > 0 + MOVD 0(R8)(R10*1), R5 + SUBC R4, R5 + SUBE R4, R4 // save CF + NEG R4, R4 + MOVD R5, 0(R2)(R10*1) + + ADD $8, R10 // i++ + SUB $1, R3 // n-- + BGT L4 // if n > 0 goto L4 + +E4: MOVD R4, c+56(FP) // return c + + RET + +// func shlVU(z, x []Word, s uint) (c Word) +TEXT ·shlVU(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R5 + MOVD $0, R0 + SUB $1, R5 // n-- + BLT X8b // n < 0 (n <= 0) + + // n > 0 + MOVD s+48(FP), R4 + CMPBEQ R0, R4, Z80 //handle 0 case beq + MOVD $64, R6 + CMPBEQ R6, R4, Z864 //handle 64 case beq + MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + SLD $3, R5 // n = n*8 + SUB R4, R6, R7 + MOVD (R8)(R5*1), R10 // w1 = x[i-1] + SRD R7, R10, R3 + MOVD R3, c+56(FP) + + MOVD $0, R1 // i = 0 + BR E8 + + // i < n-1 +L8: MOVD R10, R3 // w = w1 + MOVD -8(R8)(R5*1), R10 // w1 = x[i+1] + + SLD R4, R3 // w<>ŝ + SRD R7, R10, R6 + OR R6, R3 + MOVD R3, (R2)(R5*1) // z[i] = w<>ŝ + SUB $8, R5 // i-- + +E8: CMPBGT R5, R0, L8 // i < n-1 + + // i >= n-1 +X8a: SLD R4, R10 // w1<= n-1 + MOVD R10, (R2)(R5*1) + RET + +Z864: MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + SLD $3, R5 // n = n*8 + MOVD (R8)(R5*1), R3 // w1 = x[n-1] + MOVD R3, c+56(FP) // z[i] = x[n-1] + + BR E864 + + // i < n-1 +L864: MOVD -8(R8)(R5*1), R3 + + MOVD R3, (R2)(R5*1) // z[i] = x[n-1] + SUB $8, R5 // i-- + +E864: CMPBGT R5, R0, L864 // i < n-1 + + MOVD R0, (R2) // z[n-1] = 0 + RET + + +// CX = R4, r8 = r8, r10 = r2 , r11 = r5, DX = r3, AX = r10 , BX = R1 , 64-count = r7 (R0 set to 0) temp = R6 +// func shrVU(z, x []Word, s uint) (c Word) +TEXT ·shrVU(SB),NOSPLIT,$0 + MOVD z_len+8(FP), R5 + MOVD $0, R0 + SUB $1, R5 // n-- + BLT X9b // n < 0 (n <= 0) + + // n > 0 + MOVD s+48(FP), R4 + CMPBEQ R0, R4, ZB0 //handle 0 case beq + MOVD $64, R6 + CMPBEQ R6, R4, ZB64 //handle 64 case beq + MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + SLD $3, R5 // n = n*8 + SUB R4, R6, R7 + MOVD (R8), R10 // w1 = x[0] + SLD R7, R10, R3 + MOVD R3, c+56(FP) + + MOVD $0, R1 // i = 0 + BR E9 + + // i < n-1 +L9: MOVD R10, R3 // w = w1 + MOVD 8(R8)(R1*1), R10 // w1 = x[i+1] + + SRD R4, R3 // w>>s | w1<>s | w1<= n-1 +X9a: SRD R4, R10 // w1>>s + MOVD R10, (R2)(R5*1) // z[n-1] = w1>>s + RET + +X9b: MOVD R0, c+56(FP) + RET + +ZB0: MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + SLD $3, R5 // n = n*8 + + MOVD (R8), R10 // w1 = x[0] + MOVD $0, R3 // R10 << 64 + MOVD R3, c+56(FP) + + MOVD $0, R1 // i = 0 + BR E9Z + + // i < n-1 +L9Z: MOVD R10, R3 // w = w1 + MOVD 8(R8)(R1*1), R10 // w1 = x[i+1] + + MOVD R3, (R2)(R1*1) // z[i] = w>>s | w1<= n-1 + MOVD R10, (R2)(R5*1) // z[n-1] = w1>>s + RET + +ZB64: MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + SLD $3, R5 // n = n*8 + MOVD (R8), R3 // w1 = x[0] + MOVD R3, c+56(FP) + + MOVD $0, R1 // i = 0 + BR E964 + + // i < n-1 +L964: MOVD 8(R8)(R1*1), R3 // w1 = x[i+1] + + MOVD R3, (R2)(R1*1) // z[i] = w>>s | w1<= n-1 + MOVD $0, R10 // w1>>s + MOVD R10, (R2)(R5*1) // z[n-1] = w1>>s + RET + +// CX = R4, r8 = r8, r9=r9, r10 = r2 , r11 = r5, DX = r3, AX = r6 , BX = R1 , (R0 set to 0) + use R11 + use R7 for i +// func mulAddVWW(z, x []Word, y, r Word) (c Word) +TEXT ·mulAddVWW(SB),NOSPLIT,$0 + MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD r+56(FP), R4 // c = r + MOVD z_len+8(FP), R5 + MOVD $0, R1 // i = 0 + MOVD $0, R7 // i*8 = 0 + MOVD $0, R0 // make sure it's zero + BR E5 + +L5: MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + ADDC R4, R11 //add to low order bits + ADDE R0, R6 + MOVD R11, (R2)(R1*1) + MOVD R6, R4 + ADD $8, R1 // i*8 + 8 + ADD $1, R7 // i++ + +E5: CMPBLT R7, R5, L5 // i < n + + MOVD R4, c+64(FP) + RET + +// func addMulVVW(z, x []Word, y Word) (c Word) +// CX = R4, r8 = r8, r9=r9, r10 = r2 , r11 = r5, AX = r11, DX = R6, r12=r12, BX = R1 , (R0 set to 0) + use R11 + use R7 for i +TEXT ·addMulVVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R2 + MOVD x+24(FP), R8 + MOVD y+48(FP), R9 + MOVD z_len+8(FP), R5 + + MOVD $0, R1 // i*8 = 0 + MOVD $0, R7 // i = 0 + MOVD $0, R0 // make sure it's zero + MOVD $0, R4 // c = 0 + + MOVD R5, R12 + AND $-2, R12 + CMPBGE R5, $2, A6 + BR E6 + +A6: MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 //add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + MOVD (8)(R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (8)(R2)(R1*1), R10 + ADDC R10, R11 //add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (8)(R2)(R1*1) + + ADD $16, R1 // i*8 + 8 + ADD $2, R7 // i++ + + CMPBLT R7, R12, A6 + BR E6 + +L6: MOVD (R8)(R1*1), R6 + MULHDU R9, R6 + MOVD (R2)(R1*1), R10 + ADDC R10, R11 //add to low order bits + ADDE R0, R6 + ADDC R4, R11 + ADDE R0, R6 + MOVD R6, R4 + MOVD R11, (R2)(R1*1) + + ADD $8, R1 // i*8 + 8 + ADD $1, R7 // i++ + +E6: CMPBLT R7, R5, L6 // i < n + + MOVD R4, c+56(FP) + RET + +// func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) +// CX = R4, r8 = r8, r9=r9, r10 = r2 , r11 = r5, AX = r11, DX = R6, r12=r12, BX = R1(*8) , (R0 set to 0) + use R11 + use R7 for i +TEXT ·divWVW(SB),NOSPLIT,$0 + MOVD z+0(FP), R2 + MOVD xn+24(FP), R10 // r = xn + MOVD x+32(FP), R8 + MOVD y+56(FP), R9 + MOVD z_len+8(FP), R7 // i = z + SLD $3, R7, R1 // i*8 + MOVD $0, R0 // make sure it's zero + BR E7 + +L7: MOVD (R8)(R1*1), R11 + WORD $0xB98700A9 //DLGR R10,R9 + MOVD R11, (R2)(R1*1) + +E7: SUB $1, R7 // i-- + SUB $8, R1 + BGE L7 // i >= 0 + + MOVD R10, r+64(FP) + RET diff --git a/vendor/github.com/golang/go/src/math/big/decimal.go b/vendor/github.com/golang/go/src/math/big/decimal.go new file mode 100644 index 000000000000..ae9ffb5db6ab --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/decimal.go @@ -0,0 +1,267 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements multi-precision decimal numbers. +// The implementation is for float to decimal conversion only; +// not general purpose use. +// The only operations are precise conversion from binary to +// decimal and rounding. +// +// The key observation and some code (shr) is borrowed from +// strconv/decimal.go: conversion of binary fractional values can be done +// precisely in multi-precision decimal because 2 divides 10 (required for +// >> of mantissa); but conversion of decimal floating-point values cannot +// be done precisely in binary representation. +// +// In contrast to strconv/decimal.go, only right shift is implemented in +// decimal format - left shift can be done precisely in binary format. + +package big + +// A decimal represents an unsigned floating-point number in decimal representation. +// The value of a non-zero decimal d is d.mant * 10**d.exp with 0.1 <= d.mant < 1, +// with the most-significant mantissa digit at index 0. For the zero decimal, the +// mantissa length and exponent are 0. +// The zero value for decimal represents a ready-to-use 0.0. +type decimal struct { + mant []byte // mantissa ASCII digits, big-endian + exp int // exponent +} + +// at returns the i'th mantissa digit, starting with the most significant digit at 0. +func (d *decimal) at(i int) byte { + if 0 <= i && i < len(d.mant) { + return d.mant[i] + } + return '0' +} + +// Maximum shift amount that can be done in one pass without overflow. +// A Word has _W bits and (1<= 0), or m >> -shift (for shift < 0). +func (x *decimal) init(m nat, shift int) { + // special case 0 + if len(m) == 0 { + x.mant = x.mant[:0] + x.exp = 0 + return + } + + // Optimization: If we need to shift right, first remove any trailing + // zero bits from m to reduce shift amount that needs to be done in + // decimal format (since that is likely slower). + if shift < 0 { + ntz := m.trailingZeroBits() + s := uint(-shift) + if s >= ntz { + s = ntz // shift at most ntz bits + } + m = nat(nil).shr(m, s) + shift += int(s) + } + + // Do any shift left in binary representation. + if shift > 0 { + m = nat(nil).shl(m, uint(shift)) + shift = 0 + } + + // Convert mantissa into decimal representation. + s := m.utoa(10) + n := len(s) + x.exp = n + // Trim trailing zeros; instead the exponent is tracking + // the decimal point independent of the number of digits. + for n > 0 && s[n-1] == '0' { + n-- + } + x.mant = append(x.mant[:0], s[:n]...) + + // Do any (remaining) shift right in decimal representation. + if shift < 0 { + for shift < -maxShift { + shr(x, maxShift) + shift += maxShift + } + shr(x, uint(-shift)) + } +} + +// shr implements x >> s, for s <= maxShift. +func shr(x *decimal, s uint) { + // Division by 1<>s == 0 && r < len(x.mant) { + ch := Word(x.mant[r]) + r++ + n = n*10 + ch - '0' + } + if n == 0 { + // x == 0; shouldn't get here, but handle anyway + x.mant = x.mant[:0] + return + } + for n>>s == 0 { + r++ + n *= 10 + } + x.exp += 1 - r + + // read a digit, write a digit + w := 0 // write index + mask := Word(1)<> s + n &= mask // n -= d << s + x.mant[w] = byte(d + '0') + w++ + n = n*10 + ch - '0' + } + + // write extra digits that still fit + for n > 0 && w < len(x.mant) { + d := n >> s + n &= mask + x.mant[w] = byte(d + '0') + w++ + n = n * 10 + } + x.mant = x.mant[:w] // the number may be shorter (e.g. 1024 >> 10) + + // append additional digits that didn't fit + for n > 0 { + d := n >> s + n &= mask + x.mant = append(x.mant, byte(d+'0')) + n = n * 10 + } + + trim(x) +} + +func (x *decimal) String() string { + if len(x.mant) == 0 { + return "0" + } + + var buf []byte + switch { + case x.exp <= 0: + // 0.00ddd + buf = append(buf, "0."...) + buf = appendZeros(buf, -x.exp) + buf = append(buf, x.mant...) + + case /* 0 < */ x.exp < len(x.mant): + // dd.ddd + buf = append(buf, x.mant[:x.exp]...) + buf = append(buf, '.') + buf = append(buf, x.mant[x.exp:]...) + + default: // len(x.mant) <= x.exp + // ddd00 + buf = append(buf, x.mant...) + buf = appendZeros(buf, x.exp-len(x.mant)) + } + + return string(buf) +} + +// appendZeros appends n 0 digits to buf and returns buf. +func appendZeros(buf []byte, n int) []byte { + for ; n > 0; n-- { + buf = append(buf, '0') + } + return buf +} + +// shouldRoundUp reports if x should be rounded up +// if shortened to n digits. n must be a valid index +// for x.mant. +func shouldRoundUp(x *decimal, n int) bool { + if x.mant[n] == '5' && n+1 == len(x.mant) { + // exactly halfway - round to even + return n > 0 && (x.mant[n-1]-'0')&1 != 0 + } + // not halfway - digit tells all (x.mant has no trailing zeros) + return x.mant[n] >= '5' +} + +// round sets x to (at most) n mantissa digits by rounding it +// to the nearest even value with n (or fever) mantissa digits. +// If n < 0, x remains unchanged. +func (x *decimal) round(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + + if shouldRoundUp(x, n) { + x.roundUp(n) + } else { + x.roundDown(n) + } +} + +func (x *decimal) roundUp(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + // 0 <= n < len(x.mant) + + // find first digit < '9' + for n > 0 && x.mant[n-1] >= '9' { + n-- + } + + if n == 0 { + // all digits are '9's => round up to '1' and update exponent + x.mant[0] = '1' // ok since len(x.mant) > n + x.mant = x.mant[:1] + x.exp++ + return + } + + // n > 0 && x.mant[n-1] < '9' + x.mant[n-1]++ + x.mant = x.mant[:n] + // x already trimmed +} + +func (x *decimal) roundDown(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + x.mant = x.mant[:n] + trim(x) +} + +// trim cuts off any trailing zeros from x's mantissa; +// they are meaningless for the value of x. +func trim(x *decimal) { + i := len(x.mant) + for i > 0 && x.mant[i-1] == '0' { + i-- + } + x.mant = x.mant[:i] + if i == 0 { + x.exp = 0 + } +} diff --git a/vendor/github.com/golang/go/src/math/big/doc.go b/vendor/github.com/golang/go/src/math/big/doc.go new file mode 100644 index 000000000000..65ed019b741d --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/doc.go @@ -0,0 +1,99 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package big implements arbitrary-precision arithmetic (big numbers). +The following numeric types are supported: + + Int signed integers + Rat rational numbers + Float floating-point numbers + +The zero value for an Int, Rat, or Float correspond to 0. Thus, new +values can be declared in the usual ways and denote 0 without further +initialization: + + var x Int // &x is an *Int of value 0 + var r = &Rat{} // r is a *Rat of value 0 + y := new(Float) // y is a *Float of value 0 + +Alternatively, new values can be allocated and initialized with factory +functions of the form: + + func NewT(v V) *T + +For instance, NewInt(x) returns an *Int set to the value of the int64 +argument x, NewRat(a, b) returns a *Rat set to the fraction a/b where +a and b are int64 values, and NewFloat(f) returns a *Float initialized +to the float64 argument f. More flexibility is provided with explicit +setters, for instance: + + var z1 Int + z1.SetUint64(123) // z1 := 123 + z2 := new(Rat).SetFloat64(1.25) // z2 := 5/4 + z3 := new(Float).SetInt(z1) // z3 := 123.0 + +Setters, numeric operations and predicates are represented as methods of +the form: + + func (z *T) SetV(v V) *T // z = v + func (z *T) Unary(x *T) *T // z = unary x + func (z *T) Binary(x, y *T) *T // z = x binary y + func (x *T) Pred() P // p = pred(x) + +with T one of Int, Rat, or Float. For unary and binary operations, the +result is the receiver (usually named z in that case; see below); if it +is one of the operands x or y it may be safely overwritten (and its memory +reused). + +Arithmetic expressions are typically written as a sequence of individual +method calls, with each call corresponding to an operation. The receiver +denotes the result and the method arguments are the operation's operands. +For instance, given three *Int values a, b and c, the invocation + + c.Add(a, b) + +computes the sum a + b and stores the result in c, overwriting whatever +value was held in c before. Unless specified otherwise, operations permit +aliasing of parameters, so it is perfectly ok to write + + sum.Add(sum, x) + +to accumulate values x in a sum. + +(By always passing in a result value via the receiver, memory use can be +much better controlled. Instead of having to allocate new memory for each +result, an operation can reuse the space allocated for the result value, +and overwrite that value with the new result in the process.) + +Notational convention: Incoming method parameters (including the receiver) +are named consistently in the API to clarify their use. Incoming operands +are usually named x, y, a, b, and so on, but never z. A parameter specifying +the result is named z (typically the receiver). + +For instance, the arguments for (*Int).Add are named x and y, and because +the receiver specifies the result destination, it is called z: + + func (z *Int) Add(x, y *Int) *Int + +Methods of this form typically return the incoming receiver as well, to +enable simple call chaining. + +Methods which don't require a result value to be passed in (for instance, +Int.Sign), simply return the result. In this case, the receiver is typically +the first operand, named x: + + func (x *Int) Sign() int + +Various methods support conversions between strings and corresponding +numeric values, and vice versa: *Int, *Rat, and *Float values implement +the Stringer interface for a (default) string representation of the value, +but also provide SetString methods to initialize a value from a string in +a variety of supported formats (see the respective SetString documentation). + +Finally, *Int, *Rat, and *Float satisfy the fmt package's Scanner interface +for scanning and (except for *Rat) the Formatter interface for formatted +printing. +*/ +package big diff --git a/vendor/github.com/golang/go/src/math/big/float.go b/vendor/github.com/golang/go/src/math/big/float.go new file mode 100644 index 000000000000..c042854ebaaf --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/float.go @@ -0,0 +1,1717 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements multi-precision floating-point numbers. +// Like in the GNU MPFR library (http://www.mpfr.org/), operands +// can be of mixed precision. Unlike MPFR, the rounding mode is +// not specified with each operation, but with each operand. The +// rounding mode of the result operand determines the rounding +// mode of an operation. This is a from-scratch implementation. + +package big + +import ( + "fmt" + "math" + "math/bits" +) + +const debugFloat = false // enable for debugging + +// A nonzero finite Float represents a multi-precision floating point number +// +// sign × mantissa × 2**exponent +// +// with 0.5 <= mantissa < 1.0, and MinExp <= exponent <= MaxExp. +// A Float may also be zero (+0, -0) or infinite (+Inf, -Inf). +// All Floats are ordered, and the ordering of two Floats x and y +// is defined by x.Cmp(y). +// +// Each Float value also has a precision, rounding mode, and accuracy. +// The precision is the maximum number of mantissa bits available to +// represent the value. The rounding mode specifies how a result should +// be rounded to fit into the mantissa bits, and accuracy describes the +// rounding error with respect to the exact result. +// +// Unless specified otherwise, all operations (including setters) that +// specify a *Float variable for the result (usually via the receiver +// with the exception of MantExp), round the numeric result according +// to the precision and rounding mode of the result variable. +// +// If the provided result precision is 0 (see below), it is set to the +// precision of the argument with the largest precision value before any +// rounding takes place, and the rounding mode remains unchanged. Thus, +// uninitialized Floats provided as result arguments will have their +// precision set to a reasonable value determined by the operands and +// their mode is the zero value for RoundingMode (ToNearestEven). +// +// By setting the desired precision to 24 or 53 and using matching rounding +// mode (typically ToNearestEven), Float operations produce the same results +// as the corresponding float32 or float64 IEEE-754 arithmetic for operands +// that correspond to normal (i.e., not denormal) float32 or float64 numbers. +// Exponent underflow and overflow lead to a 0 or an Infinity for different +// values than IEEE-754 because Float exponents have a much larger range. +// +// The zero (uninitialized) value for a Float is ready to use and represents +// the number +0.0 exactly, with precision 0 and rounding mode ToNearestEven. +// +type Float struct { + prec uint32 + mode RoundingMode + acc Accuracy + form form + neg bool + mant nat + exp int32 +} + +// An ErrNaN panic is raised by a Float operation that would lead to +// a NaN under IEEE-754 rules. An ErrNaN implements the error interface. +type ErrNaN struct { + msg string +} + +func (err ErrNaN) Error() string { + return err.msg +} + +// NewFloat allocates and returns a new Float set to x, +// with precision 53 and rounding mode ToNearestEven. +// NewFloat panics with ErrNaN if x is a NaN. +func NewFloat(x float64) *Float { + if math.IsNaN(x) { + panic(ErrNaN{"NewFloat(NaN)"}) + } + return new(Float).SetFloat64(x) +} + +// Exponent and precision limits. +const ( + MaxExp = math.MaxInt32 // largest supported exponent + MinExp = math.MinInt32 // smallest supported exponent + MaxPrec = math.MaxUint32 // largest (theoretically) supported precision; likely memory-limited +) + +// Internal representation: The mantissa bits x.mant of a nonzero finite +// Float x are stored in a nat slice long enough to hold up to x.prec bits; +// the slice may (but doesn't have to) be shorter if the mantissa contains +// trailing 0 bits. x.mant is normalized if the msb of x.mant == 1 (i.e., +// the msb is shifted all the way "to the left"). Thus, if the mantissa has +// trailing 0 bits or x.prec is not a multiple of the Word size _W, +// x.mant[0] has trailing zero bits. The msb of the mantissa corresponds +// to the value 0.5; the exponent x.exp shifts the binary point as needed. +// +// A zero or non-finite Float x ignores x.mant and x.exp. +// +// x form neg mant exp +// ---------------------------------------------------------- +// ±0 zero sign - - +// 0 < |x| < +Inf finite sign mantissa exponent +// ±Inf inf sign - - + +// A form value describes the internal representation. +type form byte + +// The form value order is relevant - do not change! +const ( + zero form = iota + finite + inf +) + +// RoundingMode determines how a Float value is rounded to the +// desired precision. Rounding may change the Float value; the +// rounding error is described by the Float's Accuracy. +type RoundingMode byte + +// These constants define supported rounding modes. +const ( + ToNearestEven RoundingMode = iota // == IEEE 754-2008 roundTiesToEven + ToNearestAway // == IEEE 754-2008 roundTiesToAway + ToZero // == IEEE 754-2008 roundTowardZero + AwayFromZero // no IEEE 754-2008 equivalent + ToNegativeInf // == IEEE 754-2008 roundTowardNegative + ToPositiveInf // == IEEE 754-2008 roundTowardPositive +) + +//go:generate stringer -type=RoundingMode + +// Accuracy describes the rounding error produced by the most recent +// operation that generated a Float value, relative to the exact value. +type Accuracy int8 + +// Constants describing the Accuracy of a Float. +const ( + Below Accuracy = -1 + Exact Accuracy = 0 + Above Accuracy = +1 +) + +//go:generate stringer -type=Accuracy + +// SetPrec sets z's precision to prec and returns the (possibly) rounded +// value of z. Rounding occurs according to z's rounding mode if the mantissa +// cannot be represented in prec bits without loss of precision. +// SetPrec(0) maps all finite values to ±0; infinite values remain unchanged. +// If prec > MaxPrec, it is set to MaxPrec. +func (z *Float) SetPrec(prec uint) *Float { + z.acc = Exact // optimistically assume no rounding is needed + + // special case + if prec == 0 { + z.prec = 0 + if z.form == finite { + // truncate z to 0 + z.acc = makeAcc(z.neg) + z.form = zero + } + return z + } + + // general case + if prec > MaxPrec { + prec = MaxPrec + } + old := z.prec + z.prec = uint32(prec) + if z.prec < old { + z.round(0) + } + return z +} + +func makeAcc(above bool) Accuracy { + if above { + return Above + } + return Below +} + +// SetMode sets z's rounding mode to mode and returns an exact z. +// z remains unchanged otherwise. +// z.SetMode(z.Mode()) is a cheap way to set z's accuracy to Exact. +func (z *Float) SetMode(mode RoundingMode) *Float { + z.mode = mode + z.acc = Exact + return z +} + +// Prec returns the mantissa precision of x in bits. +// The result may be 0 for |x| == 0 and |x| == Inf. +func (x *Float) Prec() uint { + return uint(x.prec) +} + +// MinPrec returns the minimum precision required to represent x exactly +// (i.e., the smallest prec before x.SetPrec(prec) would start rounding x). +// The result is 0 for |x| == 0 and |x| == Inf. +func (x *Float) MinPrec() uint { + if x.form != finite { + return 0 + } + return uint(len(x.mant))*_W - x.mant.trailingZeroBits() +} + +// Mode returns the rounding mode of x. +func (x *Float) Mode() RoundingMode { + return x.mode +} + +// Acc returns the accuracy of x produced by the most recent operation. +func (x *Float) Acc() Accuracy { + return x.acc +} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x is ±0 +// +1 if x > 0 +// +func (x *Float) Sign() int { + if debugFloat { + x.validate() + } + if x.form == zero { + return 0 + } + if x.neg { + return -1 + } + return 1 +} + +// MantExp breaks x into its mantissa and exponent components +// and returns the exponent. If a non-nil mant argument is +// provided its value is set to the mantissa of x, with the +// same precision and rounding mode as x. The components +// satisfy x == mant × 2**exp, with 0.5 <= |mant| < 1.0. +// Calling MantExp with a nil argument is an efficient way to +// get the exponent of the receiver. +// +// Special cases are: +// +// ( ±0).MantExp(mant) = 0, with mant set to ±0 +// (±Inf).MantExp(mant) = 0, with mant set to ±Inf +// +// x and mant may be the same in which case x is set to its +// mantissa value. +func (x *Float) MantExp(mant *Float) (exp int) { + if debugFloat { + x.validate() + } + if x.form == finite { + exp = int(x.exp) + } + if mant != nil { + mant.Copy(x) + if mant.form == finite { + mant.exp = 0 + } + } + return +} + +func (z *Float) setExpAndRound(exp int64, sbit uint) { + if exp < MinExp { + // underflow + z.acc = makeAcc(z.neg) + z.form = zero + return + } + + if exp > MaxExp { + // overflow + z.acc = makeAcc(!z.neg) + z.form = inf + return + } + + z.form = finite + z.exp = int32(exp) + z.round(sbit) +} + +// SetMantExp sets z to mant × 2**exp and and returns z. +// The result z has the same precision and rounding mode +// as mant. SetMantExp is an inverse of MantExp but does +// not require 0.5 <= |mant| < 1.0. Specifically: +// +// mant := new(Float) +// new(Float).SetMantExp(mant, x.MantExp(mant)).Cmp(x) == 0 +// +// Special cases are: +// +// z.SetMantExp( ±0, exp) = ±0 +// z.SetMantExp(±Inf, exp) = ±Inf +// +// z and mant may be the same in which case z's exponent +// is set to exp. +func (z *Float) SetMantExp(mant *Float, exp int) *Float { + if debugFloat { + z.validate() + mant.validate() + } + z.Copy(mant) + if z.form != finite { + return z + } + z.setExpAndRound(int64(z.exp)+int64(exp), 0) + return z +} + +// Signbit returns true if x is negative or negative zero. +func (x *Float) Signbit() bool { + return x.neg +} + +// IsInf reports whether x is +Inf or -Inf. +func (x *Float) IsInf() bool { + return x.form == inf +} + +// IsInt reports whether x is an integer. +// ±Inf values are not integers. +func (x *Float) IsInt() bool { + if debugFloat { + x.validate() + } + // special cases + if x.form != finite { + return x.form == zero + } + // x.form == finite + if x.exp <= 0 { + return false + } + // x.exp > 0 + return x.prec <= uint32(x.exp) || x.MinPrec() <= uint(x.exp) // not enough bits for fractional mantissa +} + +// debugging support +func (x *Float) validate() { + if !debugFloat { + // avoid performance bugs + panic("validate called but debugFloat is not set") + } + if x.form != finite { + return + } + m := len(x.mant) + if m == 0 { + panic("nonzero finite number with empty mantissa") + } + const msb = 1 << (_W - 1) + if x.mant[m-1]&msb == 0 { + panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.Text('p', 0))) + } + if x.prec == 0 { + panic("zero precision finite number") + } +} + +// round rounds z according to z.mode to z.prec bits and sets z.acc accordingly. +// sbit must be 0 or 1 and summarizes any "sticky bit" information one might +// have before calling round. z's mantissa must be normalized (with the msb set) +// or empty. +// +// CAUTION: The rounding modes ToNegativeInf, ToPositiveInf are affected by the +// sign of z. For correct rounding, the sign of z must be set correctly before +// calling round. +func (z *Float) round(sbit uint) { + if debugFloat { + z.validate() + } + + z.acc = Exact + if z.form != finite { + // ±0 or ±Inf => nothing left to do + return + } + // z.form == finite && len(z.mant) > 0 + // m > 0 implies z.prec > 0 (checked by validate) + + m := uint32(len(z.mant)) // present mantissa length in words + bits := m * _W // present mantissa bits; bits > 0 + if bits <= z.prec { + // mantissa fits => nothing to do + return + } + // bits > z.prec + + // Rounding is based on two bits: the rounding bit (rbit) and the + // sticky bit (sbit). The rbit is the bit immediately before the + // z.prec leading mantissa bits (the "0.5"). The sbit is set if any + // of the bits before the rbit are set (the "0.25", "0.125", etc.): + // + // rbit sbit => "fractional part" + // + // 0 0 == 0 + // 0 1 > 0 , < 0.5 + // 1 0 == 0.5 + // 1 1 > 0.5, < 1.0 + + // bits > z.prec: mantissa too large => round + r := uint(bits - z.prec - 1) // rounding bit position; r >= 0 + rbit := z.mant.bit(r) & 1 // rounding bit; be safe and ensure it's a single bit + // The sticky bit is only needed for rounding ToNearestEven + // or when the rounding bit is zero. Avoid computation otherwise. + if sbit == 0 && (rbit == 0 || z.mode == ToNearestEven) { + sbit = z.mant.sticky(r) + } + sbit &= 1 // be safe and ensure it's a single bit + + // cut off extra words + n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision + if m > n { + copy(z.mant, z.mant[m-n:]) // move n last words to front + z.mant = z.mant[:n] + } + + // determine number of trailing zero bits (ntz) and compute lsb mask of mantissa's least-significant word + ntz := n*_W - z.prec // 0 <= ntz < _W + lsb := Word(1) << ntz + + // round if result is inexact + if rbit|sbit != 0 { + // Make rounding decision: The result mantissa is truncated ("rounded down") + // by default. Decide if we need to increment, or "round up", the (unsigned) + // mantissa. + inc := false + switch z.mode { + case ToNegativeInf: + inc = z.neg + case ToZero: + // nothing to do + case ToNearestEven: + inc = rbit != 0 && (sbit != 0 || z.mant[0]&lsb != 0) + case ToNearestAway: + inc = rbit != 0 + case AwayFromZero: + inc = true + case ToPositiveInf: + inc = !z.neg + default: + panic("unreachable") + } + + // A positive result (!z.neg) is Above the exact result if we increment, + // and it's Below if we truncate (Exact results require no rounding). + // For a negative result (z.neg) it is exactly the opposite. + z.acc = makeAcc(inc != z.neg) + + if inc { + // add 1 to mantissa + if addVW(z.mant, z.mant, lsb) != 0 { + // mantissa overflow => adjust exponent + if z.exp >= MaxExp { + // exponent overflow + z.form = inf + return + } + z.exp++ + // adjust mantissa: divide by 2 to compensate for exponent adjustment + shrVU(z.mant, z.mant, 1) + // set msb == carry == 1 from the mantissa overflow above + const msb = 1 << (_W - 1) + z.mant[n-1] |= msb + } + } + } + + // zero out trailing bits in least-significant word + z.mant[0] &^= lsb - 1 + + if debugFloat { + z.validate() + } +} + +func (z *Float) setBits64(neg bool, x uint64) *Float { + if z.prec == 0 { + z.prec = 64 + } + z.acc = Exact + z.neg = neg + if x == 0 { + z.form = zero + return z + } + // x != 0 + z.form = finite + s := bits.LeadingZeros64(x) + z.mant = z.mant.setUint64(x << uint(s)) + z.exp = int32(64 - s) // always fits + if z.prec < 64 { + z.round(0) + } + return z +} + +// SetUint64 sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to 64 (and rounding will have +// no effect). +func (z *Float) SetUint64(x uint64) *Float { + return z.setBits64(false, x) +} + +// SetInt64 sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to 64 (and rounding will have +// no effect). +func (z *Float) SetInt64(x int64) *Float { + u := x + if u < 0 { + u = -u + } + // We cannot simply call z.SetUint64(uint64(u)) and change + // the sign afterwards because the sign affects rounding. + return z.setBits64(x < 0, uint64(u)) +} + +// SetFloat64 sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to 53 (and rounding will have +// no effect). SetFloat64 panics with ErrNaN if x is a NaN. +func (z *Float) SetFloat64(x float64) *Float { + if z.prec == 0 { + z.prec = 53 + } + if math.IsNaN(x) { + panic(ErrNaN{"Float.SetFloat64(NaN)"}) + } + z.acc = Exact + z.neg = math.Signbit(x) // handle -0, -Inf correctly + if x == 0 { + z.form = zero + return z + } + if math.IsInf(x, 0) { + z.form = inf + return z + } + // normalized x != 0 + z.form = finite + fmant, exp := math.Frexp(x) // get normalized mantissa + z.mant = z.mant.setUint64(1<<63 | math.Float64bits(fmant)<<11) + z.exp = int32(exp) // always fits + if z.prec < 53 { + z.round(0) + } + return z +} + +// fnorm normalizes mantissa m by shifting it to the left +// such that the msb of the most-significant word (msw) is 1. +// It returns the shift amount. It assumes that len(m) != 0. +func fnorm(m nat) int64 { + if debugFloat && (len(m) == 0 || m[len(m)-1] == 0) { + panic("msw of mantissa is 0") + } + s := nlz(m[len(m)-1]) + if s > 0 { + c := shlVU(m, m, s) + if debugFloat && c != 0 { + panic("nlz or shlVU incorrect") + } + } + return int64(s) +} + +// SetInt sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to the larger of x.BitLen() +// or 64 (and rounding will have no effect). +func (z *Float) SetInt(x *Int) *Float { + // TODO(gri) can be more efficient if z.prec > 0 + // but small compared to the size of x, or if there + // are many trailing 0's. + bits := uint32(x.BitLen()) + if z.prec == 0 { + z.prec = umax32(bits, 64) + } + z.acc = Exact + z.neg = x.neg + if len(x.abs) == 0 { + z.form = zero + return z + } + // x != 0 + z.mant = z.mant.set(x.abs) + fnorm(z.mant) + z.setExpAndRound(int64(bits), 0) + return z +} + +// SetRat sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to the largest of a.BitLen(), +// b.BitLen(), or 64; with x = a/b. +func (z *Float) SetRat(x *Rat) *Float { + if x.IsInt() { + return z.SetInt(x.Num()) + } + var a, b Float + a.SetInt(x.Num()) + b.SetInt(x.Denom()) + if z.prec == 0 { + z.prec = umax32(a.prec, b.prec) + } + return z.Quo(&a, &b) +} + +// SetInf sets z to the infinite Float -Inf if signbit is +// set, or +Inf if signbit is not set, and returns z. The +// precision of z is unchanged and the result is always +// Exact. +func (z *Float) SetInf(signbit bool) *Float { + z.acc = Exact + z.form = inf + z.neg = signbit + return z +} + +// Set sets z to the (possibly rounded) value of x and returns z. +// If z's precision is 0, it is changed to the precision of x +// before setting z (and rounding will have no effect). +// Rounding is performed according to z's precision and rounding +// mode; and z's accuracy reports the result error relative to the +// exact (not rounded) result. +func (z *Float) Set(x *Float) *Float { + if debugFloat { + x.validate() + } + z.acc = Exact + if z != x { + z.form = x.form + z.neg = x.neg + if x.form == finite { + z.exp = x.exp + z.mant = z.mant.set(x.mant) + } + if z.prec == 0 { + z.prec = x.prec + } else if z.prec < x.prec { + z.round(0) + } + } + return z +} + +// Copy sets z to x, with the same precision, rounding mode, and +// accuracy as x, and returns z. x is not changed even if z and +// x are the same. +func (z *Float) Copy(x *Float) *Float { + if debugFloat { + x.validate() + } + if z != x { + z.prec = x.prec + z.mode = x.mode + z.acc = x.acc + z.form = x.form + z.neg = x.neg + if z.form == finite { + z.mant = z.mant.set(x.mant) + z.exp = x.exp + } + } + return z +} + +// msb32 returns the 32 most significant bits of x. +func msb32(x nat) uint32 { + i := len(x) - 1 + if i < 0 { + return 0 + } + if debugFloat && x[i]&(1<<(_W-1)) == 0 { + panic("x not normalized") + } + switch _W { + case 32: + return uint32(x[i]) + case 64: + return uint32(x[i] >> 32) + } + panic("unreachable") +} + +// msb64 returns the 64 most significant bits of x. +func msb64(x nat) uint64 { + i := len(x) - 1 + if i < 0 { + return 0 + } + if debugFloat && x[i]&(1<<(_W-1)) == 0 { + panic("x not normalized") + } + switch _W { + case 32: + v := uint64(x[i]) << 32 + if i > 0 { + v |= uint64(x[i-1]) + } + return v + case 64: + return uint64(x[i]) + } + panic("unreachable") +} + +// Uint64 returns the unsigned integer resulting from truncating x +// towards zero. If 0 <= x <= math.MaxUint64, the result is Exact +// if x is an integer and Below otherwise. +// The result is (0, Above) for x < 0, and (math.MaxUint64, Below) +// for x > math.MaxUint64. +func (x *Float) Uint64() (uint64, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + if x.neg { + return 0, Above + } + // 0 < x < +Inf + if x.exp <= 0 { + // 0 < x < 1 + return 0, Below + } + // 1 <= x < Inf + if x.exp <= 64 { + // u = trunc(x) fits into a uint64 + u := msb64(x.mant) >> (64 - uint32(x.exp)) + if x.MinPrec() <= 64 { + return u, Exact + } + return u, Below // x truncated + } + // x too large + return math.MaxUint64, Below + + case zero: + return 0, Exact + + case inf: + if x.neg { + return 0, Above + } + return math.MaxUint64, Below + } + + panic("unreachable") +} + +// Int64 returns the integer resulting from truncating x towards zero. +// If math.MinInt64 <= x <= math.MaxInt64, the result is Exact if x is +// an integer, and Above (x < 0) or Below (x > 0) otherwise. +// The result is (math.MinInt64, Above) for x < math.MinInt64, +// and (math.MaxInt64, Below) for x > math.MaxInt64. +func (x *Float) Int64() (int64, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + acc := makeAcc(x.neg) + if x.exp <= 0 { + // 0 < |x| < 1 + return 0, acc + } + // x.exp > 0 + + // 1 <= |x| < +Inf + if x.exp <= 63 { + // i = trunc(x) fits into an int64 (excluding math.MinInt64) + i := int64(msb64(x.mant) >> (64 - uint32(x.exp))) + if x.neg { + i = -i + } + if x.MinPrec() <= uint(x.exp) { + return i, Exact + } + return i, acc // x truncated + } + if x.neg { + // check for special case x == math.MinInt64 (i.e., x == -(0.5 << 64)) + if x.exp == 64 && x.MinPrec() == 1 { + acc = Exact + } + return math.MinInt64, acc + } + // x too large + return math.MaxInt64, Below + + case zero: + return 0, Exact + + case inf: + if x.neg { + return math.MinInt64, Above + } + return math.MaxInt64, Below + } + + panic("unreachable") +} + +// Float32 returns the float32 value nearest to x. If x is too small to be +// represented by a float32 (|x| < math.SmallestNonzeroFloat32), the result +// is (0, Below) or (-0, Above), respectively, depending on the sign of x. +// If x is too large to be represented by a float32 (|x| > math.MaxFloat32), +// the result is (+Inf, Above) or (-Inf, Below), depending on the sign of x. +func (x *Float) Float32() (float32, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + + const ( + fbits = 32 // float size + mbits = 23 // mantissa size (excluding implicit msb) + ebits = fbits - mbits - 1 // 8 exponent size + bias = 1<<(ebits-1) - 1 // 127 exponent bias + dmin = 1 - bias - mbits // -149 smallest unbiased exponent (denormal) + emin = 1 - bias // -126 smallest unbiased exponent (normal) + emax = bias // 127 largest unbiased exponent (normal) + ) + + // Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float32 mantissa. + e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0 + + // Compute precision p for float32 mantissa. + // If the exponent is too small, we have a denormal number before + // rounding and fewer than p mantissa bits of precision available + // (the exponent remains fixed but the mantissa gets shifted right). + p := mbits + 1 // precision of normal float + if e < emin { + // recompute precision + p = mbits + 1 - emin + int(e) + // If p == 0, the mantissa of x is shifted so much to the right + // that its msb falls immediately to the right of the float32 + // mantissa space. In other words, if the smallest denormal is + // considered "1.0", for p == 0, the mantissa value m is >= 0.5. + // If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal. + // If m == 0.5, it is rounded down to even, i.e., 0.0. + // If p < 0, the mantissa value m is <= "0.25" which is never rounded up. + if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ { + // underflow to ±0 + if x.neg { + var z float32 + return -z, Above + } + return 0.0, Below + } + // otherwise, round up + // We handle p == 0 explicitly because it's easy and because + // Float.round doesn't support rounding to 0 bits of precision. + if p == 0 { + if x.neg { + return -math.SmallestNonzeroFloat32, Below + } + return math.SmallestNonzeroFloat32, Above + } + } + // p > 0 + + // round + var r Float + r.prec = uint32(p) + r.Set(x) + e = r.exp - 1 + + // Rounding may have caused r to overflow to ±Inf + // (rounding never causes underflows to 0). + // If the exponent is too large, also overflow to ±Inf. + if r.form == inf || e > emax { + // overflow + if x.neg { + return float32(math.Inf(-1)), Below + } + return float32(math.Inf(+1)), Above + } + // e <= emax + + // Determine sign, biased exponent, and mantissa. + var sign, bexp, mant uint32 + if x.neg { + sign = 1 << (fbits - 1) + } + + // Rounding may have caused a denormal number to + // become normal. Check again. + if e < emin { + // denormal number: recompute precision + // Since rounding may have at best increased precision + // and we have eliminated p <= 0 early, we know p > 0. + // bexp == 0 for denormals + p = mbits + 1 - emin + int(e) + mant = msb32(r.mant) >> uint(fbits-p) + } else { + // normal number: emin <= e <= emax + bexp = uint32(e+bias) << mbits + mant = msb32(r.mant) >> ebits & (1< math.MaxFloat64), +// the result is (+Inf, Above) or (-Inf, Below), depending on the sign of x. +func (x *Float) Float64() (float64, Accuracy) { + if debugFloat { + x.validate() + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + + const ( + fbits = 64 // float size + mbits = 52 // mantissa size (excluding implicit msb) + ebits = fbits - mbits - 1 // 11 exponent size + bias = 1<<(ebits-1) - 1 // 1023 exponent bias + dmin = 1 - bias - mbits // -1074 smallest unbiased exponent (denormal) + emin = 1 - bias // -1022 smallest unbiased exponent (normal) + emax = bias // 1023 largest unbiased exponent (normal) + ) + + // Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float64 mantissa. + e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0 + + // Compute precision p for float64 mantissa. + // If the exponent is too small, we have a denormal number before + // rounding and fewer than p mantissa bits of precision available + // (the exponent remains fixed but the mantissa gets shifted right). + p := mbits + 1 // precision of normal float + if e < emin { + // recompute precision + p = mbits + 1 - emin + int(e) + // If p == 0, the mantissa of x is shifted so much to the right + // that its msb falls immediately to the right of the float64 + // mantissa space. In other words, if the smallest denormal is + // considered "1.0", for p == 0, the mantissa value m is >= 0.5. + // If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal. + // If m == 0.5, it is rounded down to even, i.e., 0.0. + // If p < 0, the mantissa value m is <= "0.25" which is never rounded up. + if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ { + // underflow to ±0 + if x.neg { + var z float64 + return -z, Above + } + return 0.0, Below + } + // otherwise, round up + // We handle p == 0 explicitly because it's easy and because + // Float.round doesn't support rounding to 0 bits of precision. + if p == 0 { + if x.neg { + return -math.SmallestNonzeroFloat64, Below + } + return math.SmallestNonzeroFloat64, Above + } + } + // p > 0 + + // round + var r Float + r.prec = uint32(p) + r.Set(x) + e = r.exp - 1 + + // Rounding may have caused r to overflow to ±Inf + // (rounding never causes underflows to 0). + // If the exponent is too large, also overflow to ±Inf. + if r.form == inf || e > emax { + // overflow + if x.neg { + return math.Inf(-1), Below + } + return math.Inf(+1), Above + } + // e <= emax + + // Determine sign, biased exponent, and mantissa. + var sign, bexp, mant uint64 + if x.neg { + sign = 1 << (fbits - 1) + } + + // Rounding may have caused a denormal number to + // become normal. Check again. + if e < emin { + // denormal number: recompute precision + // Since rounding may have at best increased precision + // and we have eliminated p <= 0 early, we know p > 0. + // bexp == 0 for denormals + p = mbits + 1 - emin + int(e) + mant = msb64(r.mant) >> uint(fbits-p) + } else { + // normal number: emin <= e <= emax + bexp = uint64(e+bias) << mbits + mant = msb64(r.mant) >> ebits & (1< 0, and Above for x < 0. +// If a non-nil *Int argument z is provided, Int stores +// the result in z instead of allocating a new Int. +func (x *Float) Int(z *Int) (*Int, Accuracy) { + if debugFloat { + x.validate() + } + + if z == nil && x.form <= finite { + z = new(Int) + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + acc := makeAcc(x.neg) + if x.exp <= 0 { + // 0 < |x| < 1 + return z.SetInt64(0), acc + } + // x.exp > 0 + + // 1 <= |x| < +Inf + // determine minimum required precision for x + allBits := uint(len(x.mant)) * _W + exp := uint(x.exp) + if x.MinPrec() <= exp { + acc = Exact + } + // shift mantissa as needed + if z == nil { + z = new(Int) + } + z.neg = x.neg + switch { + case exp > allBits: + z.abs = z.abs.shl(x.mant, exp-allBits) + default: + z.abs = z.abs.set(x.mant) + case exp < allBits: + z.abs = z.abs.shr(x.mant, allBits-exp) + } + return z, acc + + case zero: + return z.SetInt64(0), Exact + + case inf: + return nil, makeAcc(x.neg) + } + + panic("unreachable") +} + +// Rat returns the rational number corresponding to x; +// or nil if x is an infinity. +// The result is Exact if x is not an Inf. +// If a non-nil *Rat argument z is provided, Rat stores +// the result in z instead of allocating a new Rat. +func (x *Float) Rat(z *Rat) (*Rat, Accuracy) { + if debugFloat { + x.validate() + } + + if z == nil && x.form <= finite { + z = new(Rat) + } + + switch x.form { + case finite: + // 0 < |x| < +Inf + allBits := int32(len(x.mant)) * _W + // build up numerator and denominator + z.a.neg = x.neg + switch { + case x.exp > allBits: + z.a.abs = z.a.abs.shl(x.mant, uint(x.exp-allBits)) + z.b.abs = z.b.abs[:0] // == 1 (see Rat) + // z already in normal form + default: + z.a.abs = z.a.abs.set(x.mant) + z.b.abs = z.b.abs[:0] // == 1 (see Rat) + // z already in normal form + case x.exp < allBits: + z.a.abs = z.a.abs.set(x.mant) + t := z.b.abs.setUint64(1) + z.b.abs = t.shl(t, uint(allBits-x.exp)) + z.norm() + } + return z, Exact + + case zero: + return z.SetInt64(0), Exact + + case inf: + return nil, makeAcc(x.neg) + } + + panic("unreachable") +} + +// Abs sets z to the (possibly rounded) value |x| (the absolute value of x) +// and returns z. +func (z *Float) Abs(x *Float) *Float { + z.Set(x) + z.neg = false + return z +} + +// Neg sets z to the (possibly rounded) value of x with its sign negated, +// and returns z. +func (z *Float) Neg(x *Float) *Float { + z.Set(x) + z.neg = !z.neg + return z +} + +func validateBinaryOperands(x, y *Float) { + if !debugFloat { + // avoid performance bugs + panic("validateBinaryOperands called but debugFloat is not set") + } + if len(x.mant) == 0 { + panic("empty mantissa for x") + } + if len(y.mant) == 0 { + panic("empty mantissa for y") + } +} + +// z = x + y, ignoring signs of x and y for the addition +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) uadd(x, y *Float) { + // Note: This implementation requires 2 shifts most of the + // time. It is also inefficient if exponents or precisions + // differ by wide margins. The following article describes + // an efficient (but much more complicated) implementation + // compatible with the internal representation used here: + // + // Vincent Lefèvre: "The Generic Multiple-Precision Floating- + // Point Addition With Exact Rounding (as in the MPFR Library)" + // http://www.vinc17.net/research/papers/rnc6.pdf + + if debugFloat { + validateBinaryOperands(x, y) + } + + // compute exponents ex, ey for mantissa with "binary point" + // on the right (mantissa.0) - use int64 to avoid overflow + ex := int64(x.exp) - int64(len(x.mant))*_W + ey := int64(y.exp) - int64(len(y.mant))*_W + + al := alias(z.mant, x.mant) || alias(z.mant, y.mant) + + // TODO(gri) having a combined add-and-shift primitive + // could make this code significantly faster + switch { + case ex < ey: + if al { + t := nat(nil).shl(y.mant, uint(ey-ex)) + z.mant = z.mant.add(x.mant, t) + } else { + z.mant = z.mant.shl(y.mant, uint(ey-ex)) + z.mant = z.mant.add(x.mant, z.mant) + } + default: + // ex == ey, no shift needed + z.mant = z.mant.add(x.mant, y.mant) + case ex > ey: + if al { + t := nat(nil).shl(x.mant, uint(ex-ey)) + z.mant = z.mant.add(t, y.mant) + } else { + z.mant = z.mant.shl(x.mant, uint(ex-ey)) + z.mant = z.mant.add(z.mant, y.mant) + } + ex = ey + } + // len(z.mant) > 0 + + z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0) +} + +// z = x - y for |x| > |y|, ignoring signs of x and y for the subtraction +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) usub(x, y *Float) { + // This code is symmetric to uadd. + // We have not factored the common code out because + // eventually uadd (and usub) should be optimized + // by special-casing, and the code will diverge. + + if debugFloat { + validateBinaryOperands(x, y) + } + + ex := int64(x.exp) - int64(len(x.mant))*_W + ey := int64(y.exp) - int64(len(y.mant))*_W + + al := alias(z.mant, x.mant) || alias(z.mant, y.mant) + + switch { + case ex < ey: + if al { + t := nat(nil).shl(y.mant, uint(ey-ex)) + z.mant = t.sub(x.mant, t) + } else { + z.mant = z.mant.shl(y.mant, uint(ey-ex)) + z.mant = z.mant.sub(x.mant, z.mant) + } + default: + // ex == ey, no shift needed + z.mant = z.mant.sub(x.mant, y.mant) + case ex > ey: + if al { + t := nat(nil).shl(x.mant, uint(ex-ey)) + z.mant = t.sub(t, y.mant) + } else { + z.mant = z.mant.shl(x.mant, uint(ex-ey)) + z.mant = z.mant.sub(z.mant, y.mant) + } + ex = ey + } + + // operands may have canceled each other out + if len(z.mant) == 0 { + z.acc = Exact + z.form = zero + z.neg = false + return + } + // len(z.mant) > 0 + + z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0) +} + +// z = x * y, ignoring signs of x and y for the multiplication +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) umul(x, y *Float) { + if debugFloat { + validateBinaryOperands(x, y) + } + + // Note: This is doing too much work if the precision + // of z is less than the sum of the precisions of x + // and y which is often the case (e.g., if all floats + // have the same precision). + // TODO(gri) Optimize this for the common case. + + e := int64(x.exp) + int64(y.exp) + if x == y { + z.mant = z.mant.sqr(x.mant) + } else { + z.mant = z.mant.mul(x.mant, y.mant) + } + z.setExpAndRound(e-fnorm(z.mant), 0) +} + +// z = x / y, ignoring signs of x and y for the division +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. +func (z *Float) uquo(x, y *Float) { + if debugFloat { + validateBinaryOperands(x, y) + } + + // mantissa length in words for desired result precision + 1 + // (at least one extra bit so we get the rounding bit after + // the division) + n := int(z.prec/_W) + 1 + + // compute adjusted x.mant such that we get enough result precision + xadj := x.mant + if d := n - len(x.mant) + len(y.mant); d > 0 { + // d extra words needed => add d "0 digits" to x + xadj = make(nat, len(x.mant)+d) + copy(xadj[d:], x.mant) + } + // TODO(gri): If we have too many digits (d < 0), we should be able + // to shorten x for faster division. But we must be extra careful + // with rounding in that case. + + // Compute d before division since there may be aliasing of x.mant + // (via xadj) or y.mant with z.mant. + d := len(xadj) - len(y.mant) + + // divide + var r nat + z.mant, r = z.mant.div(nil, xadj, y.mant) + e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W + + // The result is long enough to include (at least) the rounding bit. + // If there's a non-zero remainder, the corresponding fractional part + // (if it were computed), would have a non-zero sticky bit (if it were + // zero, it couldn't have a non-zero remainder). + var sbit uint + if len(r) > 0 { + sbit = 1 + } + + z.setExpAndRound(e-fnorm(z.mant), sbit) +} + +// ucmp returns -1, 0, or +1, depending on whether +// |x| < |y|, |x| == |y|, or |x| > |y|. +// x and y must have a non-empty mantissa and valid exponent. +func (x *Float) ucmp(y *Float) int { + if debugFloat { + validateBinaryOperands(x, y) + } + + switch { + case x.exp < y.exp: + return -1 + case x.exp > y.exp: + return +1 + } + // x.exp == y.exp + + // compare mantissas + i := len(x.mant) + j := len(y.mant) + for i > 0 || j > 0 { + var xm, ym Word + if i > 0 { + i-- + xm = x.mant[i] + } + if j > 0 { + j-- + ym = y.mant[j] + } + switch { + case xm < ym: + return -1 + case xm > ym: + return +1 + } + } + + return 0 +} + +// Handling of sign bit as defined by IEEE 754-2008, section 6.3: +// +// When neither the inputs nor result are NaN, the sign of a product or +// quotient is the exclusive OR of the operands’ signs; the sign of a sum, +// or of a difference x−y regarded as a sum x+(−y), differs from at most +// one of the addends’ signs; and the sign of the result of conversions, +// the quantize operation, the roundToIntegral operations, and the +// roundToIntegralExact (see 5.3.1) is the sign of the first or only operand. +// These rules shall apply even when operands or results are zero or infinite. +// +// When the sum of two operands with opposite signs (or the difference of +// two operands with like signs) is exactly zero, the sign of that sum (or +// difference) shall be +0 in all rounding-direction attributes except +// roundTowardNegative; under that attribute, the sign of an exact zero +// sum (or difference) shall be −0. However, x+x = x−(−x) retains the same +// sign as x even when x is zero. +// +// See also: https://play.golang.org/p/RtH3UCt5IH + +// Add sets z to the rounded sum x+y and returns z. If z's precision is 0, +// it is changed to the larger of x's or y's precision before the operation. +// Rounding is performed according to z's precision and rounding mode; and +// z's accuracy reports the result error relative to the exact (not rounded) +// result. Add panics with ErrNaN if x and y are infinities with opposite +// signs. The value of z is undefined in that case. +// +// BUG(gri) When rounding ToNegativeInf, the sign of Float values rounded to 0 is incorrect. +func (z *Float) Add(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + if x.form == finite && y.form == finite { + // x + y (common case) + + // Below we set z.neg = x.neg, and when z aliases y this will + // change the y operand's sign. This is fine, because if an + // operand aliases the receiver it'll be overwritten, but we still + // want the original x.neg and y.neg values when we evaluate + // x.neg != y.neg, so we need to save y.neg before setting z.neg. + yneg := y.neg + + z.neg = x.neg + if x.neg == yneg { + // x + y == x + y + // (-x) + (-y) == -(x + y) + z.uadd(x, y) + } else { + // x + (-y) == x - y == -(y - x) + // (-x) + y == y - x == -(x - y) + if x.ucmp(y) > 0 { + z.usub(x, y) + } else { + z.neg = !z.neg + z.usub(y, x) + } + } + return z + } + + if x.form == inf && y.form == inf && x.neg != y.neg { + // +Inf + -Inf + // -Inf + +Inf + // value of z is undefined but make sure it's valid + z.acc = Exact + z.form = zero + z.neg = false + panic(ErrNaN{"addition of infinities with opposite signs"}) + } + + if x.form == zero && y.form == zero { + // ±0 + ±0 + z.acc = Exact + z.form = zero + z.neg = x.neg && y.neg // -0 + -0 == -0 + return z + } + + if x.form == inf || y.form == zero { + // ±Inf + y + // x + ±0 + return z.Set(x) + } + + // ±0 + y + // x + ±Inf + return z.Set(y) +} + +// Sub sets z to the rounded difference x-y and returns z. +// Precision, rounding, and accuracy reporting are as for Add. +// Sub panics with ErrNaN if x and y are infinities with equal +// signs. The value of z is undefined in that case. +func (z *Float) Sub(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + if x.form == finite && y.form == finite { + // x - y (common case) + yneg := y.neg + z.neg = x.neg + if x.neg != yneg { + // x - (-y) == x + y + // (-x) - y == -(x + y) + z.uadd(x, y) + } else { + // x - y == x - y == -(y - x) + // (-x) - (-y) == y - x == -(x - y) + if x.ucmp(y) > 0 { + z.usub(x, y) + } else { + z.neg = !z.neg + z.usub(y, x) + } + } + return z + } + + if x.form == inf && y.form == inf && x.neg == y.neg { + // +Inf - +Inf + // -Inf - -Inf + // value of z is undefined but make sure it's valid + z.acc = Exact + z.form = zero + z.neg = false + panic(ErrNaN{"subtraction of infinities with equal signs"}) + } + + if x.form == zero && y.form == zero { + // ±0 - ±0 + z.acc = Exact + z.form = zero + z.neg = x.neg && !y.neg // -0 - +0 == -0 + return z + } + + if x.form == inf || y.form == zero { + // ±Inf - y + // x - ±0 + return z.Set(x) + } + + // ±0 - y + // x - ±Inf + return z.Neg(y) +} + +// Mul sets z to the rounded product x*y and returns z. +// Precision, rounding, and accuracy reporting are as for Add. +// Mul panics with ErrNaN if one operand is zero and the other +// operand an infinity. The value of z is undefined in that case. +func (z *Float) Mul(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + z.neg = x.neg != y.neg + + if x.form == finite && y.form == finite { + // x * y (common case) + z.umul(x, y) + return z + } + + z.acc = Exact + if x.form == zero && y.form == inf || x.form == inf && y.form == zero { + // ±0 * ±Inf + // ±Inf * ±0 + // value of z is undefined but make sure it's valid + z.form = zero + z.neg = false + panic(ErrNaN{"multiplication of zero with infinity"}) + } + + if x.form == inf || y.form == inf { + // ±Inf * y + // x * ±Inf + z.form = inf + return z + } + + // ±0 * y + // x * ±0 + z.form = zero + return z +} + +// Quo sets z to the rounded quotient x/y and returns z. +// Precision, rounding, and accuracy reporting are as for Add. +// Quo panics with ErrNaN if both operands are zero or infinities. +// The value of z is undefined in that case. +func (z *Float) Quo(x, y *Float) *Float { + if debugFloat { + x.validate() + y.validate() + } + + if z.prec == 0 { + z.prec = umax32(x.prec, y.prec) + } + + z.neg = x.neg != y.neg + + if x.form == finite && y.form == finite { + // x / y (common case) + z.uquo(x, y) + return z + } + + z.acc = Exact + if x.form == zero && y.form == zero || x.form == inf && y.form == inf { + // ±0 / ±0 + // ±Inf / ±Inf + // value of z is undefined but make sure it's valid + z.form = zero + z.neg = false + panic(ErrNaN{"division of zero by zero or infinity by infinity"}) + } + + if x.form == zero || y.form == inf { + // ±0 / y + // x / ±Inf + z.form = zero + return z + } + + // x / ±0 + // ±Inf / y + z.form = inf + return z +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y (incl. -0 == 0, -Inf == -Inf, and +Inf == +Inf) +// +1 if x > y +// +func (x *Float) Cmp(y *Float) int { + if debugFloat { + x.validate() + y.validate() + } + + mx := x.ord() + my := y.ord() + switch { + case mx < my: + return -1 + case mx > my: + return +1 + } + // mx == my + + // only if |mx| == 1 we have to compare the mantissae + switch mx { + case -1: + return y.ucmp(x) + case +1: + return x.ucmp(y) + } + + return 0 +} + +// ord classifies x and returns: +// +// -2 if -Inf == x +// -1 if -Inf < x < 0 +// 0 if x == 0 (signed or unsigned) +// +1 if 0 < x < +Inf +// +2 if x == +Inf +// +func (x *Float) ord() int { + var m int + switch x.form { + case finite: + m = 1 + case zero: + return 0 + case inf: + m = 2 + } + if x.neg { + m = -m + } + return m +} + +func umax32(x, y uint32) uint32 { + if x > y { + return x + } + return y +} diff --git a/vendor/github.com/golang/go/src/math/big/floatconv.go b/vendor/github.com/golang/go/src/math/big/floatconv.go new file mode 100644 index 000000000000..95d1bf84e243 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/floatconv.go @@ -0,0 +1,293 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements string-to-Float conversion functions. + +package big + +import ( + "fmt" + "io" + "strings" +) + +var floatZero Float + +// SetString sets z to the value of s and returns z and a boolean indicating +// success. s must be a floating-point number of the same format as accepted +// by Parse, with base argument 0. The entire string (not just a prefix) must +// be valid for success. If the operation failed, the value of z is undefined +// but the returned value is nil. +func (z *Float) SetString(s string) (*Float, bool) { + if f, _, err := z.Parse(s, 0); err == nil { + return f, true + } + return nil, false +} + +// scan is like Parse but reads the longest possible prefix representing a valid +// floating point number from an io.ByteScanner rather than a string. It serves +// as the implementation of Parse. It does not recognize ±Inf and does not expect +// EOF at the end. +func (z *Float) scan(r io.ByteScanner, base int) (f *Float, b int, err error) { + prec := z.prec + if prec == 0 { + prec = 64 + } + + // A reasonable value in case of an error. + z.form = zero + + // sign + z.neg, err = scanSign(r) + if err != nil { + return + } + + // mantissa + var fcount int // fractional digit count; valid if <= 0 + z.mant, b, fcount, err = z.mant.scan(r, base, true) + if err != nil { + return + } + + // exponent + var exp int64 + var ebase int + exp, ebase, err = scanExponent(r, true) + if err != nil { + return + } + + // special-case 0 + if len(z.mant) == 0 { + z.prec = prec + z.acc = Exact + z.form = zero + f = z + return + } + // len(z.mant) > 0 + + // The mantissa may have a decimal point (fcount <= 0) and there + // may be a nonzero exponent exp. The decimal point amounts to a + // division by b**(-fcount). An exponent means multiplication by + // ebase**exp. Finally, mantissa normalization (shift left) requires + // a correcting multiplication by 2**(-shiftcount). Multiplications + // are commutative, so we can apply them in any order as long as there + // is no loss of precision. We only have powers of 2 and 10, and + // we split powers of 10 into the product of the same powers of + // 2 and 5. This reduces the size of the multiplication factor + // needed for base-10 exponents. + + // normalize mantissa and determine initial exponent contributions + exp2 := int64(len(z.mant))*_W - fnorm(z.mant) + exp5 := int64(0) + + // determine binary or decimal exponent contribution of decimal point + if fcount < 0 { + // The mantissa has a "decimal" point ddd.dddd; and + // -fcount is the number of digits to the right of '.'. + // Adjust relevant exponent accordingly. + d := int64(fcount) + switch b { + case 10: + exp5 = d + fallthrough // 10**e == 5**e * 2**e + case 2: + exp2 += d + case 16: + exp2 += d * 4 // hexadecimal digits are 4 bits each + default: + panic("unexpected mantissa base") + } + // fcount consumed - not needed anymore + } + + // take actual exponent into account + switch ebase { + case 10: + exp5 += exp + fallthrough + case 2: + exp2 += exp + default: + panic("unexpected exponent base") + } + // exp consumed - not needed anymore + + // apply 2**exp2 + if MinExp <= exp2 && exp2 <= MaxExp { + z.prec = prec + z.form = finite + z.exp = int32(exp2) + f = z + } else { + err = fmt.Errorf("exponent overflow") + return + } + + if exp5 == 0 { + // no decimal exponent contribution + z.round(0) + return + } + // exp5 != 0 + + // apply 5**exp5 + p := new(Float).SetPrec(z.Prec() + 64) // use more bits for p -- TODO(gri) what is the right number? + if exp5 < 0 { + z.Quo(z, p.pow5(uint64(-exp5))) + } else { + z.Mul(z, p.pow5(uint64(exp5))) + } + + return +} + +// These powers of 5 fit into a uint64. +// +// for p, q := uint64(0), uint64(1); p < q; p, q = q, q*5 { +// fmt.Println(q) +// } +// +var pow5tab = [...]uint64{ + 1, + 5, + 25, + 125, + 625, + 3125, + 15625, + 78125, + 390625, + 1953125, + 9765625, + 48828125, + 244140625, + 1220703125, + 6103515625, + 30517578125, + 152587890625, + 762939453125, + 3814697265625, + 19073486328125, + 95367431640625, + 476837158203125, + 2384185791015625, + 11920928955078125, + 59604644775390625, + 298023223876953125, + 1490116119384765625, + 7450580596923828125, +} + +// pow5 sets z to 5**n and returns z. +// n must not be negative. +func (z *Float) pow5(n uint64) *Float { + const m = uint64(len(pow5tab) - 1) + if n <= m { + return z.SetUint64(pow5tab[n]) + } + // n > m + + z.SetUint64(pow5tab[m]) + n -= m + + // use more bits for f than for z + // TODO(gri) what is the right number? + f := new(Float).SetPrec(z.Prec() + 64).SetUint64(5) + + for n > 0 { + if n&1 != 0 { + z.Mul(z, f) + } + f.Mul(f, f) + n >>= 1 + } + + return z +} + +// Parse parses s which must contain a text representation of a floating- +// point number with a mantissa in the given conversion base (the exponent +// is always a decimal number), or a string representing an infinite value. +// +// It sets z to the (possibly rounded) value of the corresponding floating- +// point value, and returns z, the actual base b, and an error err, if any. +// The entire string (not just a prefix) must be consumed for success. +// If z's precision is 0, it is changed to 64 before rounding takes effect. +// The number must be of the form: +// +// number = [ sign ] [ prefix ] mantissa [ exponent ] | infinity . +// sign = "+" | "-" . +// prefix = "0" ( "x" | "X" | "b" | "B" ) . +// mantissa = digits | digits "." [ digits ] | "." digits . +// exponent = ( "E" | "e" | "p" ) [ sign ] digits . +// digits = digit { digit } . +// digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" . +// infinity = [ sign ] ( "inf" | "Inf" ) . +// +// The base argument must be 0, 2, 10, or 16. Providing an invalid base +// argument will lead to a run-time panic. +// +// For base 0, the number prefix determines the actual base: A prefix of +// "0x" or "0X" selects base 16, and a "0b" or "0B" prefix selects +// base 2; otherwise, the actual base is 10 and no prefix is accepted. +// The octal prefix "0" is not supported (a leading "0" is simply +// considered a "0"). +// +// A "p" exponent indicates a binary (rather then decimal) exponent; +// for instance "0x1.fffffffffffffp1023" (using base 0) represents the +// maximum float64 value. For hexadecimal mantissae, the exponent must +// be binary, if present (an "e" or "E" exponent indicator cannot be +// distinguished from a mantissa digit). +// +// The returned *Float f is nil and the value of z is valid but not +// defined if an error is reported. +// +func (z *Float) Parse(s string, base int) (f *Float, b int, err error) { + // scan doesn't handle ±Inf + if len(s) == 3 && (s == "Inf" || s == "inf") { + f = z.SetInf(false) + return + } + if len(s) == 4 && (s[0] == '+' || s[0] == '-') && (s[1:] == "Inf" || s[1:] == "inf") { + f = z.SetInf(s[0] == '-') + return + } + + r := strings.NewReader(s) + if f, b, err = z.scan(r, base); err != nil { + return + } + + // entire string must have been consumed + if ch, err2 := r.ReadByte(); err2 == nil { + err = fmt.Errorf("expected end of string, found %q", ch) + } else if err2 != io.EOF { + err = err2 + } + + return +} + +// ParseFloat is like f.Parse(s, base) with f set to the given precision +// and rounding mode. +func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) { + return new(Float).SetPrec(prec).SetMode(mode).Parse(s, base) +} + +var _ fmt.Scanner = &floatZero // *Float must implement fmt.Scanner + +// Scan is a support routine for fmt.Scanner; it sets z to the value of +// the scanned number. It accepts formats whose verbs are supported by +// fmt.Scan for floating point values, which are: +// 'b' (binary), 'e', 'E', 'f', 'F', 'g' and 'G'. +// Scan doesn't handle ±Inf. +func (z *Float) Scan(s fmt.ScanState, ch rune) error { + s.SkipSpace() + _, _, err := z.scan(byteReader{s}, 0) + return err +} diff --git a/vendor/github.com/golang/go/src/math/big/floatmarsh.go b/vendor/github.com/golang/go/src/math/big/floatmarsh.go new file mode 100644 index 000000000000..d1c1dab06917 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/floatmarsh.go @@ -0,0 +1,120 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements encoding/decoding of Floats. + +package big + +import ( + "encoding/binary" + "fmt" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const floatGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +// The Float value and all its attributes (precision, +// rounding mode, accuracy) are marshaled. +func (x *Float) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + + // determine max. space (bytes) required for encoding + sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec + n := 0 // number of mantissa words + if x.form == finite { + // add space for mantissa and exponent + n = int((x.prec + (_W - 1)) / _W) // required mantissa length in words for given precision + // actual mantissa slice could be shorter (trailing 0's) or longer (unused bits): + // - if shorter, only encode the words present + // - if longer, cut off unused words when encoding in bytes + // (in practice, this should never happen since rounding + // takes care of it, but be safe and do it always) + if len(x.mant) < n { + n = len(x.mant) + } + // len(x.mant) >= n + sz += 4 + n*_S // exp + mant + } + buf := make([]byte, sz) + + buf[0] = floatGobVersion + b := byte(x.mode&7)<<5 | byte((x.acc+1)&3)<<3 | byte(x.form&3)<<1 + if x.neg { + b |= 1 + } + buf[1] = b + binary.BigEndian.PutUint32(buf[2:], x.prec) + + if x.form == finite { + binary.BigEndian.PutUint32(buf[6:], uint32(x.exp)) + x.mant[len(x.mant)-n:].bytes(buf[10:]) // cut off unused trailing words + } + + return buf, nil +} + +// GobDecode implements the gob.GobDecoder interface. +// The result is rounded per the precision and rounding mode of +// z unless z's precision is 0, in which case z is set exactly +// to the decoded value. +func (z *Float) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Float{} + return nil + } + + if buf[0] != floatGobVersion { + return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0]) + } + + oldPrec := z.prec + oldMode := z.mode + + b := buf[1] + z.mode = RoundingMode((b >> 5) & 7) + z.acc = Accuracy((b>>3)&3) - 1 + z.form = form((b >> 1) & 3) + z.neg = b&1 != 0 + z.prec = binary.BigEndian.Uint32(buf[2:]) + + if z.form == finite { + z.exp = int32(binary.BigEndian.Uint32(buf[6:])) + z.mant = z.mant.setBytes(buf[10:]) + } + + if oldPrec != 0 { + z.mode = oldMode + z.SetPrec(uint(oldPrec)) + } + + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. +// Only the Float value is marshaled (in full precision), other +// attributes such as precision or accuracy are ignored. +func (x *Float) MarshalText() (text []byte, err error) { + if x == nil { + return []byte(""), nil + } + var buf []byte + return x.Append(buf, 'g', -1), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The result is rounded per the precision and rounding mode of z. +// If z's precision is 0, it is changed to 64 before rounding takes +// effect. +func (z *Float) UnmarshalText(text []byte) error { + // TODO(gri): get rid of the []byte/string conversion + _, _, err := z.Parse(string(text), 0) + if err != nil { + err = fmt.Errorf("math/big: cannot unmarshal %q into a *big.Float (%v)", text, err) + } + return err +} diff --git a/vendor/github.com/golang/go/src/math/big/ftoa.go b/vendor/github.com/golang/go/src/math/big/ftoa.go new file mode 100644 index 000000000000..d2a85886c72d --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/ftoa.go @@ -0,0 +1,461 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements Float-to-string conversion functions. +// It is closely following the corresponding implementation +// in strconv/ftoa.go, but modified and simplified for Float. + +package big + +import ( + "bytes" + "fmt" + "strconv" +) + +// Text converts the floating-point number x to a string according +// to the given format and precision prec. The format is one of: +// +// 'e' -d.dddde±dd, decimal exponent, at least two (possibly 0) exponent digits +// 'E' -d.ddddE±dd, decimal exponent, at least two (possibly 0) exponent digits +// 'f' -ddddd.dddd, no exponent +// 'g' like 'e' for large exponents, like 'f' otherwise +// 'G' like 'E' for large exponents, like 'f' otherwise +// 'b' -ddddddp±dd, binary exponent +// 'p' -0x.dddp±dd, binary exponent, hexadecimal mantissa +// +// For the binary exponent formats, the mantissa is printed in normalized form: +// +// 'b' decimal integer mantissa using x.Prec() bits, or -0 +// 'p' hexadecimal fraction with 0.5 <= 0.mantissa < 1.0, or -0 +// +// If format is a different character, Text returns a "%" followed by the +// unrecognized format character. +// +// The precision prec controls the number of digits (excluding the exponent) +// printed by the 'e', 'E', 'f', 'g', and 'G' formats. For 'e', 'E', and 'f' +// it is the number of digits after the decimal point. For 'g' and 'G' it is +// the total number of digits. A negative precision selects the smallest +// number of decimal digits necessary to identify the value x uniquely using +// x.Prec() mantissa bits. +// The prec value is ignored for the 'b' or 'p' format. +func (x *Float) Text(format byte, prec int) string { + cap := 10 // TODO(gri) determine a good/better value here + if prec > 0 { + cap += prec + } + return string(x.Append(make([]byte, 0, cap), format, prec)) +} + +// String formats x like x.Text('g', 10). +// (String must be called explicitly, Float.Format does not support %s verb.) +func (x *Float) String() string { + return x.Text('g', 10) +} + +// Append appends to buf the string form of the floating-point number x, +// as generated by x.Text, and returns the extended buffer. +func (x *Float) Append(buf []byte, fmt byte, prec int) []byte { + // sign + if x.neg { + buf = append(buf, '-') + } + + // Inf + if x.form == inf { + if !x.neg { + buf = append(buf, '+') + } + return append(buf, "Inf"...) + } + + // pick off easy formats + switch fmt { + case 'b': + return x.fmtB(buf) + case 'p': + return x.fmtP(buf) + } + + // Algorithm: + // 1) convert Float to multiprecision decimal + // 2) round to desired precision + // 3) read digits out and format + + // 1) convert Float to multiprecision decimal + var d decimal // == 0.0 + if x.form == finite { + // x != 0 + d.init(x.mant, int(x.exp)-x.mant.bitLen()) + } + + // 2) round to desired precision + shortest := false + if prec < 0 { + shortest = true + roundShortest(&d, x) + // Precision for shortest representation mode. + switch fmt { + case 'e', 'E': + prec = len(d.mant) - 1 + case 'f': + prec = max(len(d.mant)-d.exp, 0) + case 'g', 'G': + prec = len(d.mant) + } + } else { + // round appropriately + switch fmt { + case 'e', 'E': + // one digit before and number of digits after decimal point + d.round(1 + prec) + case 'f': + // number of digits before and after decimal point + d.round(d.exp + prec) + case 'g', 'G': + if prec == 0 { + prec = 1 + } + d.round(prec) + } + } + + // 3) read digits out and format + switch fmt { + case 'e', 'E': + return fmtE(buf, fmt, prec, d) + case 'f': + return fmtF(buf, prec, d) + case 'g', 'G': + // trim trailing fractional zeros in %e format + eprec := prec + if eprec > len(d.mant) && len(d.mant) >= d.exp { + eprec = len(d.mant) + } + // %e is used if the exponent from the conversion + // is less than -4 or greater than or equal to the precision. + // If precision was the shortest possible, use eprec = 6 for + // this decision. + if shortest { + eprec = 6 + } + exp := d.exp - 1 + if exp < -4 || exp >= eprec { + if prec > len(d.mant) { + prec = len(d.mant) + } + return fmtE(buf, fmt+'e'-'g', prec-1, d) + } + if prec > d.exp { + prec = len(d.mant) + } + return fmtF(buf, max(prec-d.exp, 0), d) + } + + // unknown format + if x.neg { + buf = buf[:len(buf)-1] // sign was added prematurely - remove it again + } + return append(buf, '%', fmt) +} + +func roundShortest(d *decimal, x *Float) { + // if the mantissa is zero, the number is zero - stop now + if len(d.mant) == 0 { + return + } + + // Approach: All numbers in the interval [x - 1/2ulp, x + 1/2ulp] + // (possibly exclusive) round to x for the given precision of x. + // Compute the lower and upper bound in decimal form and find the + // shortest decimal number d such that lower <= d <= upper. + + // TODO(gri) strconv/ftoa.do describes a shortcut in some cases. + // See if we can use it (in adjusted form) here as well. + + // 1) Compute normalized mantissa mant and exponent exp for x such + // that the lsb of mant corresponds to 1/2 ulp for the precision of + // x (i.e., for mant we want x.prec + 1 bits). + mant := nat(nil).set(x.mant) + exp := int(x.exp) - mant.bitLen() + s := mant.bitLen() - int(x.prec+1) + switch { + case s < 0: + mant = mant.shl(mant, uint(-s)) + case s > 0: + mant = mant.shr(mant, uint(+s)) + } + exp += s + // x = mant * 2**exp with lsb(mant) == 1/2 ulp of x.prec + + // 2) Compute lower bound by subtracting 1/2 ulp. + var lower decimal + var tmp nat + lower.init(tmp.sub(mant, natOne), exp) + + // 3) Compute upper bound by adding 1/2 ulp. + var upper decimal + upper.init(tmp.add(mant, natOne), exp) + + // The upper and lower bounds are possible outputs only if + // the original mantissa is even, so that ToNearestEven rounding + // would round to the original mantissa and not the neighbors. + inclusive := mant[0]&2 == 0 // test bit 1 since original mantissa was shifted by 1 + + // Now we can figure out the minimum number of digits required. + // Walk along until d has distinguished itself from upper and lower. + for i, m := range d.mant { + l := lower.at(i) + u := upper.at(i) + + // Okay to round down (truncate) if lower has a different digit + // or if lower is inclusive and is exactly the result of rounding + // down (i.e., and we have reached the final digit of lower). + okdown := l != m || inclusive && i+1 == len(lower.mant) + + // Okay to round up if upper has a different digit and either upper + // is inclusive or upper is bigger than the result of rounding up. + okup := m != u && (inclusive || m+1 < u || i+1 < len(upper.mant)) + + // If it's okay to do either, then round to the nearest one. + // If it's okay to do only one, do it. + switch { + case okdown && okup: + d.round(i + 1) + return + case okdown: + d.roundDown(i + 1) + return + case okup: + d.roundUp(i + 1) + return + } + } +} + +// %e: d.ddddde±dd +func fmtE(buf []byte, fmt byte, prec int, d decimal) []byte { + // first digit + ch := byte('0') + if len(d.mant) > 0 { + ch = d.mant[0] + } + buf = append(buf, ch) + + // .moredigits + if prec > 0 { + buf = append(buf, '.') + i := 1 + m := min(len(d.mant), prec+1) + if i < m { + buf = append(buf, d.mant[i:m]...) + i = m + } + for ; i <= prec; i++ { + buf = append(buf, '0') + } + } + + // e± + buf = append(buf, fmt) + var exp int64 + if len(d.mant) > 0 { + exp = int64(d.exp) - 1 // -1 because first digit was printed before '.' + } + if exp < 0 { + ch = '-' + exp = -exp + } else { + ch = '+' + } + buf = append(buf, ch) + + // dd...d + if exp < 10 { + buf = append(buf, '0') // at least 2 exponent digits + } + return strconv.AppendInt(buf, exp, 10) +} + +// %f: ddddddd.ddddd +func fmtF(buf []byte, prec int, d decimal) []byte { + // integer, padded with zeros as needed + if d.exp > 0 { + m := min(len(d.mant), d.exp) + buf = append(buf, d.mant[:m]...) + for ; m < d.exp; m++ { + buf = append(buf, '0') + } + } else { + buf = append(buf, '0') + } + + // fraction + if prec > 0 { + buf = append(buf, '.') + for i := 0; i < prec; i++ { + buf = append(buf, d.at(d.exp+i)) + } + } + + return buf +} + +// fmtB appends the string of x in the format mantissa "p" exponent +// with a decimal mantissa and a binary exponent, or 0" if x is zero, +// and returns the extended buffer. +// The mantissa is normalized such that is uses x.Prec() bits in binary +// representation. +// The sign of x is ignored, and x must not be an Inf. +func (x *Float) fmtB(buf []byte) []byte { + if x.form == zero { + return append(buf, '0') + } + + if debugFloat && x.form != finite { + panic("non-finite float") + } + // x != 0 + + // adjust mantissa to use exactly x.prec bits + m := x.mant + switch w := uint32(len(x.mant)) * _W; { + case w < x.prec: + m = nat(nil).shl(m, uint(x.prec-w)) + case w > x.prec: + m = nat(nil).shr(m, uint(w-x.prec)) + } + + buf = append(buf, m.utoa(10)...) + buf = append(buf, 'p') + e := int64(x.exp) - int64(x.prec) + if e >= 0 { + buf = append(buf, '+') + } + return strconv.AppendInt(buf, e, 10) +} + +// fmtP appends the string of x in the format "0x." mantissa "p" exponent +// with a hexadecimal mantissa and a binary exponent, or "0" if x is zero, +// and returns the extended buffer. +// The mantissa is normalized such that 0.5 <= 0.mantissa < 1.0. +// The sign of x is ignored, and x must not be an Inf. +func (x *Float) fmtP(buf []byte) []byte { + if x.form == zero { + return append(buf, '0') + } + + if debugFloat && x.form != finite { + panic("non-finite float") + } + // x != 0 + + // remove trailing 0 words early + // (no need to convert to hex 0's and trim later) + m := x.mant + i := 0 + for i < len(m) && m[i] == 0 { + i++ + } + m = m[i:] + + buf = append(buf, "0x."...) + buf = append(buf, bytes.TrimRight(m.utoa(16), "0")...) + buf = append(buf, 'p') + if x.exp >= 0 { + buf = append(buf, '+') + } + return strconv.AppendInt(buf, int64(x.exp), 10) +} + +func min(x, y int) int { + if x < y { + return x + } + return y +} + +var _ fmt.Formatter = &floatZero // *Float must implement fmt.Formatter + +// Format implements fmt.Formatter. It accepts all the regular +// formats for floating-point numbers ('b', 'e', 'E', 'f', 'F', +// 'g', 'G') as well as 'p' and 'v'. See (*Float).Text for the +// interpretation of 'p'. The 'v' format is handled like 'g'. +// Format also supports specification of the minimum precision +// in digits, the output field width, as well as the format flags +// '+' and ' ' for sign control, '0' for space or zero padding, +// and '-' for left or right justification. See the fmt package +// for details. +func (x *Float) Format(s fmt.State, format rune) { + prec, hasPrec := s.Precision() + if !hasPrec { + prec = 6 // default precision for 'e', 'f' + } + + switch format { + case 'e', 'E', 'f', 'b', 'p': + // nothing to do + case 'F': + // (*Float).Text doesn't support 'F'; handle like 'f' + format = 'f' + case 'v': + // handle like 'g' + format = 'g' + fallthrough + case 'g', 'G': + if !hasPrec { + prec = -1 // default precision for 'g', 'G' + } + default: + fmt.Fprintf(s, "%%!%c(*big.Float=%s)", format, x.String()) + return + } + var buf []byte + buf = x.Append(buf, byte(format), prec) + if len(buf) == 0 { + buf = []byte("?") // should never happen, but don't crash + } + // len(buf) > 0 + + var sign string + switch { + case buf[0] == '-': + sign = "-" + buf = buf[1:] + case buf[0] == '+': + // +Inf + sign = "+" + if s.Flag(' ') { + sign = " " + } + buf = buf[1:] + case s.Flag('+'): + sign = "+" + case s.Flag(' '): + sign = " " + } + + var padding int + if width, hasWidth := s.Width(); hasWidth && width > len(sign)+len(buf) { + padding = width - len(sign) - len(buf) + } + + switch { + case s.Flag('0') && !x.IsInf(): + // 0-padding on left + writeMultiple(s, sign, 1) + writeMultiple(s, "0", padding) + s.Write(buf) + case s.Flag('-'): + // padding on right + writeMultiple(s, sign, 1) + s.Write(buf) + writeMultiple(s, " ", padding) + default: + // padding on left + writeMultiple(s, " ", padding) + writeMultiple(s, sign, 1) + s.Write(buf) + } +} diff --git a/vendor/github.com/golang/go/src/math/big/int.go b/vendor/github.com/golang/go/src/math/big/int.go new file mode 100644 index 000000000000..0eda9cd4e123 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/int.go @@ -0,0 +1,1033 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements signed multi-precision integers. + +package big + +import ( + "fmt" + "io" + "math/rand" + "strings" +) + +// An Int represents a signed multi-precision integer. +// The zero value for an Int represents the value 0. +type Int struct { + neg bool // sign + abs nat // absolute value of the integer +} + +var intOne = &Int{false, natOne} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x == 0 +// +1 if x > 0 +// +func (x *Int) Sign() int { + if len(x.abs) == 0 { + return 0 + } + if x.neg { + return -1 + } + return 1 +} + +// SetInt64 sets z to x and returns z. +func (z *Int) SetInt64(x int64) *Int { + neg := false + if x < 0 { + neg = true + x = -x + } + z.abs = z.abs.setUint64(uint64(x)) + z.neg = neg + return z +} + +// SetUint64 sets z to x and returns z. +func (z *Int) SetUint64(x uint64) *Int { + z.abs = z.abs.setUint64(x) + z.neg = false + return z +} + +// NewInt allocates and returns a new Int set to x. +func NewInt(x int64) *Int { + return new(Int).SetInt64(x) +} + +// Set sets z to x and returns z. +func (z *Int) Set(x *Int) *Int { + if z != x { + z.abs = z.abs.set(x.abs) + z.neg = x.neg + } + return z +} + +// Bits provides raw (unchecked but fast) access to x by returning its +// absolute value as a little-endian Word slice. The result and x share +// the same underlying array. +// Bits is intended to support implementation of missing low-level Int +// functionality outside this package; it should be avoided otherwise. +func (x *Int) Bits() []Word { + return x.abs +} + +// SetBits provides raw (unchecked but fast) access to z by setting its +// value to abs, interpreted as a little-endian Word slice, and returning +// z. The result and abs share the same underlying array. +// SetBits is intended to support implementation of missing low-level Int +// functionality outside this package; it should be avoided otherwise. +func (z *Int) SetBits(abs []Word) *Int { + z.abs = nat(abs).norm() + z.neg = false + return z +} + +// Abs sets z to |x| (the absolute value of x) and returns z. +func (z *Int) Abs(x *Int) *Int { + z.Set(x) + z.neg = false + return z +} + +// Neg sets z to -x and returns z. +func (z *Int) Neg(x *Int) *Int { + z.Set(x) + z.neg = len(z.abs) > 0 && !z.neg // 0 has no sign + return z +} + +// Add sets z to the sum x+y and returns z. +func (z *Int) Add(x, y *Int) *Int { + neg := x.neg + if x.neg == y.neg { + // x + y == x + y + // (-x) + (-y) == -(x + y) + z.abs = z.abs.add(x.abs, y.abs) + } else { + // x + (-y) == x - y == -(y - x) + // (-x) + y == y - x == -(x - y) + if x.abs.cmp(y.abs) >= 0 { + z.abs = z.abs.sub(x.abs, y.abs) + } else { + neg = !neg + z.abs = z.abs.sub(y.abs, x.abs) + } + } + z.neg = len(z.abs) > 0 && neg // 0 has no sign + return z +} + +// Sub sets z to the difference x-y and returns z. +func (z *Int) Sub(x, y *Int) *Int { + neg := x.neg + if x.neg != y.neg { + // x - (-y) == x + y + // (-x) - y == -(x + y) + z.abs = z.abs.add(x.abs, y.abs) + } else { + // x - y == x - y == -(y - x) + // (-x) - (-y) == y - x == -(x - y) + if x.abs.cmp(y.abs) >= 0 { + z.abs = z.abs.sub(x.abs, y.abs) + } else { + neg = !neg + z.abs = z.abs.sub(y.abs, x.abs) + } + } + z.neg = len(z.abs) > 0 && neg // 0 has no sign + return z +} + +// Mul sets z to the product x*y and returns z. +func (z *Int) Mul(x, y *Int) *Int { + // x * y == x * y + // x * (-y) == -(x * y) + // (-x) * y == -(x * y) + // (-x) * (-y) == x * y + if x == y { + z.abs = z.abs.sqr(x.abs) + z.neg = false + return z + } + z.abs = z.abs.mul(x.abs, y.abs) + z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign + return z +} + +// MulRange sets z to the product of all integers +// in the range [a, b] inclusively and returns z. +// If a > b (empty range), the result is 1. +func (z *Int) MulRange(a, b int64) *Int { + switch { + case a > b: + return z.SetInt64(1) // empty range + case a <= 0 && b >= 0: + return z.SetInt64(0) // range includes 0 + } + // a <= b && (b < 0 || a > 0) + + neg := false + if a < 0 { + neg = (b-a)&1 == 0 + a, b = -b, -a + } + + z.abs = z.abs.mulRange(uint64(a), uint64(b)) + z.neg = neg + return z +} + +// Binomial sets z to the binomial coefficient of (n, k) and returns z. +func (z *Int) Binomial(n, k int64) *Int { + // reduce the number of multiplications by reducing k + if n/2 < k && k <= n { + k = n - k // Binomial(n, k) == Binomial(n, n-k) + } + var a, b Int + a.MulRange(n-k+1, n) + b.MulRange(1, k) + return z.Quo(&a, &b) +} + +// Quo sets z to the quotient x/y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Quo implements truncated division (like Go); see QuoRem for more details. +func (z *Int) Quo(x, y *Int) *Int { + z.abs, _ = z.abs.div(nil, x.abs, y.abs) + z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign + return z +} + +// Rem sets z to the remainder x%y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Rem implements truncated modulus (like Go); see QuoRem for more details. +func (z *Int) Rem(x, y *Int) *Int { + _, z.abs = nat(nil).div(z.abs, x.abs, y.abs) + z.neg = len(z.abs) > 0 && x.neg // 0 has no sign + return z +} + +// QuoRem sets z to the quotient x/y and r to the remainder x%y +// and returns the pair (z, r) for y != 0. +// If y == 0, a division-by-zero run-time panic occurs. +// +// QuoRem implements T-division and modulus (like Go): +// +// q = x/y with the result truncated to zero +// r = x - y*q +// +// (See Daan Leijen, ``Division and Modulus for Computer Scientists''.) +// See DivMod for Euclidean division and modulus (unlike Go). +// +func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) { + z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs) + z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign + return z, r +} + +// Div sets z to the quotient x/y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Div implements Euclidean division (unlike Go); see DivMod for more details. +func (z *Int) Div(x, y *Int) *Int { + y_neg := y.neg // z may be an alias for y + var r Int + z.QuoRem(x, y, &r) + if r.neg { + if y_neg { + z.Add(z, intOne) + } else { + z.Sub(z, intOne) + } + } + return z +} + +// Mod sets z to the modulus x%y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// Mod implements Euclidean modulus (unlike Go); see DivMod for more details. +func (z *Int) Mod(x, y *Int) *Int { + y0 := y // save y + if z == y || alias(z.abs, y.abs) { + y0 = new(Int).Set(y) + } + var q Int + q.QuoRem(x, y, z) + if z.neg { + if y0.neg { + z.Sub(z, y0) + } else { + z.Add(z, y0) + } + } + return z +} + +// DivMod sets z to the quotient x div y and m to the modulus x mod y +// and returns the pair (z, m) for y != 0. +// If y == 0, a division-by-zero run-time panic occurs. +// +// DivMod implements Euclidean division and modulus (unlike Go): +// +// q = x div y such that +// m = x - y*q with 0 <= m < |y| +// +// (See Raymond T. Boute, ``The Euclidean definition of the functions +// div and mod''. ACM Transactions on Programming Languages and +// Systems (TOPLAS), 14(2):127-144, New York, NY, USA, 4/1992. +// ACM press.) +// See QuoRem for T-division and modulus (like Go). +// +func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) { + y0 := y // save y + if z == y || alias(z.abs, y.abs) { + y0 = new(Int).Set(y) + } + z.QuoRem(x, y, m) + if m.neg { + if y0.neg { + z.Add(z, intOne) + m.Sub(m, y0) + } else { + z.Sub(z, intOne) + m.Add(m, y0) + } + } + return z, m +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y +// +1 if x > y +// +func (x *Int) Cmp(y *Int) (r int) { + // x cmp y == x cmp y + // x cmp (-y) == x + // (-x) cmp y == y + // (-x) cmp (-y) == -(x cmp y) + switch { + case x.neg == y.neg: + r = x.abs.cmp(y.abs) + if x.neg { + r = -r + } + case x.neg: + r = -1 + default: + r = 1 + } + return +} + +// CmpAbs compares the absolute values of x and y and returns: +// +// -1 if |x| < |y| +// 0 if |x| == |y| +// +1 if |x| > |y| +// +func (x *Int) CmpAbs(y *Int) int { + return x.abs.cmp(y.abs) +} + +// low32 returns the least significant 32 bits of x. +func low32(x nat) uint32 { + if len(x) == 0 { + return 0 + } + return uint32(x[0]) +} + +// low64 returns the least significant 64 bits of x. +func low64(x nat) uint64 { + if len(x) == 0 { + return 0 + } + v := uint64(x[0]) + if _W == 32 && len(x) > 1 { + return uint64(x[1])<<32 | v + } + return v +} + +// Int64 returns the int64 representation of x. +// If x cannot be represented in an int64, the result is undefined. +func (x *Int) Int64() int64 { + v := int64(low64(x.abs)) + if x.neg { + v = -v + } + return v +} + +// Uint64 returns the uint64 representation of x. +// If x cannot be represented in a uint64, the result is undefined. +func (x *Int) Uint64() uint64 { + return low64(x.abs) +} + +// IsInt64 reports whether x can be represented as an int64. +func (x *Int) IsInt64() bool { + if len(x.abs) <= 64/_W { + w := int64(low64(x.abs)) + return w >= 0 || x.neg && w == -w + } + return false +} + +// IsUint64 reports whether x can be represented as a uint64. +func (x *Int) IsUint64() bool { + return !x.neg && len(x.abs) <= 64/_W +} + +// SetString sets z to the value of s, interpreted in the given base, +// and returns z and a boolean indicating success. The entire string +// (not just a prefix) must be valid for success. If SetString fails, +// the value of z is undefined but the returned value is nil. +// +// The base argument must be 0 or a value between 2 and MaxBase. If the base +// is 0, the string prefix determines the actual conversion base. A prefix of +// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a +// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10. +// +// For bases <= 36, lower and upper case letters are considered the same: +// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35. +// For bases > 36, the upper case letters 'A' to 'Z' represent the digit +// values 36 to 61. +// +func (z *Int) SetString(s string, base int) (*Int, bool) { + return z.setFromScanner(strings.NewReader(s), base) +} + +// setFromScanner implements SetString given an io.BytesScanner. +// For documentation see comments of SetString. +func (z *Int) setFromScanner(r io.ByteScanner, base int) (*Int, bool) { + if _, _, err := z.scan(r, base); err != nil { + return nil, false + } + // entire content must have been consumed + if _, err := r.ReadByte(); err != io.EOF { + return nil, false + } + return z, true // err == io.EOF => scan consumed all content of r +} + +// SetBytes interprets buf as the bytes of a big-endian unsigned +// integer, sets z to that value, and returns z. +func (z *Int) SetBytes(buf []byte) *Int { + z.abs = z.abs.setBytes(buf) + z.neg = false + return z +} + +// Bytes returns the absolute value of x as a big-endian byte slice. +func (x *Int) Bytes() []byte { + buf := make([]byte, len(x.abs)*_S) + return buf[x.abs.bytes(buf):] +} + +// BitLen returns the length of the absolute value of x in bits. +// The bit length of 0 is 0. +func (x *Int) BitLen() int { + return x.abs.bitLen() +} + +// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z. +// If y <= 0, the result is 1 mod |m|; if m == nil or m == 0, z = x**y. +// +// Modular exponentation of inputs of a particular size is not a +// cryptographically constant-time operation. +func (z *Int) Exp(x, y, m *Int) *Int { + // See Knuth, volume 2, section 4.6.3. + var yWords nat + if !y.neg { + yWords = y.abs + } + // y >= 0 + + var mWords nat + if m != nil { + mWords = m.abs // m.abs may be nil for m == 0 + } + + z.abs = z.abs.expNN(x.abs, yWords, mWords) + z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign + if z.neg && len(mWords) > 0 { + // make modulus result positive + z.abs = z.abs.sub(mWords, z.abs) // z == x**y mod |m| && 0 <= z < |m| + z.neg = false + } + + return z +} + +// GCD sets z to the greatest common divisor of a and b, which both must +// be > 0, and returns z. +// If x or y are not nil, GCD sets their value such that z = a*x + b*y. +// If either a or b is <= 0, GCD sets z = x = y = 0. +func (z *Int) GCD(x, y, a, b *Int) *Int { + if a.Sign() <= 0 || b.Sign() <= 0 { + z.SetInt64(0) + if x != nil { + x.SetInt64(0) + } + if y != nil { + y.SetInt64(0) + } + return z + } + if x == nil && y == nil { + return z.lehmerGCD(a, b) + } + + A := new(Int).Set(a) + B := new(Int).Set(b) + + X := new(Int) + lastX := new(Int).SetInt64(1) + + q := new(Int) + temp := new(Int) + + r := new(Int) + for len(B.abs) > 0 { + q, r = q.QuoRem(A, B, r) + + A, B, r = B, r, A + + temp.Set(X) + X.Mul(X, q) + X.Sub(lastX, X) + lastX.Set(temp) + } + + if x != nil { + *x = *lastX + } + + if y != nil { + // y = (z - a*x)/b + y.Mul(a, lastX) + y.Sub(A, y) + y.Div(y, b) + } + + *z = *A + return z +} + +// lehmerGCD sets z to the greatest common divisor of a and b, +// which both must be > 0, and returns z. +// See Knuth, The Art of Computer Programming, Vol. 2, Section 4.5.2, Algorithm L. +// This implementation uses the improved condition by Collins requiring only one +// quotient and avoiding the possibility of single Word overflow. +// See Jebelean, "Improving the multiprecision Euclidean algorithm", +// Design and Implementation of Symbolic Computation Systems, pp 45-58. +func (z *Int) lehmerGCD(a, b *Int) *Int { + // ensure a >= b + if a.abs.cmp(b.abs) < 0 { + a, b = b, a + } + + // don't destroy incoming values of a and b + B := new(Int).Set(b) // must be set first in case b is an alias of z + A := z.Set(a) + + // temp variables for multiprecision update + t := new(Int) + r := new(Int) + s := new(Int) + w := new(Int) + + // loop invariant A >= B + for len(B.abs) > 1 { + // initialize the digits + var a1, a2, u0, u1, u2, v0, v1, v2 Word + + m := len(B.abs) // m >= 2 + n := len(A.abs) // n >= m >= 2 + + // extract the top Word of bits from A and B + h := nlz(A.abs[n-1]) + a1 = (A.abs[n-1] << h) | (A.abs[n-2] >> (_W - h)) + // B may have implicit zero words in the high bits if the lengths differ + switch { + case n == m: + a2 = (B.abs[n-1] << h) | (B.abs[n-2] >> (_W - h)) + case n == m+1: + a2 = (B.abs[n-2] >> (_W - h)) + default: + a2 = 0 + } + + // Since we are calculating with full words to avoid overflow, + // we use 'even' to track the sign of the cosequences. + // For even iterations: u0, v1 >= 0 && u1, v0 <= 0 + // For odd iterations: u0, v1 <= 0 && u1, v0 >= 0 + // The first iteration starts with k=1 (odd). + even := false + // variables to track the cosequences + u0, u1, u2 = 0, 1, 0 + v0, v1, v2 = 0, 0, 1 + + // Calculate the quotient and cosequences using Collins' stopping condition. + // Note that overflow of a Word is not possible when computing the remainder + // sequence and cosequences since the cosequence size is bounded by the input size. + // See section 4.2 of Jebelean for details. + for a2 >= v2 && a1-a2 >= v1+v2 { + q := a1 / a2 + a1, a2 = a2, a1-q*a2 + u0, u1, u2 = u1, u2, u1+q*u2 + v0, v1, v2 = v1, v2, v1+q*v2 + even = !even + } + + // multiprecision step + if v0 != 0 { + // simulate the effect of the single precision steps using the cosequences + // A = u0*A + v0*B + // B = u1*A + v1*B + + t.abs = t.abs.setWord(u0) + s.abs = s.abs.setWord(v0) + t.neg = !even + s.neg = even + + t.Mul(A, t) + s.Mul(B, s) + + r.abs = r.abs.setWord(u1) + w.abs = w.abs.setWord(v1) + r.neg = even + w.neg = !even + + r.Mul(A, r) + w.Mul(B, w) + + A.Add(t, s) + B.Add(r, w) + + } else { + // single-digit calculations failed to simluate any quotients + // do a standard Euclidean step + t.Rem(A, B) + A, B, t = B, t, A + } + } + + if len(B.abs) > 0 { + // standard Euclidean algorithm base case for B a single Word + if len(A.abs) > 1 { + // A is longer than a single Word + t.Rem(A, B) + A, B, t = B, t, A + } + if len(B.abs) > 0 { + // A and B are both a single Word + a1, a2 := A.abs[0], B.abs[0] + for a2 != 0 { + a1, a2 = a2, a1%a2 + } + A.abs[0] = a1 + } + } + *z = *A + return z +} + +// Rand sets z to a pseudo-random number in [0, n) and returns z. +// +// As this uses the math/rand package, it must not be used for +// security-sensitive work. Use crypto/rand.Int instead. +func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int { + z.neg = false + if n.neg || len(n.abs) == 0 { + z.abs = nil + return z + } + z.abs = z.abs.random(rnd, n.abs, n.abs.bitLen()) + return z +} + +// ModInverse sets z to the multiplicative inverse of g in the ring ℤ/nℤ +// and returns z. If g and n are not relatively prime, the result is undefined. +func (z *Int) ModInverse(g, n *Int) *Int { + if g.neg { + // GCD expects parameters a and b to be > 0. + var g2 Int + g = g2.Mod(g, n) + } + var d Int + d.GCD(z, nil, g, n) + // x and y are such that g*x + n*y = d. Since g and n are + // relatively prime, d = 1. Taking that modulo n results in + // g*x = 1, therefore x is the inverse element. + if z.neg { + z.Add(z, n) + } + return z +} + +// Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0. +// The y argument must be an odd integer. +func Jacobi(x, y *Int) int { + if len(y.abs) == 0 || y.abs[0]&1 == 0 { + panic(fmt.Sprintf("big: invalid 2nd argument to Int.Jacobi: need odd integer but got %s", y)) + } + + // We use the formulation described in chapter 2, section 2.4, + // "The Yacas Book of Algorithms": + // http://yacas.sourceforge.net/Algo.book.pdf + + var a, b, c Int + a.Set(x) + b.Set(y) + j := 1 + + if b.neg { + if a.neg { + j = -1 + } + b.neg = false + } + + for { + if b.Cmp(intOne) == 0 { + return j + } + if len(a.abs) == 0 { + return 0 + } + a.Mod(&a, &b) + if len(a.abs) == 0 { + return 0 + } + // a > 0 + + // handle factors of 2 in 'a' + s := a.abs.trailingZeroBits() + if s&1 != 0 { + bmod8 := b.abs[0] & 7 + if bmod8 == 3 || bmod8 == 5 { + j = -j + } + } + c.Rsh(&a, s) // a = 2^s*c + + // swap numerator and denominator + if b.abs[0]&3 == 3 && c.abs[0]&3 == 3 { + j = -j + } + a.Set(&b) + b.Set(&c) + } +} + +// modSqrt3Mod4 uses the identity +// (a^((p+1)/4))^2 mod p +// == u^(p+1) mod p +// == u^2 mod p +// to calculate the square root of any quadratic residue mod p quickly for 3 +// mod 4 primes. +func (z *Int) modSqrt3Mod4Prime(x, p *Int) *Int { + e := new(Int).Add(p, intOne) // e = p + 1 + e.Rsh(e, 2) // e = (p + 1) / 4 + z.Exp(x, e, p) // z = x^e mod p + return z +} + +// modSqrtTonelliShanks uses the Tonelli-Shanks algorithm to find the square +// root of a quadratic residue modulo any prime. +func (z *Int) modSqrtTonelliShanks(x, p *Int) *Int { + // Break p-1 into s*2^e such that s is odd. + var s Int + s.Sub(p, intOne) + e := s.abs.trailingZeroBits() + s.Rsh(&s, e) + + // find some non-square n + var n Int + n.SetInt64(2) + for Jacobi(&n, p) != -1 { + n.Add(&n, intOne) + } + + // Core of the Tonelli-Shanks algorithm. Follows the description in + // section 6 of "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra + // Brown: + // https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + var y, b, g, t Int + y.Add(&s, intOne) + y.Rsh(&y, 1) + y.Exp(x, &y, p) // y = x^((s+1)/2) + b.Exp(x, &s, p) // b = x^s + g.Exp(&n, &s, p) // g = n^s + r := e + for { + // find the least m such that ord_p(b) = 2^m + var m uint + t.Set(&b) + for t.Cmp(intOne) != 0 { + t.Mul(&t, &t).Mod(&t, p) + m++ + } + + if m == 0 { + return z.Set(&y) + } + + t.SetInt64(0).SetBit(&t, int(r-m-1), 1).Exp(&g, &t, p) + // t = g^(2^(r-m-1)) mod p + g.Mul(&t, &t).Mod(&g, p) // g = g^(2^(r-m)) mod p + y.Mul(&y, &t).Mod(&y, p) + b.Mul(&b, &g).Mod(&b, p) + r = m + } +} + +// ModSqrt sets z to a square root of x mod p if such a square root exists, and +// returns z. The modulus p must be an odd prime. If x is not a square mod p, +// ModSqrt leaves z unchanged and returns nil. This function panics if p is +// not an odd integer. +func (z *Int) ModSqrt(x, p *Int) *Int { + switch Jacobi(x, p) { + case -1: + return nil // x is not a square mod p + case 0: + return z.SetInt64(0) // sqrt(0) mod p = 0 + case 1: + break + } + if x.neg || x.Cmp(p) >= 0 { // ensure 0 <= x < p + x = new(Int).Mod(x, p) + } + + // Check whether p is 3 mod 4, and if so, use the faster algorithm. + if len(p.abs) > 0 && p.abs[0]%4 == 3 { + return z.modSqrt3Mod4Prime(x, p) + } + // Otherwise, use Tonelli-Shanks. + return z.modSqrtTonelliShanks(x, p) +} + +// Lsh sets z = x << n and returns z. +func (z *Int) Lsh(x *Int, n uint) *Int { + z.abs = z.abs.shl(x.abs, n) + z.neg = x.neg + return z +} + +// Rsh sets z = x >> n and returns z. +func (z *Int) Rsh(x *Int, n uint) *Int { + if x.neg { + // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1) + t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0 + t = t.shr(t, n) + z.abs = t.add(t, natOne) + z.neg = true // z cannot be zero if x is negative + return z + } + + z.abs = z.abs.shr(x.abs, n) + z.neg = false + return z +} + +// Bit returns the value of the i'th bit of x. That is, it +// returns (x>>i)&1. The bit index i must be >= 0. +func (x *Int) Bit(i int) uint { + if i == 0 { + // optimization for common case: odd/even test of x + if len(x.abs) > 0 { + return uint(x.abs[0] & 1) // bit 0 is same for -x + } + return 0 + } + if i < 0 { + panic("negative bit index") + } + if x.neg { + t := nat(nil).sub(x.abs, natOne) + return t.bit(uint(i)) ^ 1 + } + + return x.abs.bit(uint(i)) +} + +// SetBit sets z to x, with x's i'th bit set to b (0 or 1). +// That is, if b is 1 SetBit sets z = x | (1 << i); +// if b is 0 SetBit sets z = x &^ (1 << i). If b is not 0 or 1, +// SetBit will panic. +func (z *Int) SetBit(x *Int, i int, b uint) *Int { + if i < 0 { + panic("negative bit index") + } + if x.neg { + t := z.abs.sub(x.abs, natOne) + t = t.setBit(t, uint(i), b^1) + z.abs = t.add(t, natOne) + z.neg = len(z.abs) > 0 + return z + } + z.abs = z.abs.setBit(x.abs, uint(i), b) + z.neg = false + return z +} + +// And sets z = x & y and returns z. +func (z *Int) And(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.or(x1, y1), natOne) + z.neg = true // z cannot be zero if x and y are negative + return z + } + + // x & y == x & y + z.abs = z.abs.and(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // & is symmetric + } + + // x & (-y) == x & ^(y-1) == x &^ (y-1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.andNot(x.abs, y1) + z.neg = false + return z +} + +// AndNot sets z = x &^ y and returns z. +func (z *Int) AndNot(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.andNot(y1, x1) + z.neg = false + return z + } + + // x &^ y == x &^ y + z.abs = z.abs.andNot(x.abs, y.abs) + z.neg = false + return z + } + + if x.neg { + // (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1) + x1 := nat(nil).sub(x.abs, natOne) + z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne) + z.neg = true // z cannot be zero if x is negative and y is positive + return z + } + + // x &^ (-y) == x &^ ^(y-1) == x & (y-1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.and(x.abs, y1) + z.neg = false + return z +} + +// Or sets z = x | y and returns z. +func (z *Int) Or(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.and(x1, y1), natOne) + z.neg = true // z cannot be zero if x and y are negative + return z + } + + // x | y == x | y + z.abs = z.abs.or(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // | is symmetric + } + + // x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne) + z.neg = true // z cannot be zero if one of x or y is negative + return z +} + +// Xor sets z = x ^ y and returns z. +func (z *Int) Xor(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1) + x1 := nat(nil).sub(x.abs, natOne) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.xor(x1, y1) + z.neg = false + return z + } + + // x ^ y == x ^ y + z.abs = z.abs.xor(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // ^ is symmetric + } + + // x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1) + y1 := nat(nil).sub(y.abs, natOne) + z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne) + z.neg = true // z cannot be zero if only one of x or y is negative + return z +} + +// Not sets z = ^x and returns z. +func (z *Int) Not(x *Int) *Int { + if x.neg { + // ^(-x) == ^(^(x-1)) == x-1 + z.abs = z.abs.sub(x.abs, natOne) + z.neg = false + return z + } + + // ^x == -x-1 == -(x+1) + z.abs = z.abs.add(x.abs, natOne) + z.neg = true // z cannot be zero if x is positive + return z +} + +// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z. +// It panics if x is negative. +func (z *Int) Sqrt(x *Int) *Int { + if x.neg { + panic("square root of negative number") + } + z.neg = false + z.abs = z.abs.sqrt(x.abs) + return z +} diff --git a/vendor/github.com/golang/go/src/math/big/intconv.go b/vendor/github.com/golang/go/src/math/big/intconv.go new file mode 100644 index 000000000000..6cca827c8e34 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/intconv.go @@ -0,0 +1,247 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements int-to-string conversion functions. + +package big + +import ( + "errors" + "fmt" + "io" +) + +// Text returns the string representation of x in the given base. +// Base must be between 2 and 62, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35, and +// the upper-case letters 'A' to 'Z' for digit values 36 to 61. +// No prefix (such as "0x") is added to the string. +func (x *Int) Text(base int) string { + if x == nil { + return "" + } + return string(x.abs.itoa(x.neg, base)) +} + +// Append appends the string representation of x, as generated by +// x.Text(base), to buf and returns the extended buffer. +func (x *Int) Append(buf []byte, base int) []byte { + if x == nil { + return append(buf, ""...) + } + return append(buf, x.abs.itoa(x.neg, base)...) +} + +func (x *Int) String() string { + return x.Text(10) +} + +// write count copies of text to s +func writeMultiple(s fmt.State, text string, count int) { + if len(text) > 0 { + b := []byte(text) + for ; count > 0; count-- { + s.Write(b) + } + } +} + +var _ fmt.Formatter = intOne // *Int must implement fmt.Formatter + +// Format implements fmt.Formatter. It accepts the formats +// 'b' (binary), 'o' (octal), 'd' (decimal), 'x' (lowercase +// hexadecimal), and 'X' (uppercase hexadecimal). +// Also supported are the full suite of package fmt's format +// flags for integral types, including '+' and ' ' for sign +// control, '#' for leading zero in octal and for hexadecimal, +// a leading "0x" or "0X" for "%#x" and "%#X" respectively, +// specification of minimum digits precision, output field +// width, space or zero padding, and '-' for left or right +// justification. +// +func (x *Int) Format(s fmt.State, ch rune) { + // determine base + var base int + switch ch { + case 'b': + base = 2 + case 'o': + base = 8 + case 'd', 's', 'v': + base = 10 + case 'x', 'X': + base = 16 + default: + // unknown format + fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String()) + return + } + + if x == nil { + fmt.Fprint(s, "") + return + } + + // determine sign character + sign := "" + switch { + case x.neg: + sign = "-" + case s.Flag('+'): // supersedes ' ' when both specified + sign = "+" + case s.Flag(' '): + sign = " " + } + + // determine prefix characters for indicating output base + prefix := "" + if s.Flag('#') { + switch ch { + case 'o': // octal + prefix = "0" + case 'x': // hexadecimal + prefix = "0x" + case 'X': + prefix = "0X" + } + } + + digits := x.abs.utoa(base) + if ch == 'X' { + // faster than bytes.ToUpper + for i, d := range digits { + if 'a' <= d && d <= 'z' { + digits[i] = 'A' + (d - 'a') + } + } + } + + // number of characters for the three classes of number padding + var left int // space characters to left of digits for right justification ("%8d") + var zeros int // zero characters (actually cs[0]) as left-most digits ("%.8d") + var right int // space characters to right of digits for left justification ("%-8d") + + // determine number padding from precision: the least number of digits to output + precision, precisionSet := s.Precision() + if precisionSet { + switch { + case len(digits) < precision: + zeros = precision - len(digits) // count of zero padding + case len(digits) == 1 && digits[0] == '0' && precision == 0: + return // print nothing if zero value (x == 0) and zero precision ("." or ".0") + } + } + + // determine field pad from width: the least number of characters to output + length := len(sign) + len(prefix) + zeros + len(digits) + if width, widthSet := s.Width(); widthSet && length < width { // pad as specified + switch d := width - length; { + case s.Flag('-'): + // pad on the right with spaces; supersedes '0' when both specified + right = d + case s.Flag('0') && !precisionSet: + // pad with zeros unless precision also specified + zeros = d + default: + // pad on the left with spaces + left = d + } + } + + // print number as [left pad][sign][prefix][zero pad][digits][right pad] + writeMultiple(s, " ", left) + writeMultiple(s, sign, 1) + writeMultiple(s, prefix, 1) + writeMultiple(s, "0", zeros) + s.Write(digits) + writeMultiple(s, " ", right) +} + +// scan sets z to the integer value corresponding to the longest possible prefix +// read from r representing a signed integer number in a given conversion base. +// It returns z, the actual conversion base used, and an error, if any. In the +// error case, the value of z is undefined but the returned value is nil. The +// syntax follows the syntax of integer literals in Go. +// +// The base argument must be 0 or a value from 2 through MaxBase. If the base +// is 0, the string prefix determines the actual conversion base. A prefix of +// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a +// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10. +// +func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) { + // determine sign + neg, err := scanSign(r) + if err != nil { + return nil, 0, err + } + + // determine mantissa + z.abs, base, _, err = z.abs.scan(r, base, false) + if err != nil { + return nil, base, err + } + z.neg = len(z.abs) > 0 && neg // 0 has no sign + + return z, base, nil +} + +func scanSign(r io.ByteScanner) (neg bool, err error) { + var ch byte + if ch, err = r.ReadByte(); err != nil { + return false, err + } + switch ch { + case '-': + neg = true + case '+': + // nothing to do + default: + r.UnreadByte() + } + return +} + +// byteReader is a local wrapper around fmt.ScanState; +// it implements the ByteReader interface. +type byteReader struct { + fmt.ScanState +} + +func (r byteReader) ReadByte() (byte, error) { + ch, size, err := r.ReadRune() + if size != 1 && err == nil { + err = fmt.Errorf("invalid rune %#U", ch) + } + return byte(ch), err +} + +func (r byteReader) UnreadByte() error { + return r.UnreadRune() +} + +var _ fmt.Scanner = intOne // *Int must implement fmt.Scanner + +// Scan is a support routine for fmt.Scanner; it sets z to the value of +// the scanned number. It accepts the formats 'b' (binary), 'o' (octal), +// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal). +func (z *Int) Scan(s fmt.ScanState, ch rune) error { + s.SkipSpace() // skip leading space characters + base := 0 + switch ch { + case 'b': + base = 2 + case 'o': + base = 8 + case 'd': + base = 10 + case 'x', 'X': + base = 16 + case 's', 'v': + // let scan determine the base + default: + return errors.New("Int.Scan: invalid verb") + } + _, _, err := z.scan(byteReader{s}, base) + return err +} diff --git a/vendor/github.com/golang/go/src/math/big/intmarsh.go b/vendor/github.com/golang/go/src/math/big/intmarsh.go new file mode 100644 index 000000000000..c1422e271072 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/intmarsh.go @@ -0,0 +1,80 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements encoding/decoding of Ints. + +package big + +import ( + "bytes" + "fmt" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const intGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +func (x *Int) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + buf := make([]byte, 1+len(x.abs)*_S) // extra byte for version and sign bit + i := x.abs.bytes(buf) - 1 // i >= 0 + b := intGobVersion << 1 // make space for sign bit + if x.neg { + b |= 1 + } + buf[i] = b + return buf[i:], nil +} + +// GobDecode implements the gob.GobDecoder interface. +func (z *Int) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Int{} + return nil + } + b := buf[0] + if b>>1 != intGobVersion { + return fmt.Errorf("Int.GobDecode: encoding version %d not supported", b>>1) + } + z.neg = b&1 != 0 + z.abs = z.abs.setBytes(buf[1:]) + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (x *Int) MarshalText() (text []byte, err error) { + if x == nil { + return []byte(""), nil + } + return x.abs.itoa(x.neg, 10), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +func (z *Int) UnmarshalText(text []byte) error { + if _, ok := z.setFromScanner(bytes.NewReader(text), 0); !ok { + return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Int", text) + } + return nil +} + +// The JSON marshalers are only here for API backward compatibility +// (programs that explicitly look for these two methods). JSON works +// fine with the TextMarshaler only. + +// MarshalJSON implements the json.Marshaler interface. +func (x *Int) MarshalJSON() ([]byte, error) { + return x.MarshalText() +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (z *Int) UnmarshalJSON(text []byte) error { + // Ignore null, like in the main JSON package. + if string(text) == "null" { + return nil + } + return z.UnmarshalText(text) +} diff --git a/vendor/github.com/golang/go/src/math/big/nat.go b/vendor/github.com/golang/go/src/math/big/nat.go new file mode 100644 index 000000000000..3bb818f5f259 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/nat.go @@ -0,0 +1,1267 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements unsigned multi-precision integers (natural +// numbers). They are the building blocks for the implementation +// of signed integers, rationals, and floating-point numbers. + +package big + +import ( + "math/bits" + "math/rand" + "sync" +) + +// An unsigned integer x of the form +// +// x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] +// +// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, +// with the digits x[i] as the slice elements. +// +// A number is normalized if the slice contains no leading 0 digits. +// During arithmetic operations, denormalized values may occur but are +// always normalized before returning the final result. The normalized +// representation of 0 is the empty or nil slice (length = 0). +// +type nat []Word + +var ( + natOne = nat{1} + natTwo = nat{2} + natTen = nat{10} +) + +func (z nat) clear() { + for i := range z { + z[i] = 0 + } +} + +func (z nat) norm() nat { + i := len(z) + for i > 0 && z[i-1] == 0 { + i-- + } + return z[0:i] +} + +func (z nat) make(n int) nat { + if n <= cap(z) { + return z[:n] // reuse z + } + // Choosing a good value for e has significant performance impact + // because it increases the chance that a value can be reused. + const e = 4 // extra capacity + return make(nat, n, n+e) +} + +func (z nat) setWord(x Word) nat { + if x == 0 { + return z[:0] + } + z = z.make(1) + z[0] = x + return z +} + +func (z nat) setUint64(x uint64) nat { + // single-word value + if w := Word(x); uint64(w) == x { + return z.setWord(w) + } + // 2-word value + z = z.make(2) + z[1] = Word(x >> 32) + z[0] = Word(x) + return z +} + +func (z nat) set(x nat) nat { + z = z.make(len(x)) + copy(z, x) + return z +} + +func (z nat) add(x, y nat) nat { + m := len(x) + n := len(y) + + switch { + case m < n: + return z.add(y, x) + case m == 0: + // n == 0 because m >= n; result is 0 + return z[:0] + case n == 0: + // result is x + return z.set(x) + } + // m > 0 + + z = z.make(m + 1) + c := addVV(z[0:n], x, y) + if m > n { + c = addVW(z[n:m], x[n:], c) + } + z[m] = c + + return z.norm() +} + +func (z nat) sub(x, y nat) nat { + m := len(x) + n := len(y) + + switch { + case m < n: + panic("underflow") + case m == 0: + // n == 0 because m >= n; result is 0 + return z[:0] + case n == 0: + // result is x + return z.set(x) + } + // m > 0 + + z = z.make(m) + c := subVV(z[0:n], x, y) + if m > n { + c = subVW(z[n:], x[n:], c) + } + if c != 0 { + panic("underflow") + } + + return z.norm() +} + +func (x nat) cmp(y nat) (r int) { + m := len(x) + n := len(y) + if m != n || m == 0 { + switch { + case m < n: + r = -1 + case m > n: + r = 1 + } + return + } + + i := m - 1 + for i > 0 && x[i] == y[i] { + i-- + } + + switch { + case x[i] < y[i]: + r = -1 + case x[i] > y[i]: + r = 1 + } + return +} + +func (z nat) mulAddWW(x nat, y, r Word) nat { + m := len(x) + if m == 0 || y == 0 { + return z.setWord(r) // result is r + } + // m > 0 + + z = z.make(m + 1) + z[m] = mulAddVWW(z[0:m], x, y, r) + + return z.norm() +} + +// basicMul multiplies x and y and leaves the result in z. +// The (non-normalized) result is placed in z[0 : len(x) + len(y)]. +func basicMul(z, x, y nat) { + z[0 : len(x)+len(y)].clear() // initialize z + for i, d := range y { + if d != 0 { + z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) + } + } +} + +// montgomery computes z mod m = x*y*2**(-n*_W) mod m, +// assuming k = -1/m mod 2**_W. +// z is used for storing the result which is returned; +// z must not alias x, y or m. +// See Gueron, "Efficient Software Implementations of Modular Exponentiation". +// https://eprint.iacr.org/2011/239.pdf +// In the terminology of that paper, this is an "Almost Montgomery Multiplication": +// x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result +// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m. +func (z nat) montgomery(x, y, m nat, k Word, n int) nat { + // This code assumes x, y, m are all the same length, n. + // (required by addMulVVW and the for loop). + // It also assumes that x, y are already reduced mod m, + // or else the result will not be properly reduced. + if len(x) != n || len(y) != n || len(m) != n { + panic("math/big: mismatched montgomery number lengths") + } + z = z.make(n) + z.clear() + var c Word + for i := 0; i < n; i++ { + d := y[i] + c2 := addMulVVW(z, x, d) + t := z[0] * k + c3 := addMulVVW(z, m, t) + copy(z, z[1:]) + cx := c + c2 + cy := cx + c3 + z[n-1] = cy + if cx < c2 || cy < c3 { + c = 1 + } else { + c = 0 + } + } + if c != 0 { + subVV(z, z, m) + } + return z +} + +// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. +// Factored out for readability - do not use outside karatsuba. +func karatsubaAdd(z, x nat, n int) { + if c := addVV(z[0:n], z, x); c != 0 { + addVW(z[n:n+n>>1], z[n:], c) + } +} + +// Like karatsubaAdd, but does subtract. +func karatsubaSub(z, x nat, n int) { + if c := subVV(z[0:n], z, x); c != 0 { + subVW(z[n:n+n>>1], z[n:], c) + } +} + +// Operands that are shorter than karatsubaThreshold are multiplied using +// "grade school" multiplication; for longer operands the Karatsuba algorithm +// is used. +var karatsubaThreshold = 40 // computed by calibrate_test.go + +// karatsuba multiplies x and y and leaves the result in z. +// Both x and y must have the same length n and n must be a +// power of 2. The result vector z must have len(z) >= 6*n. +// The (non-normalized) result is placed in z[0 : 2*n]. +func karatsuba(z, x, y nat) { + n := len(y) + + // Switch to basic multiplication if numbers are odd or small. + // (n is always even if karatsubaThreshold is even, but be + // conservative) + if n&1 != 0 || n < karatsubaThreshold || n < 2 { + basicMul(z, x, y) + return + } + // n&1 == 0 && n >= karatsubaThreshold && n >= 2 + + // Karatsuba multiplication is based on the observation that + // for two numbers x and y with: + // + // x = x1*b + x0 + // y = y1*b + y0 + // + // the product x*y can be obtained with 3 products z2, z1, z0 + // instead of 4: + // + // x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0 + // = z2*b*b + z1*b + z0 + // + // with: + // + // xd = x1 - x0 + // yd = y0 - y1 + // + // z1 = xd*yd + z2 + z0 + // = (x1-x0)*(y0 - y1) + z2 + z0 + // = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0 + // = x1*y0 - z2 - z0 + x0*y1 + z2 + z0 + // = x1*y0 + x0*y1 + + // split x, y into "digits" + n2 := n >> 1 // n2 >= 1 + x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 + y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 + + // z is used for the result and temporary storage: + // + // 6*n 5*n 4*n 3*n 2*n 1*n 0*n + // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] + // + // For each recursive call of karatsuba, an unused slice of + // z is passed in that has (at least) half the length of the + // caller's z. + + // compute z0 and z2 with the result "in place" in z + karatsuba(z, x0, y0) // z0 = x0*y0 + karatsuba(z[n:], x1, y1) // z2 = x1*y1 + + // compute xd (or the negative value if underflow occurs) + s := 1 // sign of product xd*yd + xd := z[2*n : 2*n+n2] + if subVV(xd, x1, x0) != 0 { // x1-x0 + s = -s + subVV(xd, x0, x1) // x0-x1 + } + + // compute yd (or the negative value if underflow occurs) + yd := z[2*n+n2 : 3*n] + if subVV(yd, y0, y1) != 0 { // y0-y1 + s = -s + subVV(yd, y1, y0) // y1-y0 + } + + // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 + // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 + p := z[n*3:] + karatsuba(p, xd, yd) + + // save original z2:z0 + // (ok to use upper half of z since we're done recursing) + r := z[n*4:] + copy(r, z[:n*2]) + + // add up all partial products + // + // 2*n n 0 + // z = [ z2 | z0 ] + // + [ z0 ] + // + [ z2 ] + // + [ p ] + // + karatsubaAdd(z[n2:], r, n) + karatsubaAdd(z[n2:], r[n:], n) + if s > 0 { + karatsubaAdd(z[n2:], p, n) + } else { + karatsubaSub(z[n2:], p, n) + } +} + +// alias reports whether x and y share the same base array. +func alias(x, y nat) bool { + return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] +} + +// addAt implements z += x<<(_W*i); z must be long enough. +// (we don't use nat.add because we need z to stay the same +// slice, and we don't need to normalize z after each addition) +func addAt(z, x nat, i int) { + if n := len(x); n > 0 { + if c := addVV(z[i:i+n], z[i:], x); c != 0 { + j := i + n + if j < len(z) { + addVW(z[j:], z[j:], c) + } + } + } +} + +func max(x, y int) int { + if x > y { + return x + } + return y +} + +// karatsubaLen computes an approximation to the maximum k <= n such that +// k = p<= 0. Thus, the +// result is the largest number that can be divided repeatedly by 2 before +// becoming about the value of karatsubaThreshold. +func karatsubaLen(n int) int { + i := uint(0) + for n > karatsubaThreshold { + n >>= 1 + i++ + } + return n << i +} + +func (z nat) mul(x, y nat) nat { + m := len(x) + n := len(y) + + switch { + case m < n: + return z.mul(y, x) + case m == 0 || n == 0: + return z[:0] + case n == 1: + return z.mulAddWW(x, y[0], 0) + } + // m >= n > 1 + + // determine if z can be reused + if alias(z, x) || alias(z, y) { + z = nil // z is an alias for x or y - cannot reuse + } + + // use basic multiplication if the numbers are small + if n < karatsubaThreshold { + z = z.make(m + n) + basicMul(z, x, y) + return z.norm() + } + // m >= n && n >= karatsubaThreshold && n >= 2 + + // determine Karatsuba length k such that + // + // x = xh*b + x0 (0 <= x0 < b) + // y = yh*b + y0 (0 <= y0 < b) + // b = 1<<(_W*k) ("base" of digits xi, yi) + // + k := karatsubaLen(n) + // k <= n + + // multiply x0 and y0 via Karatsuba + x0 := x[0:k] // x0 is not normalized + y0 := y[0:k] // y0 is not normalized + z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y + karatsuba(z, x0, y0) + z = z[0 : m+n] // z has final length but may be incomplete + z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) + + // If xh != 0 or yh != 0, add the missing terms to z. For + // + // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) + // yh = y1*b (0 <= y1 < b) + // + // the missing terms are + // + // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 + // + // since all the yi for i > 1 are 0 by choice of k: If any of them + // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would + // be a larger valid threshold contradicting the assumption about k. + // + if k < n || m != n { + var t nat + + // add x0*y1*b + x0 := x0.norm() + y1 := y[k:] // y1 is normalized because y is + t = t.mul(x0, y1) // update t so we don't lose t's underlying array + addAt(z, t, k) + + // add xi*y0< k { + xi = xi[:k] + } + xi = xi.norm() + t = t.mul(xi, y0) + addAt(z, t, i) + t = t.mul(xi, y1) + addAt(z, t, i+k) + } + } + + return z.norm() +} + +// basicSqr sets z = x*x and is asymptotically faster than basicMul +// by about a factor of 2, but slower for small arguments due to overhead. +// Requirements: len(x) > 0, len(z) >= 2*len(x) +// The (non-normalized) result is placed in z[0 : 2 * len(x)]. +func basicSqr(z, x nat) { + n := len(x) + t := make(nat, 2*n) // temporary variable to hold the products + z[1], z[0] = mulWW(x[0], x[0]) // the initial square + for i := 1; i < n; i++ { + d := x[i] + // z collects the squares x[i] * x[i] + z[2*i+1], z[2*i] = mulWW(d, d) + // t collects the products x[i] * x[j] where j < i + t[2*i] = addMulVVW(t[i:2*i], x[0:i], d) + } + t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products + addVV(z, z, t) // combine the result +} + +// Operands that are shorter than basicSqrThreshold are squared using +// "grade school" multiplication; for operands longer than karatsubaSqrThreshold +// the Karatsuba algorithm is used. +var basicSqrThreshold = 20 // computed by calibrate_test.go +var karatsubaSqrThreshold = 400 // computed by calibrate_test.go + +// z = x*x +func (z nat) sqr(x nat) nat { + n := len(x) + switch { + case n == 0: + return z[:0] + case n == 1: + d := x[0] + z = z.make(2) + z[1], z[0] = mulWW(d, d) + return z.norm() + } + + if alias(z, x) { + z = nil // z is an alias for x - cannot reuse + } + z = z.make(2 * n) + + if n < basicSqrThreshold { + basicMul(z, x, x) + return z.norm() + } + if n < karatsubaSqrThreshold { + basicSqr(z, x) + return z.norm() + } + + return z.mul(x, x) +} + +// mulRange computes the product of all the unsigned integers in the +// range [a, b] inclusively. If a > b (empty range), the result is 1. +func (z nat) mulRange(a, b uint64) nat { + switch { + case a == 0: + // cut long ranges short (optimization) + return z.setUint64(0) + case a > b: + return z.setUint64(1) + case a == b: + return z.setUint64(a) + case a+1 == b: + return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) + } + m := (a + b) / 2 + return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) +} + +// q = (x-r)/y, with 0 <= r < y +func (z nat) divW(x nat, y Word) (q nat, r Word) { + m := len(x) + switch { + case y == 0: + panic("division by zero") + case y == 1: + q = z.set(x) // result is x + return + case m == 0: + q = z[:0] // result is 0 + return + } + // m > 0 + z = z.make(m) + r = divWVW(z, 0, x, y) + q = z.norm() + return +} + +func (z nat) div(z2, u, v nat) (q, r nat) { + if len(v) == 0 { + panic("division by zero") + } + + if u.cmp(v) < 0 { + q = z[:0] + r = z2.set(u) + return + } + + if len(v) == 1 { + var r2 Word + q, r2 = z.divW(u, v[0]) + r = z2.setWord(r2) + return + } + + q, r = z.divLarge(z2, u, v) + return +} + +// getNat returns a *nat of len n. The contents may not be zero. +// The pool holds *nat to avoid allocation when converting to interface{}. +func getNat(n int) *nat { + var z *nat + if v := natPool.Get(); v != nil { + z = v.(*nat) + } + if z == nil { + z = new(nat) + } + *z = z.make(n) + return z +} + +func putNat(x *nat) { + natPool.Put(x) +} + +var natPool sync.Pool + +// q = (uIn-r)/v, with 0 <= r < y +// Uses z as storage for q, and u as storage for r if possible. +// See Knuth, Volume 2, section 4.3.1, Algorithm D. +// Preconditions: +// len(v) >= 2 +// len(uIn) >= len(v) +func (z nat) divLarge(u, uIn, v nat) (q, r nat) { + n := len(v) + m := len(uIn) - n + + // determine if z can be reused + // TODO(gri) should find a better solution - this if statement + // is very costly (see e.g. time pidigits -s -n 10000) + if alias(z, u) || alias(z, uIn) || alias(z, v) { + z = nil // z is an alias for u or uIn or v - cannot reuse + } + q = z.make(m + 1) + + qhatvp := getNat(n + 1) + qhatv := *qhatvp + if alias(u, uIn) || alias(u, v) { + u = nil // u is an alias for uIn or v - cannot reuse + } + u = u.make(len(uIn) + 1) + u.clear() // TODO(gri) no need to clear if we allocated a new u + + // D1. + var v1p *nat + shift := nlz(v[n-1]) + if shift > 0 { + // do not modify v, it may be used by another goroutine simultaneously + v1p = getNat(n) + v1 := *v1p + shlVU(v1, v, shift) + v = v1 + } + u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift) + + // D2. + vn1 := v[n-1] + for j := m; j >= 0; j-- { + // D3. + qhat := Word(_M) + if ujn := u[j+n]; ujn != vn1 { + var rhat Word + qhat, rhat = divWW(ujn, u[j+n-1], vn1) + + // x1 | x2 = q̂v_{n-2} + vn2 := v[n-2] + x1, x2 := mulWW(qhat, vn2) + // test if q̂v_{n-2} > br̂ + u_{j+n-2} + ujn2 := u[j+n-2] + for greaterThan(x1, x2, rhat, ujn2) { + qhat-- + prevRhat := rhat + rhat += vn1 + // v[n-1] >= 0, so this tests for overflow. + if rhat < prevRhat { + break + } + x1, x2 = mulWW(qhat, vn2) + } + } + + // D4. + qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0) + + c := subVV(u[j:j+len(qhatv)], u[j:], qhatv) + if c != 0 { + c := addVV(u[j:j+n], u[j:], v) + u[j+n] += c + qhat-- + } + + q[j] = qhat + } + if v1p != nil { + putNat(v1p) + } + putNat(qhatvp) + + q = q.norm() + shrVU(u, u, shift) + r = u.norm() + + return q, r +} + +// Length of x in bits. x must be normalized. +func (x nat) bitLen() int { + if i := len(x) - 1; i >= 0 { + return i*_W + bits.Len(uint(x[i])) + } + return 0 +} + +// trailingZeroBits returns the number of consecutive least significant zero +// bits of x. +func (x nat) trailingZeroBits() uint { + if len(x) == 0 { + return 0 + } + var i uint + for x[i] == 0 { + i++ + } + // x[i] != 0 + return i*_W + uint(bits.TrailingZeros(uint(x[i]))) +} + +// z = x << s +func (z nat) shl(x nat, s uint) nat { + m := len(x) + if m == 0 { + return z[:0] + } + // m > 0 + + n := m + int(s/_W) + z = z.make(n + 1) + z[n] = shlVU(z[n-m:n], x, s%_W) + z[0 : n-m].clear() + + return z.norm() +} + +// z = x >> s +func (z nat) shr(x nat, s uint) nat { + m := len(x) + n := m - int(s/_W) + if n <= 0 { + return z[:0] + } + // n > 0 + + z = z.make(n) + shrVU(z, x[m-n:], s%_W) + + return z.norm() +} + +func (z nat) setBit(x nat, i uint, b uint) nat { + j := int(i / _W) + m := Word(1) << (i % _W) + n := len(x) + switch b { + case 0: + z = z.make(n) + copy(z, x) + if j >= n { + // no need to grow + return z + } + z[j] &^= m + return z.norm() + case 1: + if j >= n { + z = z.make(j + 1) + z[n:].clear() + } else { + z = z.make(n) + } + copy(z, x) + z[j] |= m + // no need to normalize + return z + } + panic("set bit is not 0 or 1") +} + +// bit returns the value of the i'th bit, with lsb == bit 0. +func (x nat) bit(i uint) uint { + j := i / _W + if j >= uint(len(x)) { + return 0 + } + // 0 <= j < len(x) + return uint(x[j] >> (i % _W) & 1) +} + +// sticky returns 1 if there's a 1 bit within the +// i least significant bits, otherwise it returns 0. +func (x nat) sticky(i uint) uint { + j := i / _W + if j >= uint(len(x)) { + if len(x) == 0 { + return 0 + } + return 1 + } + // 0 <= j < len(x) + for _, x := range x[:j] { + if x != 0 { + return 1 + } + } + if x[j]<<(_W-i%_W) != 0 { + return 1 + } + return 0 +} + +func (z nat) and(x, y nat) nat { + m := len(x) + n := len(y) + if m > n { + m = n + } + // m <= n + + z = z.make(m) + for i := 0; i < m; i++ { + z[i] = x[i] & y[i] + } + + return z.norm() +} + +func (z nat) andNot(x, y nat) nat { + m := len(x) + n := len(y) + if n > m { + n = m + } + // m >= n + + z = z.make(m) + for i := 0; i < n; i++ { + z[i] = x[i] &^ y[i] + } + copy(z[n:m], x[n:m]) + + return z.norm() +} + +func (z nat) or(x, y nat) nat { + m := len(x) + n := len(y) + s := x + if m < n { + n, m = m, n + s = y + } + // m >= n + + z = z.make(m) + for i := 0; i < n; i++ { + z[i] = x[i] | y[i] + } + copy(z[n:m], s[n:m]) + + return z.norm() +} + +func (z nat) xor(x, y nat) nat { + m := len(x) + n := len(y) + s := x + if m < n { + n, m = m, n + s = y + } + // m >= n + + z = z.make(m) + for i := 0; i < n; i++ { + z[i] = x[i] ^ y[i] + } + copy(z[n:m], s[n:m]) + + return z.norm() +} + +// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2) +func greaterThan(x1, x2, y1, y2 Word) bool { + return x1 > y1 || x1 == y1 && x2 > y2 +} + +// modW returns x % d. +func (x nat) modW(d Word) (r Word) { + // TODO(agl): we don't actually need to store the q value. + var q nat + q = q.make(len(x)) + return divWVW(q, 0, x, d) +} + +// random creates a random integer in [0..limit), using the space in z if +// possible. n is the bit length of limit. +func (z nat) random(rand *rand.Rand, limit nat, n int) nat { + if alias(z, limit) { + z = nil // z is an alias for limit - cannot reuse + } + z = z.make(len(limit)) + + bitLengthOfMSW := uint(n % _W) + if bitLengthOfMSW == 0 { + bitLengthOfMSW = _W + } + mask := Word((1 << bitLengthOfMSW) - 1) + + for { + switch _W { + case 32: + for i := range z { + z[i] = Word(rand.Uint32()) + } + case 64: + for i := range z { + z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 + } + default: + panic("unknown word size") + } + z[len(limit)-1] &= mask + if z.cmp(limit) < 0 { + break + } + } + + return z.norm() +} + +// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; +// otherwise it sets z to x**y. The result is the value of z. +func (z nat) expNN(x, y, m nat) nat { + if alias(z, x) || alias(z, y) { + // We cannot allow in-place modification of x or y. + z = nil + } + + // x**y mod 1 == 0 + if len(m) == 1 && m[0] == 1 { + return z.setWord(0) + } + // m == 0 || m > 1 + + // x**0 == 1 + if len(y) == 0 { + return z.setWord(1) + } + // y > 0 + + // x**1 mod m == x mod m + if len(y) == 1 && y[0] == 1 && len(m) != 0 { + _, z = z.div(z, x, m) + return z + } + // y > 1 + + if len(m) != 0 { + // We likely end up being as long as the modulus. + z = z.make(len(m)) + } + z = z.set(x) + + // If the base is non-trivial and the exponent is large, we use + // 4-bit, windowed exponentiation. This involves precomputing 14 values + // (x^2...x^15) but then reduces the number of multiply-reduces by a + // third. Even for a 32-bit exponent, this reduces the number of + // operations. Uses Montgomery method for odd moduli. + if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 { + if m[0]&1 == 1 { + return z.expNNMontgomery(x, y, m) + } + return z.expNNWindowed(x, y, m) + } + + v := y[len(y)-1] // v > 0 because y is normalized and y > 0 + shift := nlz(v) + 1 + v <<= shift + var q nat + + const mask = 1 << (_W - 1) + + // We walk through the bits of the exponent one by one. Each time we + // see a bit, we square, thus doubling the power. If the bit is a one, + // we also multiply by x, thus adding one to the power. + + w := _W - int(shift) + // zz and r are used to avoid allocating in mul and div as + // otherwise the arguments would alias. + var zz, r nat + for j := 0; j < w; j++ { + zz = zz.sqr(z) + zz, z = z, zz + + if v&mask != 0 { + zz = zz.mul(z, x) + zz, z = z, zz + } + + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r + } + + v <<= 1 + } + + for i := len(y) - 2; i >= 0; i-- { + v = y[i] + + for j := 0; j < _W; j++ { + zz = zz.sqr(z) + zz, z = z, zz + + if v&mask != 0 { + zz = zz.mul(z, x) + zz, z = z, zz + } + + if len(m) != 0 { + zz, r = zz.div(r, z, m) + zz, r, q, z = q, z, zz, r + } + + v <<= 1 + } + } + + return z.norm() +} + +// expNNWindowed calculates x**y mod m using a fixed, 4-bit window. +func (z nat) expNNWindowed(x, y, m nat) nat { + // zz and r are used to avoid allocating in mul and div as otherwise + // the arguments would alias. + var zz, r nat + + const n = 4 + // powers[i] contains x^i. + var powers [1 << n]nat + powers[0] = natOne + powers[1] = x + for i := 2; i < 1<= 0; i-- { + yi := y[i] + for j := 0; j < _W; j += n { + if i != len(y)-1 || j != 0 { + // Unrolled loop for significant performance + // gain. Use go test -bench=".*" in crypto/rsa + // to check performance before making changes. + zz = zz.sqr(z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.sqr(z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.sqr(z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + zz = zz.sqr(z) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + } + + zz = zz.mul(z, powers[yi>>(_W-n)]) + zz, z = z, zz + zz, r = zz.div(r, z, m) + z, r = r, z + + yi <<= n + } + } + + return z.norm() +} + +// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window. +// Uses Montgomery representation. +func (z nat) expNNMontgomery(x, y, m nat) nat { + numWords := len(m) + + // We want the lengths of x and m to be equal. + // It is OK if x >= m as long as len(x) == len(m). + if len(x) > numWords { + _, x = nat(nil).div(nil, x, m) + // Note: now len(x) <= numWords, not guaranteed ==. + } + if len(x) < numWords { + rr := make(nat, numWords) + copy(rr, x) + x = rr + } + + // Ideally the precomputations would be performed outside, and reused + // k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson + // Iteration for Multiplicative Inverses Modulo Prime Powers". + k0 := 2 - m[0] + t := m[0] - 1 + for i := 1; i < _W; i <<= 1 { + t *= t + k0 *= (t + 1) + } + k0 = -k0 + + // RR = 2**(2*_W*len(m)) mod m + RR := nat(nil).setWord(1) + zz := nat(nil).shl(RR, uint(2*numWords*_W)) + _, RR = RR.div(RR, zz, m) + if len(RR) < numWords { + zz = zz.make(numWords) + copy(zz, RR) + RR = zz + } + // one = 1, with equal length to that of m + one := make(nat, numWords) + one[0] = 1 + + const n = 4 + // powers[i] contains x^i + var powers [1 << n]nat + powers[0] = powers[0].montgomery(one, RR, m, k0, numWords) + powers[1] = powers[1].montgomery(x, RR, m, k0, numWords) + for i := 2; i < 1<= 0; i-- { + yi := y[i] + for j := 0; j < _W; j += n { + if i != len(y)-1 || j != 0 { + zz = zz.montgomery(z, z, m, k0, numWords) + z = z.montgomery(zz, zz, m, k0, numWords) + zz = zz.montgomery(z, z, m, k0, numWords) + z = z.montgomery(zz, zz, m, k0, numWords) + } + zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords) + z, zz = zz, z + yi <<= n + } + } + // convert to regular number + zz = zz.montgomery(z, one, m, k0, numWords) + + // One last reduction, just in case. + // See golang.org/issue/13907. + if zz.cmp(m) >= 0 { + // Common case is m has high bit set; in that case, + // since zz is the same length as m, there can be just + // one multiple of m to remove. Just subtract. + // We think that the subtract should be sufficient in general, + // so do that unconditionally, but double-check, + // in case our beliefs are wrong. + // The div is not expected to be reached. + zz = zz.sub(zz, m) + if zz.cmp(m) >= 0 { + _, zz = nat(nil).div(nil, zz, m) + } + } + + return zz.norm() +} + +// bytes writes the value of z into buf using big-endian encoding. +// len(buf) must be >= len(z)*_S. The value of z is encoded in the +// slice buf[i:]. The number i of unused bytes at the beginning of +// buf is returned as result. +func (z nat) bytes(buf []byte) (i int) { + i = len(buf) + for _, d := range z { + for j := 0; j < _S; j++ { + i-- + buf[i] = byte(d) + d >>= 8 + } + } + + for i < len(buf) && buf[i] == 0 { + i++ + } + + return +} + +// setBytes interprets buf as the bytes of a big-endian unsigned +// integer, sets z to that value, and returns z. +func (z nat) setBytes(buf []byte) nat { + z = z.make((len(buf) + _S - 1) / _S) + + k := 0 + s := uint(0) + var d Word + for i := len(buf); i > 0; i-- { + d |= Word(buf[i-1]) << s + if s += 8; s == _S*8 { + z[k] = d + k++ + s = 0 + d = 0 + } + } + if k < len(z) { + z[k] = d + } + + return z.norm() +} + +// sqrt sets z = ⌊√x⌋ +func (z nat) sqrt(x nat) nat { + if x.cmp(natOne) <= 0 { + return z.set(x) + } + if alias(z, x) { + z = nil + } + + // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. + // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). + // https://members.loria.fr/PZimmermann/mca/pub226.html + // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1; + // otherwise it converges to the correct z and stays there. + var z1, z2 nat + z1 = z + z1 = z1.setUint64(1) + z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x + for n := 0; ; n++ { + z2, _ = z2.div(nil, x, z1) + z2 = z2.add(z2, z1) + z2 = z2.shr(z2, 1) + if z2.cmp(z1) >= 0 { + // z1 is answer. + // Figure out whether z1 or z2 is currently aliased to z by looking at loop count. + if n&1 == 0 { + return z1 + } + return z.set(z1) + } + z1, z2 = z2, z1 + } +} diff --git a/vendor/github.com/golang/go/src/math/big/natconv.go b/vendor/github.com/golang/go/src/math/big/natconv.go new file mode 100644 index 000000000000..21ccbd6cfafc --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/natconv.go @@ -0,0 +1,503 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements nat-to-string conversion functions. + +package big + +import ( + "errors" + "fmt" + "io" + "math" + "math/bits" + "sync" +) + +const digits = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +// Note: MaxBase = len(digits), but it must remain an untyped rune constant +// for API compatibility. + +// MaxBase is the largest number base accepted for string conversions. +const MaxBase = 10 + ('z' - 'a' + 1) + ('Z' - 'A' + 1) +const maxBaseSmall = 10 + ('z' - 'a' + 1) + +// maxPow returns (b**n, n) such that b**n is the largest power b**n <= _M. +// For instance maxPow(10) == (1e19, 19) for 19 decimal digits in a 64bit Word. +// In other words, at most n digits in base b fit into a Word. +// TODO(gri) replace this with a table, generated at build time. +func maxPow(b Word) (p Word, n int) { + p, n = b, 1 // assuming b <= _M + for max := _M / b; p <= max; { + // p == b**n && p <= max + p *= b + n++ + } + // p == b**n && p <= _M + return +} + +// pow returns x**n for n > 0, and 1 otherwise. +func pow(x Word, n int) (p Word) { + // n == sum of bi * 2**i, for 0 <= i < imax, and bi is 0 or 1 + // thus x**n == product of x**(2**i) for all i where bi == 1 + // (Russian Peasant Method for exponentiation) + p = 1 + for n > 0 { + if n&1 != 0 { + p *= x + } + x *= x + n >>= 1 + } + return +} + +// scan scans the number corresponding to the longest possible prefix +// from r representing an unsigned number in a given conversion base. +// It returns the corresponding natural number res, the actual base b, +// a digit count, and a read or syntax error err, if any. +// +// number = [ prefix ] mantissa . +// prefix = "0" [ "x" | "X" | "b" | "B" ] . +// mantissa = digits | digits "." [ digits ] | "." digits . +// digits = digit { digit } . +// digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" . +// +// Unless fracOk is set, the base argument must be 0 or a value between +// 2 and MaxBase. If fracOk is set, the base argument must be one of +// 0, 2, 10, or 16. Providing an invalid base argument leads to a run- +// time panic. +// +// For base 0, the number prefix determines the actual base: A prefix of +// ``0x'' or ``0X'' selects base 16; if fracOk is not set, the ``0'' prefix +// selects base 8, and a ``0b'' or ``0B'' prefix selects base 2. Otherwise +// the selected base is 10 and no prefix is accepted. +// +// If fracOk is set, an octal prefix is ignored (a leading ``0'' simply +// stands for a zero digit), and a period followed by a fractional part +// is permitted. The result value is computed as if there were no period +// present; and the count value is used to determine the fractional part. +// +// For bases <= 36, lower and upper case letters are considered the same: +// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35. +// For bases > 36, the upper case letters 'A' to 'Z' represent the digit +// values 36 to 61. +// +// A result digit count > 0 corresponds to the number of (non-prefix) digits +// parsed. A digit count <= 0 indicates the presence of a period (if fracOk +// is set, only), and -count is the number of fractional digits found. +// In this case, the actual value of the scanned number is res * b**count. +// +func (z nat) scan(r io.ByteScanner, base int, fracOk bool) (res nat, b, count int, err error) { + // reject illegal bases + baseOk := base == 0 || + !fracOk && 2 <= base && base <= MaxBase || + fracOk && (base == 2 || base == 10 || base == 16) + if !baseOk { + panic(fmt.Sprintf("illegal number base %d", base)) + } + + // one char look-ahead + ch, err := r.ReadByte() + if err != nil { + return + } + + // determine actual base + b = base + if base == 0 { + // actual base is 10 unless there's a base prefix + b = 10 + if ch == '0' { + count = 1 + switch ch, err = r.ReadByte(); err { + case nil: + // possibly one of 0x, 0X, 0b, 0B + if !fracOk { + b = 8 + } + switch ch { + case 'x', 'X': + b = 16 + case 'b', 'B': + b = 2 + } + switch b { + case 16, 2: + count = 0 // prefix is not counted + if ch, err = r.ReadByte(); err != nil { + // io.EOF is also an error in this case + return + } + case 8: + count = 0 // prefix is not counted + } + case io.EOF: + // input is "0" + res = z[:0] + err = nil + return + default: + // read error + return + } + } + } + + // convert string + // Algorithm: Collect digits in groups of at most n digits in di + // and then use mulAddWW for every such group to add them to the + // result. + z = z[:0] + b1 := Word(b) + bn, n := maxPow(b1) // at most n digits in base b1 fit into Word + di := Word(0) // 0 <= di < b1**i < bn + i := 0 // 0 <= i < n + dp := -1 // position of decimal point + for { + if fracOk && ch == '.' { + fracOk = false + dp = count + // advance + if ch, err = r.ReadByte(); err != nil { + if err == io.EOF { + err = nil + break + } + return + } + } + + // convert rune into digit value d1 + var d1 Word + switch { + case '0' <= ch && ch <= '9': + d1 = Word(ch - '0') + case 'a' <= ch && ch <= 'z': + d1 = Word(ch - 'a' + 10) + case 'A' <= ch && ch <= 'Z': + if b <= maxBaseSmall { + d1 = Word(ch - 'A' + 10) + } else { + d1 = Word(ch - 'A' + maxBaseSmall) + } + default: + d1 = MaxBase + 1 + } + if d1 >= b1 { + r.UnreadByte() // ch does not belong to number anymore + break + } + count++ + + // collect d1 in di + di = di*b1 + d1 + i++ + + // if di is "full", add it to the result + if i == n { + z = z.mulAddWW(z, bn, di) + di = 0 + i = 0 + } + + // advance + if ch, err = r.ReadByte(); err != nil { + if err == io.EOF { + err = nil + break + } + return + } + } + + if count == 0 { + // no digits found + switch { + case base == 0 && b == 8: + // there was only the octal prefix 0 (possibly followed by digits > 7); + // count as one digit and return base 10, not 8 + count = 1 + b = 10 + case base != 0 || b != 8: + // there was neither a mantissa digit nor the octal prefix 0 + err = errors.New("syntax error scanning number") + } + return + } + // count > 0 + + // add remaining digits to result + if i > 0 { + z = z.mulAddWW(z, pow(b1, i), di) + } + res = z.norm() + + // adjust for fraction, if any + if dp >= 0 { + // 0 <= dp <= count > 0 + count = dp - count + } + + return +} + +// utoa converts x to an ASCII representation in the given base; +// base must be between 2 and MaxBase, inclusive. +func (x nat) utoa(base int) []byte { + return x.itoa(false, base) +} + +// itoa is like utoa but it prepends a '-' if neg && x != 0. +func (x nat) itoa(neg bool, base int) []byte { + if base < 2 || base > MaxBase { + panic("invalid base") + } + + // x == 0 + if len(x) == 0 { + return []byte("0") + } + // len(x) > 0 + + // allocate buffer for conversion + i := int(float64(x.bitLen())/math.Log2(float64(base))) + 1 // off by 1 at most + if neg { + i++ + } + s := make([]byte, i) + + // convert power of two and non power of two bases separately + if b := Word(base); b == b&-b { + // shift is base b digit size in bits + shift := uint(bits.TrailingZeros(uint(b))) // shift > 0 because b >= 2 + mask := Word(1<= shift { + i-- + s[i] = digits[w&mask] + w >>= shift + nbits -= shift + } + + // convert any partial leading digit and advance to next word + if nbits == 0 { + // no partial digit remaining, just advance + w = x[k] + nbits = _W + } else { + // partial digit in current word w (== x[k-1]) and next word x[k] + w |= x[k] << nbits + i-- + s[i] = digits[w&mask] + + // advance + w = x[k] >> (shift - nbits) + nbits = _W - (shift - nbits) + } + } + + // convert digits of most-significant word w (omit leading zeros) + for w != 0 { + i-- + s[i] = digits[w&mask] + w >>= shift + } + + } else { + bb, ndigits := maxPow(b) + + // construct table of successive squares of bb*leafSize to use in subdivisions + // result (table != nil) <=> (len(x) > leafSize > 0) + table := divisors(len(x), b, ndigits, bb) + + // preserve x, create local copy for use by convertWords + q := nat(nil).set(x) + + // convert q to string s in base b + q.convertWords(s, b, ndigits, bb, table) + + // strip leading zeros + // (x != 0; thus s must contain at least one non-zero digit + // and the loop will terminate) + i = 0 + for s[i] == '0' { + i++ + } + } + + if neg { + i-- + s[i] = '-' + } + + return s[i:] +} + +// Convert words of q to base b digits in s. If q is large, it is recursively "split in half" +// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using +// repeated nat/Word division. +// +// The iterative method processes n Words by n divW() calls, each of which visits every Word in the +// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. +// Recursive conversion divides q by its approximate square root, yielding two parts, each half +// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s +// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and +// is made better by splitting the subblocks recursively. Best is to split blocks until one more +// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the +// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the +// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and +// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for +// specific hardware. +// +func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) { + // split larger blocks recursively + if table != nil { + // len(q) > leafSize > 0 + var r nat + index := len(table) - 1 + for len(q) > leafSize { + // find divisor close to sqrt(q) if possible, but in any case < q + maxLength := q.bitLen() // ~= log2 q, or at of least largest possible q of this bit length + minLength := maxLength >> 1 // ~= log2 sqrt(q) + for index > 0 && table[index-1].nbits > minLength { + index-- // desired + } + if table[index].nbits >= maxLength && table[index].bbb.cmp(q) >= 0 { + index-- + if index < 0 { + panic("internal inconsistency") + } + } + + // split q into the two digit number (q'*bbb + r) to form independent subblocks + q, r = q.div(r, q, table[index].bbb) + + // convert subblocks and collect results in s[:h] and s[h:] + h := len(s) - table[index].ndigits + r.convertWords(s[h:], b, ndigits, bb, table[0:index]) + s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1]) + } + } + + // having split any large blocks now process the remaining (small) block iteratively + i := len(s) + var r Word + if b == 10 { + // hard-coding for 10 here speeds this up by 1.25x (allows for / and % by constants) + for len(q) > 0 { + // extract least significant, base bb "digit" + q, r = q.divW(q, bb) + for j := 0; j < ndigits && i > 0; j++ { + i-- + // avoid % computation since r%10 == r - int(r/10)*10; + // this appears to be faster for BenchmarkString10000Base10 + // and smaller strings (but a bit slower for larger ones) + t := r / 10 + s[i] = '0' + byte(r-t*10) + r = t + } + } + } else { + for len(q) > 0 { + // extract least significant, base bb "digit" + q, r = q.divW(q, bb) + for j := 0; j < ndigits && i > 0; j++ { + i-- + s[i] = digits[r%b] + r /= b + } + } + } + + // prepend high-order zeros + for i > 0 { // while need more leading zeros + i-- + s[i] = '0' + } +} + +// Split blocks greater than leafSize Words (or set to 0 to disable recursive conversion) +// Benchmark and configure leafSize using: go test -bench="Leaf" +// 8 and 16 effective on 3.0 GHz Xeon "Clovertown" CPU (128 byte cache lines) +// 8 and 16 effective on 2.66 GHz Core 2 Duo "Penryn" CPU +var leafSize int = 8 // number of Word-size binary values treat as a monolithic block + +type divisor struct { + bbb nat // divisor + nbits int // bit length of divisor (discounting leading zeros) ~= log2(bbb) + ndigits int // digit length of divisor in terms of output base digits +} + +var cacheBase10 struct { + sync.Mutex + table [64]divisor // cached divisors for base 10 +} + +// expWW computes x**y +func (z nat) expWW(x, y Word) nat { + return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil) +} + +// construct table of powers of bb*leafSize to use in subdivisions +func divisors(m int, b Word, ndigits int, bb Word) []divisor { + // only compute table when recursive conversion is enabled and x is large + if leafSize == 0 || m <= leafSize { + return nil + } + + // determine k where (bb**leafSize)**(2**k) >= sqrt(x) + k := 1 + for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 { + k++ + } + + // reuse and extend existing table of divisors or create new table as appropriate + var table []divisor // for b == 10, table overlaps with cacheBase10.table + if b == 10 { + cacheBase10.Lock() + table = cacheBase10.table[0:k] // reuse old table for this conversion + } else { + table = make([]divisor, k) // create new table for this conversion + } + + // extend table + if table[k-1].ndigits == 0 { + // add new entries as needed + var larger nat + for i := 0; i < k; i++ { + if table[i].ndigits == 0 { + if i == 0 { + table[0].bbb = nat(nil).expWW(bb, Word(leafSize)) + table[0].ndigits = ndigits * leafSize + } else { + table[i].bbb = nat(nil).sqr(table[i-1].bbb) + table[i].ndigits = 2 * table[i-1].ndigits + } + + // optimization: exploit aggregated extra bits in macro blocks + larger = nat(nil).set(table[i].bbb) + for mulAddVWW(larger, larger, b, 0) == 0 { + table[i].bbb = table[i].bbb.set(larger) + table[i].ndigits++ + } + + table[i].nbits = table[i].bbb.bitLen() + } + } + } + + if b == 10 { + cacheBase10.Unlock() + } + + return table +} diff --git a/vendor/github.com/golang/go/src/math/big/prime.go b/vendor/github.com/golang/go/src/math/big/prime.go new file mode 100644 index 000000000000..848affbf5bf4 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/prime.go @@ -0,0 +1,320 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import "math/rand" + +// ProbablyPrime reports whether x is probably prime, +// applying the Miller-Rabin test with n pseudorandomly chosen bases +// as well as a Baillie-PSW test. +// +// If x is prime, ProbablyPrime returns true. +// If x is chosen randomly and not prime, ProbablyPrime probably returns false. +// The probability of returning true for a randomly chosen non-prime is at most ¼ⁿ. +// +// ProbablyPrime is 100% accurate for inputs less than 2⁶⁴. +// See Menezes et al., Handbook of Applied Cryptography, 1997, pp. 145-149, +// and FIPS 186-4 Appendix F for further discussion of the error probabilities. +// +// ProbablyPrime is not suitable for judging primes that an adversary may +// have crafted to fool the test. +// +// As of Go 1.8, ProbablyPrime(0) is allowed and applies only a Baillie-PSW test. +// Before Go 1.8, ProbablyPrime applied only the Miller-Rabin tests, and ProbablyPrime(0) panicked. +func (x *Int) ProbablyPrime(n int) bool { + // Note regarding the doc comment above: + // It would be more precise to say that the Baillie-PSW test uses the + // extra strong Lucas test as its Lucas test, but since no one knows + // how to tell any of the Lucas tests apart inside a Baillie-PSW test + // (they all work equally well empirically), that detail need not be + // documented or implicitly guaranteed. + // The comment does avoid saying "the" Baillie-PSW test + // because of this general ambiguity. + + if n < 0 { + panic("negative n for ProbablyPrime") + } + if x.neg || len(x.abs) == 0 { + return false + } + + // primeBitMask records the primes < 64. + const primeBitMask uint64 = 1<<2 | 1<<3 | 1<<5 | 1<<7 | + 1<<11 | 1<<13 | 1<<17 | 1<<19 | 1<<23 | 1<<29 | 1<<31 | + 1<<37 | 1<<41 | 1<<43 | 1<<47 | 1<<53 | 1<<59 | 1<<61 + + w := x.abs[0] + if len(x.abs) == 1 && w < 64 { + return primeBitMask&(1< 10000 { + // This is widely believed to be impossible. + // If we get a report, we'll want the exact number n. + panic("math/big: internal error: cannot find (D/n) = -1 for " + intN.String()) + } + d[0] = p*p - 4 + j := Jacobi(intD, intN) + if j == -1 { + break + } + if j == 0 { + // d = p²-4 = (p-2)(p+2). + // If (d/n) == 0 then d shares a prime factor with n. + // Since the loop proceeds in increasing p and starts with p-2==1, + // the shared prime factor must be p+2. + // If p+2 == n, then n is prime; otherwise p+2 is a proper factor of n. + return len(n) == 1 && n[0] == p+2 + } + if p == 40 { + // We'll never find (d/n) = -1 if n is a square. + // If n is a non-square we expect to find a d in just a few attempts on average. + // After 40 attempts, take a moment to check if n is indeed a square. + t1 = t1.sqrt(n) + t1 = t1.sqr(t1) + if t1.cmp(n) == 0 { + return false + } + } + } + + // Grantham definition of "extra strong Lucas pseudoprime", after Thm 2.3 on p. 876 + // (D, P, Q above have become Δ, b, 1): + // + // Let U_n = U_n(b, 1), V_n = V_n(b, 1), and Δ = b²-4. + // An extra strong Lucas pseudoprime to base b is a composite n = 2^r s + Jacobi(Δ, n), + // where s is odd and gcd(n, 2*Δ) = 1, such that either (i) U_s ≡ 0 mod n and V_s ≡ ±2 mod n, + // or (ii) V_{2^t s} ≡ 0 mod n for some 0 ≤ t < r-1. + // + // We know gcd(n, Δ) = 1 or else we'd have found Jacobi(d, n) == 0 above. + // We know gcd(n, 2) = 1 because n is odd. + // + // Arrange s = (n - Jacobi(Δ, n)) / 2^r = (n+1) / 2^r. + s := nat(nil).add(n, natOne) + r := int(s.trailingZeroBits()) + s = s.shr(s, uint(r)) + nm2 := nat(nil).sub(n, natTwo) // n-2 + + // We apply the "almost extra strong" test, which checks the above conditions + // except for U_s ≡ 0 mod n, which allows us to avoid computing any U_k values. + // Jacobsen points out that maybe we should just do the full extra strong test: + // "It is also possible to recover U_n using Crandall and Pomerance equation 3.13: + // U_n = D^-1 (2V_{n+1} - PV_n) allowing us to run the full extra-strong test + // at the cost of a single modular inversion. This computation is easy and fast in GMP, + // so we can get the full extra-strong test at essentially the same performance as the + // almost extra strong test." + + // Compute Lucas sequence V_s(b, 1), where: + // + // V(0) = 2 + // V(1) = P + // V(k) = P V(k-1) - Q V(k-2). + // + // (Remember that due to method C above, P = b, Q = 1.) + // + // In general V(k) = α^k + β^k, where α and β are roots of x² - Px + Q. + // Crandall and Pomerance (p.147) observe that for 0 ≤ j ≤ k, + // + // V(j+k) = V(j)V(k) - V(k-j). + // + // So in particular, to quickly double the subscript: + // + // V(2k) = V(k)² - 2 + // V(2k+1) = V(k) V(k+1) - P + // + // We can therefore start with k=0 and build up to k=s in log₂(s) steps. + natP := nat(nil).setWord(p) + vk := nat(nil).setWord(2) + vk1 := nat(nil).setWord(p) + t2 := nat(nil) // temp + for i := int(s.bitLen()); i >= 0; i-- { + if s.bit(uint(i)) != 0 { + // k' = 2k+1 + // V(k') = V(2k+1) = V(k) V(k+1) - P. + t1 = t1.mul(vk, vk1) + t1 = t1.add(t1, n) + t1 = t1.sub(t1, natP) + t2, vk = t2.div(vk, t1, n) + // V(k'+1) = V(2k+2) = V(k+1)² - 2. + t1 = t1.sqr(vk1) + t1 = t1.add(t1, nm2) + t2, vk1 = t2.div(vk1, t1, n) + } else { + // k' = 2k + // V(k'+1) = V(2k+1) = V(k) V(k+1) - P. + t1 = t1.mul(vk, vk1) + t1 = t1.add(t1, n) + t1 = t1.sub(t1, natP) + t2, vk1 = t2.div(vk1, t1, n) + // V(k') = V(2k) = V(k)² - 2 + t1 = t1.sqr(vk) + t1 = t1.add(t1, nm2) + t2, vk = t2.div(vk, t1, n) + } + } + + // Now k=s, so vk = V(s). Check V(s) ≡ ±2 (mod n). + if vk.cmp(natTwo) == 0 || vk.cmp(nm2) == 0 { + // Check U(s) ≡ 0. + // As suggested by Jacobsen, apply Crandall and Pomerance equation 3.13: + // + // U(k) = D⁻¹ (2 V(k+1) - P V(k)) + // + // Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n, + // or P V(k) - 2 V(k+1) == 0 mod n. + t1 := t1.mul(vk, natP) + t2 := t2.shl(vk1, 1) + if t1.cmp(t2) < 0 { + t1, t2 = t2, t1 + } + t1 = t1.sub(t1, t2) + t3 := vk1 // steal vk1, no longer needed below + vk1 = nil + _ = vk1 + t2, t3 = t2.div(t3, t1, n) + if len(t3) == 0 { + return true + } + } + + // Check V(2^t s) ≡ 0 mod n for some 0 ≤ t < r-1. + for t := 0; t < r-1; t++ { + if len(vk) == 0 { // vk == 0 + return true + } + // Optimization: V(k) = 2 is a fixed point for V(k') = V(k)² - 2, + // so if V(k) = 2, we can stop: we will never find a future V(k) == 0. + if len(vk) == 1 && vk[0] == 2 { // vk == 2 + return false + } + // k' = 2k + // V(k') = V(2k) = V(k)² - 2 + t1 = t1.sqr(vk) + t1 = t1.sub(t1, natTwo) + t2, vk = t2.div(vk, t1, n) + } + return false +} diff --git a/vendor/github.com/golang/go/src/math/big/rat.go b/vendor/github.com/golang/go/src/math/big/rat.go new file mode 100644 index 000000000000..b33fc696adf5 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/rat.go @@ -0,0 +1,517 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements multi-precision rational numbers. + +package big + +import ( + "fmt" + "math" +) + +// A Rat represents a quotient a/b of arbitrary precision. +// The zero value for a Rat represents the value 0. +type Rat struct { + // To make zero values for Rat work w/o initialization, + // a zero value of b (len(b) == 0) acts like b == 1. + // a.neg determines the sign of the Rat, b.neg is ignored. + a, b Int +} + +// NewRat creates a new Rat with numerator a and denominator b. +func NewRat(a, b int64) *Rat { + return new(Rat).SetFrac64(a, b) +} + +// SetFloat64 sets z to exactly f and returns z. +// If f is not finite, SetFloat returns nil. +func (z *Rat) SetFloat64(f float64) *Rat { + const expMask = 1<<11 - 1 + bits := math.Float64bits(f) + mantissa := bits & (1<<52 - 1) + exp := int((bits >> 52) & expMask) + switch exp { + case expMask: // non-finite + return nil + case 0: // denormal + exp -= 1022 + default: // normal + mantissa |= 1 << 52 + exp -= 1023 + } + + shift := 52 - exp + + // Optimization (?): partially pre-normalise. + for mantissa&1 == 0 && shift > 0 { + mantissa >>= 1 + shift-- + } + + z.a.SetUint64(mantissa) + z.a.neg = f < 0 + z.b.Set(intOne) + if shift > 0 { + z.b.Lsh(&z.b, uint(shift)) + } else { + z.a.Lsh(&z.a, uint(-shift)) + } + return z.norm() +} + +// quotToFloat32 returns the non-negative float32 value +// nearest to the quotient a/b, using round-to-even in +// halfway cases. It does not mutate its arguments. +// Preconditions: b is non-zero; a and b have no common factors. +func quotToFloat32(a, b nat) (f float32, exact bool) { + const ( + // float size in bits + Fsize = 32 + + // mantissa + Msize = 23 + Msize1 = Msize + 1 // incl. implicit 1 + Msize2 = Msize1 + 1 + + // exponent + Esize = Fsize - Msize1 + Ebias = 1<<(Esize-1) - 1 + Emin = 1 - Ebias + Emax = Ebias + ) + + // TODO(adonovan): specialize common degenerate cases: 1.0, integers. + alen := a.bitLen() + if alen == 0 { + return 0, true + } + blen := b.bitLen() + if blen == 0 { + panic("division by zero") + } + + // 1. Left-shift A or B such that quotient A/B is in [1<= B). + // This is 2 or 3 more than the float32 mantissa field width of Msize: + // - the optional extra bit is shifted away in step 3 below. + // - the high-order 1 is omitted in "normal" representation; + // - the low-order 1 will be used during rounding then discarded. + exp := alen - blen + var a2, b2 nat + a2 = a2.set(a) + b2 = b2.set(b) + if shift := Msize2 - exp; shift > 0 { + a2 = a2.shl(a2, uint(shift)) + } else if shift < 0 { + b2 = b2.shl(b2, uint(-shift)) + } + + // 2. Compute quotient and remainder (q, r). NB: due to the + // extra shift, the low-order bit of q is logically the + // high-order bit of r. + var q nat + q, r := q.div(a2, a2, b2) // (recycle a2) + mantissa := low32(q) + haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half + + // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1 + // (in effect---we accomplish this incrementally). + if mantissa>>Msize2 == 1 { + if mantissa&1 == 1 { + haveRem = true + } + mantissa >>= 1 + exp++ + } + if mantissa>>Msize1 != 1 { + panic(fmt.Sprintf("expected exactly %d bits of result", Msize2)) + } + + // 4. Rounding. + if Emin-Msize <= exp && exp <= Emin { + // Denormal case; lose 'shift' bits of precision. + shift := uint(Emin - (exp - 1)) // [1..Esize1) + lostbits := mantissa & (1<>= shift + exp = 2 - Ebias // == exp + shift + } + // Round q using round-half-to-even. + exact = !haveRem + if mantissa&1 != 0 { + exact = false + if haveRem || mantissa&2 != 0 { + if mantissa++; mantissa >= 1< 100...0, so shift is safe + mantissa >>= 1 + exp++ + } + } + } + mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<= B). + // This is 2 or 3 more than the float64 mantissa field width of Msize: + // - the optional extra bit is shifted away in step 3 below. + // - the high-order 1 is omitted in "normal" representation; + // - the low-order 1 will be used during rounding then discarded. + exp := alen - blen + var a2, b2 nat + a2 = a2.set(a) + b2 = b2.set(b) + if shift := Msize2 - exp; shift > 0 { + a2 = a2.shl(a2, uint(shift)) + } else if shift < 0 { + b2 = b2.shl(b2, uint(-shift)) + } + + // 2. Compute quotient and remainder (q, r). NB: due to the + // extra shift, the low-order bit of q is logically the + // high-order bit of r. + var q nat + q, r := q.div(a2, a2, b2) // (recycle a2) + mantissa := low64(q) + haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half + + // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1 + // (in effect---we accomplish this incrementally). + if mantissa>>Msize2 == 1 { + if mantissa&1 == 1 { + haveRem = true + } + mantissa >>= 1 + exp++ + } + if mantissa>>Msize1 != 1 { + panic(fmt.Sprintf("expected exactly %d bits of result", Msize2)) + } + + // 4. Rounding. + if Emin-Msize <= exp && exp <= Emin { + // Denormal case; lose 'shift' bits of precision. + shift := uint(Emin - (exp - 1)) // [1..Esize1) + lostbits := mantissa & (1<>= shift + exp = 2 - Ebias // == exp + shift + } + // Round q using round-half-to-even. + exact = !haveRem + if mantissa&1 != 0 { + exact = false + if haveRem || mantissa&2 != 0 { + if mantissa++; mantissa >= 1< 100...0, so shift is safe + mantissa >>= 1 + exp++ + } + } + } + mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1< 0 && !z.a.neg // 0 has no sign + return z +} + +// Inv sets z to 1/x and returns z. +func (z *Rat) Inv(x *Rat) *Rat { + if len(x.a.abs) == 0 { + panic("division by zero") + } + z.Set(x) + a := z.b.abs + if len(a) == 0 { + a = a.set(natOne) // materialize numerator + } + b := z.a.abs + if b.cmp(natOne) == 0 { + b = b[:0] // normalize denominator + } + z.a.abs, z.b.abs = a, b // sign doesn't change + return z +} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x == 0 +// +1 if x > 0 +// +func (x *Rat) Sign() int { + return x.a.Sign() +} + +// IsInt reports whether the denominator of x is 1. +func (x *Rat) IsInt() bool { + return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0 +} + +// Num returns the numerator of x; it may be <= 0. +// The result is a reference to x's numerator; it +// may change if a new value is assigned to x, and vice versa. +// The sign of the numerator corresponds to the sign of x. +func (x *Rat) Num() *Int { + return &x.a +} + +// Denom returns the denominator of x; it is always > 0. +// The result is a reference to x's denominator; it +// may change if a new value is assigned to x, and vice versa. +func (x *Rat) Denom() *Int { + x.b.neg = false // the result is always >= 0 + if len(x.b.abs) == 0 { + x.b.abs = x.b.abs.set(natOne) // materialize denominator + } + return &x.b +} + +func (z *Rat) norm() *Rat { + switch { + case len(z.a.abs) == 0: + // z == 0 - normalize sign and denominator + z.a.neg = false + z.b.abs = z.b.abs[:0] + case len(z.b.abs) == 0: + // z is normalized int - nothing to do + case z.b.abs.cmp(natOne) == 0: + // z is int - normalize denominator + z.b.abs = z.b.abs[:0] + default: + neg := z.a.neg + z.a.neg = false + z.b.neg = false + if f := NewInt(0).lehmerGCD(&z.a, &z.b); f.Cmp(intOne) != 0 { + z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs) + z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs) + if z.b.abs.cmp(natOne) == 0 { + // z is int - normalize denominator + z.b.abs = z.b.abs[:0] + } + } + z.a.neg = neg + } + return z +} + +// mulDenom sets z to the denominator product x*y (by taking into +// account that 0 values for x or y must be interpreted as 1) and +// returns z. +func mulDenom(z, x, y nat) nat { + switch { + case len(x) == 0: + return z.set(y) + case len(y) == 0: + return z.set(x) + } + return z.mul(x, y) +} + +// scaleDenom computes x*f. +// If f == 0 (zero value of denominator), the result is (a copy of) x. +func scaleDenom(x *Int, f nat) *Int { + var z Int + if len(f) == 0 { + return z.Set(x) + } + z.abs = z.abs.mul(x.abs, f) + z.neg = x.neg + return &z +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y +// +1 if x > y +// +func (x *Rat) Cmp(y *Rat) int { + return scaleDenom(&x.a, y.b.abs).Cmp(scaleDenom(&y.a, x.b.abs)) +} + +// Add sets z to the sum x+y and returns z. +func (z *Rat) Add(x, y *Rat) *Rat { + a1 := scaleDenom(&x.a, y.b.abs) + a2 := scaleDenom(&y.a, x.b.abs) + z.a.Add(a1, a2) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + return z.norm() +} + +// Sub sets z to the difference x-y and returns z. +func (z *Rat) Sub(x, y *Rat) *Rat { + a1 := scaleDenom(&x.a, y.b.abs) + a2 := scaleDenom(&y.a, x.b.abs) + z.a.Sub(a1, a2) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + return z.norm() +} + +// Mul sets z to the product x*y and returns z. +func (z *Rat) Mul(x, y *Rat) *Rat { + if x == y { + // a squared Rat is positive and can't be reduced + z.a.neg = false + z.a.abs = z.a.abs.sqr(x.a.abs) + z.b.abs = z.b.abs.sqr(x.b.abs) + return z + } + z.a.Mul(&x.a, &y.a) + z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + return z.norm() +} + +// Quo sets z to the quotient x/y and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +func (z *Rat) Quo(x, y *Rat) *Rat { + if len(y.a.abs) == 0 { + panic("division by zero") + } + a := scaleDenom(&x.a, y.b.abs) + b := scaleDenom(&y.a, x.b.abs) + z.a.abs = a.abs + z.b.abs = b.abs + z.a.neg = a.neg != b.neg + return z.norm() +} diff --git a/vendor/github.com/golang/go/src/math/big/ratconv.go b/vendor/github.com/golang/go/src/math/big/ratconv.go new file mode 100644 index 000000000000..157d8d006d4f --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/ratconv.go @@ -0,0 +1,283 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements rat-to-string conversion functions. + +package big + +import ( + "errors" + "fmt" + "io" + "strconv" + "strings" +) + +func ratTok(ch rune) bool { + return strings.ContainsRune("+-/0123456789.eE", ch) +} + +var ratZero Rat +var _ fmt.Scanner = &ratZero // *Rat must implement fmt.Scanner + +// Scan is a support routine for fmt.Scanner. It accepts the formats +// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent. +func (z *Rat) Scan(s fmt.ScanState, ch rune) error { + tok, err := s.Token(true, ratTok) + if err != nil { + return err + } + if !strings.ContainsRune("efgEFGv", ch) { + return errors.New("Rat.Scan: invalid verb") + } + if _, ok := z.SetString(string(tok)); !ok { + return errors.New("Rat.Scan: invalid syntax") + } + return nil +} + +// SetString sets z to the value of s and returns z and a boolean indicating +// success. s can be given as a fraction "a/b" or as a floating-point number +// optionally followed by an exponent. The entire string (not just a prefix) +// must be valid for success. If the operation failed, the value of z is +// undefined but the returned value is nil. +func (z *Rat) SetString(s string) (*Rat, bool) { + if len(s) == 0 { + return nil, false + } + // len(s) > 0 + + // parse fraction a/b, if any + if sep := strings.Index(s, "/"); sep >= 0 { + if _, ok := z.a.SetString(s[:sep], 0); !ok { + return nil, false + } + r := strings.NewReader(s[sep+1:]) + var err error + if z.b.abs, _, _, err = z.b.abs.scan(r, 0, false); err != nil { + return nil, false + } + // entire string must have been consumed + if _, err = r.ReadByte(); err != io.EOF { + return nil, false + } + if len(z.b.abs) == 0 { + return nil, false + } + return z.norm(), true + } + + // parse floating-point number + r := strings.NewReader(s) + + // sign + neg, err := scanSign(r) + if err != nil { + return nil, false + } + + // mantissa + var ecorr int + z.a.abs, _, ecorr, err = z.a.abs.scan(r, 10, true) + if err != nil { + return nil, false + } + + // exponent + var exp int64 + exp, _, err = scanExponent(r, false) + if err != nil { + return nil, false + } + + // there should be no unread characters left + if _, err = r.ReadByte(); err != io.EOF { + return nil, false + } + + // special-case 0 (see also issue #16176) + if len(z.a.abs) == 0 { + return z, true + } + // len(z.a.abs) > 0 + + // correct exponent + if ecorr < 0 { + exp += int64(ecorr) + } + + // compute exponent power + expabs := exp + if expabs < 0 { + expabs = -expabs + } + powTen := nat(nil).expNN(natTen, nat(nil).setWord(Word(expabs)), nil) + + // complete fraction + if exp < 0 { + z.b.abs = powTen + z.norm() + } else { + z.a.abs = z.a.abs.mul(z.a.abs, powTen) + z.b.abs = z.b.abs[:0] + } + + z.a.neg = neg && len(z.a.abs) > 0 // 0 has no sign + + return z, true +} + +// scanExponent scans the longest possible prefix of r representing a decimal +// ('e', 'E') or binary ('p') exponent, if any. It returns the exponent, the +// exponent base (10 or 2), or a read or syntax error, if any. +// +// exponent = ( "E" | "e" | "p" ) [ sign ] digits . +// sign = "+" | "-" . +// digits = digit { digit } . +// digit = "0" ... "9" . +// +// A binary exponent is only permitted if binExpOk is set. +func scanExponent(r io.ByteScanner, binExpOk bool) (exp int64, base int, err error) { + base = 10 + + var ch byte + if ch, err = r.ReadByte(); err != nil { + if err == io.EOF { + err = nil // no exponent; same as e0 + } + return + } + + switch ch { + case 'e', 'E': + // ok + case 'p': + if binExpOk { + base = 2 + break // ok + } + fallthrough // binary exponent not permitted + default: + r.UnreadByte() + return // no exponent; same as e0 + } + + var neg bool + if neg, err = scanSign(r); err != nil { + return + } + + var digits []byte + if neg { + digits = append(digits, '-') + } + + // no need to use nat.scan for exponent digits + // since we only care about int64 values - the + // from-scratch scan is easy enough and faster + for i := 0; ; i++ { + if ch, err = r.ReadByte(); err != nil { + if err != io.EOF || i == 0 { + return + } + err = nil + break // i > 0 + } + if ch < '0' || '9' < ch { + if i == 0 { + r.UnreadByte() + err = fmt.Errorf("invalid exponent (missing digits)") + return + } + break // i > 0 + } + digits = append(digits, ch) + } + // i > 0 => we have at least one digit + + exp, err = strconv.ParseInt(string(digits), 10, 64) + return +} + +// String returns a string representation of x in the form "a/b" (even if b == 1). +func (x *Rat) String() string { + return string(x.marshal()) +} + +// marshal implements String returning a slice of bytes +func (x *Rat) marshal() []byte { + var buf []byte + buf = x.a.Append(buf, 10) + buf = append(buf, '/') + if len(x.b.abs) != 0 { + buf = x.b.Append(buf, 10) + } else { + buf = append(buf, '1') + } + return buf +} + +// RatString returns a string representation of x in the form "a/b" if b != 1, +// and in the form "a" if b == 1. +func (x *Rat) RatString() string { + if x.IsInt() { + return x.a.String() + } + return x.String() +} + +// FloatString returns a string representation of x in decimal form with prec +// digits of precision after the decimal point. The last digit is rounded to +// nearest, with halves rounded away from zero. +func (x *Rat) FloatString(prec int) string { + var buf []byte + + if x.IsInt() { + buf = x.a.Append(buf, 10) + if prec > 0 { + buf = append(buf, '.') + for i := prec; i > 0; i-- { + buf = append(buf, '0') + } + } + return string(buf) + } + // x.b.abs != 0 + + q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs) + + p := natOne + if prec > 0 { + p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil) + } + + r = r.mul(r, p) + r, r2 := r.div(nat(nil), r, x.b.abs) + + // see if we need to round up + r2 = r2.add(r2, r2) + if x.b.abs.cmp(r2) <= 0 { + r = r.add(r, natOne) + if r.cmp(p) >= 0 { + q = nat(nil).add(q, natOne) + r = nat(nil).sub(r, p) + } + } + + if x.a.neg { + buf = append(buf, '-') + } + buf = append(buf, q.utoa(10)...) // itoa ignores sign if q == 0 + + if prec > 0 { + buf = append(buf, '.') + rs := r.utoa(10) + for i := prec - len(rs); i > 0; i-- { + buf = append(buf, '0') + } + buf = append(buf, rs...) + } + + return string(buf) +} diff --git a/vendor/github.com/golang/go/src/math/big/ratmarsh.go b/vendor/github.com/golang/go/src/math/big/ratmarsh.go new file mode 100644 index 000000000000..fbc7b6002d95 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/ratmarsh.go @@ -0,0 +1,75 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements encoding/decoding of Rats. + +package big + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const ratGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +func (x *Rat) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4) + i := x.b.abs.bytes(buf) + j := x.a.abs.bytes(buf[:i]) + n := i - j + if int(uint32(n)) != n { + // this should never happen + return nil, errors.New("Rat.GobEncode: numerator too large") + } + binary.BigEndian.PutUint32(buf[j-4:j], uint32(n)) + j -= 1 + 4 + b := ratGobVersion << 1 // make space for sign bit + if x.a.neg { + b |= 1 + } + buf[j] = b + return buf[j:], nil +} + +// GobDecode implements the gob.GobDecoder interface. +func (z *Rat) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Rat{} + return nil + } + b := buf[0] + if b>>1 != ratGobVersion { + return fmt.Errorf("Rat.GobDecode: encoding version %d not supported", b>>1) + } + const j = 1 + 4 + i := j + binary.BigEndian.Uint32(buf[j-4:j]) + z.a.neg = b&1 != 0 + z.a.abs = z.a.abs.setBytes(buf[j:i]) + z.b.abs = z.b.abs.setBytes(buf[i:]) + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (x *Rat) MarshalText() (text []byte, err error) { + if x.IsInt() { + return x.a.MarshalText() + } + return x.marshal(), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +func (z *Rat) UnmarshalText(text []byte) error { + // TODO(gri): get rid of the []byte/string conversion + if _, ok := z.SetString(string(text)); !ok { + return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Rat", text) + } + return nil +} diff --git a/vendor/github.com/golang/go/src/math/big/roundingmode_string.go b/vendor/github.com/golang/go/src/math/big/roundingmode_string.go new file mode 100644 index 000000000000..05024b806562 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/roundingmode_string.go @@ -0,0 +1,16 @@ +// generated by stringer -type=RoundingMode; DO NOT EDIT + +package big + +import "fmt" + +const _RoundingMode_name = "ToNearestEvenToNearestAwayToZeroAwayFromZeroToNegativeInfToPositiveInf" + +var _RoundingMode_index = [...]uint8{0, 13, 26, 32, 44, 57, 70} + +func (i RoundingMode) String() string { + if i+1 >= RoundingMode(len(_RoundingMode_index)) { + return fmt.Sprintf("RoundingMode(%d)", i) + } + return _RoundingMode_name[_RoundingMode_index[i]:_RoundingMode_index[i+1]] +} diff --git a/vendor/github.com/golang/go/src/math/big/sqrt.go b/vendor/github.com/golang/go/src/math/big/sqrt.go new file mode 100644 index 000000000000..00433cfe7a70 --- /dev/null +++ b/vendor/github.com/golang/go/src/math/big/sqrt.go @@ -0,0 +1,151 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package big + +import "math" + +var ( + half = NewFloat(0.5) + two = NewFloat(2.0) + three = NewFloat(3.0) +) + +// Sqrt sets z to the rounded square root of x, and returns it. +// +// If z's precision is 0, it is changed to x's precision before the +// operation. Rounding is performed according to z's precision and +// rounding mode. +// +// The function panics if z < 0. The value of z is undefined in that +// case. +func (z *Float) Sqrt(x *Float) *Float { + if debugFloat { + x.validate() + } + + if z.prec == 0 { + z.prec = x.prec + } + + if x.Sign() == -1 { + // following IEEE754-2008 (section 7.2) + panic(ErrNaN{"square root of negative operand"}) + } + + // handle ±0 and +∞ + if x.form != finite { + z.acc = Exact + z.form = x.form + z.neg = x.neg // IEEE754-2008 requires √±0 = ±0 + return z + } + + // MantExp sets the argument's precision to the receiver's, and + // when z.prec > x.prec this will lower z.prec. Restore it after + // the MantExp call. + prec := z.prec + b := x.MantExp(z) + z.prec = prec + + // Compute √(z·2**b) as + // √( z)·2**(½b) if b is even + // √(2z)·2**(⌊½b⌋) if b > 0 is odd + // √(½z)·2**(⌈½b⌉) if b < 0 is odd + switch b % 2 { + case 0: + // nothing to do + case 1: + z.Mul(two, z) + case -1: + z.Mul(half, z) + } + // 0.25 <= z < 2.0 + + // Solving x² - z = 0 directly requires a Quo call, but it's + // faster for small precisions. + // + // Solving 1/x² - z = 0 avoids the Quo call and is much faster for + // high precisions. + // + // 128bit precision is an empirically chosen threshold. + if z.prec <= 128 { + z.sqrtDirect(z) + } else { + z.sqrtInverse(z) + } + + // re-attach halved exponent + return z.SetMantExp(z, b/2) +} + +// Compute √x (up to prec 128) by solving +// t² - x = 0 +// for t, starting with a 53 bits precision guess from math.Sqrt and +// then using at most two iterations of Newton's method. +func (z *Float) sqrtDirect(x *Float) { + // let + // f(t) = t² - x + // then + // g(t) = f(t)/f'(t) = ½(t² - x)/t + // and the next guess is given by + // t2 = t - g(t) = ½(t² + x)/t + u := new(Float) + ng := func(t *Float) *Float { + u.prec = t.prec + u.Mul(t, t) // u = t² + u.Add(u, x) // = t² + x + u.Mul(half, u) // = ½(t² + x) + return t.Quo(u, t) // = ½(t² + x)/t + } + + xf, _ := x.Float64() + sq := NewFloat(math.Sqrt(xf)) + + switch { + case z.prec > 128: + panic("sqrtDirect: only for z.prec <= 128") + case z.prec > 64: + sq.prec *= 2 + sq = ng(sq) + fallthrough + default: + sq.prec *= 2 + sq = ng(sq) + } + + z.Set(sq) +} + +// Compute √x (to z.prec precision) by solving +// 1/t² - x = 0 +// for t (using Newton's method), and then inverting. +func (z *Float) sqrtInverse(x *Float) { + // let + // f(t) = 1/t² - x + // then + // g(t) = f(t)/f'(t) = -½t(1 - xt²) + // and the next guess is given by + // t2 = t - g(t) = ½t(3 - xt²) + u := new(Float) + ng := func(t *Float) *Float { + u.prec = t.prec + u.Mul(t, t) // u = t² + u.Mul(x, u) // = xt² + u.Sub(three, u) // = 3 - xt² + u.Mul(t, u) // = t(3 - xt²) + return t.Mul(half, u) // = ½t(3 - xt²) + } + + xf, _ := x.Float64() + sqi := NewFloat(1 / math.Sqrt(xf)) + for prec := z.prec + 32; sqi.prec < prec; { + sqi.prec *= 2 + sqi = ng(sqi) + } + // sqi = 1/√x + + // x/√x = √x + z.Mul(x, sqi) +} diff --git a/vendor/github.com/posener/complete/cmd/install/utils.go b/vendor/github.com/posener/complete/cmd/install/utils.go index 8bcf4e1517d2..bb709bc6cd98 100644 --- a/vendor/github.com/posener/complete/cmd/install/utils.go +++ b/vendor/github.com/posener/complete/cmd/install/utils.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "os" + "path/filepath" ) func lineInFile(name string, lookFor string) bool { @@ -37,11 +38,19 @@ func lineInFile(name string, lookFor string) bool { } func createFile(name string, content string) error { + // make sure file directory exists + if err := os.MkdirAll(filepath.Dir(name), 0775); err != nil { + return err + } + + // create the file f, err := os.Create(name) if err != nil { return err } defer f.Close() + + // write file content _, err = f.WriteString(fmt.Sprintf("%s\n", content)) return err } diff --git a/vendor/vendor.json b/vendor/vendor.json index 84f64d244af1..3b4a9df53dd2 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -924,6 +924,12 @@ "revision": "23def4e6c14b4da8ac2ed8007337bc5eb5007998", "revisionTime": "2016-01-25T20:49:56Z" }, + { + "checksumSHA1": "y1GG7DFXzx8ABfTWlVdgMVvlYU0=", + "path": "github.com/golang/go/src/math/big", + "revision": "bf86aec25972f3a100c3aa58a6abcbcc35bdea49", + "revisionTime": "2018-02-16T04:20:11Z" + }, { "checksumSHA1": "WX1+2gktHcBmE9MGwFSGs7oqexU=", "path": "github.com/golang/protobuf/proto", @@ -1539,26 +1545,26 @@ { "checksumSHA1": "eKclqCehbe7JsvlemLF7TfjMWf0=", "path": "github.com/posener/complete", - "revision": "cdc49b71388c2ab059f57997ef2575c9e8b4f146", - "revisionTime": "2018-01-19T09:07:45Z" + "revision": "98eb9847f27ba2008d380a32c98be474dea55bdf", + "revisionTime": "2018-03-09T06:24:32Z" }, { "checksumSHA1": "NB7uVS0/BJDmNu68vPAlbrq4TME=", "path": "github.com/posener/complete/cmd", - "revision": "cdc49b71388c2ab059f57997ef2575c9e8b4f146", - "revisionTime": "2018-01-19T09:07:45Z" + "revision": "98eb9847f27ba2008d380a32c98be474dea55bdf", + "revisionTime": "2018-03-09T06:24:32Z" }, { - "checksumSHA1": "/HKxX422GpzWV56uW87cwXEWYV8=", + "checksumSHA1": "llSE1833yASSLHfDuN7lKx48020=", "path": "github.com/posener/complete/cmd/install", - "revision": "cdc49b71388c2ab059f57997ef2575c9e8b4f146", - "revisionTime": "2018-01-19T09:07:45Z" + "revision": "98eb9847f27ba2008d380a32c98be474dea55bdf", + "revisionTime": "2018-03-09T06:24:32Z" }, { "checksumSHA1": "DMo94FwJAm9ZCYCiYdJU2+bh4no=", "path": "github.com/posener/complete/match", - "revision": "cdc49b71388c2ab059f57997ef2575c9e8b4f146", - "revisionTime": "2018-01-19T09:07:45Z" + "revision": "98eb9847f27ba2008d380a32c98be474dea55bdf", + "revisionTime": "2018-03-09T06:24:32Z" }, { "checksumSHA1": "vCogt04lbcE8fUgvRCOaZQUo+Pk=", diff --git a/website/Gemfile b/website/Gemfile index bb10a536ebbb..d21d35e45145 100644 --- a/website/Gemfile +++ b/website/Gemfile @@ -1,3 +1,3 @@ source "https://rubygems.org" -gem "middleman-hashicorp", "0.3.29" +gem "middleman-hashicorp", "0.3.30" diff --git a/website/Gemfile.lock b/website/Gemfile.lock index 843cf36e18f1..e567bf2edeb2 100644 --- a/website/Gemfile.lock +++ b/website/Gemfile.lock @@ -6,7 +6,7 @@ GEM minitest (~> 5.1) thread_safe (~> 0.3, >= 0.3.4) tzinfo (~> 1.1) - autoprefixer-rails (7.1.5) + autoprefixer-rails (8.1.0) execjs bootstrap-sass (3.3.7) autoprefixer-rails (>= 5.2.1) @@ -18,7 +18,7 @@ GEM rack (>= 1.0.0) rack-test (>= 0.5.4) xpath (~> 2.0) - chunky_png (1.3.8) + chunky_png (1.3.10) coffee-script (2.4.1) coffee-script-source execjs @@ -41,7 +41,7 @@ GEM erubis (2.7.0) eventmachine (1.2.5) execjs (2.7.0) - ffi (1.9.18) + ffi (1.9.23) haml (5.0.4) temple (>= 0.8.0) tilt @@ -51,7 +51,7 @@ GEM http_parser.rb (0.6.0) i18n (0.7.0) json (2.1.0) - kramdown (1.15.0) + kramdown (1.16.2) listen (3.0.8) rb-fsevent (~> 0.9, >= 0.9.4) rb-inotify (~> 0.9, >= 0.9.7) @@ -78,7 +78,7 @@ GEM rack (>= 1.4.5, < 2.0) thor (>= 0.15.2, < 2.0) tilt (~> 1.4.1, < 2.0) - middleman-hashicorp (0.3.29) + middleman-hashicorp (0.3.30) bootstrap-sass (~> 3.3) builder (~> 3.2) middleman (~> 3.4) @@ -102,22 +102,22 @@ GEM mime-types-data (~> 3.2015) mime-types-data (3.2016.0521) mini_portile2 (2.3.0) - minitest (5.10.3) - multi_json (1.12.2) - nokogiri (1.8.1) + minitest (5.11.3) + multi_json (1.13.1) + nokogiri (1.8.2) mini_portile2 (~> 2.3.0) - padrino-helpers (0.12.8.1) + padrino-helpers (0.12.9) i18n (~> 0.6, >= 0.6.7) - padrino-support (= 0.12.8.1) - tilt (~> 1.4.1) - padrino-support (0.12.8.1) + padrino-support (= 0.12.9) + tilt (>= 1.4.1, < 3) + padrino-support (0.12.9) activesupport (>= 3.1) - rack (1.6.8) + rack (1.6.9) rack-livereload (0.3.16) rack - rack-test (0.7.0) + rack-test (0.8.3) rack (>= 1.0, < 3) - rb-fsevent (0.10.2) + rb-fsevent (0.10.3) rb-inotify (0.9.10) ffi (>= 0.5.0, < 2) redcarpet (3.4.0) @@ -137,10 +137,10 @@ GEM thor (0.20.0) thread_safe (0.3.6) tilt (1.4.1) - turbolinks (5.0.1) - turbolinks-source (~> 5) - turbolinks-source (5.0.3) - tzinfo (1.2.3) + turbolinks (5.1.0) + turbolinks-source (~> 5.1) + turbolinks-source (5.1.0) + tzinfo (1.2.5) thread_safe (~> 0.1) uber (0.0.15) uglifier (2.7.2) @@ -153,7 +153,7 @@ PLATFORMS ruby DEPENDENCIES - middleman-hashicorp (= 0.3.29) + middleman-hashicorp (= 0.3.30) BUNDLED WITH 1.16.1 diff --git a/website/packer.json b/website/packer.json index cbfb5c565ad3..aeaa162e0985 100644 --- a/website/packer.json +++ b/website/packer.json @@ -8,7 +8,7 @@ "builders": [ { "type": "docker", - "image": "hashicorp/middleman-hashicorp:0.3.29", + "image": "hashicorp/middleman-hashicorp:0.3.32", "discard": "true", "volumes": { "{{ pwd }}": "/website" diff --git a/website/source/api/auth/approle/index.html.md b/website/source/api/auth/approle/index.html.md index b06451b49183..ee7f8ef41f80 100644 --- a/website/source/api/auth/approle/index.html.md +++ b/website/source/api/auth/approle/index.html.md @@ -69,8 +69,9 @@ enabled while creating or updating a role. - `role_name` `(string: )` - Name of the AppRole. - `bind_secret_id` `(bool: true)` - Require `secret_id` to be presented when logging in using this AppRole. -- `bound_cidr_list` `(array: [])` - Comma-separated list of CIDR blocks; if set, - specifies blocks of IP addresses which can perform the login operation. +- `bound_cidr_list` `(array: [])` - Comma-separated string or list of CIDR + blocks; if set, specifies blocks of IP addresses which can perform the login + operation. - `policies` `(array: [])` - Comma-separated list of policies set on tokens issued via this AppRole. - `secret_id_num_uses` `(integer: 0)` - Number of times any particular SecretID @@ -155,7 +156,7 @@ $ curl \ ], "period": 0, "bind_secret_id": true, - "bound_cidr_list": "" + "bound_cidr_list": [] }, "lease_duration": 0, "renewable": false, @@ -285,10 +286,10 @@ itself, and also to delete the SecretID from the AppRole. a JSON-formatted string containing the metadata in key-value pairs. This metadata will be set on tokens issued with this SecretID, and is logged in audit logs _in plaintext_. -- `cidr_list` `(string: "")` - Comma separated list of CIDR blocks enforcing - secret IDs to be used from specific set of IP addresses. If 'bound_cidr_list' - is set on the role, then the list of CIDR blocks listed here should be a - subset of the CIDR blocks listed on the role. +- `cidr_list` `(array: [])` - Comma separated string or list of CIDR blocks + enforcing secret IDs to be used from specific set of IP addresses. If + `bound_cidr_list` is set on the role, then the list of CIDR blocks listed + here should be a subset of the CIDR blocks listed on the role. ### Sample Payload @@ -510,10 +511,10 @@ Assigns a "custom" SecretID against an existing AppRole. This is used in the a JSON-formatted string containing the metadata in key-value pairs. This metadata will be set on tokens issued with this SecretID, and is logged in audit logs _in plaintext_. -- `cidr_list` `(string: "")` - Comma separated list of CIDR blocks enforcing - secret IDs to be used from ppecific set of IP addresses. If 'bound_cidr_list' - is set on the role, then the list of CIDR blocks listed here should be a - subset of the CIDR blocks listed on the role. +- `cidr_list` `(array: [])` - Comma separated string or list of CIDR blocks + enforcing secret IDs to be used from ppecific set of IP addresses. If + `bound_cidr_list` is set on the role, then the list of CIDR blocks listed + here should be a subset of the CIDR blocks listed on the role. ### Sample Payload diff --git a/website/source/api/auth/aws/index.html.md b/website/source/api/auth/aws/index.html.md index 986fbb1d4748..667c390df18a 100644 --- a/website/source/api/auth/aws/index.html.md +++ b/website/source/api/auth/aws/index.html.md @@ -588,6 +588,10 @@ list in order to satisfy that constraint. This constraint is checked by the ec2 auth method as well as the iam auth method only when inferring an ec2 instance. This is a comma-separated string or a JSON array. +- `bound_ec2_instance_id` `(list: [])` - If set, defines a constraint on the + EC2 instances to have one of these instance IDs. This constraint is checked by + the ec2 auth method as well as the iam auth method only when inferring an ec2 + instance. This is a comma-separated string or a JSON array. - `role_tag` `(string: "")` - If set, enables the role tags for this role. The value set for this field should be the 'key' of the tag on the EC2 instance. The 'value' of the tag should be generated using `role//tag` endpoint. @@ -681,6 +685,7 @@ list in order to satisfy that constraint. ```json { "bound_ami_id": ["ami-fce36987"], + "bound_ec2_instance_id": ["i-12345678901234567"], "role_tag": "", "policies": [ "default", diff --git a/website/source/api/secret/transit/index.html.md b/website/source/api/secret/transit/index.html.md index b1bdc9ba3336..b6f11624967d 100644 --- a/website/source/api/secret/transit/index.html.md +++ b/website/source/api/secret/transit/index.html.md @@ -54,8 +54,10 @@ values set here cannot be changed after key creation. (symmetric, supports derivation and convergent encryption) - `chacha20-poly1305` – ChaCha20-Poly1305 AEAD (symmetric, supports derivation and convergent encryption) + - `ed25519` – ED25519 (asymmetric, supports derivation). When using + derivation, a sign operation with the same context will derive the same + key and signature; this is a signing analogue to `convergent_encryption`. - `ecdsa-p256` – ECDSA using the P-256 elliptic curve (asymmetric) - - `ed25519` – ED25519 (asymmetric, supports derivation) - `rsa-2048` - RSA with bit size of 2048 (asymmetric) - `rsa-4096` - RSA with bit size of 4096 (asymmetric) @@ -774,7 +776,7 @@ supports signing. | Method | Path | Produces | | :------- | :--------------------------- | :--------------------- | -| `POST` | `/transit/sign/:name(/:algorithm)` | `200 application/json` | +| `POST` | `/transit/sign/:name(/:hash_algorithm)` | `200 application/json` | ### Parameters @@ -785,7 +787,7 @@ supports signing. signing. If not set, uses the latest version. Must be greater than or equal to the key's `min_encryption_version`, if set. -- `algorithm` `(string: "sha2-256")` – Specifies the hash algorithm to use for +- `hash_algorithm` `(string: "sha2-256")` – Specifies the hash algorithm to use for supporting key types (notably, not including `ed25519` which specifies its own hash algorithm). This can also be specified as part of the URL. Currently-supported algorithms are: @@ -803,7 +805,13 @@ supports signing. - `prehashed` `(bool: false)` - Set to `true` when the input is already hashed. If the key type is `rsa-2048` or `rsa-4096`, then the algorithm used - to hash the input should be indicated by the `algorithm` parameter. + to hash the input should be indicated by the `hash_algorithm` parameter. + +- `signature_algorithm` `(string: "pss")` – When using a RSA key, specifies the RSA + signature algorithm to use for signing. Supported signature types are: + + - `pss` + - `pkcs1v15` ### Sample Payload @@ -841,14 +849,14 @@ data. | Method | Path | Produces | | :------- | :--------------------------- | :--------------------- | -| `POST` | `/transit/verify/:name(/:algorithm)` | `200 application/json` | +| `POST` | `/transit/verify/:name(/:hash_algorithm)` | `200 application/json` | ### Parameters - `name` `(string: )` – Specifies the name of the encryption key that was used to generate the signature or HMAC. -- `algorithm` `(string: "sha2-256")` – Specifies the hash algorithm to use. This +- `hash_algorithm` `(string: "sha2-256")` – Specifies the hash algorithm to use. This can also be specified as part of the URL. Currently-supported algorithms are: - `sha2-224` @@ -872,7 +880,14 @@ data. - `prehashed` `(bool: false)` - Set to `true` when the input is already hashed. If the key type is `rsa-2048` or `rsa-4096`, then the algorithm used - to hash the input should be indicated by the `algorithm` parameter. + to hash the input should be indicated by the `hash_algorithm` parameter. + +- `signature_algorithm` `(string: "pss")` – When using a RSA key, specifies the RSA + signature algorithm to use for signature verification. Supported signature types + are: + + - `pss` + - `pkcs1v15` ### Sample Payload diff --git a/website/source/api/system/auth.html.md b/website/source/api/system/auth.html.md index 4b72e9aa4cf2..2f85c191a780 100644 --- a/website/source/api/system/auth.html.md +++ b/website/source/api/system/auth.html.md @@ -77,7 +77,20 @@ For example, enable the "foo" auth method will make it accessible at - `config` `(map: nil)` – Specifies configuration options for this auth method. These are the possible values: - - `plugin_name` + - `default_lease_ttl` `(string: "")` - The default lease duration, specified + as a string duration like "5s" or "30m". + + - `max_lease_ttl` `(string: "")` - The maximum lease duration, specified as a + string duration like "5s" or "30m". + + - `plugin_name` `(string: "")` - The name of the plugin in the plugin catalog + to use. + + - `audit_non_hmac_request_keys` `(array: [])` - Comma-separated list of keys + that will not be HMAC'd by audit devices in the request data object. + + - `audit_non_hmac_response_keys` `(array: [])` - Comma-separated list of keys + that will not be HMAC'd by audit devices in the response data object. The plugin_name can be provided in the config map or as a top-level option, with the former taking precedence. diff --git a/website/source/api/system/mounts.html.md b/website/source/api/system/mounts.html.md index bd52bdea9be0..d2a69900c062 100644 --- a/website/source/api/system/mounts.html.md +++ b/website/source/api/system/mounts.html.md @@ -80,23 +80,29 @@ This endpoint enables a new secrets engine at the given path. - `config` `(map: nil)` – Specifies configuration options for this mount. This is an object with four possible values: - - `default_lease_ttl` `(string: "")` - the default lease duration, specified - as a go string duration like "5s" or "30m". + - `default_lease_ttl` `(string: "")` - The default lease duration, specified + as a string duration like "5s" or "30m". - - `max_lease_ttl` `(string: "")` - the maximum lease duration, specified as - a go string duration like "5s" or "30m". + - `max_lease_ttl` `(string: "")` - The maximum lease duration, specified as a + string duration like "5s" or "30m". - - `force_no_cache` `(bool: false)` - disable caching. + - `force_no_cache` `(bool: false)` - Disable caching. - - `plugin_name` `(string: "")` - the name of the plugin in the plugin - catalog to use. + - `plugin_name` `(string: "")` - The name of the plugin in the plugin catalog + to use. + + - `audit_non_hmac_request_keys` `(array: [])` - Comma-separated list of keys + that will not be HMAC'd by audit devices in the request data object. + + - `audit_non_hmac_response_keys` `(array: [])` - Comma-separated list of keys + that will not be HMAC'd by audit devices in the response data object. These control the default and maximum lease time-to-live, force disabling backend caching, and option plugin name for plugin backends respectively. The first three options override the global defaults if set on a specific mount. The plugin_name can be provided in the config map or as a top-level option, with the former taking precedence. - + When used with supported seals (`pkcs11`, `awskms`, etc.), `seal_wrap` causes key material for supporting mounts to be wrapped by the seal's encryption capability. This is currently only supported for `transit` and diff --git a/website/source/assets/images/vault-approle-workflow2.png b/website/source/assets/images/vault-approle-workflow2.png index 66f85e4141ea..5d8ece655a11 100644 Binary files a/website/source/assets/images/vault-approle-workflow2.png and b/website/source/assets/images/vault-approle-workflow2.png differ diff --git a/website/source/docs/auth/aws.html.md b/website/source/docs/auth/aws.html.md index a7cca908260b..6152d1a846ce 100644 --- a/website/source/docs/auth/aws.html.md +++ b/website/source/docs/auth/aws.html.md @@ -304,7 +304,7 @@ method. "Effect": "Allow", "Action": ["sts:AssumeRole"], "Resource": [ - "arn:aws:iam::role/" + "arn:aws:iam:::role/" ] } ] diff --git a/website/source/docs/concepts/policies.html.md b/website/source/docs/concepts/policies.html.md index eaffadb2e126..182af4a0401a 100644 --- a/website/source/docs/concepts/policies.html.md +++ b/website/source/docs/concepts/policies.html.md @@ -269,7 +269,7 @@ options are: ``` * If any keys are specified, all non-specified parameters will be denied - unless there the parameter `"*"` is set to an empty array, which will + unless the parameter `"*"` is set to an empty array, which will allow all other parameters to be modified. Parameters with specific values will still be restricted to those values. @@ -338,15 +338,18 @@ Parameter values also support prefix/suffix globbing. Globbing is enabled by prepending or appending or prepending a splat (`*`) to the value: ```ruby -# Allow any parameter as long as the value starts with "foo-*". +# Only allow a parameter named "bar" with a value starting with "foo-*". path "secret/foo" { capabilities = ["create"] allowed_parameters = { - "*" = ["foo-*"] + "bar" = ["foo-*"] } } ``` +Note: the only value that can be used with the `*` parameter is `[]`. + + ### Required Response Wrapping TTLs These parameters can be used to set minimums/maximums on TTLs set by clients @@ -368,9 +371,9 @@ wrapping mandatory for a particular path. wrapped response. If both are specified, the minimum value must be less than the maximum. In -addition, if paths are merged from different stanzas, the lowest value specified -for each is the value that will result, in line with the idea of keeping token -lifetimes as short as possible. +addition, if paths are merged from different stanzas, the lowest value +specified for each is the value that will result, in line with the idea of +keeping token lifetimes as short as possible. ## Builtin Policies @@ -379,10 +382,17 @@ the two builtin policies. ### Default Policy -The `default` policy is a builtin Vault policy that cannot be modified or -removed. By default, it is attached to all tokens, but may be explicitly -detached at creation time. The policy contains basic functionality such as the -ability for the token to lookup data about itself and to use its cubbyhole data. +The `default` policy is a builtin Vault policy that cannot be removed. By +default, it is attached to all tokens, but may be explicitly excluded at token +creation time by supporting authentication methods. + +The policy contains basic functionality such as the ability for the token to +look up data about itself and to use its cubbyhole data. However, Vault is not +proscriptive about its contents. It can be modified to suit your needs; Vault +will never overwrite your modifications. If you want to stay up-to-date with +the latest upstream version of the `default` policy, simply read the contents +of the policy from an up-to-date `dev` server, and write those contents into +your Vault's `default` policy. To view all permissions granted by the default policy on your Vault installation, run: diff --git a/website/source/docs/configuration/storage/dynamodb.html.md b/website/source/docs/configuration/storage/dynamodb.html.md index 61ca9e687e87..0f68749e7234 100644 --- a/website/source/docs/configuration/storage/dynamodb.html.md +++ b/website/source/docs/configuration/storage/dynamodb.html.md @@ -19,8 +19,8 @@ The DynamoDB storage backend is used to persist Vault's data in - **Community Supported** – the DynamoDB storage backend is supported by the community. While it has undergone review by HashiCorp employees, they may not - be as knowledgeable about the technology. If you encounter problems with them, - you may be referred to the original author. + be as knowledgeable about the technology. If you encounter problems with this + storage backend, you could be referred to the original author for support. ```hcl storage "dynamodb" { diff --git a/website/source/docs/install/index.html.md b/website/source/docs/install/index.html.md index 36eb9307252e..af3137d29117 100644 --- a/website/source/docs/install/index.html.md +++ b/website/source/docs/install/index.html.md @@ -41,7 +41,7 @@ as a copy of [`git`](https://www.git-scm.com/) in your `PATH`. 1. Clone the Vault repository from GitHub into your `GOPATH`: ```shell - $ mkdir -p $GOPATH/src/github.com/hashicorp && cd $! + $ mkdir -p $GOPATH/src/github.com/hashicorp && cd $_ $ git clone https://github.com/hashicorp/vault.git $ cd vault ``` diff --git a/website/source/docs/secrets/aws/index.html.md b/website/source/docs/secrets/aws/index.html.md index db07f2a766a1..fb2948ef73da 100644 --- a/website/source/docs/secrets/aws/index.html.md +++ b/website/source/docs/secrets/aws/index.html.md @@ -46,7 +46,9 @@ the IAM credentials: on IAM credentials. Since Vault uses the official AWS SDK, it will use the specified credentials. You can also specify the credentials via the standard AWS environment credentials, shared file credentials, or IAM role/ECS task - credentials. + credentials. (Note that you can't authorize vault with IAM role credentials if you plan + on using STS Federation Tokens, since the temporary security credentials + associated with the role are not authorized to use GetFederationToken.) ~> **Notice:** Even though the path above is `aws/config/root`, do not use your AWS root account credentials. Instead generate a dedicated user or diff --git a/website/source/guides/identity/authentication.html.md b/website/source/guides/identity/authentication.html.md index e86d84e1079a..dbcbe114bf49 100644 --- a/website/source/guides/identity/authentication.html.md +++ b/website/source/guides/identity/authentication.html.md @@ -10,34 +10,35 @@ description: |- # Authentication Before a client can interact with Vault, it must authenticate against an [**auth -backend**](/docs/auth/index.html) to acquire a token. This token has policies attached so +method**](/docs/auth/index.html) to acquire a token. This token has policies attached so that the behavior of the client can be governed. Since tokens are the core method for authentication within Vault, there is a -**token** auth backend (often refer as **_token store_**). This is a special -auth backend responsible for creating and storing tokens. +**token** auth method (often referred to as **_token store_**). This is a special +auth method responsible for creating and storing tokens. -### Auth Backends +### Auth Methods -Auth backends perform authentication to verify the user or machine-supplied -information. Some of the supported auth backends are targeted towards users +Auth methods perform authentication to verify the user or machine-supplied +information. Some of the supported auth methods are targeted towards users while others are targeted toward machines or apps. For example, -[**LDAP**](/docs/auth/ldap.html) auth backend enables user authentication using +[**LDAP**](/docs/auth/ldap.html) auth method enables user authentication using an existing LDAP server while [**AppRole**](/docs/auth/approle.html) auth -backend is recommended for machines or apps. +method is recommended for machines or apps. The [Getting Started](/intro/getting-started/authentication.html) guide walks you -through how to enable the GitHub auth backend for user authentication. +through how to enable the GitHub auth method for user authentication. This introductory guide focuses on generating tokens for machines or apps by -enabling the [**AppRole**](/docs/auth/approle.html) auth backend. +enabling the [**AppRole**](/docs/auth/approle.html) auth method. ## Reference Material -- [Getting Started](/intro/getting-started/authentication.html) -- [Auth Backends](/docs/auth/index.html) -- [GitHub Auth APIs](/api/auth/github/index.html) +- [AppRole Auth Methods](/docs/auth/approle.html) +- [AppRole Auth Method (API)](/api/auth/approle/index.html) +- [Authenticating Applications with HashiCorp Vault AppRole](https://www.hashicorp.com/blog/authenticating-applications-with-vault-approle) + ## Estimated Time to Complete @@ -48,7 +49,7 @@ enabling the [**AppRole**](/docs/auth/approle.html) auth backend. The end-to-end scenario described in this guide involves two personas: -- **`admin`** with privileged permissions to configure an auth backend +- **`admin`** with privileged permissions to configure an auth method - **`app`** is the consumer of secrets stored in Vault @@ -58,12 +59,12 @@ Think of a scenario where a DevOps team wants to configure Jenkins to read secrets from Vault so that it can inject the secrets to an app's environment variables (e.g. `MYSQL_DB_HOST`) at deployment time. -Instead of hardcoding secrets in each build script as a plaintext, Jenkins +Instead of hardcoding secrets in each build script as plain text, Jenkins retrieves secrets from Vault. As a user, you can authenticate with Vault using your LDAP credentials, and -Vault generates a token. This token has policies granting you to perform -appropriate operations. +Vault generates a token. This token has policies granting you permission to perform +the appropriate operations. How can a Jenkins server programmatically request a token so that it can read secrets from Vault? @@ -71,7 +72,7 @@ secrets from Vault? ## Solution -Enable **AppRole** auth backend so that the Jenkins server can obtain a Vault +Enable **AppRole** auth method so that the Jenkins server can obtain a Vault token with appropriate policies attached. Since each AppRole has attached policies, you can write fine-grained policies limiting which app can access which path. @@ -87,21 +88,21 @@ unsealed](/intro/getting-started/deploy.html). ### Policy requirements --> **NOTE:** For the purpose of this guide, you can use **`root`** token to work +-> **NOTE:** For the purpose of this guide, you can use the **`root`** token to work with Vault. However, it is recommended that root tokens are only used for just enough initial setup or in emergencies. As a best practice, use tokens with -appropriate set of policies based on your role in the organization. +an appropriate set of policies based on your role in the organization. To perform all tasks demonstrated in this guide, your policy must include the following permissions: ```shell -# Mount the AppRole auth backend +# Mount the AppRole auth method path "sys/auth/approle" { capabilities = [ "create", "read", "update", "delete", "sudo" ] } -# Configure the AppRole auth backend +# Configure the AppRole auth method path "sys/auth/approle/*" { capabilities = [ "create", "read", "update", "delete" ] } @@ -133,7 +134,7 @@ to allow machines or apps to acquire a token to interact with Vault. It uses **Role ID** and **Secret ID** for login. The basic workflow is: -![AppRole auth backend workflow](/assets/images/vault-approle-workflow.png) +![AppRole auth method workflow](/assets/images/vault-approle-workflow.png) > For the purpose of introducing the basics of AppRole, this guide walks you > through a very simple scenario involving only two personas (admin and app). @@ -142,24 +143,24 @@ The basic workflow is: In this guide, you are going to perform the following steps: -1. [Enable AppRole auth backend](#step1) -2. [Create a role with policy attached](#step2) -3. [Get Role ID and Secret ID](#step3) -4. [Login with Role ID & Secret ID](#step4) -5. [Read secrets using the AppRole token](#step5) +1. [Enable AppRole auth method](#step1) +1. [Create a role with policy attached](#step2) +1. [Get Role ID and Secret ID](#step3) +1. [Login with Role ID & Secret ID](#step4) +1. [Read secrets using the AppRole token](#step5) Step 1 through 3 need to be performed by an `admin` user. Step 4 and 5 describe the commands that an `app` runs to get a token and read secrets from Vault. -### Step 1: Enable AppRole auth backend +### Step 1: Enable AppRole auth method (**Persona:** admin) -Like many other auth backends, AppRole must be enabled before it can be used. +Like many other auth methods, AppRole must be enabled before it can be used. #### CLI command -Enable `approle` auth backend by executing the following command: +Enable `approle` auth method by executing the following command: ```shell $ vault auth enable approle @@ -167,7 +168,7 @@ $ vault auth enable approle #### API call using cURL -Enable `approle` auth backend by mounting its endpoint at `/sys/auth/approle`: +Enable `approle` auth method by mounting its endpoint at `/sys/auth/approle`: ```shell $ curl --header "X-Vault-Token: " \ @@ -177,7 +178,7 @@ $ curl --header "X-Vault-Token: " \ ``` Where `` is your valid token, and `` holds [configuration -parameters](/api/system/auth.html#mount-auth-backend) of the backend. +parameters](/api/system/auth.html#enable-auth-method) of the method. **Example:** @@ -189,15 +190,15 @@ $ curl --header "X-Vault-Token: ..." \ https://vault.rocks/v1/sys/auth/approle ``` -The above example passes the **type** (`approle`) in the request payload which +The above example passes the **type** (`approle`) in the request payload at the `sys/auth/approle` endpoint. ### Step 2: Create a role with policy attached (**Persona:** admin) -When you enabled AppRole auth backend, it gets mounted at the +When you enabled the AppRole auth method, it gets mounted at the **`/auth/approle`** path. In this example, you are going to create a role for -**`app`** persona (`jenkins` in our scenario). +the **`app`** persona (`jenkins` in our scenario). The scenario in this guide requires the `app` to have the following policy (`jenkins-pol.hcl`): @@ -216,7 +217,7 @@ path "secret/mysql/*" { #### CLI command -Before creating a role, create `jenkins` policy: +Before creating a role, create a `jenkins` policy: ```shell $ vault policy write jenkins jenkins-pol.hcl @@ -287,8 +288,8 @@ Now, you are ready to create a role. **Example:** -The following example creates a role named `jenkins` with `jenkins` policy -attached. (NOTE: This example creates a role operates in [**pull** +The following example creates a role named `jenkins` with a `jenkins` policy +attached. (NOTE: This example creates a role which operates in [**pull** mode](/docs/auth/approle.html).) ```shell @@ -303,7 +304,7 @@ $ curl --header "X-Vault-Token: ..." --request POST \ > `secret_id_num_uses` or `secret_id_ttl` parameter values. Similarly, you can > specify `token_num_uses` and `token_ttl`. You may never want the app token to > expire. In such a case, specify the `period` so that the token generated by -> this AppRole is a periodic token. To learn more about periodic token, refer to +> this AppRole is a periodic token. To learn more about periodic tokens, refer to > the [Tokens and Leases](/guides/identity/lease.html#step4) guide. @@ -418,7 +419,7 @@ $ curl --header "X-Vault-Token:..." \ You can pass [parameters](/api/auth/approle/index.html#generate-new-secret-id) in the request -payload, or invoke the API with empty payload. +payload, or invoke the API with an empty payload. **Example:** @@ -448,7 +449,7 @@ securely. #### CLI command -To login, use `auth/approle/login` endpoint by passing the role ID and secret ID. +To login, use the `auth/approle/login` endpoint by passing the role ID and secret ID. **Example:** @@ -471,7 +472,7 @@ Now you have a **client token** with `default` and `jenkins` policies attached. #### API call using cURL -To login, use `auth/approle/login` endpoint by passing the role ID and secret ID +To login, use the `auth/approle/login` endpoint by passing the role ID and secret ID in the request payload. **Example:** @@ -543,11 +544,11 @@ $ vault read secret/mysql/webapp No value found at secret/mysql/webapp ``` -Since there is no value in the `secret/mysql/webapp`, it returns "no value +Since there is no value at `secret/mysql/webapp`, it returns a "no value found" message. **Optional:** Using the `admin` user's token, you can store some secrets in the -`secret/mysql/webapp` backend. +`secret/mysql/webapp` path. ```shell $ vault write secret/dev/config/mongodb @mysqldb.txt @@ -561,7 +562,7 @@ $ cat mysqldb.txt } ``` -Now, try to read secrets from `secret/mysql/webapp` using `client_token` again. +Now, try to read secrets from `secret/mysql/webapp` using the `client_token` again. This time, it should return the values you just created. @@ -581,10 +582,10 @@ $ curl --header "X-Vault-Token: 3e7dd0ac-8b3e-8f88-bb37-a2890455ca6e" \ } ``` -Since there is no value in the `secret/mysql/webapp`, it returns an empty array. +Since there is no value at `secret/mysql/webapp`, it returns an empty array. **Optional:** Using the **`admin`** user's token, create some secrets in the -`secret/mysql/webapp` backend. +`secret/mysql/webapp` path. ```shell $ curl --header "X-Vault-Token: ..." --request POST --data @mysqldb.txt \ @@ -598,7 +599,7 @@ $ cat mysqldb.text } ``` -Now, try to read secrets from `secret/mysql/webapp` using `client_token` again. +Now, try to read secrets from `secret/mysql/webapp` using the `client_token` again. This time, it should return the values you just created. @@ -607,7 +608,7 @@ This time, it should return the values you just created. The Role ID is equivalent to a username, and Secret ID is the corresponding password. The app needs both to log in with Vault. Naturally, the next question -becomes how to deliver those values to the expecting client. +becomes how to deliver those values to the expected client. A common solution involves **three personas** instead of two: `admin`, `app`, and `trusted entity`. The `trusted entity` delivers the Role ID and Secret ID to the @@ -617,10 +618,10 @@ For example, Terraform as a trusted entity can deliver the Role ID onto the virtual machine. When the app runs on the virtual machine, the Role ID already exists on the virtual machine. -![AppRole auth backend workflow](/assets/images/vault-approle-workflow2.png) +![AppRole auth method workflow](/assets/images/vault-approle-workflow2.png) Secret ID is like a password. To keep the Secret ID confidential, use -[**response wrapping**](/docs/concepts/response-wrapping.html) so that the only +[**response wrapping**](/docs/concepts/response-wrapping.html) so that only the expected client can unwrap the Secret ID. In [Step 3](#step3), you executed the following command to retrieve the Secret @@ -644,7 +645,7 @@ wrapping_token_creation_time: 2018-01-08 21:29:38.826611 -0800 PST wrapping_token_creation_path: auth/approle/role/jenkins/secret-id ``` -Send this `wrapping_token` to the client so that the response can be unwrap and +Send this `wrapping_token` to the client so that the response can be unwrapped and obtain the Secret ID. ```shell @@ -656,7 +657,7 @@ secret_id 575f23e4-01ad-25f7-2661-9c9bdbb1cf81 secret_id_accessor 7d8a40b7-a6fd-a634-579b-b7d673ff86fb ``` -NOTE: To retrieve the Secret ID alone, you can use `jq` as follow: +NOTE: To retrieve the Secret ID alone, you can use `jq` as follows: ```shell $ VAULT_TOKEN=2577044d-cf86-a065-e28f-e2a14ea6eaf7 vault unwrap -format=json | jq -r ".data.secret_id" @@ -667,5 +668,8 @@ b07d7a47-1d0d-741d-20b4-ae0de7c6d964 ## Next steps -To learn more about response wrapping, go to [Cubbyhole Response +Watch the video recording of the [Delivering Secret Zero: Vault AppRole with Terraform and Chef](Docs.google.com/document/d/1CCbAQ-ZEeSpxQqvo40hM0Y5oe3tBCxM1qYrk7_Wyvww/edit?ts=5a865277#heading=h.i52qa2wbjcsk) +webinar which talks about the usage of AppRole with Terraform and Chef as its trusted entities. + +To learn more about response wrapping, go to the [Cubbyhole Response Wrapping](/guides/secret-mgmt/cubbyhole.html) guide. diff --git a/website/source/guides/operations/replication.html.md b/website/source/guides/operations/replication.html.md index d6e9debd33c0..fac32785cf49 100644 --- a/website/source/guides/operations/replication.html.md +++ b/website/source/guides/operations/replication.html.md @@ -8,6 +8,8 @@ description: |- # Replication Setup & Guidance +~> **Enterprise Only:** Vault replication feature is a part of _Vault Enterprise_. + If you're unfamiliar with Vault Replication concepts, please first look at the [general information page](/docs/vault-enterprise/replication/index.html). More details can be found in the diff --git a/website/source/guides/secret-mgmt/dynamic-secrets.html.md b/website/source/guides/secret-mgmt/dynamic-secrets.html.md index bd6057ace8cf..2337526de1c7 100644 --- a/website/source/guides/secret-mgmt/dynamic-secrets.html.md +++ b/website/source/guides/secret-mgmt/dynamic-secrets.html.md @@ -1,9 +1,9 @@ --- layout: "guides" page_title: "Secret as a Service - Guides" -sidebar_current: "guides-secret-mgmt-dataynamic-secrets" +sidebar_current: "guides-secret-mgmt-dynamic-secrets" description: |- - Vault can dynamically generate secrets on--dataemand for some systems. + Vault can dynamically generate secrets on-demand for some systems. --- # Secret as a Service: Dynamic Secrets @@ -330,7 +330,7 @@ $ vault token create -policy="apps" Key Value --- ----- token e4bdf7dc-cbbf-1bb1-c06c-6a4f9a826cf2 -token_accessor 54700b7e--data828-a6c4-6141-96e71e002bd7 +token_accessor 54700b7e-d828-a6c4-6141-96e71e002bd7 token_duration 768h0m0s token_renewable true token_policies [apps default] @@ -354,7 +354,7 @@ $ vault read database/creds/readonly Key Value --- ----- -lease_id database/creds/readonly/4b5c6e82--dataf88-0dec-c0cb-f07eee8f0329 +lease_id database/creds/readonly/4b5c6e82-df88-0dec-c0cb-f07eee8f0329 lease_duration 1h0m0s lease_renewable true password A1a-4urzp0wu92r5s1q0 @@ -496,7 +496,7 @@ user name exists. ## Next steps -This guide discussed how to generate credentials on--dataemand so that the access +This guide discussed how to generate credentials on-demand so that the access credentials no longer need to be written to disk. Next, learn about the [Tokens and Leases](/guides/identity/lease.html) so that you can control the lifecycle of those credentials. diff --git a/website/source/index.html.erb b/website/source/index.html.erb index 26b0bf091ea2..0d54cb76d4a0 100644 --- a/website/source/index.html.erb +++ b/website/source/index.html.erb @@ -23,7 +23,7 @@ description: |-

- Vault secures, stores, and tightly + HashiCorp Vault secures, stores, and tightly controls access to tokens, passwords, certificates, API keys, and other secrets in modern computing. Vault handles leasing, key revocation, key rolling, and auditing. Through a unified diff --git a/website/source/intro/getting-started/authentication.html.md b/website/source/intro/getting-started/authentication.html.md index 55f90ab57402..d8a2632f63f4 100644 --- a/website/source/intro/getting-started/authentication.html.md +++ b/website/source/intro/getting-started/authentication.html.md @@ -86,7 +86,7 @@ is only used for revoking _leases_. For revoking _tokens_, use To authenticate with a token: ```text -$ vault login d08e2bd5-ffb0-440d-6486-b8f650ec8c0c +$ vault login a402d075-6d59-6129-1ac7-3718796d4346 Success! You are now authenticated. The token information displayed below is already stored in the token helper. You do NOT need to run "vault login" again. Future Vault requests will automatically use this token. diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index 1a299eb7b3ed..9c383b6ff4eb 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -308,6 +308,9 @@ > ssh + > + status + > token

- > - status - > unwrap