Skip to content

Commit

Permalink
[misc] permit using setObject/getObject on float[]/Float[] values
Browse files Browse the repository at this point in the history
  • Loading branch information
rusher committed Oct 17, 2024
1 parent 5463572 commit ae8e740
Show file tree
Hide file tree
Showing 5 changed files with 383 additions and 1 deletion.
140 changes: 140 additions & 0 deletions src/main/java/org/mariadb/jdbc/plugin/codec/FloatArrayCodec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2024 MariaDB Corporation Ab
package org.mariadb.jdbc.plugin.codec;

import java.io.IOException;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.sql.SQLDataException;
import java.util.Calendar;
import java.util.EnumSet;
import org.mariadb.jdbc.client.*;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.util.MutableInt;
import org.mariadb.jdbc.plugin.Codec;
import org.mariadb.jdbc.util.constants.ServerStatus;

/** Float codec */
public class FloatArrayCodec implements Codec<float[]> {

/** default instance */
public static final FloatArrayCodec INSTANCE = new FloatArrayCodec();

private static Class floatArrayClass = Array.newInstance(float.class, 0).getClass();

private static final EnumSet<DataType> COMPATIBLE_TYPES =
EnumSet.of(
DataType.BLOB,
DataType.TINYBLOB,
DataType.MEDIUMBLOB,
DataType.LONGBLOB,
DataType.VARSTRING,
DataType.VARCHAR,
DataType.STRING);

public String className() {
return float[].class.getName();
}

public boolean canDecode(ColumnDecoder column, Class<?> type) {
return COMPATIBLE_TYPES.contains(column.getType())
&& ((!type.isPrimitive() && type == floatArrayClass && type.isArray()));
}

public boolean canEncode(Object value) {
return value instanceof float[];
}

@Override
public float[] decodeText(
final ReadableByteBuf buf,
final MutableInt length,
final ColumnDecoder column,
final Calendar cal,
final Context context)
throws SQLDataException {

return toFloatArray(getBytes(buf, length, column));
}

@Override
public float[] decodeBinary(
final ReadableByteBuf buf,
final MutableInt length,
final ColumnDecoder column,
final Calendar cal,
final Context context)
throws SQLDataException {

return toFloatArray(getBytes(buf, length, column));
}

static final int BYTES_IN_FLOAT = Float.SIZE / Byte.SIZE;

public static byte[] toByteArray(float[] floatArray) {
ByteBuffer buffer = ByteBuffer.allocate(floatArray.length * BYTES_IN_FLOAT);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.asFloatBuffer().put(floatArray);
return buffer.array();
}

public static float[] toFloatArray(byte[] byteArray) {
float[] result = new float[byteArray.length / BYTES_IN_FLOAT];
ByteBuffer.wrap(byteArray)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer()
.get(result, 0, result.length);
return result;
}

private byte[] getBytes(ReadableByteBuf buf, MutableInt length, ColumnDecoder column)
throws SQLDataException {
switch (column.getType()) {
case BLOB:
case TINYBLOB:
case MEDIUMBLOB:
case LONGBLOB:
case STRING:
case VARSTRING:
case VARCHAR:
case GEOMETRY:
byte[] arr = new byte[length.get()];
buf.readBytes(arr);
return arr;

default:
buf.skip(length.get());
throw new SQLDataException(
String.format("Data type %s cannot be decoded as byte[]", column.getType()));
}
}

@Override
public void encodeText(Writer encoder, Context context, Object value, Calendar cal, Long maxLen)
throws IOException {
byte[] encoded = toByteArray((float[]) value);
encoder.writeBytes(ByteArrayCodec.BINARY_PREFIX);
encoder.writeBytesEscaped(
encoded,
encoded.length,
(context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) != 0);
encoder.writeByte('\'');
}

@Override
public void encodeBinary(
final Writer encoder,
final Context context,
final Object value,
final Calendar cal,
final Long maxLength)
throws IOException {
encoder.writeBytes(toByteArray((float[]) value));
}

public int getBinaryEncodeType() {
return DataType.BLOB.get();
}
}
160 changes: 160 additions & 0 deletions src/main/java/org/mariadb/jdbc/plugin/codec/FloatObjectArrayCodec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2024 MariaDB Corporation Ab
package org.mariadb.jdbc.plugin.codec;

