Skip to content

Commit

Permalink
Precompile regular expression patterns, optimize some string operations, fix the read name separator.
Browse files Browse the repository at this point in the history
  • Loading branch information
cmnbroad committed Dec 10, 2024
1 parent 923668d commit ed7bd6d
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 216 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ public static ByteBuffer allocateOutputBuffer(final int inSize) {
}

// returns a new LITTLE_ENDIAN ByteBuffer of size = bufferSize
//TODO: rename this to allocateLittleEndianByteBuffer
public static ByteBuffer allocateByteBuffer(final int bufferSize){
    // Allocate a fresh buffer of the requested capacity and force little-endian
    // byte order, which is what the CRAM codecs expect for all multi-byte reads/writes.
    final ByteBuffer littleEndianBuffer = ByteBuffer.allocate(bufferSize);
    littleEndianBuffer.order(ByteOrder.LITTLE_ENDIAN);
    return littleEndianBuffer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
import java.util.StringJoiner;

public class NameTokenisationDecode {


// TODO: lift these values to a common location since they're used by the encode, decoder and tests
// for now, since we're returning a single String of all the names (instead of a list, which would be more efficient),
// we use a single byte to separate the names; this particular byte is chosen because the calling code in the CRAM
// reader for read names already assumes it will be handed a block of '\0' separated names
public final static byte NAME_SEPARATOR = 0;
public final static CharSequence LOCAL_NAME_SEPARATOR_CHARSEQUENCE = new String(new byte[] {NAME_SEPARATOR});

// the input must be a ByteBuffer containing the read names, separated by the NAME_SEPARATOR byte, WITHOUT
// a terminating separator
public String uncompress(final ByteBuffer inBuffer) {
//TODO: make this stop sentinel into a shared static constant
// Actually, this doesn't need to be exposed as an arg on this, so move it into the uncompress method
return uncompress(inBuffer, "\0");
}

public String uncompress(
final ByteBuffer inBuffer,
final String separator) {
inBuffer.order(ByteOrder.LITTLE_ENDIAN);
final int uncompressedLength = inBuffer.getInt() & 0xFFFFFFFF; //unused but we have to consume it
final int numNames = inBuffer.getInt() & 0xFFFFFFFF;
Expand All @@ -27,21 +26,18 @@ public String uncompress(
final TokenStreams tokenStreams = new TokenStreams(inBuffer, useArith, numNames);

// track tokens we've already seen for subsequent lookup/reference (indexed as (nameIndex, tokenPosition))
//TODO: for performance reasons, it would probably be wise to separate the string tokens from the int tokens
// so we don't have to repeatedly interconvert them when fetching from this list
final List<List<String>> previousTokens = new ArrayList<>(numNames);
for (int i = 0; i < numNames; i++) {
previousTokens.add(new ArrayList<>());
}

final StringJoiner decodedNamesJoiner = new StringJoiner(separator);
final StringJoiner decodedNamesJoiner = new StringJoiner(LOCAL_NAME_SEPARATOR_CHARSEQUENCE);
for (int i = 0; i < numNames; i++) {
decodedNamesJoiner.add(decodeSingleName(tokenStreams, previousTokens, i));
}
final String uncompressedNames = decodedNamesJoiner.toString();
if (uncompressedLength == uncompressedNames.length() + separator.length()) {
return uncompressedNames + separator;
}
//TODO: this line is never executed in interop tests
return uncompressedNames;
return decodedNamesJoiner.toString();
}

private String decodeSingleName(
Expand All @@ -52,8 +48,7 @@ private String decodeSingleName(
// The information about whether a name is a duplicate or not
// is obtained from the list of tokens at tokenStreams[0,0]
final byte nameType = tokenStreams.getTokenStream(0, TokenStreams.TOKEN_TYPE).get();
//TODO: set the byte order to little endian where these are created, not here...
final ByteBuffer distBuffer = tokenStreams.getTokenStream(0, nameType).order(ByteOrder.LITTLE_ENDIAN);
final ByteBuffer distBuffer = tokenStreams.getTokenStream(0, nameType);
final int dist = distBuffer.getInt() & 0xFFFFFFFF;
final int prevNameIndex = currentNameIndex - dist;

Expand Down Expand Up @@ -82,15 +77,15 @@ private String decodeSingleName(
case TokenStreams.TOKEN_DIGITS0:
final String digits0Token = getDigitsToken(tokenStreams, tokenPos, TokenStreams.TOKEN_DIGITS0);
final int lenDigits0Token = tokenStreams.getTokenStream(tokenPos, TokenStreams.TOKEN_DZLEN).get() & 0xFF;
currentToken = leftPadNumber(digits0Token, lenDigits0Token);
currentToken = leftPadWith0(digits0Token, lenDigits0Token);
break;
case TokenStreams.TOKEN_DELTA:
currentToken = getDeltaToken(tokenStreams, tokenPos, tokensList, prevNameIndex, TokenStreams.TOKEN_DELTA);
break;
case TokenStreams.TOKEN_DELTA0:
final String delta0Token = getDeltaToken(tokenStreams, tokenPos, tokensList, prevNameIndex, TokenStreams.TOKEN_DELTA0);
final int lenDelta0Token = tokensList.get(prevNameIndex).get(tokenPos-1).length();
currentToken = leftPadNumber(delta0Token, lenDelta0Token);
currentToken = leftPadWith0(delta0Token, lenDelta0Token);
break;
case TokenStreams.TOKEN_MATCH:
currentToken = tokensList.get(prevNameIndex).get(tokenPos-1);
Expand All @@ -112,7 +107,7 @@ private String decodeSingleName(
"Invalid tokenType : %s. tokenType must be one of the valid token types",
type));
}
//TODO: this is expanding the list EVERY time, which is not efficient
//TODO: this is expanding the list many times, which is not efficient
tokensList.get(currentNameIndex).add(tokenPos - 1, currentToken);
decodedNameBuilder.append(currentToken);
tokenPos++;
Expand Down Expand Up @@ -150,7 +145,7 @@ private String getDigitsToken(
throw new CRAMException(String.format("Invalid tokenType : %s. " +
"tokenType must be either TOKEN_DIGITS or TOKEN_DIGITS0", tokenType));
}
final ByteBuffer digitsByteBuffer = tokenStreams.getTokenStream(tokenPosition, tokenType).order(ByteOrder.LITTLE_ENDIAN);
final ByteBuffer digitsByteBuffer = tokenStreams.getTokenStream(tokenPosition, tokenType);
final long digits = digitsByteBuffer.getInt() & 0xFFFFFFFFL;
return Long.toString(digits);
}
Expand All @@ -161,19 +156,23 @@ private String readString(final ByteBuffer inputBuffer) {
final StringBuilder resultStringBuilder = new StringBuilder();
byte currentByte = inputBuffer.get();
while (currentByte != 0) {
//TODO: fix this sketchy cast
resultStringBuilder.append((char) currentByte);
currentByte = inputBuffer.get();
}
return resultStringBuilder.toString();
}

private String leftPadNumber(String value, final int len) {
// return value such that it is at least len bytes long with leading zeros
//TODO: optimize this....
while (value.length() < len) {
value = "0" + value;
// return value such that it is at least len bytes long with leading zeros
private String leftPadWith0(final String value, final int len) {
if (value.length() >= len) {
return value;
} else {
final StringBuilder sb = new StringBuilder();
sb.append("0".repeat(Math.max(0, len - value.length())));
sb.append(value);
return sb.toString();
}
return value;
}

}
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package htsjdk.samtools.cram.compression.nametokenisation;

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.CompressionUtils;
import htsjdk.samtools.cram.compression.nametokenisation.tokens.EncodeToken;
import htsjdk.samtools.cram.compression.range.RangeEncode;
import htsjdk.samtools.cram.compression.range.RangeParams;
import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Encode;
import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Params;
import htsjdk.samtools.cram.structure.CRAMEncodingStrategy;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -17,41 +18,50 @@
import java.util.regex.Pattern;

public class NameTokenisationEncode {
private final static String nameTokenizerRegex = "([a-zA-Z0-9]{1,9})|([^a-zA-Z0-9]+)";
private final static Pattern nameTokenizerPattern = Pattern.compile(nameTokenizerRegex);
private final static String READ_NAME_TOK_REGEX = "([a-zA-Z0-9]{1,9})|([^a-zA-Z0-9]+)";
private final static Pattern READ_NAME_PATTERN = Pattern.compile(READ_NAME_TOK_REGEX);

private final static String DIGITS0_REGEX = "^0+[0-9]*$";
private final static Pattern DIGITS0_PATTERN = Pattern.compile(DIGITS0_REGEX);

private final static String DIGITS_REGEX = "^[0-9]+$";
private final static Pattern DIGITS_PATTERN = Pattern.compile(DIGITS_REGEX);

private int maxToken;
private int maxLength;

//TODO: can this class use the TokenStreams class
// TODO: reset the input stream before processing
// the output is a ByteBuffer containing the read names, separated by the NAME_SEPARATOR byte, WITHOUT
// a terminating separator
public ByteBuffer compress(final ByteBuffer inBuffer, final boolean useArith) {
maxToken = 0;
maxLength = 0;
//TODO: make this an ArrayList of byte[] instead of String
final ArrayList<String> names = new ArrayList<>();
final ArrayList<String> names = new ArrayList<>(CRAMEncodingStrategy.DEFAULT_READS_PER_SLICE);

// convert buffer to array of names
//int lastPosition = inBuffer.position();
//while(inBuffer.hasRemaining()){
// extract the individual names from the input buffer
for (int lastPosition = inBuffer.position(); inBuffer.hasRemaining();) {
final byte currentByte = inBuffer.get();
//TODO: is this \n the same as the shared separator ? where is this defined ?
if ((currentByte) == '\n' || inBuffer.position()==inBuffer.limit()){
if (currentByte == NameTokenisationDecode.NAME_SEPARATOR || inBuffer.position() == inBuffer.limit()) {
final int length = inBuffer.position() - lastPosition;
final byte[] bytes = new byte[length];
inBuffer.position(lastPosition);
inBuffer.get(bytes, 0, length);
names.add(new String(bytes, StandardCharsets.UTF_8).trim());
inBuffer.get(bytes, 0, length); // consume the string + the terminator
names.add(new String(
bytes,
0,
// special-case handling for the end of the buffer, where there is no trailing separator
length - (inBuffer.position() == inBuffer.limit() ? 0 : 1),
StandardCharsets.UTF_8));
lastPosition = inBuffer.position();
}
}

final int numNames = names.size();
// guess max size -> str.length*2 + 10000 (from htscodecs javascript code)
//TODO: what is this calculation ?
final ByteBuffer outBuffer = allocateOutputBuffer((inBuffer.limit()*2)+10000);
outBuffer.putInt(inBuffer.limit());
final int uncompressedInputSize = inBuffer.limit() - numNames - 1;
//TODO: guess max size -> str.length*2 + 10000 (from htscodecs javascript code)
final ByteBuffer outBuffer = CompressionUtils.allocateByteBuffer((inBuffer.limit()*2)+10000);
//TODO: what is the correct value here ? does/should this include
// the local name delimiter that we use to format the input stream (??)
outBuffer.putInt(uncompressedInputSize);
outBuffer.putInt(numNames);
outBuffer.put((byte)(useArith == true ? 1 : 0));

Expand All @@ -67,7 +77,7 @@ public ByteBuffer compress(final ByteBuffer inBuffer, final boolean useArith) {
for (int tokenPosition = 0; tokenPosition < maxToken; tokenPosition++) {
final List<ByteBuffer> tokenStream = new ArrayList(TokenStreams.TOTAL_TOKEN_TYPES);
for (int i = 0; i < TokenStreams.TOTAL_TOKEN_TYPES; i++) {
tokenStream.add(ByteBuffer.allocate(numNames* maxLength).order(ByteOrder.LITTLE_ENDIAN));
tokenStream.add(CompressionUtils.allocateByteBuffer(numNames * maxLength));
}
fillByteStreams(tokenStream, tokensList, tokenPosition, numNames);
serializeByteStreams(tokenStream, useArith, outBuffer);
Expand All @@ -89,23 +99,15 @@ private void tokeniseName(final List<List<EncodeToken>> tokensList,
final int prevNameIndex = currentNameIndex - 1;
tokensList.add(new ArrayList<>());
if (nameIndexMap.containsKey(name)) {
// TODO: Add Test to cover this code
tokensList.get(currentNameIndex).add(
// TODO: lift the common subexpressions
new EncodeToken(
String.valueOf(currentNameIndex - nameIndexMap.get(name)),
String.valueOf(currentNameIndex - nameIndexMap.get(name)),
TokenStreams.TOKEN_DUP));
final String indStr = String.valueOf(currentNameIndex - nameIndexMap.get(name));
tokensList.get(currentNameIndex).add(new EncodeToken(indStr, indStr, TokenStreams.TOKEN_DUP));
} else {
tokensList.get(currentNameIndex).add(
new EncodeToken(
String.valueOf(currentNameIndex == 0 ? 0 : 1),
String.valueOf(currentNameIndex == 0 ? 0 : 1),
TokenStreams.TOKEN_DIFF));
final String indStr = String.valueOf(currentNameIndex == 0 ? 0 : 1);
tokensList.get(currentNameIndex).add(new EncodeToken(indStr, indStr, TokenStreams.TOKEN_DIFF));
}
// Get the list of tokens `tok` for the current name
nameIndexMap.put(name, currentNameIndex);
final Matcher matcher = nameTokenizerPattern.matcher(name);
final Matcher matcher = READ_NAME_PATTERN.matcher(name);
final List<String> tok = new ArrayList<>();
while (matcher.find()) {
tok.add(matcher.group());
Expand All @@ -117,13 +119,13 @@ private void tokeniseName(final List<List<EncodeToken>> tokensList,
int tokenIndex = i + 1;
byte type = TokenStreams.TOKEN_STRING;
final String str = tok.get(i); // absolute value of the token
String val = tok.get(i); // relative value of the token (comparing to prevname's token at the same token position)
//TODO: precompile these
if (tok.get(i).matches("^0+[0-9]*$")) {
String val = str; // relative value of the token (comparing to prevname's token at the same token position)

if (DIGITS0_PATTERN.matcher(str).matches()) {
type = TokenStreams.TOKEN_DIGITS0;
} else if (tok.get(i).matches("^[0-9]+$")) {
} else if (DIGITS_PATTERN.matcher(str).matches()) {
type = TokenStreams.TOKEN_DIGITS;
} else if (tok.get(i).length() == 1) {
} else if (str.length() == 1) {
type = TokenStreams.TOKEN_CHAR;
}

Expand All @@ -144,7 +146,7 @@ private void tokeniseName(final List<List<EncodeToken>> tokensList,
type = TokenStreams.TOKEN_DELTA;
val = String.valueOf(d);
}
} else if (type==TokenStreams.TOKEN_DIGITS0 && prevToken.getActualTokenValue().length() == val.length()
} else if (type == TokenStreams.TOKEN_DIGITS0 && prevToken.getActualTokenValue().length() == val.length()
&& (prevToken.getTokenType() == TokenStreams.TOKEN_DIGITS0 || prevToken.getTokenType() == TokenStreams.TOKEN_DELTA0)) {
int d = Integer.parseInt(val) - Integer.parseInt(prevToken.getActualTokenValue());
tokenFrequencies[tokenIndex]++;
Expand Down Expand Up @@ -224,15 +226,15 @@ private void fillByteStreams(
tokenStream.get(TokenStreams.TOKEN_DELTA0).put((byte)Integer.parseInt(encodeToken.getRelativeTokenValue()));
break;

// case TokenStreams.TOKEN_NOP:
// case TokenStreams.TOKEN_MATCH:
// case TokenStreams.TOKEN_END:
// //TODO: do we need to handle these token types here? throwing causes exceptions
// //throw new CRAMException("Invalid token type: " + type);
// break;
//
// default:
// throw new CRAMException("Invalid token type: " + type);
case TokenStreams.TOKEN_NOP:
case TokenStreams.TOKEN_MATCH:
case TokenStreams.TOKEN_END:
//TODO: do we need to handle these token types here? throwing causes exceptions
//throw new CRAMException("Invalid token type: " + type);
break;

default:
throw new CRAMException("Invalid token type: " + type);
}
}
}
Expand Down Expand Up @@ -322,18 +324,4 @@ private void serializeByteStreams(
}
}
}

//TODO: consolidate this with the same method in CompressionUtils
//TODO: consolidate this with the same method in CompressionUtils
// (same as the allocateOutputBuffer in RANS4x8Encode and RANSNx16Encode)
/**
 * Allocates a LITTLE_ENDIAN ByteBuffer sized to the worst-case compressed output
 * for an input of {@code inSize} bytes, using the same size bound formula as the
 * RANS encoders.
 */
private ByteBuffer allocateOutputBuffer(final int inSize) {
    final int compressedSize = (int) (1.05 * inSize + 257 * 257 * 3 + 9);
    // ByteBuffer.allocate either returns a buffer with remaining() == compressedSize
    // or throws, so the old post-allocation remaining() check could never fire
    // and has been dropped
    return ByteBuffer.allocate(compressedSize).order(ByteOrder.LITTLE_ENDIAN);
}
}
Loading

0 comments on commit ed7bd6d

Please sign in to comment.