Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner,infoschema,executor: Add tiflash fine grained shuffle support for hash join and aggregation #40121

Merged
merged 24 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 4 additions & 124 deletions executor/memtable_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/diagnosticspb"
"github.com/pingcap/log"
"github.com/pingcap/sysutil"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/infoschema"
Expand All @@ -47,7 +46,6 @@ import (
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/pdapi"
"github.com/pingcap/tidb/util/set"
"go.uber.org/zap"
"golang.org/x/exp/slices"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -177,7 +175,7 @@ func fetchClusterConfig(sctx sessionctx.Context, nodeTypes, nodeAddrs set.String
if err != nil {
return nil, err
}
serversInfo = filterClusterServerInfo(serversInfo, nodeTypes, nodeAddrs)
serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, nodeAddrs)
//nolint: prealloc
var finalRows [][]types.Datum
wg := sync.WaitGroup{}
Expand Down Expand Up @@ -310,108 +308,12 @@ func (e *clusterServerInfoRetriever) retrieve(ctx context.Context, sctx sessionc
return nil, nil
}
e.retrieved = true

serversInfo, err := infoschema.GetClusterServerInfo(sctx)
if err != nil {
return nil, err
}
serversInfo = filterClusterServerInfo(serversInfo, e.extractor.NodeTypes, e.extractor.Instances)

type result struct {
idx int
rows [][]types.Datum
err error
}
wg := sync.WaitGroup{}
ch := make(chan result, len(serversInfo))
infoTp := e.serverInfoType
finalRows := make([][]types.Datum, 0, len(serversInfo)*10)
for i, srv := range serversInfo {
address := srv.Address
remote := address
if srv.ServerType == "tidb" {
remote = srv.StatusAddr
}
wg.Add(1)
go func(index int, remote, address, serverTP string) {
util.WithRecovery(func() {
defer wg.Done()
items, err := getServerInfoByGRPC(ctx, remote, infoTp)
if err != nil {
ch <- result{idx: index, err: err}
return
}
partRows := serverInfoItemToRows(items, serverTP, address)
ch <- result{idx: index, rows: partRows}
}, nil)
}(i, remote, address, srv.ServerType)
}
wg.Wait()
close(ch)
// Keep the original order to make the result more stable
var results []result //nolint: prealloc
for result := range ch {
if result.err != nil {
sctx.GetSessionVars().StmtCtx.AppendWarning(result.err)
continue
}
results = append(results, result)
}
slices.SortFunc(results, func(i, j result) bool { return i.idx < j.idx })
for _, result := range results {
finalRows = append(finalRows, result.rows...)
}
return finalRows, nil
}

func serverInfoItemToRows(items []*diagnosticspb.ServerInfoItem, tp, addr string) [][]types.Datum {
rows := make([][]types.Datum, 0, len(items))
for _, v := range items {
for _, item := range v.Pairs {
row := types.MakeDatums(
tp,
addr,
v.Tp,
v.Name,
item.Key,
item.Value,
)
rows = append(rows, row)
}
}
return rows
}

func getServerInfoByGRPC(ctx context.Context, address string, tp diagnosticspb.ServerInfoType) ([]*diagnosticspb.ServerInfoItem, error) {
opt := grpc.WithInsecure()
security := config.GetGlobalConfig().Security
if len(security.ClusterSSLCA) != 0 {
clusterSecurity := security.ClusterSecurity()
tlsConfig, err := clusterSecurity.ToTLSConfig()
if err != nil {
return nil, errors.Trace(err)
}
opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
}
conn, err := grpc.Dial(address, opt)
if err != nil {
return nil, err
}
defer func() {
err := conn.Close()
if err != nil {
log.Error("close grpc connection error", zap.Error(err))
}
}()

cli := diagnosticspb.NewDiagnosticsClient(conn)
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
r, err := cli.ServerInfo(ctx, &diagnosticspb.ServerInfoRequest{Tp: tp})
if err != nil {
return nil, err
}
return r.Items, nil
serversInfo = infoschema.FilterClusterServerInfo(serversInfo, e.extractor.NodeTypes, e.extractor.Instances)
return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx, serversInfo, e.serverInfoType, true)
}

