Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split main package into subpackages #71

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions clock/clock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package clock

import (
"context"
"time"
)

const WALLCLOCK_PRECISION = 1 * time.Second

func AfterWallClock(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
deadline := time.Now().Add(d).Truncate(0)
after_ch := time.After(d)
ticker := time.NewTicker(WALLCLOCK_PRECISION)
go func() {
var t time.Time
defer ticker.Stop()
for {
select {
case t = <-after_ch:
ch <- t
return
case t = <-ticker.C:
if t.After(deadline) {
ch <- t
return
}
}
}
}()
return ch
}

func RunTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
go func() {
var err error
for {
nextInterval := interval
if err != nil {
nextInterval = retryInterval
}
select {
case <-ctx.Done():
return
case <-AfterWallClock(nextInterval):
err = cb(ctx)
}
}
}()
}
2 changes: 1 addition & 1 deletion fixed.go → dialer/fixed.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package dialer

import (
"context"
Expand Down
2 changes: 1 addition & 1 deletion resolver.go → dialer/resolver.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package dialer

import (
"context"
Expand Down
59 changes: 38 additions & 21 deletions upstream.go → dialer/upstream.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package main
package dialer

import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
Expand Down Expand Up @@ -47,11 +48,11 @@ CV4Ks2dH/hzg1cEo70qLRDEmBDeNiXQ2Lu+lIg+DdEmSx/cQwgwp+7e9un/jX9Wf
`
)

var UpstreamBlockedError = errors.New("blocked by upstream")

var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT))
var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes)

type stringCb = func() (string, error)

type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
Expand All @@ -62,15 +63,15 @@ type ContextDialer interface {
}

type ProxyDialer struct {
address string
tlsServerName string
auth AuthProvider
address stringCb
tlsServerName stringCb
auth stringCb
next ContextDialer
intermediateWorkaround bool
caPool *x509.CertPool
}

func NewProxyDialer(address, tlsServerName string, auth AuthProvider, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
func NewProxyDialer(address, tlsServerName, auth stringCb, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
return &ProxyDialer{
address: address,
tlsServerName: tlsServerName,
Expand All @@ -85,7 +86,7 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
host := u.Hostname()
port := u.Port()
tlsServerName := ""
var auth AuthProvider = nil
var auth stringCb = nil

switch strings.ToLower(u.Scheme) {
case "http":
Expand All @@ -106,12 +107,9 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
if u.User != nil {
username := u.User.Username()
password, _ := u.User.Password()
authHeader := basic_auth_header(username, password)
auth = func() string {
return authHeader
}
auth = WrapStringToCb(BasicAuthHeader(username, password))
}
return NewProxyDialer(address, tlsServerName, auth, false, nil, next), nil
return NewProxyDialer(WrapStringToCb(address), WrapStringToCb(tlsServerName), auth, false, nil, next), nil
}

func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
Expand All @@ -121,12 +119,20 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
return nil, errors.New("bad network specified for DialContext: only tcp is supported")
}

conn, err := d.next.DialContext(ctx, "tcp", d.address)
uAddress, err := d.address()
if err != nil {
return nil, err
}
conn, err := d.next.DialContext(ctx, "tcp", uAddress)
if err != nil {
return nil, err
}

if d.tlsServerName != "" {
uTLSServerName, err := d.tlsServerName()
if err != nil {
return nil, err
}
if uTLSServerName != "" {
// Custom cert verification logic:
// DO NOT send SNI extension of TLS ClientHello
// DO peer certificate verification against specified servername
Expand All @@ -135,7 +141,7 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error {
opts := x509.VerifyOptions{
DNSName: d.tlsServerName,
DNSName: uTLSServerName,
Intermediates: x509.NewCertPool(),
Roots: d.caPool,
}
Expand Down Expand Up @@ -169,7 +175,11 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
}

if d.auth != nil {
req.Header.Set(PROXY_AUTHORIZATION_HEADER, d.auth())
auth, err := d.auth()
if err != nil {
return nil, err
}
req.Header.Set(PROXY_AUTHORIZATION_HEADER, auth)
}

rawreq, err := httputil.DumpRequest(req, false)
Expand All @@ -188,10 +198,6 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
}

if proxyResp.StatusCode != http.StatusOK {
if proxyResp.StatusCode == http.StatusForbidden &&
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
return nil, UpstreamBlockedError
}
return nil, errors.New(fmt.Sprintf("bad response from upstream proxy server: %s", proxyResp.Status))
}

Expand Down Expand Up @@ -228,3 +234,14 @@ func readResponse(r io.Reader, req *http.Request) (*http.Response, error) {
}
return http.ReadResponse(bufio.NewReader(buf), req)
}

func BasicAuthHeader(login, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
}

func WrapStringToCb(s string) func() (string, error) {
return func() (string, error) {
return s, nil
}
}
149 changes: 142 additions & 7 deletions handler.go → handler/handler.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
package main
package handler

import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
)

const BAD_REQ_MSG = "Bad Request\n"
"github.com/Snawoot/opera-proxy/dialer"
clog "github.com/Snawoot/opera-proxy/log"
)

type AuthProvider func() string
const (
COPY_BUF = 128 * 1024
BAD_REQ_MSG = "Bad Request\n"
)

type ProxyHandler struct {
logger *CondLogger
dialer ContextDialer
logger *clog.CondLogger
dialer dialer.ContextDialer
httptransport http.RoundTripper
}

func NewProxyHandler(dialer ContextDialer, logger *CondLogger) *ProxyHandler {
func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler {
httptransport := &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
Expand Down Expand Up @@ -104,3 +114,128 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.HandleRequest(wr, req)
}
}

func proxy(ctx context.Context, left, right net.Conn) {
wg := sync.WaitGroup{}
cpy := func(dst, src net.Conn) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
wg.Add(2)
go cpy(left, right)
go cpy(right, left)
groupdone := make(chan struct{})
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
left.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}

func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
wg := sync.WaitGroup{}
ltr := func(dst net.Conn, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
rtl := func(dst io.Writer, src io.Reader) {
defer wg.Done()
copyBody(dst, src)
}
wg.Add(2)
go ltr(right, leftreader)
go rtl(leftwriter, right)
groupdone := make(chan struct{}, 1)
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
leftreader.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}

// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Connection",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}

func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}

func delHopHeaders(header http.Header) {
for _, h := range hopHeaders {
header.Del(h)
}
}

func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := hijackable.(http.Hijacker)
if !ok {
return nil, nil, errors.New("Connection doesn't support hijacking")
}
conn, rw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
var emptytime time.Time
err = conn.SetDeadline(emptytime)
if err != nil {
conn.Close()
return nil, nil, err
}
return conn, rw, nil
}

func flush(flusher interface{}) bool {
f, ok := flusher.(http.Flusher)
if !ok {
return false
}
f.Flush()
return true
}

func copyBody(wr io.Writer, body io.Reader) {
buf := make([]byte, COPY_BUF)
for {
bread, read_err := body.Read(buf)
var write_err error
if bread > 0 {
_, write_err = wr.Write(buf[:bread])
flush(wr)
}
if read_err != nil || write_err != nil {
break
}
}
}
2 changes: 1 addition & 1 deletion condlog.go → log/condlog.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package log

import (
"fmt"
Expand Down
2 changes: 1 addition & 1 deletion logwriter.go → log/logwriter.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package log

import (
"errors"
Expand Down
Loading