diff --git a/js/common/util.go b/js/common/util.go index 6077ffa2749..22a15174772 100644 --- a/js/common/util.go +++ b/js/common/util.go @@ -65,3 +65,17 @@ func ToBytes(data interface{}) ([]byte, error) { return nil, fmt.Errorf("invalid type %T, expected string, []byte or ArrayBuffer", data) } } + +// ToString tries to return a string from compatible types. +func ToString(data interface{}) (string, error) { + switch dt := data.(type) { + case []byte: + return string(dt), nil + case string: + return dt, nil + case goja.ArrayBuffer: + return string(dt.Bytes()), nil + default: + return "", fmt.Errorf("invalid type %T, expected string, []byte or ArrayBuffer", data) + } +} diff --git a/js/common/util_test.go b/js/common/util_test.go index c0bcba14ec3..7cc1e15ff59 100644 --- a/js/common/util_test.go +++ b/js/common/util_test.go @@ -45,6 +45,7 @@ func TestThrow(t *testing.T) { } func TestToBytes(t *testing.T) { + t.Parallel() rt := goja.New() b := []byte("hello") testCases := []struct { @@ -70,3 +71,31 @@ func TestToBytes(t *testing.T) { }) } } + +func TestToString(t *testing.T) { + t.Parallel() + rt := goja.New() + s := "hello" + testCases := []struct { + in interface{} + expOut, expErr string + }{ + {s, s, ""}, + {"hello", s, ""}, + {rt.NewArrayBuffer([]byte(s)), s, ""}, + {struct{}{}, "", "invalid type struct {}, expected string, []byte or ArrayBuffer"}, + } + + for _, tc := range testCases { //nolint: paralleltest // false positive: https://github.com/kunwardeep/paralleltest/issues/8 + tc := tc + t.Run(fmt.Sprintf("%T", tc.in), func(t *testing.T) { + t.Parallel() + out, err := ToString(tc.in) + if tc.expErr != "" { + assert.EqualError(t, err, tc.expErr) + return + } + assert.Equal(t, tc.expOut, out) + }) + } +} diff --git a/js/initcontext.go b/js/initcontext.go index eeb286694d0..4bfcbe3e32a 100644 --- a/js/initcontext.go +++ b/js/initcontext.go @@ -207,8 +207,9 @@ func (i *InitContext) compileImport(src, filename string) (*goja.Program, error) return pgm, err } -// Open implements open() in the init context and will read and return the contents of a file. -// If the second argument is "b" it returns the data as a binary array, otherwise as a string. +// Open implements open() in the init context and will read and return the +// contents of a file. If the second argument is "b" it returns an ArrayBuffer +// instance, otherwise a string representation. func (i *InitContext) Open(ctx context.Context, filename string, args ...string) (goja.Value, error) { if lib.GetState(ctx) != nil { return nil, errors.New(openCantBeUsedOutsideInitContextMsg) @@ -242,7 +243,8 @@ func (i *InitContext) Open(ctx context.Context, filename string, args ...string) } if len(args) > 0 && args[0] == "b" { - return i.runtime.ToValue(data), nil + ab := i.runtime.NewArrayBuffer(data) + return i.runtime.ToValue(&ab), nil } return i.runtime.ToValue(string(data)), nil } diff --git a/js/initcontext_test.go b/js/initcontext_test.go index 99bb5cbaf5a..4bb436ce8f2 100644 --- a/js/initcontext_test.go +++ b/js/initcontext_test.go @@ -239,24 +239,22 @@ func TestInitContextRequire(t *testing.T) { }) } -func createAndReadFile(t *testing.T, file string, content []byte, expectedLength int, binary bool) (*BundleInstance, error) { +func createAndReadFile(t *testing.T, file string, content []byte, expectedLength int, binary string) (*BundleInstance, error) { + t.Helper() fs := afero.NewMemMapFs() assert.NoError(t, fs.MkdirAll("/path/to", 0o755)) assert.NoError(t, afero.WriteFile(fs, "/path/to/"+file, content, 0o644)) - binaryArg := "" - if binary { - binaryArg = ",\"b\"" - } - data := fmt.Sprintf(` - export let data = open("/path/to/%s"%s); + let binArg = "%s"; + export let data = open("/path/to/%s", binArg); var expectedLength = %d; - if (data.length != expectedLength) { - throw new Error("Length not equal, expected: " + expectedLength + ", actual: " + data.length); + var len = binArg === "b" ? "byteLength" : "length"; + if (data[len] != expectedLength) { + throw new Error("Length not equal, expected: " + expectedLength + ", actual: " + data[len]); } export default function() {} - `, file, binaryArg, expectedLength) + `, binary, file, expectedLength) b, err := getSimpleBundle(t, "/path/to/script.js", data, fs) if !assert.NoError(t, err) { @@ -282,8 +280,9 @@ func TestInitContextOpen(t *testing.T) { //{[]byte{00, 36, 32, 127}, "utf-16", 2}, // $€ } for _, tc := range testCases { + tc := tc t.Run(tc.file, func(t *testing.T) { - bi, err := createAndReadFile(t, tc.file, tc.content, tc.length, false) + bi, err := createAndReadFile(t, tc.file, tc.content, tc.length, "") if !assert.NoError(t, err) { return } @@ -292,12 +291,12 @@ func TestInitContextOpen(t *testing.T) { } t.Run("Binary", func(t *testing.T) { - bi, err := createAndReadFile(t, "/path/to/file.bin", []byte("hi!\x0f\xff\x01"), 6, true) + bi, err := createAndReadFile(t, "/path/to/file.bin", []byte("hi!\x0f\xff\x01"), 6, "b") if !assert.NoError(t, err) { return } - bytes := []byte{104, 105, 33, 15, 255, 1} - assert.Equal(t, bytes, bi.Runtime.Get("data").Export()) + buf := bi.Runtime.NewArrayBuffer([]byte{104, 105, 33, 15, 255, 1}) + assert.Equal(t, buf, bi.Runtime.Get("data").Export()) }) testdata := map[string]string{ @@ -306,8 +305,9 @@ func TestInitContextOpen(t *testing.T) { } for name, loadPath := range testdata { + loadPath := loadPath t.Run(name, func(t *testing.T) { - _, err := createAndReadFile(t, loadPath, []byte("content"), 7, false) + _, err := createAndReadFile(t, loadPath, []byte("content"), 7, "") if !assert.NoError(t, err) { return } @@ -409,7 +409,7 @@ func TestRequestWithBinaryFile(t *testing.T) { v, err := bi.exports[consts.DefaultFn](goja.Undefined()) assert.NoError(t, err) - assert.NotNil(t, v) + require.NotNil(t, v) assert.Equal(t, true, v.Export()) <-ch diff --git a/js/modules/k6/crypto/crypto.go b/js/modules/k6/crypto/crypto.go index 94910f2cf9e..8fb43592958 100644 --- a/js/modules/k6/crypto/crypto.go +++ b/js/modules/k6/crypto/crypto.go @@ -36,6 +36,8 @@ import ( "golang.org/x/crypto/md4" "golang.org/x/crypto/ripemd160" + "github.com/dop251/goja" + "github.com/loadimpact/k6/js/common" ) @@ -52,16 +54,18 @@ func New() *Crypto { } // RandomBytes returns random data of the given size. -func (*Crypto) RandomBytes(ctx context.Context, size int) []byte { +func (*Crypto) RandomBytes(ctx context.Context, size int) *goja.ArrayBuffer { + rt := common.GetRuntime(ctx) if size < 1 { - common.Throw(common.GetRuntime(ctx), errors.New("invalid size")) + common.Throw(rt, errors.New("invalid size")) } bytes := make([]byte, size) _, err := rand.Read(bytes) if err != nil { - common.Throw(common.GetRuntime(ctx), err) + common.Throw(rt, err) } - return bytes + ab := rt.NewArrayBuffer(bytes) + return &ab } // Md4 returns the MD4 hash of input in the given encoding. @@ -171,6 +175,7 @@ func (hasher *Hasher) Update(input interface{}) { // Digest returns the hash value in the given encoding. func (hasher *Hasher) Digest(outputEncoding string) interface{} { sum := hasher.hash.Sum(nil) + rt := common.GetRuntime(hasher.ctx) switch outputEncoding { case "base64": @@ -186,11 +191,12 @@ func (hasher *Hasher) Digest(outputEncoding string) interface{} { return hex.EncodeToString(sum) case "binary": - return sum + ab := rt.NewArrayBuffer(sum) + return &ab default: err := errors.New("Invalid output encoding: " + outputEncoding) - common.Throw(common.GetRuntime(hasher.ctx), err) + common.Throw(rt, err) } return "" diff --git a/js/modules/k6/crypto/crypto_test.go b/js/modules/k6/crypto/crypto_test.go index a9f9be45c32..b52dd212a82 100644 --- a/js/modules/k6/crypto/crypto_test.go +++ b/js/modules/k6/crypto/crypto_test.go @@ -54,9 +54,9 @@ func TestCryptoAlgorithms(t *testing.T) { t.Run("RandomBytesSuccess", func(t *testing.T) { _, err := rt.RunString(` - var bytes = crypto.randomBytes(5); - if (bytes.length !== 5) { - throw new Error("Incorrect size: " + bytes.length); + var buf = crypto.randomBytes(5); + if (buf.byteLength !== 5) { + throw new Error("Incorrect size: " + buf.byteLength); }`) assert.NoError(t, err) @@ -303,7 +303,7 @@ func TestOutputEncoding(t *testing.T) { return true; } - var resultBinary = hasher.digest("binary"); + var resultBinary = new Uint8Array(hasher.digest("binary")); if (!arraysEqual(resultBinary, correctBinary)) { throw new Error("Binary encoding mismatch: " + JSON.stringify(resultBinary)); } diff --git a/js/modules/k6/encoding/encoding.go b/js/modules/k6/encoding/encoding.go index 8c2893dc233..8c00fcbfa3e 100644 --- a/js/modules/k6/encoding/encoding.go +++ b/js/modules/k6/encoding/encoding.go @@ -55,8 +55,9 @@ func (e *Encoding) B64encode(ctx context.Context, input interface{}, encoding st } // B64decode returns the decoded data of the base64 encoded input string using -// the given encoding. -func (e *Encoding) B64decode(ctx context.Context, input string, encoding string) string { +// the given encoding. If format is "s" it returns the data as a string, +// otherwise as an ArrayBuffer. +func (e *Encoding) B64decode(ctx context.Context, input, encoding, format string) interface{} { var output []byte var err error @@ -73,9 +74,18 @@ func (e *Encoding) B64decode(ctx context.Context, input string, encoding string) output, err = base64.StdEncoding.DecodeString(input) } + rt := common.GetRuntime(ctx) //nolint: ifshort if err != nil { - common.Throw(common.GetRuntime(ctx), err) + common.Throw(rt, err) + } + + var out interface{} + if format == "s" { + out = string(output) + } else { + ab := rt.NewArrayBuffer(output) + out = &ab } - return string(output) + return out } diff --git a/js/modules/k6/encoding/encoding_test.go b/js/modules/k6/encoding/encoding_test.go index 5d0d1bdd88c..ba12470d6a6 100644 --- a/js/modules/k6/encoding/encoding_test.go +++ b/js/modules/k6/encoding/encoding_test.go @@ -31,6 +31,7 @@ import ( ) func TestEncodingAlgorithms(t *testing.T) { + t.Parallel() if testing.Short() { return } @@ -54,9 +55,12 @@ func TestEncodingAlgorithms(t *testing.T) { t.Run("DefaultDec", func(t *testing.T) { _, err := rt.RunString(` var correct = "hello world"; - var decoded = encoding.b64decode("aGVsbG8gd29ybGQ="); - if (decoded !== correct) { - throw new Error("Decoding mismatch: " + decoded); + var decBin = encoding.b64decode("aGVsbG8gd29ybGQ="); + + var decText = String.fromCharCode.apply(null, new Uint8Array(decBin)); + decText = decodeURIComponent(escape(decText)); + if (decText !== correct) { + throw new Error("Decoding mismatch: " + decText); }`) assert.NoError(t, err) }) @@ -70,6 +74,17 @@ func TestEncodingAlgorithms(t *testing.T) { }`) assert.NoError(t, err) }) + t.Run("DefaultArrayBufferDec", func(t *testing.T) { //nolint: paralleltest // weird that it triggers here, and these tests can't be parallel + _, err := rt.RunString(` + var exp = "hello"; + var decBin = encoding.b64decode("aGVsbG8="); + var decText = String.fromCharCode.apply(null, new Uint8Array(decBin)); + decText = decodeURIComponent(escape(decText)); + if (decText !== exp) { + throw new Error("Decoding mismatch: " + decText); + }`) + assert.NoError(t, err) + }) t.Run("DefaultUnicodeEnc", func(t *testing.T) { _, err := rt.RunString(` var correct = "44GT44KT44Gr44Gh44Gv5LiW55WM"; @@ -82,9 +97,11 @@ func TestEncodingAlgorithms(t *testing.T) { t.Run("DefaultUnicodeDec", func(t *testing.T) { _, err := rt.RunString(` var correct = "こんにちは世界"; - var decoded = encoding.b64decode("44GT44KT44Gr44Gh44Gv5LiW55WM"); - if (decoded !== correct) { - throw new Error("Decoding mismatch: " + decoded); + var decBin = encoding.b64decode("44GT44KT44Gr44Gh44Gv5LiW55WM"); + var decText = String.fromCharCode.apply(null, new Uint8Array(decBin)); + decText = decodeURIComponent(escape(decText)); + if (decText !== correct) { + throw new Error("Decoding mismatch: " + decText); }`) assert.NoError(t, err) }) @@ -100,7 +117,7 @@ func TestEncodingAlgorithms(t *testing.T) { t.Run("StdDec", func(t *testing.T) { _, err := rt.RunString(` var correct = "hello world"; - var decoded = encoding.b64decode("aGVsbG8gd29ybGQ=", "std"); + var decoded = encoding.b64decode("aGVsbG8gd29ybGQ=", "std", "s"); if (decoded !== correct) { throw new Error("Decoding mismatch: " + decoded); }`) @@ -118,7 +135,7 @@ func TestEncodingAlgorithms(t *testing.T) { t.Run("RawStdDec", func(t *testing.T) { _, err := rt.RunString(` var correct = "hello world"; - var decoded = encoding.b64decode("aGVsbG8gd29ybGQ", "rawstd"); + var decoded = encoding.b64decode("aGVsbG8gd29ybGQ", "rawstd", "s"); if (decoded !== correct) { throw new Error("Decoding mismatch: " + decoded); }`) @@ -136,7 +153,7 @@ func TestEncodingAlgorithms(t *testing.T) { t.Run("URLDec", func(t *testing.T) { _, err := rt.RunString(` var correct = "小飼弾.."; - var decoded = encoding.b64decode("5bCP6aO85by-Li4=", "url"); + var decoded = encoding.b64decode("5bCP6aO85by-Li4=", "url", "s"); if (decoded !== correct) { throw new Error("Decoding mismatch: " + decoded); }`) @@ -154,7 +171,7 @@ func TestEncodingAlgorithms(t *testing.T) { t.Run("RawURLDec", func(t *testing.T) { _, err := rt.RunString(` var correct = "小飼弾.."; - var decoded = encoding.b64decode("5bCP6aO85by-Li4", "rawurl"); + var decoded = encoding.b64decode("5bCP6aO85by-Li4", "rawurl", "s"); if (decoded !== correct) { throw new Error("Decoding mismatch: " + decoded); }`) diff --git a/js/modules/k6/http/request.go b/js/modules/k6/http/request.go index 22f18ab9dac..dc252bc1f5b 100644 --- a/js/modules/k6/http/request.go +++ b/js/modules/k6/http/request.go @@ -113,6 +113,7 @@ func (h *HTTP) Request(ctx context.Context, method string, url goja.Value, args if err != nil { return nil, err } + processResponse(ctx, resp, req.ResponseType) return h.responseFromHttpext(resp), nil } @@ -450,6 +451,7 @@ func (h *HTTP) Batch(ctx context.Context, reqsV goja.Value) (goja.Value, error) errs := httpext.MakeBatchRequests( ctx, batchReqs, reqCount, int(state.Options.Batch.Int64), int(state.Options.BatchPerHost.Int64), + processResponse, ) for i := 0; i < reqCount; i++ { diff --git a/js/modules/k6/http/request_test.go b/js/modules/k6/http/request_test.go index d38cc10f876..c547fb85600 100644 --- a/js/modules/k6/http/request_test.go +++ b/js/modules/k6/http/request_test.go @@ -1429,8 +1429,19 @@ func TestRequestArrayBufferBody(t *testing.T) { var res = http.post("HTTPBIN_URL/post-arraybuffer", arr.buffer, { responseType: 'binary' }); if (res.status != 200) { throw new Error("wrong status: " + res.status) } - if (res.body != "%[2]s") { throw new Error( - "incorrect data: expected '%[2]s', received '" + res.body + "'") } + + var resTyped = new Uint8Array(res.body); + var exp = new %[1]s([%[2]s]); + if (exp.length !== resTyped.length) { + throw new Error( + "incorrect data length: expected " + exp.length + ", received " + resTypedLength) + } + for (var i = 0; i < exp.length; i++) { + if (exp[i] !== resTyped[i]) { + throw new Error( + "incorrect data at index " + i + ": expected " + exp[i] + ", received " + resTyped[i]) + } + } `, tc.arr, tc.expected))) assert.NoError(t, err) }) @@ -1687,31 +1698,51 @@ func TestResponseTypes(t *testing.T) { } // Check binary transmission of the text response as well - var respTextInBin = http.get("HTTPBIN_URL/get-text", { responseType: "binary" }).body; + var respBin = http.get("HTTPBIN_URL/get-text", { responseType: "binary" }); - // Hack to convert a utf-8 array to a JS string - var strConv = ""; - function pad(n) { return n.length < 2 ? "0" + n : n; } - for( var i = 0; i < respTextInBin.length; i++ ) { - strConv += ( "%" + pad(respTextInBin[i].toString(16))); - } - strConv = decodeURIComponent(strConv); + // Convert a UTF-8 ArrayBuffer to a JS string + var respBinText = String.fromCharCode.apply(null, new Uint8Array(respBin.body)); + var strConv = decodeURIComponent(escape(respBinText)); if (strConv !== expText) { throw new Error("converted response body should be '" + expText + "' but was '" + strConv + "'"); } - http.post("HTTPBIN_URL/compare-text", respTextInBin); + http.post("HTTPBIN_URL/compare-text", respBin.body); // Check binary response + var respBin = http.get("HTTPBIN_URL/get-bin", { responseType: "binary" }); + var respBinTyped = new Uint8Array(respBin.body); + if (expBinLength !== respBinTyped.length) { + throw new Error("response body length should be '" + expBinLength + + "' but was '" + respBinTyped.length + "'"); + } + for(var i = 0; i < respBinTyped.length; i++) { + if (respBinTyped[i] !== i%256) { + throw new Error("expected value " + (i%256) + " to be at position " + + i + " but it was " + respBinTyped[i]); + } + } + http.post("HTTPBIN_URL/compare-bin", respBin.body); + + // Check ArrayBuffer response var respBin = http.get("HTTPBIN_URL/get-bin", { responseType: "binary" }).body; - if (respBin.length !== expBinLength) { - throw new Error("response body length should be '" + expBinLength + "' but was '" + respBin.length + "'"); + if (respBin.byteLength !== expBinLength) { + throw new Error("response body length should be '" + expBinLength + "' but was '" + respBin.byteLength + "'"); + } + + // Check ArrayBuffer responses with http.batch() + var responses = http.batch([ + ["GET", "HTTPBIN_URL/get-bin", null, { responseType: "binary" }], + ["GET", "HTTPBIN_URL/get-bin", null, { responseType: "binary" }], + ]); + if (responses.length != 2) { + throw new Error("expected 2 responses, received " + responses.length); } - for( var i = 0; i < respBin.length; i++ ) { - if ( respBin[i] !== i%256 ) { - throw new Error("expected value " + (i%256) + " to be at position " + i + " but it was " + respBin[i]); + for (var i = 0; i < responses.length; i++) { + if (responses[i].body.byteLength !== expBinLength) { + throw new Error("response body length should be '" + + expBinLength + "' but was '" + responses[i].body.byteLength + "'"); } } - http.post("HTTPBIN_URL/compare-bin", respBin); `)) assert.NoError(t, err) diff --git a/js/modules/k6/http/response.go b/js/modules/k6/http/response.go index 34506d7225c..c618893fcae 100644 --- a/js/modules/k6/http/response.go +++ b/js/modules/k6/http/response.go @@ -21,12 +21,15 @@ package http import ( + "context" + "encoding/json" "errors" "fmt" "net/url" "strings" "github.com/dop251/goja" + "github.com/tidwall/gjson" "github.com/loadimpact/k6/js/common" "github.com/loadimpact/k6/js/modules/k6/html" @@ -37,34 +40,41 @@ import ( type Response struct { *httpext.Response `js:"-"` h *HTTP + + cachedJSON interface{} + validatedJSON bool } -func (h *HTTP) responseFromHttpext(resp *httpext.Response) *Response { - return &Response{Response: resp, h: h} +type jsonError struct { + line int + character int + err error } -// JSON parses the body of a response as json and returns it to the goja VM -func (res *Response) JSON(selector ...string) goja.Value { - v, err := res.Response.JSON(selector...) - if err != nil { - common.Throw(common.GetRuntime(res.GetCtx()), err) - } - if v == nil { - return goja.Undefined() +func (j jsonError) Error() string { + errMessage := "cannot parse json due to an error at line" + return fmt.Sprintf("%s %d, character %d , error: %v", errMessage, j.line, j.character, j.err) +} + +// processResponse stores the body as an ArrayBuffer if indicated by +// respType. This is done here instead of in httpext.readResponseBody to avoid +// a reverse dependency on js/common or goja. +func processResponse(ctx context.Context, resp *httpext.Response, respType httpext.ResponseType) { + if respType == httpext.ResponseTypeBinary { + rt := common.GetRuntime(ctx) + resp.Body = rt.NewArrayBuffer(resp.Body.([]byte)) } - return common.GetRuntime(res.GetCtx()).ToValue(v) +} + +func (h *HTTP) responseFromHttpext(resp *httpext.Response) *Response { + return &Response{Response: resp, h: h, cachedJSON: nil, validatedJSON: false} } // HTML returns the body as an html.Selection func (res *Response) HTML(selector ...string) html.Selection { - var body string - switch b := res.Body.(type) { - case []byte: - body = string(b) - case string: - body = b - default: - common.Throw(common.GetRuntime(res.GetCtx()), errors.New("invalid response type")) + body, err := common.ToString(res.Body) + if err != nil { + common.Throw(common.GetRuntime(res.GetCtx()), err) } sel, err := html.HTML{}.ParseHTML(res.GetCtx(), body) @@ -78,6 +88,70 @@ func (res *Response) HTML(selector ...string) html.Selection { return sel } +// JSON parses the body of a response as JSON and returns it to the goja VM. +func (res *Response) JSON(selector ...string) goja.Value { + rt := common.GetRuntime(res.GetCtx()) + hasSelector := len(selector) > 0 + if res.cachedJSON == nil || hasSelector { //nolint:nestif + var v interface{} + + body, err := common.ToBytes(res.Body) + if err != nil { + common.Throw(rt, err) + } + + if hasSelector { + if !res.validatedJSON { + if !gjson.ValidBytes(body) { + return goja.Undefined() + } + res.validatedJSON = true + } + + result := gjson.GetBytes(body, selector[0]) + + if !result.Exists() { + return goja.Undefined() + } + return rt.ToValue(result.Value()) + } + + if err := json.Unmarshal(body, &v); err != nil { + var syntaxError *json.SyntaxError + if errors.As(err, &syntaxError) { + err = checkErrorInJSON(body, int(syntaxError.Offset), err) + } + common.Throw(rt, err) + } + res.validatedJSON = true + res.cachedJSON = v + } + + return rt.ToValue(res.cachedJSON) +} + +func checkErrorInJSON(input []byte, offset int, err error) error { + lf := '\n' + str := string(input) + + // Humans tend to count from 1. + line := 1 + character := 0 + + for i, b := range str { + if b == lf { + line++ + character = 0 + } + character++ + if i == offset { + break + } + } + + return jsonError{line: line, character: character, err: err} +} + // SubmitForm parses the body as an html looking for a from and then submitting it // TODO: document the actual arguments that can be provided func (res *Response) SubmitForm(args ...goja.Value) (*Response, error) { diff --git a/js/modules/k6/ws/ws.go b/js/modules/k6/ws/ws.go index e19813897bc..b089c99ae60 100644 --- a/js/modules/k6/ws/ws.go +++ b/js/modules/k6/ws/ws.go @@ -70,6 +70,11 @@ type WSHTTPResponse struct { Error string `json:"error"` } +type message struct { + mtype int // message type consts as defined in gorilla/websocket/conn.go + data []byte +} + const writeWait = 10 * time.Second func New() *WS { @@ -240,7 +245,7 @@ func (*WS) Connect(ctx context.Context, url string, args ...goja.Value) (*WSHTTP conn.SetPingHandler(func(msg string) error { pingChan <- msg; return nil }) conn.SetPongHandler(func(pingID string) error { pongChan <- pingID; return nil }) - readDataChan := make(chan []byte) + readDataChan := make(chan *message) readCloseChan := make(chan int) readErrChan := make(chan error) @@ -280,14 +285,20 @@ func (*WS) Connect(ctx context.Context, url string, args ...goja.Value) (*WSHTTP socket.trackPong(pingID) socket.handleEvent("pong") - case readData := <-readDataChan: + case msg := <-readDataChan: stats.PushIfNotDone(ctx, socket.samplesOutput, stats.Sample{ Metric: metrics.WSMessagesReceived, Time: time.Now(), Tags: socket.sampleTags, Value: 1, }) - socket.handleEvent("message", rt.ToValue(string(readData))) + + if msg.mtype == websocket.BinaryMessage { + ab := rt.NewArrayBuffer(msg.data) + socket.handleEvent("binaryMessage", rt.ToValue(&ab)) + } else { + socket.handleEvent("message", rt.ToValue(string(msg.data))) + } case readErr := <-readErrChan: socket.handleEvent("error", rt.ToValue(readErr)) @@ -329,14 +340,41 @@ func (s *Socket) handleEvent(event string, args ...goja.Value) { } } +// Send writes the given string message to the connection. func (s *Socket) Send(message string) { - // NOTE: No binary message support for the time being since goja doesn't - // support typed arrays. - rt := common.GetRuntime(s.ctx) + if err := s.conn.WriteMessage(websocket.TextMessage, []byte(message)); err != nil { + s.handleEvent("error", common.GetRuntime(s.ctx).ToValue(err)) + } - writeData := []byte(message) - if err := s.conn.WriteMessage(websocket.TextMessage, writeData); err != nil { - s.handleEvent("error", rt.ToValue(err)) + stats.PushIfNotDone(s.ctx, s.samplesOutput, stats.Sample{ + Metric: metrics.WSMessagesSent, + Time: time.Now(), + Tags: s.sampleTags, + Value: 1, + }) +} + +// SendBinary writes the given ArrayBuffer message to the connection. +func (s *Socket) SendBinary(message goja.Value) { + if message == nil { + common.Throw(common.GetRuntime(s.ctx), errors.New("missing argument, expected ArrayBuffer")) + } + + msg := message.Export() + if ab, ok := msg.(goja.ArrayBuffer); ok { + if err := s.conn.WriteMessage(websocket.BinaryMessage, ab.Bytes()); err != nil { + s.handleEvent("error", common.GetRuntime(s.ctx).ToValue(err)) + } + } else { + rt := common.GetRuntime(s.ctx) + var jsType string + switch { + case goja.IsNull(message), goja.IsUndefined(message): + jsType = message.String() + default: + jsType = message.ToObject(rt).ClassName() + } + common.Throw(rt, fmt.Errorf("expected ArrayBuffer as argument, received: %s", jsType)) } stats.PushIfNotDone(s.ctx, s.samplesOutput, stats.Sample{ @@ -486,9 +524,9 @@ func (s *Socket) closeConnection(code int) error { } // Wraps conn.ReadMessage in a channel -func (s *Socket) readPump(readChan chan []byte, errorChan chan error, closeChan chan int) { +func (s *Socket) readPump(readChan chan *message, errorChan chan error, closeChan chan int) { //nolint: cyclop for { - _, message, err := s.conn.ReadMessage() + messageType, data, err := s.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError( err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { @@ -511,7 +549,7 @@ func (s *Socket) readPump(readChan chan []byte, errorChan chan error, closeChan } select { - case readChan <- message: + case readChan <- &message{messageType, data}: case <-s.done: return } diff --git a/js/modules/k6/ws/ws_test.go b/js/modules/k6/ws/ws_test.go index 90d2c3ae5cd..30437ff8807 100644 --- a/js/modules/k6/ws/ws_test.go +++ b/js/modules/k6/ws/ws_test.go @@ -347,6 +347,102 @@ func TestSession(t *testing.T) { } } +func TestSocketSendBinary(t *testing.T) { //nolint: tparallel + t.Parallel() + tb := httpmultibin.NewHTTPMultiBin(t) + t.Cleanup(tb.Cleanup) + sr := tb.Replacer.Replace + + root, err := lib.NewGroup("", nil) + assert.NoError(t, err) + + rt := goja.New() + rt.SetFieldNameMapper(common.FieldNameMapper{}) + samples := make(chan stats.SampleContainer, 1000) + state := &lib.State{ //nolint: exhaustivestruct + Group: root, + Dialer: tb.Dialer, + Options: lib.Options{ //nolint: exhaustivestruct + SystemTags: stats.NewSystemTagSet( + stats.TagURL, + stats.TagProto, + stats.TagStatus, + stats.TagSubproto, + ), + }, + Samples: samples, + TLSConfig: tb.TLSClientConfig, + } + + ctx := context.Background() + ctx = lib.WithState(ctx, state) + ctx = common.WithRuntime(ctx, rt) + + err = rt.Set("ws", common.Bind(rt, New(), &ctx)) + assert.NoError(t, err) + + t.Run("ok", func(t *testing.T) { + _, err = rt.RunString(sr(` + var gotMsg = false; + var res = ws.connect('WSBIN_URL/ws-echo', function(socket){ + var data = new Uint8Array([104, 101, 108, 108, 111]); // 'hello' + + socket.on('open', function() { + socket.sendBinary(data.buffer); + }) + socket.on('binaryMessage', function(msg) { + gotMsg = true; + let decText = String.fromCharCode.apply(null, new Uint8Array(msg)); + decText = decodeURIComponent(escape(decText)); + if (decText !== 'hello') { + throw new Error('received unexpected binary message: ' + decText); + } + socket.close() + }); + }); + if (!gotMsg) { + throw new Error("the 'binaryMessage' handler wasn't called") + } + `)) + assert.NoError(t, err) + }) + + errTestCases := []struct { + in, expErrType string + }{ + {"", ""}, + {"undefined", "undefined"}, + {"null", "null"}, + {"true", "Boolean"}, + {"1", "Number"}, + {"3.14", "Number"}, + {"'str'", "String"}, + {"[1, 2, 3]", "Array"}, + {"new Uint8Array([1, 2, 3])", "Object"}, + {"Symbol('a')", "Symbol"}, + {"function() {}", "Function"}, + } + + for _, tc := range errTestCases { //nolint: paralleltest + tc := tc + t.Run(fmt.Sprintf("err_%s", tc.expErrType), func(t *testing.T) { + _, err = rt.RunString(fmt.Sprintf(sr(` + var res = ws.connect('WSBIN_URL/ws-echo', function(socket){ + socket.on('open', function() { + socket.sendBinary(%s); + }) + }); + `), tc.in)) + require.Error(t, err) + if tc.in == "" { + assert.Contains(t, err.Error(), "missing argument, expected ArrayBuffer") + } else { + assert.Contains(t, err.Error(), fmt.Sprintf("expected ArrayBuffer as argument, received: %s", tc.expErrType)) + } + }) + } +} + func TestErrors(t *testing.T) { t.Parallel() tb := httpmultibin.NewHTTPMultiBin(t) @@ -607,7 +703,7 @@ func TestReadPump(t *testing.T) { _ = conn.Close() }() - msgChan := make(chan []byte) + msgChan := make(chan *message) errChan := make(chan error) closeChan := make(chan int) s := &Socket{conn: conn} diff --git a/lib/netext/httpext/batch.go b/lib/netext/httpext/batch.go index c61197367cd..e7256e5d5a5 100644 --- a/lib/netext/httpext/batch.go +++ b/lib/netext/httpext/batch.go @@ -42,10 +42,13 @@ type BatchParsedHTTPRequest struct { // pre-initialized. In addition, each processed request would emit either a nil // value, or an error, via the returned errors channel. The goroutines exit when // the requests channel is closed. +// The processResponse callback can be used to modify the response, e.g. +// to replace the body. func MakeBatchRequests( ctx context.Context, requests []BatchParsedHTTPRequest, reqCount, globalLimit, perHostLimit int, + processResponse func(context.Context, *Response, ResponseType), ) <-chan error { workers := globalLimit if reqCount < workers { @@ -62,6 +65,7 @@ func MakeBatchRequests( resp, err := MakeRequest(ctx, req.ParsedHTTPRequest) if resp != nil { + processResponse(ctx, resp, req.ParsedHTTPRequest.ResponseType) *req.Response = *resp } result <- err diff --git a/lib/netext/httpext/compression.go b/lib/netext/httpext/compression.go index d52f07c8a73..1e4cc9e632f 100644 --- a/lib/netext/httpext/compression.go +++ b/lib/netext/httpext/compression.go @@ -212,6 +212,8 @@ func readResponseBody( case ResponseTypeBinary: // Copy the data to a new slice before we return the buffer to the pool, // because buf.Bytes() points to the underlying buffer byte slice. + // The ArrayBuffer wrapping will be done in the js/modules/k6/http + // package to avoid a reverse dependency, since it depends on goja. binData := make([]byte, buf.Len()) copy(binData, buf.Bytes()) result = binData diff --git a/lib/netext/httpext/response.go b/lib/netext/httpext/response.go index 4a2df1f877b..04e558de334 100644 --- a/lib/netext/httpext/response.go +++ b/lib/netext/httpext/response.go @@ -23,11 +23,6 @@ package httpext import ( "context" "crypto/tls" - "encoding/json" - "errors" - "fmt" - - "github.com/tidwall/gjson" "github.com/loadimpact/k6/lib/netext" ) @@ -57,17 +52,6 @@ const ( ResponseTypeNone ) -type jsonError struct { - line int - character int - err error -} - -func (j jsonError) Error() string { - errMessage := "cannot parse json due to an error at line" - return fmt.Sprintf("%s %d, character %d , error: %v", errMessage, j.line, j.character, j.err) -} - // ResponseTimings is a struct to put all timings for a given HTTP response/request type ResponseTimings struct { Duration float64 `json:"duration"` @@ -108,9 +92,6 @@ type Response struct { Error string `json:"error"` ErrorCode int `json:"error_code"` Request Request `json:"request"` - - cachedJSON interface{} - validatedJSON bool } func (res *Response) setTLSInfo(tlsState *tls.ConnectionState) { @@ -124,68 +105,3 @@ func (res *Response) setTLSInfo(tlsState *tls.ConnectionState) { func (res *Response) GetCtx() context.Context { return res.ctx } - -// JSON parses the body of a response as json and returns it to the goja VM -func (res *Response) JSON(selector ...string) (interface{}, error) { - hasSelector := len(selector) > 0 - if res.cachedJSON == nil || hasSelector { - var v interface{} - var body []byte - switch b := res.Body.(type) { - case []byte: - body = b - case string: - body = []byte(b) - default: - return nil, errors.New("invalid response type") - } - - if hasSelector { - if !res.validatedJSON { - if !gjson.ValidBytes(body) { - return nil, nil - } - res.validatedJSON = true - } - - result := gjson.GetBytes(body, selector[0]) - - if !result.Exists() { - return nil, nil - } - return result.Value(), nil - } - - if err := json.Unmarshal(body, &v); err != nil { - if syntaxError, ok := err.(*json.SyntaxError); ok { - err = checkErrorInJSON(body, int(syntaxError.Offset), err) - } - return nil, err - } - res.validatedJSON = true - res.cachedJSON = v - } - return res.cachedJSON, nil -} - -func checkErrorInJSON(input []byte, offset int, err error) error { - lf := '\n' - str := string(input) - - // Humans tend to count from 1. - line := 1 - character := 0 - - for i, b := range str { - if b == lf { - line++ - character = 0 - } - character++ - if i == offset { - break - } - } - - return jsonError{line: line, character: character, err: err} -}