diff --git a/cmd/oras/internal/option/remote.go b/cmd/oras/internal/option/remote.go index f1cc938e4..44622d6d9 100644 --- a/cmd/oras/internal/option/remote.go +++ b/cmd/oras/internal/option/remote.go @@ -51,6 +51,8 @@ type Remote struct { resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) applyDistributionSpec bool distributionSpec distributionSpec + headerFlags []string + headers http.Header } // EnableDistributionSpecFlag set distribution specification flag as applicable. @@ -62,6 +64,7 @@ func (opts *Remote) EnableDistributionSpecFlag() { func (opts *Remote) ApplyFlags(fs *pflag.FlagSet) { opts.ApplyFlagsWithPrefix(fs, "", "") fs.BoolVarP(&opts.PasswordFromStdin, "password-stdin", "", false, "read password or identity token from stdin") + fs.StringArrayVarP(&opts.headerFlags, "header", "H", nil, "add custom headers to requests") } func applyPrefix(prefix, description string) (flagPrefix, notePrefix string) { @@ -105,6 +108,9 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description // Parse tries to read password with optional cmd prompt. func (opts *Remote) Parse() error { + if err := opts.parseCustomHeaders(); err != nil { + return err + } if err := opts.readPassword(); err != nil { return err } @@ -209,7 +215,8 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client TLSClientConfig: config, }, }, - Cache: auth.NewCache(), + Cache: auth.NewCache(), + Header: opts.headers, } client.SetUserAgent("oras/" + version.GetVersion()) if debug { @@ -243,6 +250,23 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client return } +func (opts *Remote) parseCustomHeaders() error { + if len(opts.headerFlags) != 0 { + headers := map[string][]string{} + for _, h := range opts.headerFlags { + name, value, found := strings.Cut(h, ":") + if !found || strings.TrimSpace(name) == "" { + // In conformance to the RFC 2616 specification + // Reference: https://www.rfc-editor.org/rfc/rfc2616#section-4.2 + return fmt.Errorf("invalid header: %q", h) + } + headers[name] = append(headers[name], value) + } + opts.headers = headers + } + return nil +} + // Credential returns a credential based on the remote options. func (opts *Remote) Credential() auth.Credential { return credential.Credential(opts.Username, opts.Password) diff --git a/cmd/oras/internal/option/remote_test.go b/cmd/oras/internal/option/remote_test.go index b57ea7422..c4a1cf375 100644 --- a/cmd/oras/internal/option/remote_test.go +++ b/cmd/oras/internal/option/remote_test.go @@ -285,3 +285,104 @@ func TestRemote_parseResolve_err(t *testing.T) { }) } } + +func TestRemote_parseCustomHeaders(t *testing.T) { + tests := []struct { + name string + headerFlags []string + want nhttp.Header + wantErr bool + }{ + { + name: "no custom header is provided", + headerFlags: []string{}, + want: nil, + wantErr: false, + }, + { + name: "one name-value pair", + headerFlags: []string{"key:value"}, + want: map[string][]string{"key": {"value"}}, + wantErr: false, + }, + { + name: "multiple name-value pairs", + headerFlags: []string{"key:value", "k:v"}, + want: map[string][]string{"key": {"value"}, "k": {"v"}}, + wantErr: false, + }, + { + name: "multiple name-value pairs with commas", + headerFlags: []string{"key:value,value2,value3", "k:v,v2,v3"}, + want: map[string][]string{"key": {"value,value2,value3"}, "k": {"v,v2,v3"}}, + wantErr: false, + }, + { + name: "empty string is a valid value", + headerFlags: []string{"k:", "key:value,value2,value3"}, + want: map[string][]string{"k": {""}, "key": {"value,value2,value3"}}, + wantErr: false, + }, + { + name: "multiple colons are allowed", + headerFlags: []string{"k::::v,v2,v3", "key:value,value2,value3"}, + want: map[string][]string{"k": {":::v,v2,v3"}, "key": {"value,value2,value3"}}, + wantErr: false, + }, + { + name: "name with spaces", + headerFlags: []string{"bar :b"}, + want: map[string][]string{"bar ": {"b"}}, + wantErr: false, + }, + { + name: "value with spaces", + headerFlags: []string{"foo: a"}, + want: map[string][]string{"foo": {" a"}}, + wantErr: false, + }, + { + name: "repeated pairs", + headerFlags: []string{"key:value", "key:value"}, + want: map[string][]string{"key": {"value", "value"}}, + wantErr: false, + }, + { + name: "repeated name with different values", + headerFlags: []string{"key:value", "key:value2"}, + want: map[string][]string{"key": {"value", "value2"}}, + wantErr: false, + }, + { + name: "one valid header and one invalid header(no pair)", + headerFlags: []string{"key:value,value2,value3", "vk"}, + want: nil, + wantErr: true, + }, + { + name: "one valid header and one invalid header(empty name)", + headerFlags: []string{":v", "key:value,value2,value3"}, + want: nil, + wantErr: true, + }, + { + name: "pure-space name is invalid", + headerFlags: []string{" : foo "}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &Remote{ + headerFlags: tt.headerFlags, + } + if err := opts.parseCustomHeaders(); (err != nil) != tt.wantErr { + t.Errorf("Remote.parseCustomHeaders() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(tt.want, opts.headers) { + t.Errorf("Remote.parseCustomHeaders() = %v, want %v", opts.headers, tt.want) + } + }) + } +}