func parseFailpointServerInfo(s string) []infoschema.ServerInfo {
Expand All @@ -428,28 +330,6 @@ func parseFailpointServerInfo(s string) []infoschema.ServerInfo {
return serversInfo
}

func filterClusterServerInfo(serversInfo []infoschema.ServerInfo, nodeTypes, addresses set.StringSet) []infoschema.ServerInfo {
if len(nodeTypes) == 0 && len(addresses) == 0 {
return serversInfo
}

filterServers := make([]infoschema.ServerInfo, 0, len(serversInfo))
for _, srv := range serversInfo {
// Skip some node type which has been filtered in WHERE clause
// e.g: SELECT * FROM cluster_config WHERE type='tikv'
if len(nodeTypes) > 0 && !nodeTypes.Exist(srv.ServerType) {
continue
}
// Skip some node address which has been filtered in WHERE clause
// e.g: SELECT * FROM cluster_config WHERE address='192.16.8.12:2379'
if len(addresses) > 0 && !addresses.Exist(srv.Address) {
continue
}
filterServers = append(filterServers, srv)
}
return filterServers
}

type clusterLogRetriever struct {
isDrained bool
retrieving bool
Expand Down Expand Up @@ -515,7 +395,7 @@ func (e *clusterLogRetriever) initialize(ctx context.Context, sctx sessionctx.Co

instances := e.extractor.Instances
nodeTypes := e.extractor.NodeTypes
serversInfo = filterClusterServerInfo(serversInfo, nodeTypes, instances)
serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, instances)

var levels = make([]diagnosticspb.LogLevel, 0, len(e.extractor.LogLevels))
for l := range e.extractor.LogLevels {
Expand Down
2 changes: 1 addition & 1 deletion executor/set_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (s *SetConfigExec) Next(ctx context.Context, req *chunk.Chunk) error {
if s.p.Instance != "" {
nodeAddrs.Insert(s.p.Instance)
}
serversInfo = filterClusterServerInfo(serversInfo, nodeTypes, nodeAddrs)
serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, nodeAddrs)
if s.p.Instance != "" && len(serversInfo) == 0 {
return errors.Errorf("instance %v is not found in this cluster", s.p.Instance)
}
Expand Down
4 changes: 4 additions & 0 deletions infoschema/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ go_library(
"@com_github_ngaut_pools//:pools",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_kvproto//pkg/diagnosticspb",
"@com_github_pingcap_kvproto//pkg/metapb",
"@com_github_pingcap_log//:log",
"@com_github_tikv_client_go_v2//tikv",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//credentials",
"@org_golang_x_exp//slices",
"@org_uber_go_zap//:zap",
],
Expand Down
147 changes: 147 additions & 0 deletions infoschema/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/diagnosticspb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/ddl/placement"
"github.com/pingcap/tidb/domain/infosync"
Expand All @@ -47,9 +51,13 @@ import (
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/pdapi"
"github.com/pingcap/tidb/util/sem"
"github.com/pingcap/tidb/util/set"
"github.com/pingcap/tidb/util/stmtsummary"
"github.com/tikv/client-go/v2/tikv"
"go.uber.org/zap"
"golang.org/x/exp/slices"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

const (
Expand Down Expand Up @@ -2226,3 +2234,142 @@ func (vt *VirtualTable) GetPhysicalID() int64 {
func (vt *VirtualTable) Type() table.Type {
return table.VirtualTable
}

// GetTiFlashServerInfo returns all TiFlash server infos
func GetTiFlashServerInfo(sctx sessionctx.Context) ([]ServerInfo, error) {
if config.GetGlobalConfig().DisaggregatedTiFlash {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DisaggregatedTiFlash is not supported because pd does not have related information?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is still under design, especially for auto scaling.

Copy link
Collaborator

@guo-shaoge guo-shaoge Jan 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When AutoScaler manage tiflash_compute nodes, PD cannot get the server info of tiflash_compute, we need add API in AutoScaler to support this.

return nil, table.ErrUnsupportedOp
}
serversInfo, err := GetStoreServerInfo(sctx)
if err != nil {
return nil, err
}
serversInfo = FilterClusterServerInfo(serversInfo, set.NewStringSet(kv.TiFlash.Name()), set.NewStringSet())
return serversInfo, nil
}

// FetchClusterServerInfoWithoutPrivilegeCheck fetches cluster server information
func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, sctx sessionctx.Context, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) {
type result struct {
idx int
rows [][]types.Datum
err error
}
wg := sync.WaitGroup{}
ch := make(chan result, len(serversInfo))
infoTp := serverInfoType
finalRows := make([][]types.Datum, 0, len(serversInfo)*10)
for i, srv := range serversInfo {
address := srv.Address
remote := address
if srv.ServerType == "tidb" {
remote = srv.StatusAddr
}
wg.Add(1)
go func(index int, remote, address, serverTP string) {
util.WithRecovery(func() {
defer wg.Done()
items, err := getServerInfoByGRPC(ctx, remote, infoTp)
if err != nil {
ch <- result{idx: index, err: err}
return
}
partRows := serverInfoItemToRows(items, serverTP, address)
ch <- result{idx: index, rows: partRows}
}, nil)
}(i, remote, address, srv.ServerType)
}
wg.Wait()
close(ch)
// Keep the original order to make the result more stable
var results []result //nolint: prealloc
for result := range ch {
if result.err != nil {
if recordWarningInStmtCtx {
sctx.GetSessionVars().StmtCtx.AppendWarning(result.err)
} else {
log.Warn(result.err.Error())
}
continue
}
results = append(results, result)
}
slices.SortFunc(results, func(i, j result) bool { return i.idx < j.idx })
for _, result := range results {
finalRows = append(finalRows, result.rows...)
}
return finalRows, nil
}

func serverInfoItemToRows(items []*diagnosticspb.ServerInfoItem, tp, addr string) [][]types.Datum {
rows := make([][]types.Datum, 0, len(items))
for _, v := range items {
for _, item := range v.Pairs {
row := types.MakeDatums(
tp,
addr,
v.Tp,
v.Name,
item.Key,
item.Value,
)
rows = append(rows, row)
}
}
return rows
}

func getServerInfoByGRPC(ctx context.Context, address string, tp diagnosticspb.ServerInfoType) ([]*diagnosticspb.ServerInfoItem, error) {
opt := grpc.WithInsecure()
security := config.GetGlobalConfig().Security
if len(security.ClusterSSLCA) != 0 {
clusterSecurity := security.ClusterSecurity()
tlsConfig, err := clusterSecurity.ToTLSConfig()
if err != nil {
return nil, errors.Trace(err)
}
opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
}
conn, err := grpc.Dial(address, opt)
if err != nil {
return nil, err
}
defer func() {
err := conn.Close()
if err != nil {
log.Error("close grpc connection error", zap.Error(err))
}
}()

cli := diagnosticspb.NewDiagnosticsClient(conn)
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
r, err := cli.ServerInfo(ctx, &diagnosticspb.ServerInfoRequest{Tp: tp})
if err != nil {
return nil, err
}
return r.Items, nil
}

// FilterClusterServerInfo filters serversInfo by nodeTypes and addresses
func FilterClusterServerInfo(serversInfo []ServerInfo, nodeTypes, addresses set.StringSet) []ServerInfo {
if len(nodeTypes) == 0 && len(addresses) == 0 {
return serversInfo
}

filterServers := make([]ServerInfo, 0, len(serversInfo))
for _, srv := range serversInfo {
// Skip some node type which has been filtered in WHERE clause
// e.g: SELECT * FROM cluster_config WHERE type='tikv'
if len(nodeTypes) > 0 && !nodeTypes.Exist(srv.ServerType) {
continue
}
// Skip some node address which has been filtered in WHERE clause
// e.g: SELECT * FROM cluster_config WHERE address='192.16.8.12:2379'
if len(addresses) > 0 && !addresses.Exist(srv.Address) {
continue
}
filterServers = append(filterServers, srv)
}
return filterServers
}
1 change: 1 addition & 0 deletions planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ go_library(
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_kvproto//pkg/coprocessor",
"@com_github_pingcap_kvproto//pkg/diagnosticspb",
"@com_github_pingcap_tipb//go-tipb",
"@com_github_tikv_client_go_v2//kv",
"@com_github_tikv_client_go_v2//tikv",
Expand Down
6 changes: 6 additions & 0 deletions planner/core/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ func (p *basePhysicalAgg) explainInfo(normalized bool) string {
builder.WriteString(", ")
}
}
if p.TiFlashFineGrainedShuffleStreamCount > 0 {
builder.WriteString(fmt.Sprintf(", stream_count: %d", p.TiFlashFineGrainedShuffleStreamCount))
}
return builder.String()
}

Expand Down Expand Up @@ -543,6 +546,9 @@ func (p *PhysicalHashJoin) explainInfo(normalized bool) string {
buffer.WriteString(", other cond:")
buffer.Write(sortedExplainExpressionList(p.OtherConditions))
}
if p.TiFlashFineGrainedShuffleStreamCount > 0 {
buffer.WriteString(fmt.Sprintf(", stream_count: %d", p.TiFlashFineGrainedShuffleStreamCount))
}
return buffer.String()
}

Expand Down
Loading