diff --git a/clock/clock.go b/clock/clock.go new file mode 100644 index 0000000..03a93a7 --- /dev/null +++ b/clock/clock.go @@ -0,0 +1,50 @@ +package clock + +import ( + "context" + "time" +) + +const WALLCLOCK_PRECISION = 1 * time.Second + +func AfterWallClock(d time.Duration) <-chan time.Time { + ch := make(chan time.Time, 1) + deadline := time.Now().Add(d).Truncate(0) + after_ch := time.After(d) + ticker := time.NewTicker(WALLCLOCK_PRECISION) + go func() { + var t time.Time + defer ticker.Stop() + for { + select { + case t = <-after_ch: + ch <- t + return + case t = <-ticker.C: + if t.After(deadline) { + ch <- t + return + } + } + } + }() + return ch +} + +func RunTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) { + go func() { + var err error + for { + nextInterval := interval + if err != nil { + nextInterval = retryInterval + } + select { + case <-ctx.Done(): + return + case <-AfterWallClock(nextInterval): + err = cb(ctx) + } + } + }() +} diff --git a/fixed.go b/dialer/fixed.go similarity index 97% rename from fixed.go rename to dialer/fixed.go index d5b5631..d407da4 100644 --- a/fixed.go +++ b/dialer/fixed.go @@ -1,4 +1,4 @@ -package main +package dialer import ( "context" diff --git a/resolver.go b/dialer/resolver.go similarity index 98% rename from resolver.go rename to dialer/resolver.go index 76ab612..a4c4f82 100644 --- a/resolver.go +++ b/dialer/resolver.go @@ -1,4 +1,4 @@ -package main +package dialer import ( "context" diff --git a/upstream.go b/dialer/upstream.go similarity index 82% rename from upstream.go rename to dialer/upstream.go index 3396564..362c1ee 100644 --- a/upstream.go +++ b/dialer/upstream.go @@ -1,4 +1,4 @@ -package main +package dialer import ( "bufio" @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/pem" "errors" "fmt" @@ -47,11 +48,11 @@ CV4Ks2dH/hzg1cEo70qLRDEmBDeNiXQ2Lu+lIg+DdEmSx/cQwgwp+7e9un/jX9Wf ` ) -var UpstreamBlockedError = errors.New("blocked by upstream") - var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT)) var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes) +type stringCb = func() (string, error) + type Dialer interface { Dial(network, address string) (net.Conn, error) } @@ -62,15 +63,15 @@ type ContextDialer interface { } type ProxyDialer struct { - address string - tlsServerName string - auth AuthProvider + address stringCb + tlsServerName stringCb + auth stringCb next ContextDialer intermediateWorkaround bool caPool *x509.CertPool } -func NewProxyDialer(address, tlsServerName string, auth AuthProvider, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer { +func NewProxyDialer(address, tlsServerName, auth stringCb, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer { return &ProxyDialer{ address: address, tlsServerName: tlsServerName, @@ -85,7 +86,7 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) { host := u.Hostname() port := u.Port() tlsServerName := "" - var auth AuthProvider = nil + var auth stringCb = nil switch strings.ToLower(u.Scheme) { case "http": @@ -106,12 +107,9 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) { if u.User != nil { username := u.User.Username() password, _ := u.User.Password() - authHeader := basic_auth_header(username, password) - auth = func() string { - return authHeader - } + auth = WrapStringToCb(BasicAuthHeader(username, password)) } - return NewProxyDialer(address, tlsServerName, auth, false, nil, next), nil + return NewProxyDialer(WrapStringToCb(address), WrapStringToCb(tlsServerName), auth, false, nil, next), nil } func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { @@ -121,12 +119,20 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) return nil, errors.New("bad network specified for DialContext: only tcp is supported") } - conn, err := d.next.DialContext(ctx, "tcp", d.address) + uAddress, err := d.address() + if err != nil { + return nil, err + } + conn, err := d.next.DialContext(ctx, "tcp", uAddress) if err != nil { return nil, err } - if d.tlsServerName != "" { + uTLSServerName, err := d.tlsServerName() + if err != nil { + return nil, err + } + if uTLSServerName != "" { // Custom cert verification logic: // DO NOT send SNI extension of TLS ClientHello // DO peer certificate verification against specified servername @@ -135,7 +141,7 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) InsecureSkipVerify: true, VerifyConnection: func(cs tls.ConnectionState) error { opts := x509.VerifyOptions{ - DNSName: d.tlsServerName, + DNSName: uTLSServerName, Intermediates: x509.NewCertPool(), Roots: d.caPool, } @@ -169,7 +175,11 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) } if d.auth != nil { - req.Header.Set(PROXY_AUTHORIZATION_HEADER, d.auth()) + auth, err := d.auth() + if err != nil { + return nil, err + } + req.Header.Set(PROXY_AUTHORIZATION_HEADER, auth) } rawreq, err := httputil.DumpRequest(req, false) @@ -188,10 +198,6 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) } if proxyResp.StatusCode != http.StatusOK { - if proxyResp.StatusCode == http.StatusForbidden && - proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" { - return nil, UpstreamBlockedError - } return nil, errors.New(fmt.Sprintf("bad response from upstream proxy server: %s", proxyResp.Status)) } @@ -228,3 +234,14 @@ func readResponse(r io.Reader, req *http.Request) (*http.Response, error) { } return http.ReadResponse(bufio.NewReader(buf), req) } + +func BasicAuthHeader(login, password string) string { + return "Basic " + base64.StdEncoding.EncodeToString( + []byte(login+":"+password)) +} + +func WrapStringToCb(s string) func() (string, error) { + return func() (string, error) { + return s, nil + } +} diff --git a/handler.go b/handler/handler.go similarity index 50% rename from handler.go rename to handler/handler.go index cb544f1..22224b2 100644 --- a/handler.go +++ b/handler/handler.go @@ -1,23 +1,33 @@ -package main +package handler import ( + "bufio" + "context" + "errors" "fmt" + "io" + "net" "net/http" "strings" + "sync" "time" -) -const BAD_REQ_MSG = "Bad Request\n" + "github.com/Snawoot/opera-proxy/dialer" + clog "github.com/Snawoot/opera-proxy/log" +) -type AuthProvider func() string +const ( + COPY_BUF = 128 * 1024 + BAD_REQ_MSG = "Bad Request\n" +) type ProxyHandler struct { - logger *CondLogger - dialer ContextDialer + logger *clog.CondLogger + dialer dialer.ContextDialer httptransport http.RoundTripper } -func NewProxyHandler(dialer ContextDialer, logger *CondLogger) *ProxyHandler { +func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler { httptransport := &http.Transport{ MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -104,3 +114,128 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { s.HandleRequest(wr, req) } } + +func proxy(ctx context.Context, left, right net.Conn) { + wg := sync.WaitGroup{} + cpy := func(dst, src net.Conn) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + wg.Add(2) + go cpy(left, right) + go cpy(right, left) + groupdone := make(chan struct{}) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + left.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return +} + +func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { + wg := sync.WaitGroup{} + ltr := func(dst net.Conn, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + rtl := func(dst io.Writer, src io.Reader) { + defer wg.Done() + copyBody(dst, src) + } + wg.Add(2) + go ltr(right, leftreader) + go rtl(leftwriter, right) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + leftreader.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Connection", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func delHopHeaders(header http.Header) { + for _, h := range hopHeaders { + header.Del(h) + } +} + +func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { + hj, ok := hijackable.(http.Hijacker) + if !ok { + return nil, nil, errors.New("Connection doesn't support hijacking") + } + conn, rw, err := hj.Hijack() + if err != nil { + return nil, nil, err + } + var emptytime time.Time + err = conn.SetDeadline(emptytime) + if err != nil { + conn.Close() + return nil, nil, err + } + return conn, rw, nil +} + +func flush(flusher interface{}) bool { + f, ok := flusher.(http.Flusher) + if !ok { + return false + } + f.Flush() + return true +} + +func copyBody(wr io.Writer, body io.Reader) { + buf := make([]byte, COPY_BUF) + for { + bread, read_err := body.Read(buf) + var write_err error + if bread > 0 { + _, write_err = wr.Write(buf[:bread]) + flush(wr) + } + if read_err != nil || write_err != nil { + break + } + } +} diff --git a/condlog.go b/log/condlog.go similarity index 98% rename from condlog.go rename to log/condlog.go index 96a18f3..b40572a 100644 --- a/condlog.go +++ b/log/condlog.go @@ -1,4 +1,4 @@ -package main +package log import ( "fmt" diff --git a/logwriter.go b/log/logwriter.go similarity index 98% rename from logwriter.go rename to log/logwriter.go index 657c2f3..b36f727 100644 --- a/logwriter.go +++ b/log/logwriter.go @@ -1,4 +1,4 @@ -package main +package log import ( "errors" diff --git a/main.go b/main.go index 6b82a9d..9f774e8 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,10 @@ import ( xproxy "golang.org/x/net/proxy" + "github.com/Snawoot/opera-proxy/clock" + "github.com/Snawoot/opera-proxy/dialer" + "github.com/Snawoot/opera-proxy/handler" + clog "github.com/Snawoot/opera-proxy/log" se "github.com/Snawoot/opera-proxy/seclient" ) @@ -151,12 +155,12 @@ func parse_args() *CLIArgs { } func proxyFromURLWrapper(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) { - cdialer, ok := next.(ContextDialer) + cdialer, ok := next.(dialer.ContextDialer) if !ok { return nil, errors.New("only context dialers are accepted") } - return ProxyDialerFromURL(u, cdialer) + return dialer.ProxyDialerFromURL(u, cdialer) } func run() int { @@ -166,19 +170,19 @@ func run() int { return 0 } - logWriter := NewLogWriter(os.Stderr) + logWriter := clog.NewLogWriter(os.Stderr) defer logWriter.Close() - mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ", + mainLogger := clog.NewCondLogger(log.New(logWriter, "MAIN : ", log.LstdFlags|log.Lshortfile), args.verbosity) - proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ", + proxyLogger := clog.NewCondLogger(log.New(logWriter, "PROXY : ", log.LstdFlags|log.Lshortfile), args.verbosity) mainLogger.Info("opera-proxy client version %s is starting...", version) - var dialer ContextDialer = &net.Dialer{ + var d dialer.ContextDialer = &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } @@ -191,22 +195,22 @@ func run() int { mainLogger.Critical("Unable to parse base proxy URL: %v", err) return 6 } - pxDialer, err := xproxy.FromURL(proxyURL, dialer) + pxDialer, err := xproxy.FromURL(proxyURL, d) if err != nil { mainLogger.Critical("Unable to instantiate base proxy dialer: %v", err) return 7 } - dialer = pxDialer.(ContextDialer) + d = pxDialer.(dialer.ContextDialer) } - seclientDialer := dialer + seclientDialer := d if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 { var apiAddress string if args.apiAddress != "" { apiAddress = args.apiAddress mainLogger.Info("Using fixed API host IP address = %s", apiAddress) } else { - resolver, err := NewResolver(args.bootstrapDNS.values, args.timeout) + resolver, err := dialer.NewResolver(args.bootstrapDNS.values, args.timeout) if err != nil { mainLogger.Critical("Unable to instantiate DNS resolver: %v", err) return 4 @@ -234,7 +238,7 @@ func run() int { apiAddress = addrs[0].String() mainLogger.Info("Discovered address of API host = %s", apiAddress) } - seclientDialer = NewFixedDialer(apiAddress, dialer) + seclientDialer = dialer.NewFixedDialer(apiAddress, d) } // Dialing w/o SNI, receiving self-signed certificate, so skip verification. @@ -303,7 +307,7 @@ func run() int { return 13 } - runTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error { + clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error { mainLogger.Info("Refreshing login...") reqCtx, cl := context.WithTimeout(ctx, args.timeout) defer cl() @@ -327,9 +331,6 @@ func run() int { }) endpoint := ips[0] - auth := func() string { - return basic_auth_header(seclient.GetProxyCredentials()) - } var caPool *x509.CertPool if args.caFile != "" { @@ -345,18 +346,26 @@ func run() int { } } - handlerDialer := NewProxyDialer(endpoint.NetAddr(), fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX), auth, args.certChainWorkaround, caPool, dialer) + handlerDialer := dialer.NewProxyDialer( + dialer.WrapStringToCb(endpoint.NetAddr()), + dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)), + func() (string, error) { + return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil + }, + args.certChainWorkaround, + caPool, + d) mainLogger.Info("Endpoint: %s", endpoint.NetAddr()) mainLogger.Info("Starting proxy server...") - handler := NewProxyHandler(handlerDialer, proxyLogger) + h := handler.NewProxyHandler(handlerDialer, proxyLogger) mainLogger.Info("Init complete.") - err = http.ListenAndServe(args.bindAddress, handler) + err = http.ListenAndServe(args.bindAddress, h) mainLogger.Critical("Server terminated with a reason: %v", err) mainLogger.Info("Shutting down...") return 0 } -func printCountries(logger *CondLogger, timeout time.Duration, seclient *se.SEClient) int { +func printCountries(logger *clog.CondLogger, timeout time.Duration, seclient *se.SEClient) int { ctx, cl := context.WithTimeout(context.Background(), timeout) defer cl() list, err := seclient.GeoList(ctx) @@ -380,7 +389,7 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int { login, password := seclient.GetProxyCredentials() fmt.Println("Proxy login:", login) fmt.Println("Proxy password:", password) - fmt.Println("Proxy-Authorization:", basic_auth_header(login, password)) + fmt.Println("Proxy-Authorization:", dialer.BasicAuthHeader(login, password)) fmt.Println("") wr.Write([]string{"host", "ip_address", "port"}) for i, ip := range ips { diff --git a/utils.go b/utils.go deleted file mode 100644 index 4686055..0000000 --- a/utils.go +++ /dev/null @@ -1,190 +0,0 @@ -package main - -import ( - "bufio" - "context" - "encoding/base64" - "errors" - "io" - "net" - "net/http" - "sync" - "time" -) - -const ( - COPY_BUF = 128 * 1024 - WALLCLOCK_PRECISION = 1 * time.Second -) - -func basic_auth_header(login, password string) string { - return "Basic " + base64.StdEncoding.EncodeToString( - []byte(login+":"+password)) -} - -func proxy(ctx context.Context, left, right net.Conn) { - wg := sync.WaitGroup{} - cpy := func(dst, src net.Conn) { - defer wg.Done() - io.Copy(dst, src) - dst.Close() - } - wg.Add(2) - go cpy(left, right) - go cpy(right, left) - groupdone := make(chan struct{}) - go func() { - wg.Wait() - groupdone <- struct{}{} - }() - select { - case <-ctx.Done(): - left.Close() - right.Close() - case <-groupdone: - return - } - <-groupdone - return -} - -func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { - wg := sync.WaitGroup{} - ltr := func(dst net.Conn, src io.Reader) { - defer wg.Done() - io.Copy(dst, src) - dst.Close() - } - rtl := func(dst io.Writer, src io.Reader) { - defer wg.Done() - copyBody(dst, src) - } - wg.Add(2) - go ltr(right, leftreader) - go rtl(leftwriter, right) - groupdone := make(chan struct{}, 1) - go func() { - wg.Wait() - groupdone <- struct{}{} - }() - select { - case <-ctx.Done(): - leftreader.Close() - right.Close() - case <-groupdone: - return - } - <-groupdone - return -} - -// Hop-by-hop headers. These are removed when sent to the backend. -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html -var hopHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Connection", - "Te", // canonicalized version of "TE" - "Trailers", - "Transfer-Encoding", - "Upgrade", -} - -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } - } -} - -func delHopHeaders(header http.Header) { - for _, h := range hopHeaders { - header.Del(h) - } -} - -func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { - hj, ok := hijackable.(http.Hijacker) - if !ok { - return nil, nil, errors.New("Connection doesn't support hijacking") - } - conn, rw, err := hj.Hijack() - if err != nil { - return nil, nil, err - } - var emptytime time.Time - err = conn.SetDeadline(emptytime) - if err != nil { - conn.Close() - return nil, nil, err - } - return conn, rw, nil -} - -func flush(flusher interface{}) bool { - f, ok := flusher.(http.Flusher) - if !ok { - return false - } - f.Flush() - return true -} - -func copyBody(wr io.Writer, body io.Reader) { - buf := make([]byte, COPY_BUF) - for { - bread, read_err := body.Read(buf) - var write_err error - if bread > 0 { - _, write_err = wr.Write(buf[:bread]) - flush(wr) - } - if read_err != nil || write_err != nil { - break - } - } -} - -func AfterWallClock(d time.Duration) <-chan time.Time { - ch := make(chan time.Time, 1) - deadline := time.Now().Add(d).Truncate(0) - after_ch := time.After(d) - ticker := time.NewTicker(WALLCLOCK_PRECISION) - go func() { - var t time.Time - defer ticker.Stop() - for { - select { - case t = <-after_ch: - ch <- t - return - case t = <-ticker.C: - if t.After(deadline) { - ch <- t - return - } - } - } - }() - return ch -} - -func runTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) { - go func() { - var err error - for { - nextInterval := interval - if err != nil { - nextInterval = retryInterval - } - select { - case <-ctx.Done(): - return - case <-AfterWallClock(nextInterval): - err = cb(ctx) - } - } - }() -}