Skip to content

Commit

Permalink
fix: correct order of verification hooks
Browse files Browse the repository at this point in the history
Fixes #265
  • Loading branch information
mefellows committed Jul 10, 2023
1 parent 628b845 commit 830d6a7
Showing 1 changed file with 50 additions and 33 deletions.
83 changes: 50 additions & 33 deletions provider/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -101,12 +102,8 @@ func (v *Verifier) verifyProviderRaw(request VerifyRequest, writer outputWriter)
m = append(m, beforeEachMiddleware(request.BeforeEach))
}

if request.AfterEach != nil {
m = append(m, afterEachMiddleware(request.AfterEach))
}

if len(request.StateHandlers) > 0 {
m = append(m, stateHandlerMiddleware(request.StateHandlers))
m = append(m, stateHandlerMiddleware(request.StateHandlers, request.AfterEach))
}

if len(request.MessageHandlers) > 0 {
Expand Down Expand Up @@ -197,62 +194,70 @@ func beforeEachMiddleware(BeforeEach Hook) proxy.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == providerStatesSetupPath {
state, err := getStateFromRequest(r)

log.Println("[DEBUG] executing before hook")
err := BeforeEach()
// Before each should only fire on the "setup" phase
if err == nil && state.Action == "setup" {
log.Println("[DEBUG] executing before hook")
err := BeforeEach()

if err != nil {
log.Println("[ERROR] error executing before hook:", err)
w.WriteHeader(http.StatusInternalServerError)
if err != nil {
log.Println("[ERROR] error executing before hook:", err)
w.WriteHeader(http.StatusInternalServerError)
}
}
}
next.ServeHTTP(w, r)
})
}
}

// afterEachMiddleware is invoked after any other, and is the last
// function to be called prior to returning to the test suite. It is
// therefore not invoked on __setup
func afterEachMiddleware(AfterEach Hook) proxy.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)

if r.URL.Path != providerStatesSetupPath {
log.Println("[DEBUG] executing after hook")
err := AfterEach()

if err != nil {
log.Println("[ERROR] error executing after hook:", err)
w.WriteHeader(http.StatusInternalServerError)
}
}
})
}
}

// {"action":"teardown","id":"foo","state":"User foo exists"}
type stateHandlerAction struct {
Action string `json:"action"`
State string `json:"state"`
Params map[string]interface{}
}

func getStateFromRequest(r *http.Request) (stateHandlerAction, error) {
var state stateHandlerAction
buf := new(strings.Builder)
tr := io.TeeReader(r.Body, buf)
io.ReadAll(tr)

// Body is consumed above, need to put it back after ;P
r.Body = ioutil.NopCloser(strings.NewReader(buf.String()))
log.Println("[TRACE] getStateFromRequest received raw input", buf.String())

err := json.Unmarshal([]byte(buf.String()), &state)
log.Println("[TRACE] getStateFromRequest parsed input (without params)", state)

if err != nil {
log.Println("[ERROR] getStateFromRequest unable to decode incoming state change payload", err)
return stateHandlerAction{}, err
}

return state, nil
}

// stateHandlerMiddleware responds to the various states that are
// given during provider verification
//
// statehandler accepts a state object from the verifier and executes
// any state handlers associated with the provider.
// It will not execute further middleware if it is the designted "state" request
func stateHandlerMiddleware(stateHandlers models.StateHandlers) proxy.Middleware {
func stateHandlerMiddleware(stateHandlers models.StateHandlers, afterEach Hook) proxy.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == providerStatesSetupPath {
log.Println("[INFO] executing state handler middleware")
var state stateHandlerAction
buf := new(strings.Builder)
io.Copy(buf, r.Body)
tr := io.TeeReader(r.Body, buf)
io.ReadAll(tr)

// Body is consumed above, need to put it back after ;P
r.Body = ioutil.NopCloser(strings.NewReader(buf.String()))
log.Println("[TRACE] state handler received raw input", buf.String())

err := json.Unmarshal([]byte(buf.String()), &state)
Expand Down Expand Up @@ -295,6 +300,16 @@ func stateHandlerMiddleware(stateHandlers models.StateHandlers) proxy.Middleware
return
}

if state.Action == "teardown" && afterEach != nil {
err := afterEach()

if err != nil {
log.Printf("[ERROR] after each hook for test errored: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

// Return provider state values for generator
if res != nil {
log.Println("[TRACE] returning values from provider state (raw)", res)
Expand All @@ -309,7 +324,9 @@ func stateHandlerMiddleware(stateHandlers models.StateHandlers) proxy.Middleware
}

w.Header().Add("content-type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(resBody)
return
}
}

Expand Down

0 comments on commit 830d6a7

Please sign in to comment.