Skip to content

Commit

Permalink
Add concurrency config
Browse files Browse the repository at this point in the history
Visualize progress
  • Loading branch information
Alexander Sheiko committed Sep 23, 2023
1 parent 23d9b35 commit 8954111
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 57 deletions.
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
type config struct {
All bool `arg:"--all,-a" help:"Scan all providers"`
AltPing bool `arg:"--alt-ping" help:"Use alternative ICMP ping method"`
Concurrent int `arg:"--concurrent,-C" placeholder:"NUM" default:"10" help:"Number of concurrent pings"`
Count int `arg:"--count,-c" placeholder:"NUM" default:"4" help:"Number of pings to send"`
HideErrors bool `arg:"--hide-errors,-e" help:"Hide errors from results"`
FilterRegion []string `arg:"--region,-r,separate" placeholder:"NAME" help:"Filter by regions, can be specified multiple times"`
Expand Down
43 changes: 35 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"os"
"runtime"
"sort"
"strings"
"sync"
Expand All @@ -13,30 +14,56 @@ import (
"github.com/jedib0t/go-pretty/v6/text"
)

var (
args config
wg sync.WaitGroup
)
var args config

func main() {
arg.MustParse(&args)
args.Provider = strings.ToLower(args.Provider)

semaphore := make(chan struct{}, args.Concurrent)
var wg sync.WaitGroup

for key, provider := range providers {
if args.All || key == args.Provider {
for _, region := range provider.regions {
if (len(args.FilterRegion) > 0 && !isFiltered(region.name, args.FilterRegion)) ||
(len(args.FilterLocation) > 0 && !isFiltered(region.location, args.FilterLocation)) {
for _, r := range provider.regions {
if (len(args.FilterRegion) > 0 && !isFiltered(r.name, args.FilterRegion)) ||
(len(args.FilterLocation) > 0 && !isFiltered(r.location, args.FilterLocation)) {
continue
}
if r.endpoint == "" {
code := r.name
if r.code != "" {
code = r.code
}
r.endpoint = fmt.Sprintf(provider.hostTemplate, code)
}
semaphore <- struct{}{}
wg.Add(1)
go endpointPing(provider.hostTemplate, region)
go func(region *region) {
defer wg.Done()
defer func() { <-semaphore }()
region.rtt, region.err = endpointPing(region.endpoint)
if region.rtt == 0 && region.err == nil {
region.err = fmt.Errorf("timeout")
}
if region.err != nil {
fmt.Print(".")
} else {
fmt.Print("!")
}
}(r)
time.Sleep(time.Millisecond)
}
}
}
wg.Wait()

clearLine := "\033[2K\r"
if runtime.GOOS == "windows" {
clearLine = "\r"
}
fmt.Print(clearLine)

results := []*region{}
for key, provider := range providers {
if args.All || key == args.Provider {
Expand Down
72 changes: 23 additions & 49 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,54 +13,28 @@ import (
"github.com/valyala/fasthttp"
)

func endpointPing(template string, region *region) {
defer wg.Done()

if region.endpoint == "" {
code := region.name
if region.code != "" {
code = region.code
}
region.endpoint = fmt.Sprintf(template, code)
}

var err error
if strings.HasPrefix(region.endpoint, "http") {
region.rtt, err = httpPing(region.endpoint, args.Count, args.Timeout)
if err != nil {
region.err = err
}
} else if strings.HasPrefix(region.endpoint, "tcp") {
address := strings.TrimPrefix(region.endpoint, "tcp://")
region.rtt, err = tcpPing(address, args.Count, args.Timeout)
if err != nil {
region.err = err
}
} else {
icmpPingFn := icmpPing
if args.AltPing {
icmpPingFn = icmpAltPing
}
region.rtt, err = icmpPingFn(region.endpoint, args.Count, args.Timeout)
if err != nil {
region.err = err
}
func endpointPing(endpoint string) (time.Duration, error) {
if strings.HasPrefix(endpoint, "http") {
return httpPing(endpoint)
} else if strings.HasPrefix(endpoint, "tcp") {
address := strings.TrimPrefix(endpoint, "tcp://")
return tcpPing(address)
}

if region.rtt == 0 && region.err == nil {
region.err = fmt.Errorf("timeout")
if args.AltPing {
return icmpAltPing(endpoint)
}
return icmpPing(endpoint)
}

func icmpPing(endpoint string, count int, timeout time.Duration) (time.Duration, error) {
func icmpPing(endpoint string) (time.Duration, error) {
pinger, err := probing.NewPinger(endpoint)
if err != nil {
return 0, err
}
defer pinger.Stop()

pinger.Count = count
pinger.Timeout = timeout
pinger.Count = args.Count
pinger.Timeout = args.Timeout

resolve_count := 0
for {
Expand All @@ -80,14 +54,14 @@ func icmpPing(endpoint string, count int, timeout time.Duration) (time.Duration,
return pinger.Statistics().AvgRtt, nil
}

func icmpAltPing(endpoint string, count int, timeout time.Duration) (time.Duration, error) {
func icmpAltPing(endpoint string) (time.Duration, error) {
pinger, err := ping.New(endpoint)
if err != nil {
return 0, err
}

pinger.SetCount(count)
pinger.SetTimeout(timeout.String())
pinger.SetCount(args.Count)
pinger.SetTimeout(args.Timeout.String())

r, err := pinger.Run()
if err != nil {
Expand All @@ -105,18 +79,18 @@ func icmpAltPing(endpoint string, count int, timeout time.Duration) (time.Durati
return time.Duration((rtt / float64(count_success)) * float64(time.Millisecond)), nil
}

func httpPing(url string, count int, timeout time.Duration) (time.Duration, error) {
func httpPing(url string) (time.Duration, error) {
var rtt time.Duration
client := fasthttp.Client{
ReadTimeout: timeout,
ReadTimeout: args.Timeout,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
req := fasthttp.AcquireRequest()
req.Header.SetMethod(fasthttp.MethodGet)
defer fasthttp.ReleaseRequest(req)
for i := 0; i < count; i++ {
for i := 0; i < args.Count; i++ {
stampedURL := fmt.Sprintf("%s?%d", url, time.Now().UnixNano()/int64(time.Millisecond))
req.SetRequestURI(stampedURL)
start := time.Now()
Expand All @@ -128,19 +102,19 @@ func httpPing(url string, count int, timeout time.Duration) (time.Duration, erro
}
rtt += time.Since(start)
}
return rtt / time.Duration(count), nil
return rtt / time.Duration(args.Count), nil
}

func tcpPing(endpoint string, count int, timeout time.Duration) (time.Duration, error) {
func tcpPing(endpoint string) (time.Duration, error) {
var rtt time.Duration
for i := 0; i < count; i++ {
for i := 0; i < args.Count; i++ {
start := time.Now()
conn, err := net.DialTimeout("tcp", endpoint, timeout)
conn, err := net.DialTimeout("tcp", endpoint, args.Timeout)
if err != nil {
return 0, err
}
conn.Close()
rtt += time.Since(start)
}
return rtt / time.Duration(count), nil
return rtt / time.Duration(args.Count), nil
}

0 comments on commit 8954111

Please sign in to comment.