Skip to content

Commit

Permalink
简单重构
Browse files Browse the repository at this point in the history
  • Loading branch information
Kisesy committed Jul 5, 2023
1 parent e1e343a commit b40fd7e
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 149 deletions.
64 changes: 32 additions & 32 deletions gscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,56 +48,56 @@ func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
}

func initConfig(cfgfile, execFolder string) *GScanConfig {
if strings.HasPrefix(cfgfile, "./") {
cfgfile = filepath.Join(execFolder, cfgfile)
func initConfig(cfgFile, execFolder string) *GScanConfig {
if strings.HasPrefix(cfgFile, "./") {
cfgFile = filepath.Join(execFolder, cfgFile)
}

gcfg := new(GScanConfig)
if err := readJsonConfig(cfgfile, gcfg); err != nil {
config := new(GScanConfig)
if err := readJsonConfig(cfgFile, config); err != nil {
log.Panicln(err)
}

if gcfg.EnableBackup {
if strings.HasPrefix(gcfg.BackupDir, "./") {
gcfg.BackupDir = filepath.Join(execFolder, gcfg.BackupDir)
if config.EnableBackup {
if strings.HasPrefix(config.BackupDir, "./") {
config.BackupDir = filepath.Join(execFolder, config.BackupDir)
}
if _, err := os.Stat(gcfg.BackupDir); os.IsNotExist(err) {
if err := os.MkdirAll(gcfg.BackupDir, 0o755); err != nil {
if _, err := os.Stat(config.BackupDir); os.IsNotExist(err) {
if err := os.MkdirAll(config.BackupDir, 0o755); err != nil {
log.Println(err)
}
}
}

gcfg.ScanMode = strings.ToLower(gcfg.ScanMode)
if gcfg.ScanMode == "ping" {
gcfg.VerifyPing = false
config.ScanMode = strings.ToLower(config.ScanMode)
if config.ScanMode == "ping" {
config.VerifyPing = false
}

gcfg.ScanMinPingRTT *= time.Millisecond
gcfg.ScanMaxPingRTT *= time.Millisecond
config.ScanMinPingRTT *= time.Millisecond
config.ScanMaxPingRTT *= time.Millisecond

cfgs := []*ScanConfig{&gcfg.Quic, &gcfg.Tls, &gcfg.Sni, &gcfg.Ping}
for _, c := range cfgs {
if strings.HasPrefix(c.InputFile, "./") {
c.InputFile = filepath.Join(execFolder, c.InputFile)
scanConfigs := []*ScanConfig{&config.Quic, &config.Tls, &config.Sni, &config.Ping}
for _, scanConfig := range scanConfigs {
if strings.HasPrefix(scanConfig.InputFile, "./") {
scanConfig.InputFile = filepath.Join(execFolder, scanConfig.InputFile)
} else {
c.InputFile, _ = filepath.Abs(c.InputFile)
scanConfig.InputFile, _ = filepath.Abs(scanConfig.InputFile)
}
if strings.HasPrefix(c.OutputFile, "./") {
c.OutputFile = filepath.Join(execFolder, c.OutputFile)
if strings.HasPrefix(scanConfig.OutputFile, "./") {
scanConfig.OutputFile = filepath.Join(execFolder, scanConfig.OutputFile)
} else {
c.OutputFile, _ = filepath.Abs(c.OutputFile)
scanConfig.OutputFile, _ = filepath.Abs(scanConfig.OutputFile)
}
if _, err := os.Stat(c.InputFile); os.IsNotExist(err) {
os.Create(c.InputFile)
if _, err := os.Stat(scanConfig.InputFile); os.IsNotExist(err) {
os.Create(scanConfig.InputFile)
}

c.ScanMinRTT *= time.Millisecond
c.ScanMaxRTT *= time.Millisecond
c.HandshakeTimeout *= time.Millisecond
scanConfig.ScanMinRTT *= time.Millisecond
scanConfig.ScanMaxRTT *= time.Millisecond
scanConfig.HandshakeTimeout *= time.Millisecond
}
return gcfg
return config
}

func main() {
Expand Down Expand Up @@ -196,15 +196,15 @@ func main() {
if err := os.WriteFile(cfg.OutputFile, b.Bytes(), 0o644); err != nil {
log.Printf("Failed to write output file:%s for reason:%v\n", cfg.OutputFile, err)
} else {
log.Printf("All results writed to %s\n", cfg.OutputFile)
log.Printf("All results written to %s\n", cfg.OutputFile)
}
if gcfg.EnableBackup {
filename := fmt.Sprintf("%s_%s_lv%d.txt", scanMode, time.Now().Format(time.DateTime), cfg.Level)
filename := fmt.Sprintf("%s_%s_lv%d.txt", scanMode, time.Now().Format("20060102_150405"), cfg.Level)
bakfilename := filepath.Join(gcfg.BackupDir, filename)
if err := os.WriteFile(bakfilename, b.Bytes(), 0o644); err != nil {
log.Printf("Failed to write output file:%s for reason:%v\n", bakfilename, err)
} else {
log.Printf("All results writed to %s\n", bakfilename)
log.Printf("All results written to %s\n", bakfilename)
}
}
}
3 changes: 2 additions & 1 deletion ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ package main

import (
"bytes"
"context"
"errors"
"net"
"os"
"time"
)

func testPing(ip string, config *ScanConfig, record *ScanRecord) bool {
func testPing(ctx context.Context, ip string, config *ScanConfig, record *ScanRecord) bool {
start := time.Now()
if err := Pinger(ip, config.ScanMaxRTT); err != nil {
return false
Expand Down
41 changes: 15 additions & 26 deletions quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
Expand All @@ -18,26 +17,19 @@ import (

var errNoSuchBucket = []byte("<?xml version='1.0' encoding='UTF-8'?><Error><Code>NoSuchBucket</Code><Message>The specified bucket does not exist.</Message></Error>")

func testQuic(ip string, config *ScanConfig, record *ScanRecord) bool {
func testQuic(ctx context.Context, ip string, config *ScanConfig, record *ScanRecord) bool {
start := time.Now()

udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return false
}
udpConn.SetDeadline(time.Now().Add(config.ScanMaxRTT))
defer udpConn.Close()

quicCfg := &quic.Config{
HandshakeIdleTimeout: config.HandshakeTimeout,
KeepAlivePeriod: 0,
}

var serverName string
serverName := ""
if len(config.ServerName) == 0 {
serverName = randomHost()
} else {
serverName = config.ServerName[rand.Intn(len(config.ServerName))]
serverName = randomChoice(config.ServerName)
}

tlsCfg := &tls.Config{
Expand All @@ -46,28 +38,25 @@ func testQuic(ip string, config *ScanConfig, record *ScanRecord) bool {
NextProtos: []string{"h3-29", "h3", "hq", "quic"},
}

ctx := context.TODO()
udpAddr := &net.UDPAddr{IP: net.ParseIP(ip), Port: 443}
quicSessn, err := quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, quicCfg)
ctx, cancel := context.WithTimeout(ctx, config.ScanMaxRTT)
defer cancel()

quicConn, err := quic.DialAddrEarly(ctx, net.JoinHostPort(ip, "443"), tlsCfg, quicCfg)
if err != nil {
return false
}
defer quicSessn.CloseWithError(0, "")
defer quicConn.CloseWithError(0, "")

// lv1 只会验证证书是否存在
cs := quicSessn.ConnectionState().TLS
if !cs.HandshakeComplete {
return false
}
pcs := cs.PeerCertificates
if len(pcs) < 2 {
cs := quicConn.ConnectionState().TLS
if !cs.HandshakeComplete || len(cs.PeerCertificates) < 2 {
return false
}

// lv2 验证证书是否正确
if config.Level > 1 {
pkp := pcs[1].RawSubjectPublicKeyInfo
if !bytes.Equal(g2pkp, pkp) && !bytes.Equal(g3pkp, pkp) { // && !bytes.Equal(g3ecc, pkp[:]) {
pkp := cs.PeerCertificates[1].RawSubjectPublicKeyInfo
if !bytes.Equal(gpkp, pkp) {
return false
}
}
Expand All @@ -77,7 +66,7 @@ func testQuic(ip string, config *ScanConfig, record *ScanRecord) bool {
tr := &http3.RoundTripper{DisableCompression: true}
defer tr.Close()
tr.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
return quicSessn, err
return quicConn, err
}
// 设置超时
hclient := &http.Client{
Expand All @@ -98,12 +87,12 @@ func testQuic(ip string, config *ScanConfig, record *ScanRecord) bool {
defer resp.Body.Close()
// lv4 验证是否是 NoSuchBucket 错误
if config.Level > 3 && resp.Header.Get("Content-Type") == "application/xml; charset=UTF-8" { // 也许条件改为 || 更好
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil || bytes.Equal(body, errNoSuchBucket) {
return false
}
} else {
io.Copy(ioutil.Discard, resp.Body)
io.Copy(io.Discard, resp.Body)
}
}
}
Expand Down
89 changes: 26 additions & 63 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ScanRecord struct {
}

type ScanRecords struct {
recordMutex sync.RWMutex
recordMutex sync.Mutex
records []*ScanRecord
scanCounter int32
}
Expand All @@ -36,21 +36,21 @@ func (srs *ScanRecords) IncScanCounter() {
}

func (srs *ScanRecords) RecordSize() int {
srs.recordMutex.RLock()
defer srs.recordMutex.RUnlock()
srs.recordMutex.Lock()
defer srs.recordMutex.Unlock()
return len(srs.records)
}

func (srs *ScanRecords) ScanCount() int32 {
return atomic.LoadInt32(&srs.scanCounter)
}

var testIPFunc func(ip string, config *ScanConfig, record *ScanRecord) bool
var testIPFunc func(ctx context.Context, ip string, config *ScanConfig, record *ScanRecord) bool

func testip(ip string, config *ScanConfig) *ScanRecord {
func testip(ctx context.Context, ip string, config *ScanConfig) *ScanRecord {
record := new(ScanRecord)
for i := 0; i < config.ScanCountPerIP; i++ {
if !testIPFunc(ip, config, record) {
if !testIPFunc(ctx, ip, config, record) {
return nil
}
}
Expand All @@ -59,86 +59,49 @@ func testip(ip string, config *ScanConfig) *ScanRecord {
return record
}

func testip_worker(ctx context.Context, ch chan string, gcfg *GScanConfig, cfg *ScanConfig, srs *ScanRecords, wg *sync.WaitGroup) {
defer wg.Done()

timer := time.NewTimer(cfg.ScanMaxRTT + 100*time.Millisecond)
defer timer.Stop()

ctx, cancal := context.WithCancel(ctx)
defer cancal()

for ip := range ch {
func testIPWorker(ctx context.Context, ipQueue chan string, gcfg *GScanConfig, cfg *ScanConfig, srs *ScanRecords) {
for ip := range ipQueue {
srs.IncScanCounter()

if gcfg.VerifyPing {
start := time.Now()
if err := Ping(ip, gcfg.ScanMaxPingRTT); err != nil {
continue
}
if time.Since(start) < gcfg.ScanMinPingRTT {

pingErr := Ping(ip, gcfg.ScanMaxPingRTT)
if pingErr != nil || time.Since(start) < gcfg.ScanMinPingRTT {
continue
}
}

done := make(chan struct{}, 1)
go func() {
r := testip(ip, cfg)
select {
case <-ctx.Done():
return
default:
r := testip(ctx, ip, cfg)
if r != nil {
srs.AddRecord(r) // 这里放到前面,扫描时可能会多出一些记录, 但是不影响
if srs.RecordSize() >= cfg.RecordLimit {
close(done)
return
}
srs.AddRecord(r)
}
done <- struct{}{}
}()

timer.Reset(cfg.ScanMaxRTT + 100*time.Millisecond)
select {
case <-ctx.Done():
return
case <-timer.C:
log.Println(ip, "timeout")
case <-done:
}

}
}

func StartScan(gcfg *GScanConfig, cfg *ScanConfig, ipqueue chan string) *ScanRecords {
func StartScan(gcfg *GScanConfig, cfg *ScanConfig, ipQueue chan string) *ScanRecords {
var wg sync.WaitGroup
var srs ScanRecords

wg.Add(gcfg.ScanWorker)

interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()

go func() {
<-interrupt
cancel()
}()

ch := make(chan string, 100)
wg.Add(gcfg.ScanWorker)
for i := 0; i < gcfg.ScanWorker; i++ {
go testip_worker(ctx, ch, gcfg, cfg, &srs, &wg)
}

for ip := range ipqueue {
select {
case ch <- ip:
case <-ctx.Done():
return &srs
}
if srs.RecordSize() >= cfg.RecordLimit {
break
}
go func() {
defer wg.Done()
testIPWorker(ctx, ipQueue, gcfg, cfg, &srs)
}()
}

close(ch)
wg.Wait()
return &srs
}
Loading

0 comments on commit b40fd7e

Please sign in to comment.