diff --git a/pkg/client/client.go b/pkg/client/client.go
index e3a4cf5..86e979d 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -8,6 +8,7 @@ import (
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
 
 	"github.com/cenkalti/backoff/v4"
+	nebula "github.com/vesoft-inc/nebula-go/v3"
 )
 
 type (
@@ -136,3 +137,42 @@ func (c *defaultClient) Close() error {
 	}
 	return nil
 }
+
+// NewSessionPool builds a nebula-go session pool from the client options.
+//
+// Every configured address must have the form "host:port" with a non-empty
+// host and a numeric port; otherwise errors.ErrInvalidAddress is returned.
+func NewSessionPool(opts ...Option) (*nebula.SessionPool, error) {
+	// NOTE(review): the space name and the pool size are hard-coded; they
+	// should be taken from the configuration before this leaves testing.
+	const (
+		sessionPoolSpace   = "sf300_2"
+		sessionPoolMaxSize = 3000
+	)
+
+	ops := newOptions(opts...)
+	hostAddresses := make([]nebula.HostAddress, 0, len(ops.addresses))
+	for _, h := range ops.addresses {
+		hostPort := strings.Split(h, ":")
+		if len(hostPort) != 2 || hostPort[0] == "" {
+			return nil, errors.ErrInvalidAddress
+		}
+		port, err := strconv.Atoi(hostPort[1])
+		if err != nil {
+			// Reject the address instead of falling through with port 0.
+			return nil, errors.ErrInvalidAddress
+		}
+		hostAddresses = append(hostAddresses, nebula.HostAddress{Host: hostPort[0], Port: port})
+	}
+
+	conf, err := nebula.NewSessionPoolConf(ops.user, ops.password, hostAddresses,
+		sessionPoolSpace, nebula.WithMaxSize(sessionPoolMaxSize))
+	if err != nil {
+		return nil, err
+	}
+	pool, err := nebula.NewSessionPool(*conf, nebula.DefaultLogger{})
+	if err != nil {
+		return nil, err
+	}
+	return pool, nil
+}
diff --git a/pkg/config/base/client.go b/pkg/config/base/client.go
index 612d45d..31f7fbe 100644
--- a/pkg/config/base/client.go
+++ b/pkg/config/base/client.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/client"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
+	"github.com/vesoft-inc/nebula-importer/v4/pkg/manager"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/utils"
 )
 
@@ -85,6 +86,11 @@ func (c *Client) BuildClientPool(opts ...client.Option) (client.Pool, error) {
 	}
 	options = append(options, opts...)
 	pool := newClientPool(options...)
+	sessionPool, err := client.NewSessionPool(options...)
+	if err != nil {
+		return nil, err
+	}
+	manager.DefaultSessionPool = sessionPool
 	return pool, nil
 }
 