import java.io.IOException;
import java.lang.reflect.Array;
import java.sql.SQLDataException;
import java.util.Calendar;
import java.util.EnumSet;
import org.mariadb.jdbc.client.ColumnDecoder;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.DataType;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.util.MutableInt;
import org.mariadb.jdbc.plugin.Codec;
import org.mariadb.jdbc.util.constants.ServerStatus;

/** Float codec */
public class FloatObjectArrayCodec implements Codec<Float[]> {

/** default instance */
public static final FloatObjectArrayCodec INSTANCE = new FloatObjectArrayCodec();

private static Class floatArrayClass = Array.newInstance(Float.class, 0).getClass();
private static final EnumSet<DataType> COMPATIBLE_TYPES =
EnumSet.of(
DataType.BLOB,
DataType.TINYBLOB,
DataType.MEDIUMBLOB,
DataType.LONGBLOB,
DataType.VARSTRING,
DataType.VARCHAR,
DataType.STRING);

public String className() {
return float[].class.getName();
}

public boolean canDecode(ColumnDecoder column, Class<?> type) {
return COMPATIBLE_TYPES.contains(column.getType())
&& ((!type.isPrimitive() && type == floatArrayClass && type.isArray()));
}

public boolean canEncode(Object value) {
return value instanceof Float[];
}

@Override
public Float[] decodeText(
final ReadableByteBuf buf,
final MutableInt length,
final ColumnDecoder column,
final Calendar cal,
final Context context)
throws SQLDataException {

return toFloatArray(getBytes(buf, length, column));
}

@Override
public Float[] decodeBinary(
final ReadableByteBuf buf,
final MutableInt length,
final ColumnDecoder column,
final Calendar cal,
final Context context)
throws SQLDataException {

return toFloatArray(getBytes(buf, length, column));
}

static final int BYTES_IN_FLOAT = Float.SIZE / Byte.SIZE;

public static byte[] toByteArray(Float[] floatArray) {
byte[] buf = new byte[floatArray.length * BYTES_IN_FLOAT];
int pos = 0;
for (Float f : floatArray) {
int value = Float.floatToIntBits(f);
buf[pos] = (byte) value;
buf[pos + 1] = (byte) (value >> 8);
buf[pos + 2] = (byte) (value >> 16);
buf[pos + 3] = (byte) (value >> 24);
pos += 4;
}
return buf;
}

public static Float[] toFloatArray(byte[] byteArray) {
int len = (int) Math.ceil(byteArray.length / 4.0);
Float[] res = new Float[len];
int pos = 0;
int value;
while (pos < len) {
if (pos + 1 <= len) {
value =
((byteArray[pos * 4] & 0xff)
+ ((byteArray[pos * 4 + 1] & 0xff) << 8)
+ ((byteArray[pos * 4 + 2] & 0xff) << 16)
+ ((byteArray[pos * 4 + 3] & 0xff) << 24));
} else {
value = (byteArray[pos * 4] & 0xff);
if (pos + 1 < byteArray.length) value += ((byteArray[pos * 4 + 1] & 0xff) << 8);
if (pos + 2 < byteArray.length) value += ((byteArray[pos * 4 + 2] & 0xff) << 16);
}
res[pos++] = Float.intBitsToFloat(value);
}
return res;
}

private byte[] getBytes(ReadableByteBuf buf, MutableInt length, ColumnDecoder column)
throws SQLDataException {
switch (column.getType()) {
case BLOB:
case TINYBLOB:
case MEDIUMBLOB:
case LONGBLOB:
case STRING:
case VARSTRING:
case VARCHAR:
case GEOMETRY:
byte[] arr = new byte[length.get()];
buf.readBytes(arr);
return arr;

default:
buf.skip(length.get());
throw new SQLDataException(
String.format("Data type %s cannot be decoded as byte[]", column.getType()));
}
}

@Override
public void encodeText(Writer encoder, Context context, Object value, Calendar cal, Long maxLen)
throws IOException {
byte[] encoded = toByteArray((Float[]) value);
encoder.writeBytes(ByteArrayCodec.BINARY_PREFIX);
encoder.writeBytesEscaped(
encoded,
encoded.length,
(context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) != 0);
encoder.writeByte('\'');
}

@Override
public void encodeBinary(
final Writer encoder,
final Context context,
final Object value,
final Calendar cal,
final Long maxLength)
throws IOException {
encoder.writeBytes(toByteArray((Float[]) value));
}

public int getBinaryEncodeType() {
return DataType.BLOB.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ org.mariadb.jdbc.plugin.codec.TimeCodec
org.mariadb.jdbc.plugin.codec.TimestampCodec
org.mariadb.jdbc.plugin.codec.UuidCodec
org.mariadb.jdbc.plugin.codec.ZonedDateTimeCodec
org.mariadb.jdbc.plugin.codec.FloatArrayCodec
org.mariadb.jdbc.plugin.codec.FloatObjectArrayCodec
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ public void ensureClassDefined() {
Type it = codec.getClass().getGenericInterfaces()[0];
ParameterizedType parameterizedType = (ParameterizedType) it;
Type typeParameter = parameterizedType.getActualTypeArguments()[0];
if (!"byte[]".equals(codec.className()))
if (!"byte[]".equals(codec.className()) && !"[F".equals(codec.className()))
assertEquals(((Class<?>) typeParameter).getName(), codec.className());
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2024 MariaDB Corporation Ab
package org.mariadb.jdbc.integration.codec;

import static org.junit.jupiter.api.Assertions.*;

import java.sql.*;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mariadb.jdbc.Statement;

public class FloatArrayCodecTest extends CommonCodecTest {
@AfterAll
public static void drop() throws SQLException {
Statement stmt = sharedConn.createStatement();
stmt.execute("DROP TABLE IF EXISTS BinaryCodec");
stmt.execute("DROP TABLE IF EXISTS BinaryCodec2");
}

@BeforeAll
public static void beforeAll2() throws SQLException {
drop();
Statement stmt = sharedConn.createStatement();
stmt.execute(
"CREATE TABLE BinaryCodec (t1 blob) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci");
stmt.execute("FLUSH TABLES");
}

@Test
public void floatArray() throws SQLException {
Statement stmt = sharedConn.createStatement();
stmt.execute("TRUNCATE TABLE BinaryCodec");
floatArray(sharedConn);
floatArray(sharedConnBinary);
}

private void floatArray(Connection con) throws SQLException {
float[] val = new float[] {1, 2, 3};
byte[] expectedConverstion =
new byte[] {0x00, 0x00, (byte) 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40};
try (PreparedStatement prep = con.prepareStatement("INSERT INTO BinaryCodec(t1) VALUES (?)")) {
prep.setObject(1, val);
prep.execute();
}
try (PreparedStatement prep = con.prepareStatement("SELECT * FROM BinaryCodec")) {
ResultSet rs = prep.executeQuery();
assertTrue(rs.next());
assertArrayEquals(expectedConverstion, rs.getBytes(1));
float[] res = rs.getObject(1, float[].class);
assertArrayEquals(val, res);
}
}

@Test
public void floatObjectArray() throws SQLException {
Statement stmt = sharedConn.createStatement();
stmt.execute("TRUNCATE TABLE BinaryCodec");
floatObjectArray(sharedConn);
floatObjectArray(sharedConnBinary);
}

private void floatObjectArray(Connection con) throws SQLException {
Float[] val = new Float[] {1f, 2f, 3f};
byte[] expectedConverstion =
new byte[] {0x00, 0x00, (byte) 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40};
try (PreparedStatement prep = con.prepareStatement("INSERT INTO BinaryCodec(t1) VALUES (?)")) {
prep.setObject(1, val);
prep.execute();
}
try (PreparedStatement prep = con.prepareStatement("SELECT * FROM BinaryCodec")) {
ResultSet rs = prep.executeQuery();
assertTrue(rs.next());
assertArrayEquals(expectedConverstion, rs.getBytes(1));
Float[] res = rs.getObject(1, Float[].class);
assertArrayEquals(val, res);
}
}
}

0 comments on commit ae8e740

Please sign in to comment.