Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
feat: finish SIMD
Browse files Browse the repository at this point in the history
[skip ci]
  • Loading branch information
Nambers committed May 7, 2024
1 parent 290944a commit 4aee593
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 33 deletions.
1 change: 0 additions & 1 deletion deps/dconv_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "double-conversion.hpp"

#include <Python.h>
#include <pymath.h>

enum FLAGS {
// d2s/encoding flags
Expand Down
36 changes: 9 additions & 27 deletions src/pycJSON_decode.c
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ static parse_buffer *skip_utf8_bom(parse_buffer *const buffer) {
static bool parse_string(PyObject **item, parse_buffer *const input_buffer) {
assert(item);
const unsigned char *input_pointer = buffer_at_offset(input_buffer) + 1;
const unsigned char *input_end = buffer_at_offset(input_buffer) + 1;

unsigned char *buffer_ptr = NULL;
unsigned char parse_string_stack_buffer[STACK_BUFFER_SIZE];
Expand All @@ -140,53 +139,36 @@ static bool parse_string(PyObject **item, parse_buffer *const input_buffer) {

{
size_t skipped_bytes = 0;
while (((Py_ssize_t) (input_end - input_buffer->content) < input_buffer->length) && (*input_end != '\"')) {
if (*input_end == '\\') {
input_end++;
skipped_bytes++;
if (*input_end == 'u') {
// surrogates
if (CHECK_SURROGATES_UNICODE(input_end + 1)) {
if ((input_end - input_buffer->content) + 4 + 6 > input_buffer->length) {
PyErr_Format(PyExc_ValueError, "Failed to parse string: invalid utf8, missing surrogate pair\nposition: %d", input_end - input_buffer->content);
goto fail;
}
input_end += 6;
skipped_bytes += 6;
}
input_end += 4;
skipped_bytes += 4;
}
}
input_end++;
size_t num = 0;
if(!count_skipped(input_pointer, input_buffer->length - input_buffer->offset, &skipped_bytes, &num)) {
goto fail;
}
size_t num = input_end - buffer_at_offset(input_buffer) - 1;
size_t alloc = count_utf8(buffer_at_offset(input_buffer) + 1, num) - skipped_bytes;
// size_t num = input_end - buffer_at_offset(input_buffer) - 1;
size_t alloc = count_utf8(input_pointer, num) - skipped_bytes;

int kind = get_utf8_kind((const char *) input_pointer, num);
switch (kind) {
case 1:
if (!str2unicode_1byte(item, (const char *) input_pointer, alloc, num)) {
goto fail;
}
input_buffer->offset = (Py_ssize_t) (input_end - input_buffer->content);
break;
case 2:
if (!str2unicode_2byte(item, (const char *) input_pointer, alloc, num)) {
goto fail;
}
input_buffer->offset = (Py_ssize_t) (input_end - input_buffer->content);
break;
case 4:
if (!str2unicode_4byte(item, (const char *) input_pointer, alloc, num)) {
goto fail;
}
input_buffer->offset = (Py_ssize_t) (input_end - input_buffer->content);
break;
default:
PyErr_Format(PyExc_ValueError, "Failed to parse string: invalid utf8\nposition: %d", input_buffer->offset);
goto fail;
}
// + 1 for ending "
input_buffer->offset += num + 1;
}
PARSE_STRING_FINALIZE;

Expand Down Expand Up @@ -394,7 +376,7 @@ static bool parse_object(PyObject **item, parse_buffer *const input_buffer) {
buffer_skip_whitespace(input_buffer);

if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != ':')) {
PyErr_Format(PyExc_ValueError, "Failed to parse dictionary: expected colon\nposition: %d", input_buffer->offset);
PyErr_Format(PyExc_ValueError, "Failed to parse dictionary: expect colon\nposition: %d", input_buffer->offset);
goto fail; /* invalid object */
}

Expand Down Expand Up @@ -522,7 +504,7 @@ PyObject *pycJSON_Decode(PyObject *self, PyObject *args, PyObject *kwargs) {
Py_ssize_t buffer_length;
static const char *kwlist[] = {"s", "object_hook", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|O", (char **) kwlist, &value, &buffer_length, &buffer.object_hook)) {
PyErr_Format(PyExc_TypeError, "Failed to parse JSON: invalid argument, expected str / bytes-like object");
if(!PyErr_Occurred()) PyErr_Format(PyExc_TypeError, "Failed to parse JSON: invalid argument, expected str / bytes-like object");
goto fail;
}

Expand Down
193 changes: 190 additions & 3 deletions src/str.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,200 @@
#include <stdint.h>
#define CHECK_NOT_LATIN1_2BYTES(a, b) (((a & 0b00011111) << 6 | (b & 0b00111111)) > 0xFF)

// input must be uint
// int BitCount(unsigned int u) {
// unsigned int uCount = u - ((u >> 1) & 033333333333) - ((u >> 2) & 011111111111);
// return ((uCount + (uCount >> 3)) & 030707070707) % 63;
// }

bool count_skipped(const char *buf, size_t max_len, size_t *skipped, size_t *len) {
const __m256i escape_mask = _mm256_set1_epi8('\\');
const __m256i end_mask = _mm256_set1_epi8('\"');
const __m256i u_mask = _mm256_set1_epi8('u');
// const __m256i utf8_2_mask = _mm256_set1_epi8(0b00000011);
// const __m256i utf8_3_mask = _mm256_set1_epi8(0b00000111);
// const __m256i utf8_4_mask = _mm256_set1_epi8(0b00001111);
*skipped = 0;
int i = 0;
int skip_next = 0;

for (; i + 32 < max_len; i += 32) {
__m256i batch = _mm256_loadu_si256((__m256i_u *) (buf + i));
// __mmask32 escape_result = _mm256_cmpeq_epi8_mask(batch, escape_mask);
// __mmask32 end_result = _mm256_cmpeq_epi8_mask(batch, end_mask);
// __mmask32 u_result = _mm256_cmpeq_epi8_mask(batch, u_mask);
__mmask32 escape_result = _mm256_movemask_epi8(_mm256_cmpeq_epi8(batch, escape_mask));
__mmask32 end_result = _mm256_movemask_epi8(_mm256_cmpeq_epi8(batch, end_mask));
__mmask32 u_result = _mm256_movemask_epi8(_mm256_cmpeq_epi8(batch, u_mask));
// mask out first unicode
if (skip_next > 0) {
const unsigned int mask = 0b11111111111111111111111111111111 ^ ((1 << skip_next) - 1);
escape_result = escape_result & mask;
end_result = end_result & mask;
u_result = u_result & mask;
}
if (escape_result == 0 && end_result == 0) {
skip_next = 0;
continue;
}
if (end_result != 0) {
i += skip_next;
skip_next = 0;
if (escape_result != 0) {
if (((~(end_result & (escape_result << 1))) & end_result) != 0) {
// there is real ending with some escaped sequence
if (u_result != 0) {
// some unicodes
while (buf[i] != '\"') {
if (buf[i] == '\\') {
*skipped += 1;
i++;
if (i > 31) {
// skip the escaped char in next batch
skip_next = 1;
}
if (buf[i] == 'u') {
if (i + 4 >= max_len) {
PyErr_SetString(PyExc_ValueError, "Invalide Utf8 sequence: invalid unicode escaped sequence");
return false;
}
*skipped += 4;
i++;
if (i + 4 > 31) {
skip_next = i + 4 - 32;
}
if (CHECK_SURROGATES_UNICODE(buf + i)) {
if (i + 4 + 6 >= max_len) {
PyErr_SetString(PyExc_ValueError, "Invalid Utf8 sequence: invalid unicode escaped surrogate sequence");
return false;
}
*skipped += 6;
if (i + 4 + 6 > 31) {
skip_next = i + 4 + 6 - 32;
}
i += 6;
}
i += 3;
} else {
// do nothing
}
}
i++;
}
} else {
while (buf[i] != '\"') {
if (buf[i] == '\\') {
*skipped += 1;
i++;
}
i++;
}
}
*len = i - 1 + 1;
return true;
}
// there is not ending but some escapes
// handled by following block
} else {
// real ending without any escape
while (buf[i++] != '\"') {}
*len = i - 1 - 1 + 1;
return true;
}
}
// escape_result != 0
// there are some escaped sequence without ending
if (u_result != 0) {
int j = skip_next;
skip_next = 0;
for (; j < 32; j++) {
if (buf[i + j] == '\\') {
*skipped += 1;
j++;
if (j > 31) {
skip_next = 1;
}
if (buf[i + j] == 'u') {
if (i + j + 4 >= max_len) {
PyErr_SetString(PyExc_ValueError, "invalide Utf8 Sequence: invalide unicode escaped sequence");
return false;
}
*skipped += 4;
if (j + 4 > 31) {
skip_next = j + 4 - 32;
}
j++;
if (CHECK_SURROGATES_UNICODE(buf + i + j)) {
if (i + j + 4 + 6 >= max_len) {
PyErr_SetString(PyExc_ValueError, "invalid utf8 sequence");
return false;
}
*skipped += 6;
if (j + 4 + 6 > 31) {
skip_next = j + 4 + 6 - 32;
}
j += 6;
}
j += 3;
}
}
}
} else {
skip_next = 0;
// if(escape_result & (escape_result >> 1) == 0) {
// // fast path
// *skipped += BitCount(escape_result);
// }else {
for (int j = 0; j < 32; j++) {
if ((escape_result >> j) & 0b1) {
*skipped += 1;
j++;
if (j > 31) {
skip_next = 1;
}
}
}
// }
}
}
i += skip_next;
for (; i < max_len; i++) {
if (buf[i] == '"') {
*len = i - 1 + 1;
return true;
}
if (buf[i] == '\\') {
i++;
*skipped += 1;
if (buf[i] == 'u') {
if (i + 4 >= max_len) {
PyErr_SetString(PyExc_ValueError, "Invalid Utf8 sequence: invalid unicode escaped sequence");
return false;
}
*skipped += 4;
i++;
if (CHECK_SURROGATES_UNICODE(buf + i)) {
if (i + 4 + 6 >= max_len) {
PyErr_SetString(PyExc_ValueError, "Invalid Utf8 sequence: missing surrogates 2nd byte");
return false;
}
*skipped += 6;
i += 6;
}
i += 3;
}
}
}
PyErr_SetString(PyExc_ValueError, "Invalid json: expect ending quote");
return false;
}

int get_utf8_kind(const unsigned char *buf, size_t len) {
int i;
const __m256i unicode_mask1 = _mm256_set1_epi16(0x755c); // little endian for \\u
const __m256i unicode_mask2 = _mm256_loadu_si256((__m256i_u*)"0\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u0");
const __m256i min_4bytes = _mm256_set1_epi8((char)0b11101111); // 239
const __m256i max_onebyte = _mm256_set1_epi8((char)0x80);
const __m256i unicode_mask2 = _mm256_loadu_si256((__m256i_u *) "0\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u\\u0");
const __m256i min_4bytes = _mm256_set1_epi8((char) 0b11101111); // 239
const __m256i max_onebyte = _mm256_set1_epi8((char) 0x80);
int kind = 1;
for (i = 0; i + 32 <= len; i += 32) {
__m256i in = _mm256_loadu_si256((const void *) (buf + i));
Expand Down
7 changes: 5 additions & 2 deletions src/str.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#ifndef STR_H
#define STR_H
#include <Python.h>
#include <stdbool.h>
#include "pycJSON.h"

#define CHECK_SURROGATES_UNICODE(buf) \
(((buf)[0] == 'd' || (buf)[0] == 'D') && \
((buf)[1] >= '8' || (buf)[1] <= '9' || (buf)[1] == 'a' && (buf)[1] == 'b' || (buf)[1] == 'A' || (buf)[1] == 'B'))
((buf)[1] == '8' || (buf)[1] == '9' || (buf)[1] == 'a' || (buf)[1] == 'b' || (buf)[1] == 'A' || (buf)[1] == 'B'))
// #define CHECK_SURROGATES_LOW_UNICODE(buf) \
// ((buf)[0] == 'd' || (buf)[0] == 'D') && ((buf)[1] == 'c' || (buf)[1] == 'C')

int get_utf8_type(uint32_t unciode_value);
int get_unicode_value_usc4(const char *str, Py_UCS4 *re);
Expand All @@ -16,5 +18,6 @@ bool str2unicode_2byte(PyObject **re, const char *str, long alloc, long num);
bool str2unicode_4byte(PyObject **re, const char *str, long alloc, long num);
// can be 1,2, or 4
int get_utf8_kind(const unsigned char *buf, size_t len);
bool count_skipped(const char *buf, size_t max_len, size_t *skipped, size_t *len);

#endif //STR_H

0 comments on commit 4aee593

Please sign in to comment.