Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When removeNullBytes is set, length calculations did not take into account null bytes. #17232

Merged
merged 4 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ static long writeUtf8ByteBuffers(
written++;

if (len > 0) {
FrameWriterUtils.copyByteBufferToMemoryDisallowingNullBytes(
int lenWritten = FrameWriterUtils.copyByteBufferToMemoryDisallowingNullBytes(
utf8Datum,
memory,
position + written,
len,
removeNullBytes
);
written += len;
written += lenWritten;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,29 +212,36 @@ public static void copyByteBufferToMemoryAllowingNullBytes(

/**
* Copies {@code src} to {@code dst}, disallowing null bytes to be written to the destination. If {@code removeNullBytes}
* is true, the method will drop the null bytes, and if it is false, the method will throw an exception.
* is true, the method will drop the null bytes, and if it is false, the method will throw an exception. The written bytes
* can be less than "len" if the null bytes are dropped, and the callers must evaluate the return value to see the actual
* length of the buffer that is copied
*/
public static void copyByteBufferToMemoryDisallowingNullBytes(
public static int copyByteBufferToMemoryDisallowingNullBytes(
final ByteBuffer src,
final WritableMemory dst,
final long dstPosition,
final int len,
final boolean removeNullBytes
)
{
copyByteBufferToMemory(src, dst, dstPosition, len, false, removeNullBytes);
return copyByteBufferToMemory(src, dst, dstPosition, len, false, removeNullBytes);
}

/**
* Copies "len" bytes from {@code src.position()} to "dstPosition" in "memory". Does not update the position of src.
* Tries to copy "len" bytes from {@code src.position()} to "dstPosition" in "memory". If removeNullBytes is set to true,
* it will remove the U+0000 bytes from the src buffer, and the written bytes will be less than "len". It is imperative that the
* callers check the number of written bytes when "removeNullBytes" can be set to true, i.e. this method is invoked via
* {@link #copyByteBufferToMemoryDisallowingNullBytes}
* <p>
* Does not update the position of src.
* <p>
* Whenever "allowNullBytes" is true, "removeNullBytes" must be false. Use the methods {@link #copyByteBufferToMemoryAllowingNullBytes}
* and {@link #copyByteBufferToMemoryDisallowingNullBytes} to copy between the memory
* <p>
*
* @throws InvalidNullByteException if "allowNullBytes" and "removeNullBytes" is false and a null byte is encountered
*/
private static void copyByteBufferToMemory(
private static int copyByteBufferToMemory(
final ByteBuffer src,
final WritableMemory dst,
final long dstPosition,
Expand All @@ -251,6 +258,7 @@ private static void copyByteBufferToMemory(
}

final int srcEnd = src.position() + len;
int writtenLength = 0;

if (allowNullBytes) {
if (src.hasArray()) {
Expand All @@ -264,6 +272,8 @@ private static void copyByteBufferToMemory(
dst.putByte(q, b);
}
}
// The method does not alter the length of the memory copied if null bytes are allowed
writtenLength = len;
} else {
long q = dstPosition;
for (int p = src.position(); p < srcEnd; p++) {
Expand All @@ -282,9 +292,11 @@ private static void copyByteBufferToMemory(
} else {
dst.putByte(q, b);
q++;
writtenLength++;
}
}
}
return writtenLength;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.write.InvalidNullByteException;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.ColumnValueSelector;
Expand All @@ -40,9 +41,11 @@
import org.mockito.quality.Strictness;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

public class StringFieldWriterTest extends InitializedNullHandlingTest
{
Expand All @@ -57,23 +60,45 @@ public class StringFieldWriterTest extends InitializedNullHandlingTest
@Mock
public DimensionSelector selectorUtf8;


private WritableMemory memory;
private FieldWriter fieldWriter;
private FieldWriter fieldWriterUtf8;
private FieldWriter fieldWriterRemoveNull;
private FieldWriter fieldWriterUtf8RemoveNull;

@Before
public void setUp()
{
memory = WritableMemory.allocate(1000);
fieldWriter = new StringFieldWriter(selector, false);
fieldWriterUtf8 = new StringFieldWriter(selectorUtf8, false);
fieldWriterRemoveNull = new StringFieldWriter(selector, true);
fieldWriterUtf8RemoveNull = new StringFieldWriter(selectorUtf8, true);
}

@After
public void tearDown()
{
fieldWriter.close();
fieldWriterUtf8.close();
for (FieldWriter fw : getFieldWriter(FieldWritersType.ALL)) {
try {
fw.close();
}
catch (Exception ignore) {
}
}
}


private List<FieldWriter> getFieldWriter(FieldWritersType fieldWritersType)
{
if (fieldWritersType == FieldWritersType.NULL_REPLACING) {
return Arrays.asList(fieldWriterRemoveNull, fieldWriterUtf8RemoveNull);
} else if (fieldWritersType == FieldWritersType.ALL) {
return Arrays.asList(fieldWriter, fieldWriterUtf8, fieldWriterRemoveNull, fieldWriterUtf8RemoveNull);
} else {
throw new ISE("Handler missing for type:[%s]", fieldWritersType);
}
}

@Test
Expand All @@ -100,31 +125,63 @@ public void testMultiValueString()
doTest(Arrays.asList("foo", "bar"));
}


@Test
public void testMultiValueStringContainingNulls()
{
doTest(Arrays.asList("foo", NullHandling.emptyToNullIfNeeded(""), "bar", null));
}

@Test
Copy link
Contributor

@LakshSingla LakshSingla Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is misleading because it tests "nulls" instead of "null bytes". The string should be like "foo\u0000" in order for it to have the null byte. Does this test fail without the changes to the StringFieldWriter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed up a new patch.

public void testNullByteReplacement()
{
doTest(
Arrays.asList("abc\u0000", "foo" + NullHandling.emptyToNullIfNeeded("") + "bar", "def"),
FieldWritersType.NULL_REPLACING
);
}

@Test
public void testNullByteNotReplaced()
{
mockSelectors(Arrays.asList("abc\u0000", "foo" + NullHandling.emptyToNullIfNeeded("") + "bar", "def"));
Assert.assertThrows(InvalidNullByteException.class, () -> {
doTestWithSpecificFieldWriter(fieldWriter);
});
Assert.assertThrows(InvalidNullByteException.class, () -> {
doTestWithSpecificFieldWriter(fieldWriterUtf8);
});
}

private void doTest(final List<String> values)
{
doTest(values, FieldWritersType.ALL);
}

private void doTest(final List<String> values, FieldWritersType fieldWritersType)
{
mockSelectors(values);

// Non-UTF8 test
{
final long written = writeToMemory(fieldWriter);
final Object[] valuesRead = readFromMemory(written);
Assert.assertEquals("values read (non-UTF8)", values, Arrays.asList(valuesRead));
List<FieldWriter> fieldWriters = getFieldWriter(fieldWritersType);
for (FieldWriter fw : fieldWriters) {
final Object[] valuesRead = doTestWithSpecificFieldWriter(fw);
List<String> expectedResults = new ArrayList<>(values);
if (fieldWritersType == FieldWritersType.NULL_REPLACING) {
expectedResults = expectedResults.stream()
.map(val -> StringUtils.replace(val, "\u0000", ""))
.collect(Collectors.toList());
}
Assert.assertEquals("values read", expectedResults, Arrays.asList(valuesRead));
}
}

// UTF8 test
{
final long writtenUtf8 = writeToMemory(fieldWriterUtf8);
final Object[] valuesReadUtf8 = readFromMemory(writtenUtf8);
Assert.assertEquals("values read (UTF8)", values, Arrays.asList(valuesReadUtf8));
}
private Object[] doTestWithSpecificFieldWriter(FieldWriter fieldWriter)
{
final long written = writeToMemory(fieldWriter);
return readFromMemory(written);
}


private void mockSelectors(final List<String> values)
{
final RangeIndexedInts row = new RangeIndexedInts();
Expand Down Expand Up @@ -183,9 +240,20 @@ private Object[] readFromMemory(final long written)
memory.getByteArray(MEMORY_POSITION, bytes, 0, (int) written);

final FieldReader fieldReader = FieldReaders.create("columnNameDoesntMatterHere", ColumnType.STRING_ARRAY);
final ColumnValueSelector<?> selector =
fieldReader.makeColumnValueSelector(memory, new ConstantFieldPointer(MEMORY_POSITION, -1));
final ColumnValueSelector<?> selector = fieldReader.makeColumnValueSelector(
memory,
new ConstantFieldPointer(
MEMORY_POSITION,
-1
)
);

return (Object[]) selector.getObject();
}

private enum FieldWritersType
{
NULL_REPLACING, // include null replacing writers only
ALL // include all writers
}
}
Loading