diff --git a/aipcli/dial.go b/aipcli/dial.go index e9a3cc7..3098e0d 100644 --- a/aipcli/dial.go +++ b/aipcli/dial.go @@ -48,37 +48,7 @@ func dial(cmd *cobra.Command) (*grpc.ClientConn, error) { ) } if isForceTrace(cmd) { - traceID, err := generateTraceID() - if err != nil { - return nil, fmt.Errorf("failed to generate trace ID: %w", err) - } - spanID, err := generateSpanID() - if err != nil { - return nil, fmt.Errorf("failed to generate span ID: %w", err) - } - opts = append( - opts, - grpc.WithUnaryInterceptor(func( - ctx context.Context, - method string, - req interface{}, - reply interface{}, - cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, - opts ...grpc.CallOption, - ) error { - // See https://cloud.google.com/trace/docs/setup#force-trace - const header = "x-cloud-trace-context" - value := fmt.Sprintf("%s/%d;o=1", traceID, spanID) - cmd.PrintErrln(">> trace ID:", traceID) - if IsVerbose(cmd) { - cmd.PrintErrln(">> span ID:", spanID) - cmd.PrintErrln(">> trace header:", header, "=", value) - } - ctx = metadata.AppendToOutgoingContext(ctx, header, value) - return invoker(ctx, method, req, reply, cc, opts...) - }), - ) + opts = append(opts, withForceTrace(cmd)) } systemCertPool, err := x509.SystemCertPool() if err != nil { @@ -101,6 +71,9 @@ func dialInsecure(cmd *cobra.Command) (*grpc.ClientConn, error) { if token != "" { opts = append(opts, grpc.WithPerRPCCredentials(insecureTokenCredentials(token))) } + if isForceTrace(cmd) { + opts = append(opts, withForceTrace(cmd)) + } return grpc.NewClient(withDefaultPort(address, 443), opts...) } @@ -130,18 +103,39 @@ func withDefaultPort(target string, port int) string { return target } -func generateTraceID() (string, error) { +func withForceTrace(cmd *cobra.Command) grpc.DialOption { + traceID := generateTraceID() + spanID := generateSpanID() + return grpc.WithUnaryInterceptor(func( + ctx context.Context, + method string, + req interface{}, + reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + // See https://cloud.google.com/trace/docs/setup#force-trace + const header = "x-cloud-trace-context" + value := fmt.Sprintf("%s/%d;o=1", traceID, spanID) + cmd.PrintErrln(">> trace ID:", traceID) + if IsVerbose(cmd) { + cmd.PrintErrln(">> span ID:", spanID) + cmd.PrintErrln(">> trace header:", header, "=", value) + } + ctx = metadata.AppendToOutgoingContext(ctx, header, value) + return invoker(ctx, method, req, reply, cc, opts...) + }) +} + +func generateTraceID() string { var id [16]byte - if _, err := rand.Read(id[:]); err != nil { - return "", err - } - return hex.EncodeToString(id[:]), nil + _, _ = rand.Read(id[:]) // panics on error + return hex.EncodeToString(id[:]) } -func generateSpanID() (uint64, error) { +func generateSpanID() uint64 { var id [8]byte - if _, err := rand.Read(id[:]); err != nil { - return 0, err - } - return binary.LittleEndian.Uint64(id[:]), nil + _, _ = rand.Read(id[:]) // panics on error + return binary.LittleEndian.Uint64(id[:]) }