diff --git a/pkg/config/v3/config.go b/pkg/config/v3/config.go
index 15b4a62..c053513 100644
--- a/pkg/config/v3/config.go
+++ b/pkg/config/v3/config.go
@@ -78,6 +78,7 @@ func (c *Config) Build() error {
 	if err != nil {
 		return err
 	}
+
 	mgr, err = c.Manager.BuildManager(l, pool, c.Sources,
 		manager.WithGetClientOptions(client.WithClientInitFunc(nil)), // clean the USE SPACE in 3.x
 	)
diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go
index c0b5481..8c69858 100644
--- a/pkg/errors/errors.go
+++ b/pkg/errors/errors.go
@@ -27,4 +27,6 @@ var (
 	ErrUnsupportedFunction = stderrors.New("unsupported function")
 	ErrFilterSyntax        = stderrors.New("filter syntax")
 	ErrUnsupportedMode     = stderrors.New("unsupported mode")
+	ErrNoDynamicParam      = stderrors.New("no dynamic param")
+	ErrFetchFailed         = stderrors.New("fetch failed")
 )
diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go
index e657cff..8718762 100644
--- a/pkg/manager/manager.go
+++ b/pkg/manager/manager.go
@@ -8,6 +8,8 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/panjf2000/ants/v2"
+	nebula "github.com/vesoft-inc/nebula-go/v3"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/client"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/importer"
@@ -16,8 +18,6 @@ import (
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/source"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/spec"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/stats"
-
-	"github.com/panjf2000/ants/v2"
 )
 
 const (
@@ -158,6 +158,7 @@ func WithLogger(l logger.Logger) Option {
 }
 
 func (m *defaultManager) Import(s source.Source, brr reader.BatchRecordReader, importers ...importer.Importer) error {
+
 	if len(importers) == 0 {
 		return nil
 	}
@@ -268,6 +269,9 @@ func (m *defaultManager) Stop() (err error) {
 	m.importerWaitGroup.Wait()
 
 	m.logStats()
+	if DefaultSessionPool != nil {
+		DefaultSessionPool.Close()
+	}
 	return m.After()
 }
 
@@ -437,3 +441,9 @@ func (m *defaultManager) logError(err error, msg string, fields ...logger.Field)
 	fields = append(fields, logger.MapToFields(e.Fields())...)
 	m.logger.SkipCaller(1).WithError(e.Cause()).Error(msg, fields...)
 }
+
+// DefaultSessionPool is the session pool used by the spec package for batch
+// updates. It is assigned when the client pool is built and closed by Stop.
+//
+// NOTE: temporary global kept for testing; it should become an explicit dependency.
+var DefaultSessionPool *nebula.SessionPool
diff --git a/pkg/spec/base/dynamicParam.go b/pkg/spec/base/dynamicParam.go
new file mode 100644
index 0000000..9cd001c
--- /dev/null
+++ b/pkg/spec/base/dynamicParam.go
@@ -0,0 +1,10 @@
+package specbase
+
+// DynamicParam is the optional "dynamicParam" section of a node spec.
+// It must be present when the node uses the BATCH_UPDATE mode.
+type DynamicParam struct {
+	Address  string `yaml:"address,omitempty"`
+	User     string `yaml:"user,omitempty"`
+	Password string `yaml:"password,omitempty"`
+	Space    string `yaml:"space,omitempty"`
+}
diff --git a/pkg/spec/base/mode.go b/pkg/spec/base/mode.go
index 9b018d7..c70fcca 100644
--- a/pkg/spec/base/mode.go
+++ b/pkg/spec/base/mode.go
@@ -3,10 +3,11 @@ package specbase
 import "strings"
 
 const (
-	DefaultMode      = InsertMode
-	InsertMode  Mode = "INSERT"
-	UpdateMode  Mode = "UPDATE"
-	DeleteMode  Mode = "DELETE"
+	DefaultMode          = InsertMode
+	InsertMode      Mode = "INSERT"
+	UpdateMode      Mode = "UPDATE"
+	DeleteMode      Mode = "DELETE"
+	BatchUpdateMode Mode = "BATCH_UPDATE"
 )
 
 type Mode string
@@ -19,5 +20,5 @@ func (m Mode) Convert() Mode {
 }
 
 func (m Mode) IsSupport() bool {
-	return m == InsertMode || m == UpdateMode || m == DeleteMode
+	return m == InsertMode || m == UpdateMode || m == DeleteMode || m == BatchUpdateMode
 }
diff --git a/pkg/spec/v3/node.go b/pkg/spec/v3/node.go
index 2c3929e..24d3350 100644
--- a/pkg/spec/v3/node.go
+++ b/pkg/spec/v3/node.go
@@ -2,10 +2,13 @@ package specv3
 
 import (
 	"fmt"
+	"sort"
 	"strings"
 
+	nebula "github.com/vesoft-inc/nebula-go/v3"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/bytebufferpool"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
+	"github.com/vesoft-inc/nebula-importer/v4/pkg/manager"
 	specbase "github.com/vesoft-inc/nebula-importer/v4/pkg/spec/base"
 	"github.com/vesoft-inc/nebula-importer/v4/pkg/utils"
 )
@@ -21,13 +24,17 @@ type (
 
 		Filter *specbase.Filter `yaml:"filter,omitempty"`
 
-		Mode specbase.Mode `yaml:"mode,omitempty"`
+		Mode         specbase.Mode          `yaml:"mode,omitempty"`
+		DynamicParam *specbase.DynamicParam `yaml:"dynamicParam,omitempty"`
 
-		fnStatement func(records ...Record) (string, int, error)
+		// fnStatement builds the statement for a batch of records; Complete picks it by Mode.
+		// dynamicFnStatement is not assigned anywhere in this change.
+		fnStatement        func(records ...Record) (string, int, error)
+		dynamicFnStatement func(pool *nebula.SessionPool, records ...Record) (string, int, error)
 		// "INSERT VERTEX name(prop_name, ..., prop_name) VALUES "
 		// "UPDATE VERTEX ON name "
 		// "DELETE TAG name FROM "
 		statementPrefix string
 	}
 
 	Nodes []*Node
@@ -109,6 +116,11 @@ func (n *Node) Complete() {
 	case specbase.DeleteMode:
 		n.fnStatement = n.deleteStatement
 		n.statementPrefix = fmt.Sprintf("DELETE TAG %s FROM ", utils.ConvertIdentifier(n.Name))
+	case specbase.BatchUpdateMode:
+		// Batch update fetches the stored vertices first and then rewrites
+		// them with the configured props. The statement prefix depends on
+		// the fetched columns, so statementPrefix is not set here.
+		n.fnStatement = n.updateBatchStatement
 	}
 }
 
@@ -279,3 +291,192 @@ func (ns Nodes) Validate() error {
 	}
 	return nil
 }
