diff --git a/src/Init/Data/String/Extra.lean b/src/Init/Data/String/Extra.lean index c7dfc1a99b66..2d1615d6cf99 100644 --- a/src/Init/Data/String/Extra.lean +++ b/src/Init/Data/String/Extra.lean @@ -17,14 +17,25 @@ def toNat! (s : String) : Nat := else panic! "Nat expected" -/-- - Convert a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`. - The result is unspecified if `a` is not properly UTF-8 encoded. --/ -@[extern "lean_string_from_utf8_unchecked"] -opaque fromUTF8Unchecked (a : @& ByteArray) : String +/-- Returns true if the given byte array consists of valid UTF-8. -/ +@[extern "lean_string_validate_utf8"] +opaque validateUTF8 (a : @& ByteArray) : Bool + +/-- Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`. -/ +@[extern "lean_string_from_utf8"] +opaque fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String + +/-- Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`, +or returns `none` if `a` is not properly UTF-8 encoded. -/ +@[inline] def fromUTF8? (a : ByteArray) : Option String := + if h : validateUTF8 a then fromUTF8 a h else none + +/-- Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`, +or panics if `a` is not properly UTF-8 encoded. -/ +@[inline] def fromUTF8! (a : ByteArray) : String := + if h : validateUTF8 a then fromUTF8 a h else panic! "invalid UTF-8 string" -/-- Convert the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/ +/-- Converts the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/ @[extern "lean_string_to_utf8"] opaque toUTF8 (a : @& String) : ByteArray diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index 3353e9f8683b..3d53b77ecfbf 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -768,12 +768,16 @@ def ofBuffer (r : Ref Buffer) : Stream where write := fun data => r.modify fun b => -- set `exact` to `false` so that repeatedly writing to the stream does not impose quadratic run time { b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size } - getLine := r.modifyGet fun b => - let pos := match b.data.findIdx? (start := b.pos) fun u => u == 0 || u = '\n'.toNat.toUInt8 with - -- include '\n', but not '\0' - | some pos => if b.data.get! pos == 0 then pos else pos + 1 - | none => b.data.size - (String.fromUTF8Unchecked <| b.data.extract b.pos pos, { b with pos := pos }) + getLine := do + let buf ← r.modifyGet fun b => + let pos := match b.data.findIdx? (start := b.pos) fun u => u == 0 || u = '\n'.toNat.toUInt8 with + -- include '\n', but not '\0' + | some pos => if b.data.get! pos == 0 then pos else pos + 1 + | none => b.data.size + (b.data.extract b.pos pos, { b with pos := pos }) + match String.fromUTF8? buf with + | some str => pure str + | none => throw (.userError "invalid UTF-8") putStr := fun s => r.modify fun b => let data := s.toUTF8 { b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size } @@ -791,7 +795,7 @@ def withIsolatedStreams [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (x : m (if isolateStderr then withStderr (Stream.ofBuffer bOut) else id) <| x let bOut ← liftM (m := BaseIO) bOut.get - let out := String.fromUTF8Unchecked bOut.data + let out := String.fromUTF8! bOut.data pure (out, r) end FS diff --git a/src/Init/System/Uri.lean b/src/Init/System/Uri.lean index 64f36becb0e0..5e7cf47c3f52 100644 --- a/src/Init/System/Uri.lean +++ b/src/Init/System/Uri.lean @@ -50,7 +50,7 @@ def decodeUri (uri : String) : String := Id.run do ((decoded.push c).push h1, i + 2) else (decoded.push c, i + 1) - return String.fromUTF8Unchecked decoded + return String.fromUTF8! decoded where hexDigitToUInt8? (c : UInt8) : Option UInt8 := if zero ≤ c ∧ c ≤ nine then some (c - zero) else if lettera ≤ c ∧ c ≤ letterf then some (c - lettera + 10) diff --git a/src/Lean/Data/Json/Stream.lean b/src/Lean/Data/Json/Stream.lean index 080d4274fa47..cfd58eb52075 100644 --- a/src/Lean/Data/Json/Stream.lean +++ b/src/Lean/Data/Json/Stream.lean @@ -18,7 +18,7 @@ open IO /-- Consumes `nBytes` bytes from the stream, interprets the bytes as a utf-8 string and the string as a valid JSON object. -/ def readJson (h : FS.Stream) (nBytes : Nat) : IO Json := do let bytes ← h.read (USize.ofNat nBytes) - let s := String.fromUTF8Unchecked bytes + let some s := String.fromUTF8? bytes | throw (IO.userError "invalid UTF-8") ofExcept (Json.parse s) def writeJson (h : FS.Stream) (j : Json) : IO Unit := do diff --git a/src/lake/tests/toml/Test.lean b/src/lake/tests/toml/Test.lean index dea7b1e13428..5d025bfacab6 100644 --- a/src/lake/tests/toml/Test.lean +++ b/src/lake/tests/toml/Test.lean @@ -24,32 +24,6 @@ inductive TomlOutcome where | fail (log : MessageLog) | error (e : IO.Error) -@[inline] def Fin.allM [Monad m] (n) (f : Fin n → m Bool) : m Bool := - loop 0 -where - loop (i : Nat) : m Bool := do - if h : i < n then - if (← f ⟨i, h⟩) then loop (i+1) else pure false - else - pure true - termination_by n - i - -@[inline] def Fin.all (n) (f : Fin n → Bool) : Bool := - Id.run <| allM n f - -def bytesBEq (a b : ByteArray) : Bool := - if h_size : a.size = b.size then - Fin.all a.size fun i => a[i] = b[i]'(h_size ▸ i.isLt) - else - false - -def String.fromUTF8 (bytes : ByteArray) : String := - String.fromUTF8Unchecked bytes |>.map id - -@[inline] def String.fromUTF8? (bytes : ByteArray) : Option String := - let s := String.fromUTF8 bytes - if bytesBEq s.toUTF8 bytes then some s else none - nonrec def loadToml (tomlFile : FilePath) : BaseIO TomlOutcome := do let fileName := tomlFile.fileName.getD tomlFile.toString let input ← diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 8d9c04b72b42..89af9b2604fa 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1614,10 +1614,14 @@ extern "C" LEAN_EXPORT object * lean_mk_string(char const * s) { return lean_mk_string_from_bytes(s, strlen(s)); } -extern "C" LEAN_EXPORT obj_res lean_string_from_utf8_unchecked(b_obj_arg a) { +extern "C" LEAN_EXPORT obj_res lean_string_from_utf8(b_obj_arg a) { return lean_mk_string_from_bytes(reinterpret_cast(lean_sarray_cptr(a)), lean_sarray_size(a)); } +extern "C" LEAN_EXPORT uint8 lean_string_validate_utf8(b_obj_arg a) { + return validate_utf8(lean_sarray_cptr(a), lean_sarray_size(a)); +} + extern "C" LEAN_EXPORT obj_res lean_string_to_utf8(b_obj_arg s) { size_t sz = lean_string_size(s) - 1; obj_res r = lean_alloc_sarray(1, sz, sz); @@ -1741,38 +1745,38 @@ extern "C" LEAN_EXPORT obj_res lean_string_data(obj_arg s) { static bool lean_string_utf8_get_core(char const * str, usize size, usize i, uint32 & result) { unsigned c = static_cast(str[i]); - /* zero continuation (0 to 127) */ + /* zero continuation (0 to 0x7F) */ if ((c & 0x80) == 0) { result = c; return true; } - /* one continuation (128 to 2047) */ + /* one continuation (0x80 to 0x7FF) */ if ((c & 0xe0) == 0xc0 && i + 1 < size) { unsigned c1 = static_cast(str[i+1]); result = ((c & 0x1f) << 6) | (c1 & 0x3f); - if (result >= 128) { + if (result >= 0x80) { return true; } } - /* two continuations (2048 to 55295 and 57344 to 65535) */ + /* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */ if ((c & 0xf0) == 0xe0 && i + 2 < size) { unsigned c1 = static_cast(str[i+1]); unsigned c2 = static_cast(str[i+2]); result = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f); - if (result >= 2048 && (result < 55296 || result > 57343)) { + if (result >= 0x800 && (result < 0xD800 || result > 0xDFFF)) { return true; } } - /* three continuations (65536 to 1114111) */ + /* three continuations (0x10000 to 0x10FFFF) */ if ((c & 0xf8) == 0xf0 && i + 3 < size) { unsigned c1 = static_cast(str[i+1]); unsigned c2 = static_cast(str[i+2]); unsigned c3 = static_cast(str[i+3]); result = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f); - if (result >= 65536 && result <= 1114111) { + if (result >= 0x10000 && result <= 0x10FFFF) { return true; } } @@ -1810,32 +1814,32 @@ extern "C" LEAN_EXPORT uint32 lean_string_utf8_get(b_obj_arg s, b_obj_arg i0) { } extern "C" LEAN_EXPORT uint32_t lean_string_utf8_get_fast_cold(char const * str, size_t i, size_t size, unsigned char c) { - /* one continuation (128 to 2047) */ + /* one continuation (0x80 to 0x7FF) */ if ((c & 0xe0) == 0xc0 && i + 1 < size) { unsigned c1 = static_cast(str[i+1]); uint32_t result = ((c & 0x1f) << 6) | (c1 & 0x3f); - if (result >= 128) { + if (result >= 0x80) { return result; } } - /* two continuations (2048 to 55295 and 57344 to 65535) */ + /* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */ if ((c & 0xf0) == 0xe0 && i + 2 < size) { unsigned c1 = static_cast(str[i+1]); unsigned c2 = static_cast(str[i+2]); uint32_t result = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f); - if (result >= 2048 && (result < 55296 || result > 57343)) { + if (result >= 0x800 && (result < 0xD800 || result > 0xDFFF)) { return result; } } - /* three continuations (65536 to 1114111) */ + /* three continuations (0x10000 to 0x10FFFF) */ if ((c & 0xf8) == 0xf0 && i + 3 < size) { unsigned c1 = static_cast(str[i+1]); unsigned c2 = static_cast(str[i+2]); unsigned c3 = static_cast(str[i+3]); uint32_t result = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f); - if (result >= 65536 && result <= 1114111) { + if (result >= 0x10000 && result <= 0x10FFFF) { return result; } } diff --git a/src/runtime/utf8.cpp b/src/runtime/utf8.cpp index 06b114a50607..1b0c65c1de7e 100644 --- a/src/runtime/utf8.cpp +++ b/src/runtime/utf8.cpp @@ -113,7 +113,7 @@ unsigned utf8_to_unicode(uchar const * begin, uchar const * end) { auto it = begin; unsigned c = *it; ++it; - if (c < 128) + if (c < 0x80) return c; unsigned mask = (1u << 6) -1; unsigned hmask = mask; @@ -164,40 +164,40 @@ optional get_utf8_first_byte_opt(unsigned char c) { unsigned next_utf8(char const * str, size_t size, size_t & i) { unsigned c = static_cast(str[i]); - /* zero continuation (0 to 127) */ + /* zero continuation (0 to 0x7F) */ if ((c & 0x80) == 0) { i++; return c; } - /* one continuation (128 to 2047) */ + /* one continuation (0x80 to 0x7FF) */ if ((c & 0xe0) == 0xc0 && i + 1 < size) { unsigned c1 = static_cast(str[i+1]); unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f); - if (r >= 128) { + if (r >= 0x80) { i += 2; return r; } } - /* two continuations (2048 to 55295 and 57344 to 65535) */ + /* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */ if ((c & 0xf0) == 0xe0 && i + 2 < size) { unsigned c1 = static_cast(str[i+1]); unsigned c2 = static_cast(str[i+2]); unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f); - if (r >= 2048 && (r < 55296 || r > 57343)) { + if (r >= 0x800 && (r < 0xD800 || r > 0xDFFF)) { i += 3; return r; } } - /* three continuations (65536 to 1114111) */ + /* three continuations (0x10000 to 0x10FFFF) */ if ((c & 0xf8) == 0xf0 && i + 3 < size) { unsigned c1 = static_cast(str[i+1]); unsigned c2 = static_cast(str[i+2]); unsigned c3 = static_cast(str[i+3]); unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f); - if (r >= 65536 && r <= 1114111) { + if (r >= 0x10000 && r <= 0x10FFFF) { i += 4; return r; } @@ -220,6 +220,56 @@ void utf8_decode(std::string const & str, std::vector & out) { } } +bool validate_utf8(uint8_t const * str, size_t size) { + size_t i = 0; + while (i < size) { + unsigned c = str[i]; + if ((c & 0x80) == 0) { + /* zero continuation (0 to 0x7F) */ + i++; + } else if ((c & 0xe0) == 0xc0) { + /* one continuation (0x80 to 0x7FF) */ + if (i + 1 >= size) return false; + + unsigned c1 = str[i+1]; + if ((c1 & 0xc0) != 0x80) return false; + + unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f); + if (r < 0x80) return false; + + i += 2; + } else if ((c & 0xf0) == 0xe0) { + /* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */ + if (i + 2 >= size) return false; + + unsigned c1 = str[i+1]; + unsigned c2 = str[i+2]; + if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80) return false; + + unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f); + if (r < 0x800 || (r >= 0xD800 && r < 0xDFFF)) return false; + + i += 3; + } else if ((c & 0xf8) == 0xf0) { + /* three continuations (0x10000 to 0x10FFFF) */ + if (i + 3 >= size) return false; + + unsigned c1 = str[i+1]; + unsigned c2 = str[i+2]; + unsigned c3 = str[i+3]; + if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80 || (c3 & 0xc0) != 0x80) return false; + + unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f); + if (r < 0x10000 || r > 0x10FFFF) return false; + + i += 4; + } else { + return false; + } + } + return true; +} + #define TAG_CONT static_cast(0b10000000) #define TAG_TWO_B static_cast(0b11000000) #define TAG_THREE_B static_cast(0b11100000) diff --git a/src/runtime/utf8.h b/src/runtime/utf8.h index 3eeb6c96ff82..3ba942c96588 100644 --- a/src/runtime/utf8.h +++ b/src/runtime/utf8.h @@ -45,6 +45,9 @@ LEAN_EXPORT unsigned next_utf8(char const * str, size_t size, size_t & i); /* Decode a UTF-8 encoded string `str` into unicode scalar values */ LEAN_EXPORT void utf8_decode(std::string const & str, std::vector & out); +/* Returns true if the provided string is valid UTF-8 */ +LEAN_EXPORT bool validate_utf8(uint8_t const * str, size_t size); + /* Push a unicode scalar value into a utf-8 encoded string */ LEAN_EXPORT void push_unicode_scalar(std::string & s, unsigned code); diff --git "a/tests/lean/run/utf8\350\213\261\350\252\236.lean" "b/tests/lean/run/utf8\350\213\261\350\252\236.lean" index 1f655d71d564..a40fffe10a8d 100644 --- "a/tests/lean/run/utf8\350\213\261\350\252\236.lean" +++ "b/tests/lean/run/utf8\350\213\261\350\252\236.lean" @@ -4,10 +4,13 @@ def check_eq {α} [BEq α] [Repr α] (tag : String) (expected actual : α) : IO s!"assertion failure \"{tag}\":\n expected: {repr expected}\n actual: {repr actual}" def DecodeUTF8: IO Unit := do - let cs := String.toList "Hello, 英語!" + let str := "Hello, 英語!" + let cs := String.toList str let ns := cs.map Char.toNat IO.println cs IO.println ns check_eq "utf-8 chars" [72, 101, 108, 108, 111, 44, 32, 33521, 35486, 33] ns + check_eq "utf-8 bytes" #[72, 101, 108, 108, 111, 44, 32, 232, 139, 177, 232, 170, 158, 33] str.toUTF8.data + check_eq "string eq" (some str) (String.fromUTF8? str.toUTF8) -#eval DecodeUTF8 \ No newline at end of file +#eval DecodeUTF8