diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 62dd9bda13b..7794b57c3f9 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -329,10 +329,19 @@ public static Scalar timestampFromLong(DType type, Long value) { } public static Scalar fromString(String value) { + return fromUTF8String(value == null ? null : value.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Creates a String scalar from an array of UTF8 bytes. + * @param value the array of UTF8 bytes + * @return a String scalar + */ + public static Scalar fromUTF8String(byte[] value) { if (value == null) { return fromNull(DType.STRING); } - return new Scalar(DType.STRING, makeStringScalar(value.getBytes(StandardCharsets.UTF_8), true)); + return new Scalar(DType.STRING, makeStringScalar(value, true)); } /** diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index b09850bc3d9..a1078f2546b 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -27,6 +27,7 @@ import org.junit.jupiter.api.Test; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; @@ -244,6 +245,22 @@ public void testString() { } } + @Test + public void testUTF8String() { + try (Scalar s = Scalar.fromUTF8String("TEST".getBytes(StandardCharsets.UTF_8))) { + assertEquals(DType.STRING, s.getType()); + assertTrue(s.isValid()); + assertEquals("TEST", s.getJavaString()); + assertArrayEquals(new byte[]{'T', 'E', 'S', 'T'}, s.getUTF8()); + } + try (Scalar s = Scalar.fromUTF8String("".getBytes(StandardCharsets.UTF_8))) { + assertEquals(DType.STRING, s.getType()); + assertTrue(s.isValid()); + assertEquals("", s.getJavaString()); + assertArrayEquals(new byte[]{}, s.getUTF8()); + } + } + @Test public void testList() { // list of int