Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small codebase refactors / cleanups #75

Merged
merged 6 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions cmd/rtrdump/rtrdump.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ package main
import (
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"net"
"os"
"runtime"
"strings"

rtr "github.com/bgp/stayrtr/lib"
Expand Down Expand Up @@ -125,8 +123,6 @@ func (c *Client) ClientDisconnected(cs *rtr.ClientSession) {
}

func main() {
runtime.GOMAXPROCS(runtime.NumCPU())

flag.Parse()
if flag.NArg() > 0 {
fmt.Printf("%s: illegal positional argument(s) provided (\"%s\") - did you mean to provide a flag?\n", os.Args[0], strings.Join(flag.Args(), " "))
Expand Down Expand Up @@ -167,7 +163,7 @@ func main() {
serverKeyHash := ssh.FingerprintSHA256(key)
if *ValidateSSH {
if serverKeyHash != fmt.Sprintf("SHA256:%v", *SSHServerKey) {
return errors.New(fmt.Sprintf("Server key hash %v is different than expected key hash SHA256:%v", serverKeyHash, *SSHServerKey))
return fmt.Errorf("server key hash %v is different than expected key hash SHA256:%v", serverKeyHash, *SSHServerKey)
}
}
log.Infof("Connected to server %v via ssh. Fingerprint: %v", remote.String(), serverKeyHash)
Expand Down Expand Up @@ -212,10 +208,10 @@ func main() {
var f io.Writer
if *OutFile != "" {
ff, err := os.Create(*OutFile)
defer ff.Close()
if err != nil {
log.Fatal(err)
}
defer ff.Close()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

f = ff
} else {
f = os.Stdout
Expand Down
15 changes: 4 additions & 11 deletions cmd/rtrmon/rtrmon.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -263,15 +262,13 @@ func (c *Client) Start(id int, ch chan int) {
}

connType := pathUrl.Scheme
rtrAddr := fmt.Sprintf("%s", pathUrl.Host)
rtrAddr := pathUrl.Host

bypass := true
for {

if !bypass {
select {
case <-time.After(c.RefreshInterval):
}
<-time.After(c.RefreshInterval)
}
bypass = false

Expand Down Expand Up @@ -325,10 +322,8 @@ func (c *Client) Start(id int, ch chan int) {
log.Fatal(err)
}

select {
case <-c.qrtr:
log.Infof("%d: Quitting RTR session", id)
}
<-c.qrtr
log.Infof("%d: Quitting RTR session", id)
} else {
log.Infof("%d: Fetching %s", c.id, c.Path)
data, statusCode, _, err := c.FetchConfig.FetchFile(c.Path)
Expand Down Expand Up @@ -873,8 +868,6 @@ func (c *Comparator) Start() error {
}

func main() {
runtime.GOMAXPROCS(runtime.NumCPU())

flag.Parse()
if flag.NArg() > 0 {
fmt.Printf("%s: illegal positional argument(s) provided (\"%s\") - did you mean to provide a flag?\n", os.Args[0], strings.Join(flag.Args(), " "))
Expand Down
10 changes: 5 additions & 5 deletions cmd/stayrtr/stayrtr.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func (s *state) updateFromNewState() error {
sessid := s.server.GetSessionId()

vrpsjson := s.lastdata.Data
if (vrpsjson == nil) {
if vrpsjson == nil {
return nil
}

Expand Down Expand Up @@ -566,7 +566,7 @@ func run() error {
}

if *Bind == "" && *BindTLS == "" && *BindSSH == "" {
log.Fatalf("Specify at least a bind address")
log.Fatalf("Specify at least a bind address using -bind , -tls.bind , or -ssh.bind")
}

_, err := s.updateFile(*CacheBin)
Expand Down Expand Up @@ -639,7 +639,7 @@ func run() error {
}
private, err := ssh.ParsePrivateKey(sshkey)
if err != nil {
log.Fatal("Failed to parse private key: ", err)
log.Fatal("Failed to parse SSH private key: ", err)
}

sshConfig := ssh.ServerConfig{}
Expand All @@ -654,7 +654,7 @@ func run() error {
log.Infof("Connected (ssh-password): %v/%v", conn.User(), conn.RemoteAddr())
if conn.User() != *SSHAuthUser || !bytes.Equal(suppliedPassword, []byte(password)) {
log.Warnf("Wrong user or password for %v/%v. Disconnecting.", conn.User(), conn.RemoteAddr())
return nil, errors.New("Wrong user or password")
return nil, errors.New("wrong user or password")
}

return &ssh.Permissions{
Expand Down Expand Up @@ -693,7 +693,7 @@ func run() error {
}
if !noKeys {
log.Warnf("No key for %v/%v %v %v. Disconnecting.", conn.User(), conn.RemoteAddr(), key.Type(), keyBase64)
return nil, errors.New("Key not found")
return nil, errors.New("provided ssh key not found")
}
} else {
log.Infof("Connected (ssh-key): %v/%v with key %v %v", conn.User(), conn.RemoteAddr(), key.Type(), keyBase64)
Expand Down
3 changes: 1 addition & 2 deletions lib/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rtrlib

import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -220,6 +219,6 @@ func (c *ClientSession) Start(addr string, connType int, configTLS *tls.Config,
case TYPE_SSH:
return c.StartSSH(addr, configSSH)
default:
return errors.New(fmt.Sprintf("Unknown type %v", connType))
return fmt.Errorf("unknown ClientSession type %v", connType)
}
}
2 changes: 1 addition & 1 deletion lib/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ func (r VRP) String() string {
}

func (r1 VRP) Equals(r2 VRP) bool {
return r1.MaxLen == r2.MaxLen && r1.ASN == r2.ASN && bytes.Equal(r1.Prefix.IP, r2.Prefix.IP) && bytes.Equal(r1.Prefix.Mask, r2.Prefix.Mask)
return r1.MaxLen == r2.MaxLen && r1.ASN == r2.ASN && r1.Prefix.IP.Equal(r2.Prefix.IP) && bytes.Equal(r1.Prefix.Mask, r2.Prefix.Mask)
}

func (r1 VRP) Copy() VRP {
Expand Down
32 changes: 16 additions & 16 deletions lib/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ func DecodeBytes(b []byte) (PDU, error) {

func Decode(rdr io.Reader) (PDU, error) {
if rdr == nil {
return nil, errors.New("Reader for decoding is nil")
return nil, errors.New("reader for decoding is nil")
}
var pver uint8
var pduType uint8
Expand All @@ -536,10 +536,10 @@ func Decode(rdr io.Reader) (PDU, error) {
}

if length < 8 {
return nil, fmt.Errorf("Wrong length: %d < 8", length)
return nil, fmt.Errorf("wrong length: %d < 8", length)
}
if length > messageMaxSize {
return nil, fmt.Errorf("Wrong length: %d > %d", length, messageMaxSize)
return nil, fmt.Errorf("wrong length: %d > %d", length, messageMaxSize)
}
toread := make([]byte, length-8)
err = binary.Read(rdr, binary.BigEndian, toread)
Expand All @@ -550,7 +550,7 @@ func Decode(rdr io.Reader) (PDU, error) {
switch pduType {
case PDU_ID_SERIAL_NOTIFY:
if len(toread) != 4 {
return nil, fmt.Errorf("Wrong length for Serial Notify PDU: %d != 4", len(toread))
return nil, fmt.Errorf("wrong length for Serial Notify PDU: %d != 4", len(toread))
}
serial := binary.BigEndian.Uint32(toread)
return &PDUSerialNotify{
Expand All @@ -560,7 +560,7 @@ func Decode(rdr io.Reader) (PDU, error) {
}, nil
case PDU_ID_SERIAL_QUERY:
if len(toread) != 4 {
return nil, fmt.Errorf("Wrong length for Serial Query PDU: %d != 4", len(toread))
return nil, fmt.Errorf("wrong length for Serial Query PDU: %d != 4", len(toread))
}
serial := binary.BigEndian.Uint32(toread)
return &PDUSerialQuery{
Expand All @@ -570,22 +570,22 @@ func Decode(rdr io.Reader) (PDU, error) {
}, nil
case PDU_ID_RESET_QUERY:
if len(toread) != 0 {
return nil, fmt.Errorf("Wrong length for Reset Query PDU: %d != 0", len(toread))
return nil, fmt.Errorf("wrong length for Reset Query PDU: %d != 0", len(toread))
}
return &PDUResetQuery{
Version: pver,
}, nil
case PDU_ID_CACHE_RESPONSE:
if len(toread) != 0 {
return nil, fmt.Errorf("Wrong length for Cache Response PDU: %d != 0", len(toread))
return nil, fmt.Errorf("wrong length for Cache Response PDU: %d != 0", len(toread))
}
return &PDUCacheResponse{
Version: pver,
SessionId: sessionId,
}, nil
case PDU_ID_IPV4_PREFIX:
if len(toread) != 12 {
return nil, fmt.Errorf("Wrong length for IPv4 Prefix PDU: %d != 12", len(toread))
return nil, fmt.Errorf("wrong length for IPv4 Prefix PDU: %d != 12", len(toread))
}
prefixLen := int(toread[1])
ip := toread[4:8]
Expand All @@ -603,7 +603,7 @@ func Decode(rdr io.Reader) (PDU, error) {
}, nil
case PDU_ID_IPV6_PREFIX:
if len(toread) != 24 {
return nil, fmt.Errorf("Wrong length for IPv6 Prefix PDU: %d != 24", len(toread))
return nil, fmt.Errorf("wrong length for IPv6 Prefix PDU: %d != 24", len(toread))
}
prefixLen := int(toread[1])
ip := toread[4:20]
Expand All @@ -621,7 +621,7 @@ func Decode(rdr io.Reader) (PDU, error) {
}, nil
case PDU_ID_END_OF_DATA:
if len(toread) != 4 && len(toread) != 16 {
return nil, fmt.Errorf("Wrong length for End of Data PDU: %d != 4 or != 16", len(toread))
return nil, fmt.Errorf("wrong length for End of Data PDU: %d != 4 or != 16", len(toread))
}

var serial uint32
Expand All @@ -647,14 +647,14 @@ func Decode(rdr io.Reader) (PDU, error) {
}, nil
case PDU_ID_CACHE_RESET:
if len(toread) != 0 {
return nil, fmt.Errorf("Wrong length for Cache Reset PDU: %d != 0", len(toread))
return nil, fmt.Errorf("wrong length for Cache Reset PDU: %d != 0", len(toread))
}
return &PDUCacheReset{
Version: pver,
}, nil
case PDU_ID_ROUTER_KEY:
if len(toread) != 28 {
return nil, fmt.Errorf("Wrong length for Router Key PDU: %d < 8", len(toread))
return nil, fmt.Errorf("wrong length for Router Key PDU: %d < 8", len(toread))
}
asn := binary.BigEndian.Uint32(toread[20:24])
spki := binary.BigEndian.Uint32(toread[24:28])
Expand All @@ -668,18 +668,18 @@ func Decode(rdr io.Reader) (PDU, error) {
}, nil
case PDU_ID_ERROR_REPORT:
if len(toread) < 8 {
return nil, fmt.Errorf("Wrong length for Error Report PDU: %d < 8", len(toread))
return nil, fmt.Errorf("wrong length for Error Report PDU: %d < 8", len(toread))
}
lenPdu := binary.BigEndian.Uint32(toread[0:4])
if len(toread) < int(lenPdu)+8 {
return nil, fmt.Errorf("Wrong length for Error Report PDU: %d < %d", len(toread), lenPdu+4)
return nil, fmt.Errorf("wrong length for Error Report PDU: %d < %d", len(toread), lenPdu+4)
}
errPdu := toread[4 : lenPdu+4]
lenErrText := binary.BigEndian.Uint32(toread[lenPdu+4 : lenPdu+8])
// int casting for each value is needed here to prevent an uint32 overflow that could result in
// upper bound being lower than lower bound causing a crash
if len(toread) < int(lenPdu)+8+int(lenErrText) {
return nil, fmt.Errorf("Wrong length for Error Report PDU: %d < %d", len(toread), lenPdu+8+lenErrText)
return nil, fmt.Errorf("wrong length for Error Report PDU: %d < %d", len(toread), lenPdu+8+lenErrText)
}
errMsg := string(toread[lenPdu+8 : lenPdu+8+lenErrText])
return &PDUErrorReport{
Expand All @@ -689,6 +689,6 @@ func Decode(rdr io.Reader) (PDU, error) {
ErrorMsg: errMsg,
}, nil
default:
return nil, errors.New("Could not decode packet")
return nil, errors.New("could not decode packet")
}
}
11 changes: 5 additions & 6 deletions prefixfile/prefixfile.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package prefixfile

import (
"errors"
"fmt"
"net"
"strconv"
Expand All @@ -17,8 +16,8 @@ type VRPJson struct {
}

type MetaData struct {
Counts int `json:"vrps"`
Buildtime string `json:"buildtime,omitempty"`
Counts int `json:"vrps"`
Buildtime string `json:"buildtime,omitempty"`
}

type VRPList struct {
Expand All @@ -32,7 +31,7 @@ func (vrp *VRPJson) GetASN2() (uint32, error) {
asnStr := strings.TrimLeft(asnc, "aAsS")
asnInt, err := strconv.ParseUint(asnStr, 10, 32)
if err != nil {
return 0, errors.New(fmt.Sprintf("Could not decode ASN string: %v", vrp.ASN))
return 0, fmt.Errorf("could not decode ASN string: %v", vrp.ASN)
}
asn := uint32(asnInt)
return asn, nil
Expand All @@ -43,7 +42,7 @@ func (vrp *VRPJson) GetASN2() (uint32, error) {
case int:
return uint32(asnc), nil
default:
return 0, errors.New(fmt.Sprintf("Could not decode ASN: %v", vrp.ASN))
return 0, fmt.Errorf("could not decode ASN: %v", vrp.ASN)
}
}

Expand All @@ -55,7 +54,7 @@ func (vrp *VRPJson) GetASN() uint32 {
func (vrp *VRPJson) GetPrefix2() (*net.IPNet, error) {
_, prefix, err := net.ParseCIDR(vrp.Prefix)
if err != nil {
return nil, errors.New(fmt.Sprintf("Could not decode prefix: %v", vrp.Prefix))
return nil, fmt.Errorf("could not decode prefix: %v", vrp.Prefix)
}
return prefix, nil
}
Expand Down