diff --git a/client.go b/client.go index 61772b6..e58f820 100644 --- a/client.go +++ b/client.go @@ -40,21 +40,22 @@ type ResponseValidator func(c *Client, resp *http.Response) error // Client handles an incoming server stream type Client struct { - Retry time.Time - ReconnectStrategy backoff.BackOff - disconnectcb ConnCallback + Connected bool connectedcb ConnCallback - subscribed map[chan *Event]chan struct{} - Headers map[string]string - ReconnectNotify backoff.Notify - ResponseValidator ResponseValidator Connection *http.Client - URL string + disconnectcb ConnCallback + EncodingBase64 bool + Headers map[string]string LastEventID atomic.Value // []byte maxBufferSize int mu sync.Mutex - EncodingBase64 bool - Connected bool + ReconnectNotify backoff.Notify + ReconnectStrategy backoff.BackOff + Request *http.Request + ResponseValidator ResponseValidator + Retry time.Time + subscribed map[chan *Event]chan struct{} + URL string } // NewClient creates a new client @@ -74,6 +75,22 @@ func NewClient(url string, opts ...func(c *Client)) *Client { return c } +func NewClientFromReq(req *http.Request, opts ...func(c *Client)) *Client { + c := &Client{ + Request: req, + Connection: &http.Client{}, + Headers: make(map[string]string), + subscribed: make(map[chan *Event]chan struct{}), + maxBufferSize: 1 << 16, + } + + for _, opt := range opts { + opt(c) + } + + return c +} + // Subscribe to a data stream func (c *Client) Subscribe(stream string, handler func(msg *Event)) error { return c.SubscribeWithContext(context.Background(), stream, handler) @@ -289,34 +306,38 @@ func (c *Client) OnConnect(fn ConnCallback) { } func (c *Client) request(ctx context.Context, stream string) (*http.Response, error) { - req, err := http.NewRequest("GET", c.URL, nil) - if err != nil { - return nil, err + if c.Request != nil { + } else { + req, err := http.NewRequest("GET", c.URL, nil) + if err != nil { + return nil, err + } + c.Request = req } - req = req.WithContext(ctx) + c.Request = c.Request.WithContext(ctx) // Setup request, specify stream to connect to if stream != "" { - query := req.URL.Query() + query := c.Request.URL.Query() query.Add("stream", stream) - req.URL.RawQuery = query.Encode() + c.Request.URL.RawQuery = query.Encode() } - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Connection", "keep-alive") + c.Request.Header.Set("Cache-Control", "no-cache") + c.Request.Header.Set("Accept", "text/event-stream") + c.Request.Header.Set("Connection", "keep-alive") lastID, exists := c.LastEventID.Load().([]byte) if exists && lastID != nil { - req.Header.Set("Last-Event-ID", string(lastID)) + c.Request.Header.Set("Last-Event-ID", string(lastID)) } // Add user specified headers for k, v := range c.Headers { - req.Header.Set(k, v) + c.Request.Header.Set(k, v) } - return c.Connection.Do(req) + return c.Connection.Do(c.Request) } func (c *Client) processEvent(msg []byte) (event *Event, err error) {