diff --git a/cpp/src/io/fst/lookup_tables.cuh b/cpp/src/io/fst/lookup_tables.cuh index 37c99453361..42036b79751 100644 --- a/cpp/src/io/fst/lookup_tables.cuh +++ b/cpp/src/io/fst/lookup_tables.cuh @@ -753,7 +753,7 @@ class TranslationOp { RelativeOffsetT const relative_offset, SymbolT const read_symbol) const { - return translation_op(*this, state_id, match_id, relative_offset, read_symbol); + return translation_op(state_id, match_id, relative_offset, read_symbol); } template @@ -761,7 +761,7 @@ class TranslationOp { SymbolIndexT const match_id, SymbolT const read_symbol) const { - return translation_op(*this, state_id, match_id, read_symbol); + return translation_op(state_id, match_id, read_symbol); } }; diff --git a/cpp/src/io/json/nested_json_gpu.cu b/cpp/src/io/json/nested_json_gpu.cu index c9107357239..3702d94fd2b 100644 --- a/cpp/src/io/json/nested_json_gpu.cu +++ b/cpp/src/io/json/nested_json_gpu.cu @@ -91,6 +91,98 @@ void check_input_size(std::size_t input_size) namespace cudf::io::json { +// FST to help fixing the stack context of characters that follow the first record on each JSON line +namespace fix_stack_of_excess_chars { + +// Type used to represent the target state in the transition table +using StateT = char; + +// Type used to represent a symbol group id +using SymbolGroupT = uint8_t; + +/** + * @brief Definition of the DFA's states. + */ +enum class dfa_states : StateT { + // Before the first record on the JSON line + BEFORE, + // Within the first record on the JSON line + WITHIN, + // Excess data that follows the first record on the JSON line + EXCESS, + // Total number of states + NUM_STATES +}; + +/** + * @brief Definition of the symbol groups + */ +enum class dfa_symbol_group_id : SymbolGroupT { + ROOT, ///< Symbol for root stack context + DELIMITER, ///< Line delimiter symbol group + OTHER, ///< Symbol group that implicitly matches all other tokens + NUM_SYMBOL_GROUPS ///< Total number of symbol groups +}; + +constexpr auto TT_NUM_STATES = static_cast(dfa_states::NUM_STATES); +constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); + +/** + * @brief Function object to map (input_symbol,stack_context) tuples to a symbol group. + */ +struct SymbolPairToSymbolGroupId { + CUDF_HOST_DEVICE SymbolGroupT operator()(thrust::tuple symbol) const + { + auto const input_symbol = thrust::get<0>(symbol); + auto const stack_symbol = thrust::get<1>(symbol); + return static_cast( + input_symbol == '\n' + ? dfa_symbol_group_id::DELIMITER + : (stack_symbol == '_' ? dfa_symbol_group_id::ROOT : dfa_symbol_group_id::OTHER)); + } +}; + +/** + * @brief Translation function object that fixes the stack context of excess data that follows after + * the first JSON record on each line. + */ +struct TransduceInputOp { + template + constexpr CUDF_HOST_DEVICE StackSymbolT operator()(StateT const state_id, + SymbolGroupT const match_id, + RelativeOffsetT const relative_offset, + SymbolT const read_symbol) const + { + if (state_id == static_cast(dfa_states::EXCESS)) { return '_'; } + return thrust::get<1>(read_symbol); + } + + template + constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id, + SymbolGroupT const match_id, + SymbolT const read_symbol) const + { + constexpr int32_t single_output_item = 1; + return single_output_item; + } +}; + +// Aliases for readability of the transition table +constexpr auto TT_BEFORE = dfa_states::BEFORE; +constexpr auto TT_INSIDE = dfa_states::WITHIN; +constexpr auto TT_EXCESS = dfa_states::EXCESS; + +// Transition table +std::array, TT_NUM_STATES> constexpr transition_table{ + {/* IN_STATE ROOT NEWLINE OTHER */ + /* TT_BEFORE */ {{TT_BEFORE, TT_BEFORE, TT_INSIDE}}, + /* TT_INSIDE */ {{TT_EXCESS, TT_BEFORE, TT_INSIDE}}, + /* TT_EXCESS */ {{TT_EXCESS, TT_BEFORE, TT_EXCESS}}}}; + +// The DFA's starting state +constexpr auto start_state = static_cast(dfa_states::BEFORE); +} // namespace fix_stack_of_excess_chars + // FST to prune tokens of invalid lines for recovering JSON lines format namespace token_filter { @@ -146,9 +238,8 @@ struct UnwrapTokenFromSymbolOp { * invalid lines. */ struct TransduceToken { - template - constexpr CUDF_HOST_DEVICE SymbolT operator()(TransducerTableT const&, - StateT const state_id, + template + constexpr CUDF_HOST_DEVICE SymbolT operator()(StateT const state_id, SymbolGroupT const match_id, RelativeOffsetT const relative_offset, SymbolT const read_symbol) const @@ -165,9 +256,8 @@ struct TransduceToken { } } - template - constexpr CUDF_HOST_DEVICE int32_t operator()(TransducerTableT const&, - StateT const state_id, + template + constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id, SymbolGroupT const match_id, SymbolT const read_symbol) const { @@ -643,6 +733,11 @@ auto get_transition_table(json_format_cfg_t format) // PD_ANL describes the target state after a new line after encountering error state auto const PD_ANL = (format == json_format_cfg_t::JSON_LINES_RECOVER) ? PD_BOV : PD_ERR; + // Target state after having parsed the first JSON value on a JSON line + // Spark has the special need to ignore everything that comes after the first JSON object + // on a JSON line instead of marking those as invalid + auto const PD_AFS = (format == json_format_cfg_t::JSON_LINES_RECOVER) ? PD_PVL : PD_ERR; + // First row: empty stack ("root" level of the JSON) // Second row: '[' on top of stack (we're parsing a list value) // Third row: '{' on top of stack (we're parsing a struct value) @@ -668,7 +763,7 @@ auto get_transition_table(json_format_cfg_t format) PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_BOV, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_BOV, PD_STR}; pda_tt[static_cast(pda_state_t::PD_PVL)] = { - PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_PVL, PD_BOV, PD_ERR, + PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_PVL, PD_BOV, PD_AFS, PD_ERR, PD_ERR, PD_ERR, PD_PVL, PD_ERR, PD_ERR, PD_BOV, PD_ERR, PD_PVL, PD_BOV, PD_ERR, PD_ERR, PD_ERR, PD_PVL, PD_ERR, PD_ERR, PD_ERR, PD_BFN, PD_ERR, PD_PVL, PD_BOV, PD_ERR}; pda_tt[static_cast(pda_state_t::PD_BFN)] = { @@ -733,6 +828,18 @@ auto get_translation_table(bool recover_from_error) return regular_tokens; }; + /** + * @brief Helper function that returns `recovering_tokens` if `recover_from_error` is true and + * returns `regular_tokens` otherwise. This is used to ignore excess characters after the first + * value in the case of JSON lines that recover from invalid lines, as Spark ignores any excess + * characters that follow the first record on a JSON line. + */ + auto alt_tokens = [recover_from_error](std::vector regular_tokens, + std::vector recovering_tokens) { + if (recover_from_error) { return recovering_tokens; } + return regular_tokens; + }; + std::array, NUM_PDA_SGIDS>, PD_NUM_STATES> pda_tlt; pda_tlt[static_cast(pda_state_t::PD_BOV)] = {{ /*ROOT*/ {StructBegin}, // OPENING_BRACE @@ -920,18 +1027,18 @@ auto get_translation_table(bool recover_from_error) {}}}; // OTHER pda_tlt[static_cast(pda_state_t::PD_PVL)] = { - { /*ROOT*/ - {ErrorBegin}, // OPENING_BRACE - {ErrorBegin}, // OPENING_BRACKET - {ErrorBegin}, // CLOSING_BRACE - {ErrorBegin}, // CLOSING_BRACKET - {ErrorBegin}, // QUOTE - {ErrorBegin}, // ESCAPE - {ErrorBegin}, // COMMA - {ErrorBegin}, // COLON - {}, // WHITE_SPACE - nl_tokens({}, {}), // LINE_BREAK - {ErrorBegin}, // OTHER + { /*ROOT*/ + {alt_tokens({ErrorBegin}, {})}, // OPENING_BRACE + {alt_tokens({ErrorBegin}, {})}, // OPENING_BRACKET + {alt_tokens({ErrorBegin}, {})}, // CLOSING_BRACE + {alt_tokens({ErrorBegin}, {})}, // CLOSING_BRACKET + {alt_tokens({ErrorBegin}, {})}, // QUOTE + {alt_tokens({ErrorBegin}, {})}, // ESCAPE + {alt_tokens({ErrorBegin}, {})}, // COMMA + {alt_tokens({ErrorBegin}, {})}, // COLON + {}, // WHITE_SPACE + nl_tokens({}, {}), // LINE_BREAK + {alt_tokens({ErrorBegin}, {})}, // OTHER /*LIST*/ {ErrorBegin}, // OPENING_BRACE {ErrorBegin}, // OPENING_BRACKET @@ -1446,6 +1553,26 @@ std::pair, rmm::device_uvector> ge // character. auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_symbols.data()); + // Spark, as the main stakeholder in the `recover_from_error` option, has the specific need to + // ignore any characters that follow the first value on each JSON line. This is an FST that + // fixes the stack context for those excess characters. That is, that all those excess characters + // will be interpreted in the root stack context + if (recover_from_error) { + auto fix_stack_of_excess_chars = fst::detail::make_fst( + fst::detail::make_symbol_group_lookup_op( + fix_stack_of_excess_chars::SymbolPairToSymbolGroupId{}), + fst::detail::make_transition_table(fix_stack_of_excess_chars::transition_table), + fst::detail::make_translation_functor(fix_stack_of_excess_chars::TransduceInputOp{}), + stream); + fix_stack_of_excess_chars.Transduce(zip_in, + static_cast(json_in.size()), + stack_symbols.data(), + thrust::make_discard_iterator(), + thrust::make_discard_iterator(), + fix_stack_of_excess_chars::start_state, + stream); + } + constexpr auto max_translation_table_size = tokenizer_pda::NUM_PDA_SGIDS * static_cast(tokenizer_pda::pda_state_t::PD_NUM_STATES); diff --git a/cpp/tests/io/json_test.cpp b/cpp/tests/io/json_test.cpp index 2ddb0b76544..0149a467c32 100644 --- a/cpp/tests/io/json_test.cpp +++ b/cpp/tests/io/json_test.cpp @@ -1957,11 +1957,11 @@ TEST_F(JsonReaderTest, JSONLinesRecovering) // 2 -> (invalid) R"({"b":{"a":[321})" "\n" - // 3 -> c: [1] (valid) + // 3 -> c: 1.2 (valid) R"({"c":1.2})" "\n" "\n" - // 4 -> a: 123 (valid) + // 4 -> a: 4 (valid) R"({"a":4})" "\n" // 5 -> (invalid) @@ -2020,4 +2020,71 @@ TEST_F(JsonReaderTest, JSONLinesRecovering) c_validity.cbegin()}); } +TEST_F(JsonReaderTest, JSONLinesRecoveringIgnoreExcessChars) +{ + /** + * @brief Spark has the specific need to ignore extra characters that come after the first record + * on a JSON line + */ + std::string data = + // 0 -> a: -2 (valid) + R"({"a":-2}{})" + "\n" + // 1 -> (invalid) + R"({"b":{}should_be_invalid})" + "\n" + // 2 -> b (valid) + R"({"b":{"a":3} })" + "\n" + // 3 -> c: (valid) + R"({"c":1.2 } )" + "\n" + "\n" + // 4 -> (valid) + R"({"a":4} 123)" + "\n" + // 5 -> (valid) + R"({"a":5}//Comment after record)" + "\n" + // 6 -> (valid) + R"({"a":6} //Comment after whitespace)" + "\n" + // 7 -> (invalid) + R"({"a":5 //Invalid Comment within record})"; + + auto filepath = temp_env->get_temp_dir() + "RecoveringLinesExcessChars.json"; + { + std::ofstream outfile(filepath, std::ofstream::out); + outfile << data; + } + + cudf::io::json_reader_options in_options = + cudf::io::json_reader_options::builder(cudf::io::source_info{filepath}) + .lines(true) + .recovery_mode(cudf::io::json_recovery_mode_t::RECOVER_WITH_NULL); + + cudf::io::table_with_metadata result = cudf::io::read_json(in_options); + + EXPECT_EQ(result.tbl->num_columns(), 3); + EXPECT_EQ(result.tbl->num_rows(), 8); + EXPECT_EQ(result.tbl->get_column(0).type().id(), cudf::type_id::INT64); + EXPECT_EQ(result.tbl->get_column(1).type().id(), cudf::type_id::STRUCT); + EXPECT_EQ(result.tbl->get_column(2).type().id(), cudf::type_id::FLOAT64); + + std::vector a_validity{true, false, false, false, true, true, true, false}; + std::vector b_validity{false, false, true, false, false, false, false, false}; + std::vector c_validity{false, false, false, true, false, false, false, false}; + + // Child column b->a + auto b_a_col = int64_wrapper({0, 0, 3, 0, 0, 0, 0, 0}); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tbl->get_column(0), + int64_wrapper{{-2, 0, 0, 0, 4, 5, 6, 0}, a_validity.cbegin()}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL( + result.tbl->get_column(1), cudf::test::structs_column_wrapper({b_a_col}, b_validity.cbegin())); + CUDF_TEST_EXPECT_COLUMNS_EQUAL( + result.tbl->get_column(2), + float64_wrapper{{0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0, 0.0}, c_validity.cbegin()}); +} + CUDF_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/io/nested_json_test.cpp b/cpp/tests/io/nested_json_test.cpp index 3cb7e1f287a..5f79d5b862b 100644 --- a/cpp/tests/io/nested_json_test.cpp +++ b/cpp/tests/io/nested_json_test.cpp @@ -543,7 +543,7 @@ TEST_F(JsonTest, RecoveringTokenStream) { // Test input. Inline comments used to indicate character indexes // 012345678 <= line 0 - std::string const input = R"({"a":-2},)" + std::string const input = R"({"a":2 {})" // 9 "\n" // 01234 <= line 1