diff --git a/consumer/option.go b/consumer/option.go index f2cb9e8a..089d529a 100644 --- a/consumer/option.go +++ b/consumer/option.go @@ -18,9 +18,10 @@ limitations under the License. package consumer import ( - "github.com/apache/rocketmq-client-go/v2/hooks" + "strings" "time" + "github.com/apache/rocketmq-client-go/v2/hooks" "github.com/apache/rocketmq-client-go/v2/internal" "github.com/apache/rocketmq-client-go/v2/primitive" ) @@ -327,7 +328,28 @@ func WithNameServer(nameServers primitive.NamesrvAddr) Option { // WithNameServerDomain set NameServer domain func WithNameServerDomain(nameServerUrl string) Option { return func(opts *consumerOptions) { - opts.Resolver = primitive.NewHttpResolver("DEFAULT", nameServerUrl) + h := primitive.NewHttpResolver("DEFAULT", nameServerUrl) + if opts.UnitName != "" { + h.SetUnitName(opts.UnitName) + } + opts.Resolver = h + } +} + +// WithUnitMode set the unit mode +func WithUnitMode(unitMode bool) Option { + return func(opts *consumerOptions) { + opts.UnitMode = unitMode + } +} + +// WithUnitName set the name of specified unit +func WithUnitName(unitName string) Option { + return func(opts *consumerOptions) { + opts.UnitName = strings.TrimSpace(unitName) + if ns, ok := opts.Resolver.(*primitive.HttpResolver); ok { + ns.SetUnitName(opts.UnitName) + } } } diff --git a/consumer/option_test.go b/consumer/option_test.go new file mode 100644 index 00000000..0e50bb11 --- /dev/null +++ b/consumer/option_test.go @@ -0,0 +1,71 @@ +package consumer + +import ( + "fmt" + "reflect" + "strings" + "testing" +) + +func getFieldString(obj interface{}, field string) string { + v := reflect.Indirect(reflect.ValueOf(obj)) + return v.FieldByNameFunc(func(n string) bool { + return n == field + }).String() +} + +func TestWithUnitMode(t *testing.T) { + opt := defaultPullConsumerOptions() + WithUnitMode(true)(&opt) + if !opt.UnitMode { + t.Errorf("consumer option WithUnitMode. want:true, got=%v", opt.UnitMode) + } +} + +func TestWithUnitName(t *testing.T) { + opt := defaultPullConsumerOptions() + unitName := "unsh" + WithUnitName(unitName)(&opt) + if opt.UnitName != unitName { + t.Errorf("consumer option WithUnitName. want:%s, got=%s", unitName, opt.UnitName) + } +} + +func TestWithNameServerDomain(t *testing.T) { + opt := defaultPullConsumerOptions() + nameServerAddr := "http://127.0.0.1:8080/nameserver/addr" + WithNameServerDomain(nameServerAddr)(&opt) + domainStr := getFieldString(opt.Resolver, "domain") + if domainStr != nameServerAddr { + t.Errorf("consumer option WithUnitName. want:%s, got=%s", nameServerAddr, domainStr) + } +} + +func TestWithNameServerDomainAndUnitName(t *testing.T) { + nameServerAddr := "http://127.0.0.1:8080/nameserver/addr" + unitName := "unsh" + suffix := fmt.Sprintf("-%s?nofix=1", unitName) + + // test with two different orders + t.Run("WithNameServerDomain & WithUnitName", func(t *testing.T) { + opt := defaultPullConsumerOptions() + WithNameServerDomain(nameServerAddr)(&opt) + WithUnitName(unitName)(&opt) + + domainStr := getFieldString(opt.Resolver, "domain") + if !strings.Contains(domainStr, nameServerAddr) || !strings.Contains(domainStr, suffix) { + t.Errorf("consumer option should contains %s and %s", nameServerAddr, suffix) + } + }) + + t.Run("WithUnitName & WithNameServerDomain", func(t *testing.T) { + opt := defaultPullConsumerOptions() + WithNameServerDomain(nameServerAddr)(&opt) + WithUnitName(unitName)(&opt) + + domainStr := getFieldString(opt.Resolver, "domain") + if !strings.Contains(domainStr, nameServerAddr) || !strings.Contains(domainStr, suffix) { + t.Errorf("consumer option should contains %s and %s", nameServerAddr, suffix) + } + }) +} diff --git a/primitive/nsresolver.go b/primitive/nsresolver.go index 4e5917c8..ac39df09 100644 --- a/primitive/nsresolver.go +++ b/primitive/nsresolver.go @@ -6,7 +6,7 @@ The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -111,6 +111,16 @@ func NewHttpResolver(instance string, domain ...string) *HttpResolver { return h } +func (h *HttpResolver) SetUnitName(unitName string) { + if unitName == "" { + return + } + if strings.Contains(h.domain, "?nofix=1") { + return + } + h.domain = fmt.Sprintf("%s-%s?nofix=1", h.domain, unitName) +} + func (h *HttpResolver) Resolve() []string { addrs := h.get() if len(addrs) > 0 { @@ -152,14 +162,14 @@ func (h *HttpResolver) get() []string { return nil } - bodyStr := string(body) + bodyStr := strings.TrimSpace(string(body)) if bodyStr == "" { return nil } - h.saveSnapshot(body) + _ = h.saveSnapshot([]byte(bodyStr)) - return strings.Split(string(body), ";") + return strings.Split(bodyStr, ";") } func (h *HttpResolver) saveSnapshot(body []byte) error { diff --git a/primitive/nsresolver_test.go b/primitive/nsresolver_test.go index d42d2c6b..9cee5581 100644 --- a/primitive/nsresolver_test.go +++ b/primitive/nsresolver_test.go @@ -6,7 +6,7 @@ The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -18,7 +18,6 @@ package primitive import ( "fmt" - "github.com/apache/rocketmq-client-go/v2/rlog" "io/ioutil" "net" "net/http" @@ -26,6 +25,8 @@ import ( "strings" "testing" + "github.com/apache/rocketmq-client-go/v2/rlog" + . "github.com/smartystreets/goconvey/convey" ) @@ -81,6 +82,43 @@ func TestHttpResolverWithGet(t *testing.T) { }) } +func TestHttpResolverWithGetUnitName(t *testing.T) { + Convey("Test UpdateNameServerAddress Save Local Snapshot", t, func() { + srvs := []string{ + "192.168.100.1", + "192.168.100.2", + "192.168.100.3", + "192.168.100.4", + "192.168.100.5", + } + http.HandleFunc("/nameserver/addrs3-unsh", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("nofix") == "1" { + fmt.Fprintf(w, strings.Join(srvs, ";")) + } + fmt.Fprintf(w, "") + }) + server := &http.Server{Addr: ":0", Handler: nil} + listener, _ := net.Listen("tcp", ":0") + go server.Serve(listener) + + port := listener.Addr().(*net.TCPAddr).Port + nameServerDommain := fmt.Sprintf("http://127.0.0.1:%d/nameserver/addrs3", port) + rlog.Info("Temporary Nameserver", map[string]interface{}{ + "domain": nameServerDommain, + }) + + resolver := NewHttpResolver("DEFAULT", nameServerDommain) + resolver.SetUnitName("unsh") + resolver.Resolve() + + // check snapshot saved + filePath := resolver.getSnapshotFilePath("DEFAULT") + body := strings.Join(srvs, ";") + bs, _ := ioutil.ReadFile(filePath) + So(string(bs), ShouldEqual, body) + }) +} + func TestHttpResolverWithSnapshotFile(t *testing.T) { Convey("Test UpdateNameServerAddress Use Local Snapshot", t, func() { srvs := []string{ diff --git a/producer/option.go b/producer/option.go index 9fd8374d..082000fe 100644 --- a/producer/option.go +++ b/producer/option.go @@ -18,6 +18,7 @@ limitations under the License. package producer import ( + "strings" "time" "github.com/apache/rocketmq-client-go/v2/internal" @@ -144,7 +145,28 @@ func WithNameServer(nameServers primitive.NamesrvAddr) Option { // WithNameServerDomain set NameServer domain func WithNameServerDomain(nameServerUrl string) Option { return func(opts *producerOptions) { - opts.Resolver = primitive.NewHttpResolver("DEFAULT", nameServerUrl) + h := primitive.NewHttpResolver("DEFAULT", nameServerUrl) + if opts.UnitName != "" { + h.SetUnitName(opts.UnitName) + } + opts.Resolver = h + } +} + +// WithUnitMode set the unit mode +func WithUnitMode(unitMode bool) Option { + return func(opts *producerOptions) { + opts.UnitMode = unitMode + } +} + +// WithUnitName set the name of specified unit +func WithUnitName(unitName string) Option { + return func(opts *producerOptions) { + opts.UnitName = strings.TrimSpace(unitName) + if ns, ok := opts.Resolver.(*primitive.HttpResolver); ok { + ns.SetUnitName(opts.UnitName) + } } } diff --git a/producer/option_test.go b/producer/option_test.go new file mode 100644 index 00000000..5231d108 --- /dev/null +++ b/producer/option_test.go @@ -0,0 +1,71 @@ +package producer + +import ( + "fmt" + "reflect" + "strings" + "testing" +) + +func getFieldString(obj interface{}, field string) string { + v := reflect.Indirect(reflect.ValueOf(obj)) + return v.FieldByNameFunc(func(n string) bool { + return n == field + }).String() +} + +func TestWithUnitMode(t *testing.T) { + opt := defaultProducerOptions() + WithUnitMode(true)(&opt) + if !opt.UnitMode { + t.Errorf("consumer option WithUnitMode. want:true, got=%v", opt.UnitMode) + } +} + +func TestWithUnitName(t *testing.T) { + opt := defaultProducerOptions() + unitName := "unsh" + WithUnitName(unitName)(&opt) + if opt.UnitName != unitName { + t.Errorf("consumer option WithUnitName. want:%s, got=%s", unitName, opt.UnitName) + } +} + +func TestWithNameServerDomain(t *testing.T) { + opt := defaultProducerOptions() + nameServerAddr := "http://127.0.0.1:8080/nameserver/addr" + WithNameServerDomain(nameServerAddr)(&opt) + domainStr := getFieldString(opt.Resolver, "domain") + if domainStr != nameServerAddr { + t.Errorf("consumer option WithUnitName. want:%s, got=%s", nameServerAddr, domainStr) + } +} + +func TestWithNameServerDomainAndUnitName(t *testing.T) { + nameServerAddr := "http://127.0.0.1:8080/nameserver/addr" + unitName := "unsh" + suffix := fmt.Sprintf("-%s?nofix=1", unitName) + + // test with two different orders + t.Run("WithNameServerDomain & WithUnitName", func(t *testing.T) { + opt := defaultProducerOptions() + WithNameServerDomain(nameServerAddr)(&opt) + WithUnitName(unitName)(&opt) + + domainStr := getFieldString(opt.Resolver, "domain") + if !strings.Contains(domainStr, nameServerAddr) || !strings.Contains(domainStr, suffix) { + t.Errorf("consumer option should contains %s and %s", nameServerAddr, suffix) + } + }) + + t.Run("WithUnitName & WithNameServerDomain", func(t *testing.T) { + opt := defaultProducerOptions() + WithNameServerDomain(nameServerAddr)(&opt) + WithUnitName(unitName)(&opt) + + domainStr := getFieldString(opt.Resolver, "domain") + if !strings.Contains(domainStr, nameServerAddr) || !strings.Contains(domainStr, suffix) { + t.Errorf("consumer option should contains %s and %s", nameServerAddr, suffix) + } + }) +}