diff --git a/.gitignore b/.gitignore index 019529ac..5fcab651 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ bin logs .coverage.* .vscode +.idea \ No newline at end of file diff --git a/cc/server.go b/cc/server.go index 7a0b42a3..96729747 100644 --- a/cc/server.go +++ b/cc/server.go @@ -80,7 +80,13 @@ type ListNamespaceResp struct { Data []string `json:"data"` } -// return names of all namespace +// @Summary 返回所有namespace名称 +// @Description 获取集群名称, 返回所有namespace名称, 未传入为默认集群 +// @Produce json +// @Param cluster header string false "cluster name" +// @Success 200 {object} ListNamespaceResp +// @Security BasicAuth +// @Router /api/cc/namespace/list [get] func (s *Server) listNamespace(c *gin.Context) { var err error r := &ListNamespaceResp{RetHeader: &RetHeader{RetCode: -1, RetMessage: ""}} @@ -109,6 +115,15 @@ type QueryNamespaceResp struct { Data []*models.Namespace `json:"data"` } +// @Summary 返回namespace配置详情, 已废弃 +// @Description 获取集群名称, 返回多个指定namespace配置详情, 未传入为默认集群, 已废弃 +// @Accept json +// @Produce json +// @Param cluster header string false "cluster name" +// @Param names body json true "{"names":["namespace_1","namespace_2"]}" +// @Success 200 {object} QueryNamespaceResp +// @Security BasicAuth +// @Router /api/cc/namespace [get] func (s *Server) queryNamespace(c *gin.Context) { var err error var req QueryReq @@ -136,6 +151,14 @@ func (s *Server) queryNamespace(c *gin.Context) { return } +// @Summary 返回namespace配置详情 +// @Description 获取集群名称, 返回指定namespace配置详情, 未传入为默认集群 +// @Produce json +// @Param cluster header string false "cluster name" +// @Param name path string true "namespace Name" +// @Success 200 {object} QueryNamespaceResp +// @Security BasicAuth +// @Router /api/cc/namespace/detail/{name} [get] func (s *Server) detailNamespace(c *gin.Context) { var err error var names []string @@ -164,6 +187,15 @@ func (s *Server) detailNamespace(c *gin.Context) { return } +// @Summary 创建修改namespace配置 +// @Description 获取集群名称, 根据json body创建或修改namespace配置, 未传入为默认集群 +// @Accept json +// @Produce json +// @Param cluster header string false "cluster name" +// @Param name body json true "namespace" +// @Success 200 {object} RetHeader +// @Security BasicAuth +// @Router /api/cc/namespace/modify [put] func (s *Server) modifyNamespace(c *gin.Context) { var err error var namespace models.Namespace @@ -190,6 +222,14 @@ func (s *Server) modifyNamespace(c *gin.Context) { return } +// @Summary 删除namespace配置 +// @Description 获取集群名称, 根据namespace name删除namespace, 未传入为默认集群 +// @Produce json +// @Param cluster header string false "cluster name" +// @Param name path string true "namespace name" +// @Success 200 {object} RetHeader +// @Security BasicAuth +// @Router /api/cc/namespace/delete/{name} [put] func (s *Server) delNamespace(c *gin.Context) { var err error h := &RetHeader{RetCode: -1, RetMessage: ""} @@ -219,6 +259,14 @@ type sqlFingerprintResp struct { SlowSQLs map[string]string `json:"slow_sqls"` } +// @Summary 获取namespce慢SQL、错误SQL +// @Description 获取集群名称, 根据namespace name获取该namespce慢SQL、错误SQL, 未传入为默认集群 +// @Produce json +// @Param cluster header string false "cluster name" +// @Param name path string true "namespace name" +// @Success 200 {object} sqlFingerprintResp +// @Security BasicAuth +// @Router /api/cc/namespace/sqlfingerprint/{name} [get] func (s *Server) sqlFingerprint(c *gin.Context) { var err error r := &sqlFingerprintResp{RetHeader: &RetHeader{RetCode: -1, RetMessage: ""}} @@ -246,6 +294,13 @@ type proxyConfigFingerprintResp struct { Data map[string]string `json:"data"` // key: ip:port value: md5 of config } +// @Summary 获取集群管理地址 +// @Description 获取集群名称, 返回集群管理地址, 未传入为默认集群 +// @Produce json +// @Param cluster header string false "cluster name" +// @Success 200 {object} proxyConfigFingerprintResp +// @Security BasicAuth +// @Router /api/cc/proxy/config/fingerprint [get] func (s *Server) proxyConfigFingerprint(c *gin.Context) { var err error r := &proxyConfigFingerprintResp{RetHeader: &RetHeader{RetCode: -1, RetMessage: ""}} diff --git a/core/version.go b/core/version.go index a0c95279..d302c4ac 100644 --- a/core/version.go +++ b/core/version.go @@ -28,6 +28,8 @@ var ( buildHost = "unknown" buildStatus = "unknown" buildTime = "unknown" + buildBranch = "unknown" + buildGitDirty = "0" ) // BuildInfo describes version information about the binary build. @@ -39,6 +41,8 @@ type BuildInfo struct { GolangVersion string `json:"golang_version"` BuildStatus string `json:"status"` BuildTime string `json:"time"` + BuildBranch string `json:"build_branch"` + BuildGitDirty string `json:"build_git_dirty"` } var ( @@ -54,13 +58,15 @@ var ( // user@host--- // ``` func (b BuildInfo) String() string { - return fmt.Sprintf("%v@%v-%v-%v-%v-%v", + return fmt.Sprintf("%v@%v-%v-%v-%v-%v-%v-%v", b.User, b.Host, b.Version, b.GitRevision, b.BuildStatus, - b.BuildTime) + b.BuildTime, + b.BuildBranch, + b.BuildGitDirty) } // LongForm returns a multi-line version information @@ -81,6 +87,8 @@ User: %v@%v GolangVersion: %v BuildStatus: %v BuildTime: %v +BuildBranch: %v +BuildGitDirty: %v `, b.Version, b.GitRevision, @@ -88,7 +96,9 @@ BuildTime: %v b.Host, b.GolangVersion, b.BuildStatus, - b.BuildTime) + b.BuildTime, + b.BuildBranch, + b.BuildGitDirty) } func init() { @@ -100,5 +110,7 @@ func init() { GolangVersion: runtime.Version(), BuildStatus: buildStatus, BuildTime: buildTime, + BuildBranch: buildBranch, + BuildGitDirty: buildGitDirty, } } diff --git a/etc/gaea.ini b/etc/gaea.ini index 20ede192..e30319a1 100644 --- a/etc/gaea.ini +++ b/etc/gaea.ini @@ -44,3 +44,6 @@ stats_interval=10 ;encrypt key encrypt_key=1234abcd5678efg* + +;server_version +server_version=5.6.20-gaea \ No newline at end of file diff --git a/gen_version.sh b/gen_version.sh index 277833a3..06d873e7 100755 --- a/gen_version.sh +++ b/gen_version.sh @@ -15,6 +15,10 @@ else tree_status="Modified" fi +# Check for git branch and git dirty +BRANCH=$(git rev-parse --abbrev-ref HEAD) +GIT_DIRTY=$(git diff --no-ext-diff 2> /dev/null | wc -l) + # XXX This needs to be updated to accomodate tags added after building, rather than prior to builds RELEASE_TAG=$(git describe --match '[0-9]*\.[0-9]*\.[0-9]*' --exact-match --tags 2> /dev/null || echo "") @@ -26,10 +30,12 @@ elif [[ -n ${MY_VERSION} ]]; then VERSION="${MY_VERSION}" fi -# used by pkg/version +# used by core/version echo buildVersion "${VERSION}" echo buildGitRevision "${BUILD_GIT_REVISION}" echo buildUser "$(whoami)" echo buildHost "$(hostname -f)" echo buildStatus "${tree_status}" echo buildTime "$(date +%Y-%m-%d--%T)" +echo buildBranch "${BRANCH}" +echo buildGitDirty "${GIT_DIRTY}" \ No newline at end of file diff --git a/models/proxy.go b/models/proxy.go index faac2380..395c4227 100644 --- a/models/proxy.go +++ b/models/proxy.go @@ -61,6 +61,8 @@ type Proxy struct { StatsInterval int `ini:"stats_interval"` // set stats interval of connect pool EncryptKey string `ini:"encrypt_key"` + + ServerVersion string `ini:"server_version"` } // ParseProxyConfigFromFile parser proxy config from file diff --git a/proxy/server/admin.go b/proxy/server/admin.go index 6c6bebf3..1071aa3f 100644 --- a/proxy/server/admin.go +++ b/proxy/server/admin.go @@ -161,6 +161,10 @@ func (s *AdminServer) registerURL() { }) } +// @Summary 获取proxy prometheus指标信息 +// @Description 获取gaea proxy prometheus指标信息 +// @Security BasicAuth +// @Router /api/metric/metrics [get] func (s *AdminServer) registerMetric() { metricGroup := s.engine.Group("/api/metric", gin.BasicAuth(gin.Accounts{s.adminUser: s.adminPassword})) for path, handler := range s.proxy.manager.GetStatisticManager().GetHandlers() { @@ -254,10 +258,22 @@ func (s *AdminServer) unregisterProxy() error { return nil } +// @Summary 获取proxy admin接口状态 +// @Description 获取proxy admin接口状态 +// @Success 200 {string} string "OK" +// @Security BasicAuth +// @Router /api/proxy/ping [get] func (s *AdminServer) ping(c *gin.Context) { c.JSON(http.StatusOK, "OK") } +// @Summary prepare namespace配置 +// @Description 通过管理接口, 二阶段提交, prepare namespace配置 +// @Produce json +// @Param name path string true "namespace name" +// @Success 200 {string} string "OK" +// @Security BasicAuth +// @Router /api/proxy/config/prepare/{name} [put] func (s *AdminServer) prepareConfig(c *gin.Context) { name := strings.TrimSpace(c.Param("name")) if name == "" { @@ -275,6 +291,13 @@ func (s *AdminServer) prepareConfig(c *gin.Context) { c.JSON(http.StatusOK, "OK") } +// @Summary commit namespace配置 +// @Description 通过管理接口, 二阶段提交, commit namespace配置, 使etcd配置生效 +// @Produce json +// @Param name path string true "namespace name" +// @Success 200 {string} string "OK" +// @Security BasicAuth +// @Router /api/proxy/config/commit/{name} [put] func (s *AdminServer) commitConfig(c *gin.Context) { name := strings.TrimSpace(c.Param("name")) if name == "" { @@ -289,6 +312,13 @@ func (s *AdminServer) commitConfig(c *gin.Context) { c.JSON(http.StatusOK, "OK") } +// @Summary 删除namespace配置 +// @Description 通过管理接口删除指定namespace配置 +// @Produce json +// @Param name path string true "namespace name" +// @Success 200 {string} string "OK" +// @Security BasicAuth +// @Router /api/proxy/config/delete/{name} [put] func (s *AdminServer) deleteNamespace(c *gin.Context) { name := strings.TrimSpace(c.Param("name")) if name == "" { @@ -304,11 +334,23 @@ func (s *AdminServer) deleteNamespace(c *gin.Context) { c.JSON(http.StatusOK, "OK") } +// @Summary 返回配置指纹 +// @Description 返回配置指纹, 指纹随配置变化而变化 +// @Produce json +// @Success 200 {string} string "Config Fingerprint" +// @Security BasicAuth +// @Router /api/proxy/config/fingerprint [get] func (s *AdminServer) configFingerprint(c *gin.Context) { c.JSON(http.StatusOK, s.proxy.manager.ConfigFingerprint()) } -// getNamespaceSessionSQLFingerprint return namespace sql fingerprint information +// @Summary 获取Porxy 慢SQL、错误SQL信息 +// @Description 通过管理接口获取Porxy 慢SQL、错误SQL信息 +// @Produce json +// @Param namespace path string true "namespace name" +// @Success 200 {object} SQLFingerprint +// @Security BasicAuth +// @Router /api/proxy/stats/sessionsqlfingerprint/{namespace} [get] func (s *AdminServer) getNamespaceSessionSQLFingerprint(c *gin.Context) { ns := strings.TrimSpace(c.Param("namespace")) namespace := s.proxy.manager.GetNamespace(ns) @@ -324,6 +366,13 @@ func (s *AdminServer) getNamespaceSessionSQLFingerprint(c *gin.Context) { c.JSON(http.StatusOK, ret) } +// @Summary 获取后端节点慢SQL、错误SQL信息 +// @Description 通过管理接口获取后端节点慢SQL、错误SQL信息 +// @Produce json +// @Param namespace path string true "namespace name" +// @Success 200 {object} SQLFingerprint +// @Security BasicAuth +// @Router /api/proxy/stats/backendsqlfingerprint/{namespace} [get] func (s *AdminServer) getNamespaceBackendSQLFingerprint(c *gin.Context) { ns := strings.TrimSpace(c.Param("namespace")) namespace := s.proxy.manager.GetNamespace(ns) @@ -339,6 +388,13 @@ func (s *AdminServer) getNamespaceBackendSQLFingerprint(c *gin.Context) { c.JSON(http.StatusOK, ret) } +// @Summary 清空Porxy节点慢SQL、错误SQL信息 +// @Description 通过管理接口清空Porxy慢SQL、错误SQL信息 +// @Produce json +// @Param namespace path string true "namespace name" +// @Success 200 {object} SQLFingerprint +// @Security BasicAuth +// @Router /api/proxy/stats/sessionsqlfingerprint/{namespace} [delete] func (s *AdminServer) clearNamespaceSessionSQLFingerprint(c *gin.Context) { ns := strings.TrimSpace(c.Param("namespace")) namespace := s.proxy.manager.GetNamespace(ns) @@ -353,6 +409,13 @@ func (s *AdminServer) clearNamespaceSessionSQLFingerprint(c *gin.Context) { c.JSON(http.StatusOK, "OK") } +// @Summary 清空后端节点慢SQL、错误SQL信息 +// @Description 通过管理接口清空后端节点慢SQL、错误SQL信息 +// @Produce json +// @Param namespace path string true "namespace name" +// @Success 200 {object} SQLFingerprint +// @Security BasicAuth +// @Router /api/proxy/stats/backendsqlfingerprint/{namespace} [delete] func (s *AdminServer) clearNamespaceBackendSQLFingerprint(c *gin.Context) { ns := strings.TrimSpace(c.Param("namespace")) namespace := s.proxy.manager.GetNamespace(ns) diff --git a/proxy/server/client_conn.go b/proxy/server/client_conn.go index b61591d0..9b448b67 100644 --- a/proxy/server/client_conn.go +++ b/proxy/server/client_conn.go @@ -16,9 +16,9 @@ package server import ( "fmt" - "github.com/XiaoMi/Gaea/log" "github.com/XiaoMi/Gaea/mysql" + "strings" ) // ClientConn session client connection @@ -29,6 +29,8 @@ type ClientConn struct { manager *Manager + capability uint32 + namespace string // TODO: remove it when refactor is done } @@ -51,10 +53,24 @@ func NewClientConn(c *mysql.Conn, manager *Manager) *ClientConn { } } -func (cc *ClientConn) writeInitialHandshakeV10() error { +func (cc *ClientConn) CompactVersion(sv string ) string { + version := strings.Trim(sv, " ") + if version != "" { + v := strings.Split(sv, ".") + if len(v) < 3 { + return mysql.ServerVersion + } + return version + } else { + return mysql.ServerVersion + } +} + +func (cc *ClientConn) writeInitialHandshakeV10(sv string) error { + ServerVersion:= cc.CompactVersion(sv) length := 1 + // protocol version - mysql.LenNullString(mysql.ServerVersion) + + mysql.LenNullString(ServerVersion) + 4 + // connection ID 8 + // first part of salt data 1 + // filler byte @@ -75,7 +91,7 @@ func (cc *ClientConn) writeInitialHandshakeV10() error { // Copy server version. // server version data with terminate character 0x00, type: string[NUL]. - pos = mysql.WriteNullString(data, pos, mysql.ServerVersion) + pos = mysql.WriteNullString(data, pos, ServerVersion) // Add connectionID in. // connection id type: 4 bytes. @@ -149,6 +165,7 @@ func (cc *ClientConn) readHandshakeResponse() (HandshakeResponseInfo, error) { return info, fmt.Errorf("readHandshakeResponse: only support protocol 4.1") } + cc.capability = capability // Max packet size. Don't do anything with this now. _, pos, ok = mysql.ReadUint32(data, pos) if !ok { @@ -330,7 +347,7 @@ func (cc *ClientConn) writeColumnDefinition(field *mysql.Field) error { 1 + // decimals 2 // filler if field.DefaultValue != nil { - length += 8 + len(field.DefaultValue) + length += mysql.LenEncIntSize(uint64(len(field.DefaultValue))) + len(field.DefaultValue) } data := cc.StartEphemeralPacket(length) @@ -366,7 +383,7 @@ func (cc *ClientConn) writeColumnDefinition(field *mysql.Field) error { pos = mysql.WriteUint16(data, pos, uint16(0x0000)) if field.DefaultValue != nil { - pos = mysql.WriteUint64(data, pos, field.DefaultValueLength) + pos = mysql.WriteLenEncInt(data, pos, field.DefaultValueLength) copy(data[pos:], field.DefaultValue) pos += len(field.DefaultValue) } diff --git a/proxy/server/server.go b/proxy/server/server.go index dbb55fcf..d1abbd88 100644 --- a/proxy/server/server.go +++ b/proxy/server/server.go @@ -43,6 +43,7 @@ type Server struct { adminServer *AdminServer manager *Manager EncryptKey string + ServerVersion string } // NewServer create new server @@ -55,6 +56,8 @@ func NewServer(cfg *models.Proxy, manager *Manager) (*Server, error) { s.manager = manager + s.ServerVersion = cfg.ServerVersion + // if error occurs, recycle the resources during creation. defer func() { if e := recover(); e != nil { @@ -135,7 +138,13 @@ func (s *Server) onConn(c net.Conn) { // added into time wheel s.tw.Add(s.sessionTimeout, cc, cc.Close) - + log.Notice("Connected conn_id=%d, %s@%s (%s) namespace:%s capability: %d", + cc.c.ConnectionID, + cc.executor.user, + cc.executor.clientAddr, + cc.executor.db, + cc.executor.namespace, + cc.c.capability) cc.Run() } diff --git a/proxy/server/session.go b/proxy/server/session.go index 7375a78b..d8fe325a 100644 --- a/proxy/server/session.go +++ b/proxy/server/session.go @@ -97,7 +97,7 @@ func (cc *Session) IsAllowConnect() bool { // step3: server send ok/err packets to client func (cc *Session) Handshake() error { // First build and send the server handshake packet. - if err := cc.c.writeInitialHandshakeV10(); err != nil { + if err := cc.c.writeInitialHandshakeV10(cc.proxy.ServerVersion); err != nil { clientHost, _, innerErr := net.SplitHostPort(cc.c.RemoteAddr().String()) if innerErr != nil { log.Warn("[server] Session parse host error: %v", innerErr) diff --git a/util/resource_pool.go b/util/resource_pool.go index 3ef97274..c8b899ba 100644 --- a/util/resource_pool.go +++ b/util/resource_pool.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/XiaoMi/Gaea/util/sync2" @@ -54,14 +55,21 @@ type ResourcePool struct { capacity sync2.AtomicInt64 idleTimeout sync2.AtomicDuration idleTimer *timer.Timer + capTimer *timer.Timer // stats - available sync2.AtomicInt64 - active sync2.AtomicInt64 - inUse sync2.AtomicInt64 - waitCount sync2.AtomicInt64 - waitTime sync2.AtomicDuration - idleClosed sync2.AtomicInt64 + available sync2.AtomicInt64 + active sync2.AtomicInt64 + inUse sync2.AtomicInt64 + waitCount sync2.AtomicInt64 + waitTime sync2.AtomicDuration + idleClosed sync2.AtomicInt64 + baseCapacity sync2.AtomicInt64 + maxCapacity sync2.AtomicInt64 + lock *sync.Mutex + scaleOutTime int64 + scaleInTodo chan int8 + Dynamic bool } type resourceWrapper struct { @@ -86,11 +94,16 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur panic(errors.New("invalid/out of range capacity")) } rp := &ResourcePool{ - resources: make(chan resourceWrapper, maxCap), - factory: factory, - available: sync2.NewAtomicInt64(int64(capacity)), - capacity: sync2.NewAtomicInt64(int64(capacity)), - idleTimeout: sync2.NewAtomicDuration(idleTimeout), + resources: make(chan resourceWrapper, maxCap), + factory: factory, + available: sync2.NewAtomicInt64(int64(capacity)), + capacity: sync2.NewAtomicInt64(int64(capacity)), + idleTimeout: sync2.NewAtomicDuration(idleTimeout), + baseCapacity: sync2.NewAtomicInt64(int64(capacity)), + maxCapacity: sync2.NewAtomicInt64(int64(maxCap)), + lock: &sync.Mutex{}, + scaleInTodo: make(chan int8, 1), + Dynamic: true, // 动态扩展连接池 } for i := 0; i < capacity; i++ { rp.resources <- resourceWrapper{} @@ -100,6 +113,8 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur rp.idleTimer = timer.NewTimer(idleTimeout / 10) rp.idleTimer.Start(rp.closeIdleResources) } + rp.capTimer = timer.NewTimer(5 * time.Second) + rp.capTimer.Start(rp.scaleInResources) return rp } @@ -111,7 +126,14 @@ func (rp *ResourcePool) Close() { if rp.idleTimer != nil { rp.idleTimer.Stop() } - _ = rp.SetCapacity(0) + if rp.capTimer != nil { + rp.capTimer.Stop() + } + _ = rp.ScaleCapacity(0) +} + +func (rp *ResourcePool) SetDynamic(value bool) { + rp.Dynamic = value } // IsClosed returns true if the resource pool is closed. @@ -169,6 +191,9 @@ func (rp *ResourcePool) get(ctx context.Context, wait bool) (resource Resource, select { case wrapper, ok = <-rp.resources: default: + if rp.Dynamic { + rp.scaleOutResources() + } if !wait { return nil, nil } @@ -178,7 +203,10 @@ func (rp *ResourcePool) get(ctx context.Context, wait bool) (resource Resource, case <-ctx.Done(): return nil, ErrTimeout } - rp.recordWait(startTime) + endTime := time.Now() + if int64(startTime.UnixNano()/100000) != int64(endTime.UnixNano()/100000) { + rp.recordWait(startTime) + } } if !ok { return nil, ErrClosed @@ -218,14 +246,23 @@ func (rp *ResourcePool) Put(resource Resource) { rp.available.Add(1) } +func (rp *ResourcePool) SetCapacity(capacity int) error { + oldcap := rp.baseCapacity.Get() + rp.baseCapacity.CompareAndSwap(oldcap, int64(capacity)) + if int(oldcap) < capacity { + rp.ScaleCapacity(capacity) + } + return nil +} + // SetCapacity changes the capacity of the pool. // You can use it to shrink or expand, but not beyond // the max capacity. If the change requires the pool // to be shrunk, SetCapacity waits till the necessary // number of resources are returned to the pool. // A SetCapacity of 0 is equivalent to closing the ResourcePool. -func (rp *ResourcePool) SetCapacity(capacity int) error { - if capacity < 0 || capacity > cap(rp.resources) { +func (rp *ResourcePool) ScaleCapacity(capacity int) error { + if capacity < 0 || capacity > int(rp.maxCapacity.Get()) { return fmt.Errorf("capacity %d is out of range", capacity) } @@ -266,6 +303,33 @@ func (rp *ResourcePool) SetCapacity(capacity int) error { return nil } +// 扩容 +func (rp *ResourcePool) scaleOutResources() { + rp.lock.Lock() + defer rp.lock.Unlock() + if rp.capacity.Get() < rp.maxCapacity.Get() { + rp.ScaleCapacity(int(rp.capacity.Get()) + 1) + rp.scaleOutTime = time.Now().Unix() + } +} + +// 缩容 +func (rp *ResourcePool) scaleInResources() { + rp.lock.Lock() + defer rp.lock.Unlock() + if rp.capacity.Get() > rp.baseCapacity.Get() && time.Now().Unix()-rp.scaleOutTime > 60 { + select { + case rp.scaleInTodo <- 0: + go func() { + rp.ScaleCapacity(int(rp.capacity.Get()) - 1) + <-rp.scaleInTodo + }() + default: + return + } + } +} + func (rp *ResourcePool) recordWait(start time.Time) { rp.waitCount.Add(1) rp.waitTime.Add(time.Now().Sub(start)) diff --git a/util/resource_pool_test.go b/util/resource_pool_test.go index f7051dd4..1df612f6 100644 --- a/util/resource_pool_test.go +++ b/util/resource_pool_test.go @@ -58,7 +58,8 @@ func TestOpen(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(PoolFactory, 6, 6, time.Second) - p.SetCapacity(5) + p.SetDynamic(false) + p.ScaleCapacity(5) var resources [10]Resource // Test Get @@ -143,8 +144,8 @@ func TestOpen(t *testing.T) { t.Errorf("Expecting 6, received %d", lastID.Get()) } - // SetCapacity - p.SetCapacity(3) + // ScaleCapacity + p.ScaleCapacity(3) if count.Get() != 3 { t.Errorf("Expecting 3, received %d", count.Get()) } @@ -157,7 +158,7 @@ func TestOpen(t *testing.T) { if p.Available() != 3 { t.Errorf("Expecting 3, received %d", p.Available()) } - p.SetCapacity(6) + p.ScaleCapacity(6) if p.Capacity() != 6 { t.Errorf("Expecting 6, received %d", p.Capacity()) } @@ -194,11 +195,111 @@ func TestOpen(t *testing.T) { } } +func TestOpenDynamic(t *testing.T) { + ctx := context.Background() + lastID.Set(0) + count.Set(0) + p := NewResourcePool(PoolFactory, 6, 10, time.Second) + p.ScaleCapacity(5) + p.SetDynamic(true) + var resources [10]Resource + + // Test Get + for i := 0; i < 7; i++ { + r, err := p.Get(ctx) + resources[i] = r + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if i < 5 { + if p.Available() != int64(5-i-1) { + t.Errorf("expecting %d, received %d", 5-i-1, p.Available()) + } + } else { + if p.Available() != 0 { + t.Errorf("expecting %d, received %d", 0, p.Available()) + } + } + + if p.WaitCount() != 0 { + t.Errorf("expecting 0, received %d", p.WaitCount()) + } + if p.WaitTime() != 0 { + t.Errorf("expecting 0, received %d", p.WaitTime()) + } + if lastID.Get() != int64(i+1) { + t.Errorf("Expecting %d, received %d", i+1, lastID.Get()) + } + if count.Get() != int64(i+1) { + t.Errorf("Expecting %d, received %d", i+1, count.Get()) + } + } + + // Test that Get waits + ch := make(chan bool) + go func() { + for i := 0; i < 7; i++ { + r, err := p.Get(ctx) + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + for i := 0; i < 7; i++ { + p.Put(resources[i]) + } + ch <- true + }() + for i := 0; i < 7; i++ { + // Sleep to ensure the goroutine waits + time.Sleep(10 * time.Millisecond) + p.Put(resources[i]) + } + <-ch + if p.WaitCount() != 4 { + t.Errorf("Expecting 4, received %d", p.WaitCount()) + } + if p.WaitTime() == 0 { + t.Errorf("Expecting non-zero") + } + if lastID.Get() != 10 { + t.Errorf("Expecting 10, received %d", lastID.Get()) + } + + // Test Close resource + r, err := p.Get(ctx) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + r.Close() + p.Put(nil) + if count.Get() != 9 { + t.Errorf("Expecting 9, received %d", count.Get()) + } + for i := 0; i < 5; i++ { + r, err := p.Get(ctx) + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + for i := 0; i < 5; i++ { + p.Put(resources[i]) + } + if count.Get() != 9 { + t.Errorf("Expecting 9, received %d", count.Get()) + } + if lastID.Get() != 10 { + t.Errorf("Expecting 10, received %d", lastID.Get()) + } +} + func TestShrinking(t *testing.T) { ctx := context.Background() lastID.Set(0) count.Set(0) p := NewResourcePool(PoolFactory, 5, 5, time.Second) + p.SetDynamic(false) var resources [10]Resource // Leave one empty slot in the pool for i := 0; i < 4; i++ { @@ -210,7 +311,7 @@ func TestShrinking(t *testing.T) { } done := make(chan bool) go func() { - p.SetCapacity(3) + p.ScaleCapacity(3) done <- true }() expected := `{"Capacity": 3, "Available": 0, "Active": 4, "InUse": 4, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000, "IdleClosed": 0}` @@ -224,7 +325,7 @@ func TestShrinking(t *testing.T) { } } // There are already 2 resources available in the pool. - // So, returning one should be enough for SetCapacity to complete. + // So, returning one should be enough for ScaleCapacity to complete. p.Put(resources[3]) <-done // Return the rest of the resources @@ -240,7 +341,7 @@ func TestShrinking(t *testing.T) { t.Errorf("Expecting 3, received %d", count.Get()) } - // Ensure no deadlock if SetCapacity is called after we start + // Ensure no deadlock if ScaleCapacity is called after we start // waiting for a resource var err error for i := 0; i < 3; i++ { @@ -261,7 +362,7 @@ func TestShrinking(t *testing.T) { // This will also wait go func() { - p.SetCapacity(2) + p.ScaleCapacity(2) done <- true }() time.Sleep(10 * time.Millisecond) @@ -285,8 +386,8 @@ func TestShrinking(t *testing.T) { t.Errorf("Expecting 2, received %d", count.Get()) } - // Test race condition of SetCapacity with itself - p.SetCapacity(3) + // Test race condition of ScaleCapacity with itself + p.ScaleCapacity(3) for i := 0; i < 3; i++ { resources[i], err = p.Get(ctx) if err != nil { @@ -305,9 +406,9 @@ func TestShrinking(t *testing.T) { time.Sleep(10 * time.Millisecond) // This will wait till we Put - go p.SetCapacity(2) + go p.ScaleCapacity(2) time.Sleep(10 * time.Millisecond) - go p.SetCapacity(4) + go p.ScaleCapacity(4) time.Sleep(10 * time.Millisecond) // This should not hang @@ -316,11 +417,11 @@ func TestShrinking(t *testing.T) { } <-done - err = p.SetCapacity(-1) + err = p.ScaleCapacity(-1) if err == nil { t.Errorf("Expecting error") } - err = p.SetCapacity(255555) + err = p.ScaleCapacity(255555) if err == nil { t.Errorf("Expecting error") } @@ -338,6 +439,7 @@ func TestClosing(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(PoolFactory, 5, 5, time.Second) + p.SetDynamic(false) var resources [10]Resource for i := 0; i < 5; i++ { r, err := p.Get(ctx) @@ -368,8 +470,8 @@ func TestClosing(t *testing.T) { // Wait for Close to return <-ch - // SetCapacity must be ignored after Close - err := p.SetCapacity(1) + // ScaleCapacity must be ignored after Close + err := p.ScaleCapacity(1) if err == nil { t.Errorf("expecting error") } @@ -392,6 +494,7 @@ func TestIdleTimeout(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(PoolFactory, 1, 1, 10*time.Millisecond) + p.SetDynamic(false) defer p.Close() r, err := p.Get(ctx) @@ -497,6 +600,7 @@ func TestCreateFail(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(FailFactory, 5, 5, time.Second) + p.SetDynamic(false) defer p.Close() if _, err := p.Get(ctx); err.Error() != "Failed" { t.Errorf("Expecting Failed, received %v", err) @@ -513,6 +617,7 @@ func TestSlowCreateFail(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(SlowFailFactory, 2, 2, time.Second) + p.SetDynamic(false) defer p.Close() ch := make(chan bool) // The third Get should not wait indefinitely @@ -535,6 +640,7 @@ func TestTimeout(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(PoolFactory, 1, 1, time.Second) + p.SetDynamic(false) defer p.Close() r, err := p.Get(ctx) if err != nil { @@ -554,6 +660,7 @@ func TestExpired(t *testing.T) { lastID.Set(0) count.Set(0) p := NewResourcePool(PoolFactory, 1, 1, time.Second) + p.SetDynamic(false) defer p.Close() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) r, err := p.Get(ctx)