+
+// updateBatchStatement builds a single INSERT VERTEX statement that rewrites
+// the given records. The stored properties of every vertex are fetched
+// through manager.DefaultSessionPool and merged with the values of n.Props,
+// so columns that are not configured keep the values returned by the fetch.
+func (n *Node) updateBatchStatement(records ...Record) (statement string, nRecord int, err error) {
+	if n.DynamicParam == nil {
+		return "", 0, errors.ErrNoDynamicParam
+	}
+
+	idValues := make([]string, 0, len(records))
+	needUpdateRecords := make([]Record, 0, len(records))
+	for _, record := range records {
+		idValue, err := n.ID.Value(record)
+		if err != nil {
+			return "", 0, n.importError(err)
+		}
+		propsSetValueList, err := n.Props.ValueList(record)
+		if err != nil {
+			return "", 0, n.importError(err)
+		}
+		idValues = append(idValues, idValue)
+		needUpdateRecords = append(needUpdateRecords, propsSetValueList)
+	}
+
+	cols := make([]string, 0, len(n.Props))
+	for _, prop := range n.Props {
+		cols = append(cols, prop.Name)
+	}
+
+	updatedCols, updatedRecords, err := n.genDynamicUpdateRecord(manager.DefaultSessionPool, idValues, cols, needUpdateRecords)
+	if err != nil {
+		return "", 0, err
+	}
+
+	// Quote the column names the same way as the tag name.
+	quotedCols := make([]string, 0, len(updatedCols))
+	for _, col := range updatedCols {
+		quotedCols = append(quotedCols, utils.ConvertIdentifier(col))
+	}
+
+	buff := bytebufferpool.Get()
+	defer bytebufferpool.Put(buff)
+
+	// INSERT VERTEX name(prop_name, ..., prop_name) VALUES id:(prop_value, ...), ...
+	buff.SetString(fmt.Sprintf("INSERT VERTEX %s(%s) VALUES ", utils.ConvertIdentifier(n.Name), strings.Join(quotedCols, ", ")))
+	for index, record := range updatedRecords {
+		if nRecord > 0 {
+			_, _ = buff.WriteString(", ")
+		}
+
+		// id:(prop_value1, prop_value2, ...)
+		_, _ = buff.WriteString(idValues[index])
+		_, _ = buff.WriteString(":(")
+		_, _ = buff.WriteStringSlice(record, ", ")
+		_, _ = buff.WriteString(")")
+
+		nRecord++
+	}
+	return buff.String(), nRecord, nil
+}
+
+// genDynamicUpdateRecord fetches the stored properties of the vertices in
+// idValues and merges them with records, which hold the new values of cols.
+// It returns the full column list and, for every id, the merged record.
+func (n *Node) genDynamicUpdateRecord(pool *nebula.SessionPool, idValues []string, cols []string, records []Record) ([]string, []Record, error) {
+	const maxFetchAttempts = 3
+
+	if pool == nil {
+		return nil, nil, fmt.Errorf("%w: session pool is not initialized", errors.ErrFetchFailed)
+	}
+	stat := fmt.Sprintf("FETCH PROP ON %s %s YIELD VERTEX as v;", utils.ConvertIdentifier(n.Name), strings.Join(idValues, ","))
+
+	var (
+		rs  *nebula.ResultSet
+		err error
+	)
+	for i := 0; i < maxFetchAttempts; i++ {
+		rs, err = pool.Execute(stat)
+		if err == nil && rs.IsSucceed() {
+			// Stop retrying as soon as the fetch succeeds.
+			break
+		}
+	}
+	if err != nil {
+		return nil, nil, err
+	}
+	if !rs.IsSucceed() {
+		return nil, nil, fmt.Errorf("%w: %s", errors.ErrFetchFailed, rs.GetErrorMsg())
+	}
+
+	fetchData, err := n.getNebulaFetchData(rs)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	var updatedCols []string
+	updatedRecords := make([]Record, 0, len(idValues))
+	for index, id := range idValues {
+		originalData, ok := fetchData[id]
+		if !ok {
+			return nil, nil, fmt.Errorf("%w: cannot find id, id: %s", errors.ErrFetchFailed, id)
+		}
+		if index == 0 {
+			// Every row holds the properties of the same tag, so the first vertex defines the columns.
+			updatedCols = n.getDynamicUpdateCols(cols, originalData)
+		}
+		updatedRecords = append(updatedRecords, n.getUpdateRocord(originalData, updatedCols, records[index]))
+	}
+	return updatedCols, updatedRecords, nil
+}
+
+// getDynamicUpdateCols returns the stored property names that are not part
+// of updateCols, sorted by name, followed by updateCols in their given order.
+func (n *Node) getDynamicUpdateCols(updateCols []string, properties map[string]*nebula.ValueWrapper) []string {
+	needUpdate := make(map[string]struct{}, len(updateCols))
+	for _, c := range updateCols {
+		needUpdate[c] = struct{}{}
+	}
+
+	cols := make([]string, 0, len(properties)+len(updateCols))
+	for k := range properties {
+		if _, ok := needUpdate[k]; !ok {
+			cols = append(cols, k)
+		}
+	}
+	// Map iteration order is random; sort for a deterministic column order.
+	sort.Strings(cols)
+	return append(cols, updateCols...)
+}
+
+// getNebulaFetchData converts the rows of a FETCH result into a map from the
+// string form of the vertex id to the properties of the tag n.Name.
+func (n *Node) getNebulaFetchData(rs *nebula.ResultSet) (map[string]map[string]*nebula.ValueWrapper, error) {
+	rowSize := rs.GetRowSize()
+	m := make(map[string]map[string]*nebula.ValueWrapper, rowSize)
+	for i := 0; i < rowSize; i++ {
+		row, err := rs.GetRowValuesByIndex(i)
+		if err != nil {
+			return nil, err
+		}
+		cell, err := row.GetValueByIndex(0)
+		if err != nil {
+			return nil, err
+		}
+		node, err := cell.AsNode()
+		if err != nil {
+			return nil, err
+		}
+		property, err := node.Properties(n.Name)
+		if err != nil {
+			return nil, err
+		}
+		m[node.GetID().String()] = property
+	}
+	return m, nil
+}
+
+// getUpdateRocord builds the record for one vertex: the leading columns take
+// the stored values from original, the trailing len(update) columns take the
+// values of update.
+func (n *Node) getUpdateRocord(original map[string]*nebula.ValueWrapper, Columns []string, update Record) Record {
+	if len(update) > len(Columns) {
+		// Keep the values aligned with the end of Columns.
+		update = update[len(update)-len(Columns):]
+	}
+	keep := len(Columns) - len(update)
+
+	r := make(Record, 0, len(Columns))
+	for _, c := range Columns[:keep] {
+		value := original[c]
+		if value == nil {
+			// The stored vertex has no such property: write an explicit NULL.
+			r = append(r, "NULL")
+			continue
+		}
+		switch t := value.GetType(); t {
+		case "null":
+			r = append(r, "NULL")
+		case "date", "time", "datetime":
+			// Temporal values are rebuilt through the function named after their type.
+			r = append(r, fmt.Sprintf("%s(\"%s\")", t, value.String()))
+		default:
+			// TODO: other types (e.g. geography, duration) may need a conversion too.
+			r = append(r, value.String())
+		}
+	}
+	return append(r, update...)
+}