diff --git a/src/main/java/org/mariadb/jdbc/plugin/codec/FloatArrayCodec.java b/src/main/java/org/mariadb/jdbc/plugin/codec/FloatArrayCodec.java new file mode 100644 index 000000000..39a98b0e7 --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/plugin/codec/FloatArrayCodec.java @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: LGPL-2.1-or-later +// Copyright (c) 2012-2014 Monty Program Ab +// Copyright (c) 2015-2024 MariaDB Corporation Ab +package org.mariadb.jdbc.plugin.codec; + +import java.io.IOException; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.sql.SQLDataException; +import java.util.Calendar; +import java.util.EnumSet; +import org.mariadb.jdbc.client.*; +import org.mariadb.jdbc.client.socket.Writer; +import org.mariadb.jdbc.client.util.MutableInt; +import org.mariadb.jdbc.plugin.Codec; +import org.mariadb.jdbc.util.constants.ServerStatus; + +/** Float codec */ +public class FloatArrayCodec implements Codec { + + /** default instance */ + public static final FloatArrayCodec INSTANCE = new FloatArrayCodec(); + + private static Class floatArrayClass = Array.newInstance(float.class, 0).getClass(); + + private static final EnumSet COMPATIBLE_TYPES = + EnumSet.of( + DataType.BLOB, + DataType.TINYBLOB, + DataType.MEDIUMBLOB, + DataType.LONGBLOB, + DataType.VARSTRING, + DataType.VARCHAR, + DataType.STRING); + + public String className() { + return float[].class.getName(); + } + + public boolean canDecode(ColumnDecoder column, Class type) { + return COMPATIBLE_TYPES.contains(column.getType()) + && ((!type.isPrimitive() && type == floatArrayClass && type.isArray())); + } + + public boolean canEncode(Object value) { + return value instanceof float[]; + } + + @Override + public float[] decodeText( + final ReadableByteBuf buf, + final MutableInt length, + final ColumnDecoder column, + final Calendar cal, + final Context context) + throws SQLDataException { + + return toFloatArray(getBytes(buf, length, column)); + } + + @Override + public float[] decodeBinary( + final ReadableByteBuf buf, + final MutableInt length, + final ColumnDecoder column, + final Calendar cal, + final Context context) + throws SQLDataException { + + return toFloatArray(getBytes(buf, length, column)); + } + + static final int BYTES_IN_FLOAT = Float.SIZE / Byte.SIZE; + + public static byte[] toByteArray(float[] floatArray) { + ByteBuffer buffer = ByteBuffer.allocate(floatArray.length * BYTES_IN_FLOAT); + buffer.order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(floatArray); + return buffer.array(); + } + + public static float[] toFloatArray(byte[] byteArray) { + float[] result = new float[byteArray.length / BYTES_IN_FLOAT]; + ByteBuffer.wrap(byteArray) + .order(ByteOrder.LITTLE_ENDIAN) + .asFloatBuffer() + .get(result, 0, result.length); + return result; + } + + private byte[] getBytes(ReadableByteBuf buf, MutableInt length, ColumnDecoder column) + throws SQLDataException { + switch (column.getType()) { + case BLOB: + case TINYBLOB: + case MEDIUMBLOB: + case LONGBLOB: + case STRING: + case VARSTRING: + case VARCHAR: + case GEOMETRY: + byte[] arr = new byte[length.get()]; + buf.readBytes(arr); + return arr; + + default: + buf.skip(length.get()); + throw new SQLDataException( + String.format("Data type %s cannot be decoded as byte[]", column.getType())); + } + } + + @Override + public void encodeText(Writer encoder, Context context, Object value, Calendar cal, Long maxLen) + throws IOException { + byte[] encoded = toByteArray((float[]) value); + encoder.writeBytes(ByteArrayCodec.BINARY_PREFIX); + encoder.writeBytesEscaped( + encoded, + encoded.length, + (context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) != 0); + encoder.writeByte('\''); + } + + @Override + public void encodeBinary( + final Writer encoder, + final Context context, + final Object value, + final Calendar cal, + final Long maxLength) + throws IOException { + encoder.writeBytes(toByteArray((float[]) value)); + } + + public int getBinaryEncodeType() { + return DataType.BLOB.get(); + } +} diff --git a/src/main/java/org/mariadb/jdbc/plugin/codec/FloatObjectArrayCodec.java b/src/main/java/org/mariadb/jdbc/plugin/codec/FloatObjectArrayCodec.java new file mode 100644 index 000000000..0be5fc549 --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/plugin/codec/FloatObjectArrayCodec.java @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: LGPL-2.1-or-later +// Copyright (c) 2012-2014 Monty Program Ab +// Copyright (c) 2015-2024 MariaDB Corporation Ab +package org.mariadb.jdbc.plugin.codec; + +import java.io.IOException; +import java.lang.reflect.Array; +import java.sql.SQLDataException; +import java.util.Calendar; +import java.util.EnumSet; +import org.mariadb.jdbc.client.ColumnDecoder; +import org.mariadb.jdbc.client.Context; +import org.mariadb.jdbc.client.DataType; +import org.mariadb.jdbc.client.ReadableByteBuf; +import org.mariadb.jdbc.client.socket.Writer; +import org.mariadb.jdbc.client.util.MutableInt; +import org.mariadb.jdbc.plugin.Codec; +import org.mariadb.jdbc.util.constants.ServerStatus; + +/** Float codec */ +public class FloatObjectArrayCodec implements Codec { + + /** default instance */ + public static final FloatObjectArrayCodec INSTANCE = new FloatObjectArrayCodec(); + + private static Class floatArrayClass = Array.newInstance(Float.class, 0).getClass(); + private static final EnumSet COMPATIBLE_TYPES = + EnumSet.of( + DataType.BLOB, + DataType.TINYBLOB, + DataType.MEDIUMBLOB, + DataType.LONGBLOB, + DataType.VARSTRING, + DataType.VARCHAR, + DataType.STRING); + + public String className() { + return float[].class.getName(); + } + + public boolean canDecode(ColumnDecoder column, Class type) { + return COMPATIBLE_TYPES.contains(column.getType()) + && ((!type.isPrimitive() && type == floatArrayClass && type.isArray())); + } + + public boolean canEncode(Object value) { + return value instanceof Float[]; + } + + @Override + public Float[] decodeText( + final ReadableByteBuf buf, + final MutableInt length, + final ColumnDecoder column, + final Calendar cal, + final Context context) + throws SQLDataException { + + return toFloatArray(getBytes(buf, length, column)); + } + + @Override + public Float[] decodeBinary( + final ReadableByteBuf buf, + final MutableInt length, + final ColumnDecoder column, + final Calendar cal, + final Context context) + throws SQLDataException { + + return toFloatArray(getBytes(buf, length, column)); + } + + static final int BYTES_IN_FLOAT = Float.SIZE / Byte.SIZE; + + public static byte[] toByteArray(Float[] floatArray) { + byte[] buf = new byte[floatArray.length * BYTES_IN_FLOAT]; + int pos = 0; + for (Float f : floatArray) { + int value = Float.floatToIntBits(f); + buf[pos] = (byte) value; + buf[pos + 1] = (byte) (value >> 8); + buf[pos + 2] = (byte) (value >> 16); + buf[pos + 3] = (byte) (value >> 24); + pos += 4; + } + return buf; + } + + public static Float[] toFloatArray(byte[] byteArray) { + int len = (int) Math.ceil(byteArray.length / 4.0); + Float[] res = new Float[len]; + int pos = 0; + int value; + while (pos < len) { + if (pos + 1 <= len) { + value = + ((byteArray[pos * 4] & 0xff) + + ((byteArray[pos * 4 + 1] & 0xff) << 8) + + ((byteArray[pos * 4 + 2] & 0xff) << 16) + + ((byteArray[pos * 4 + 3] & 0xff) << 24)); + } else { + value = (byteArray[pos * 4] & 0xff); + if (pos + 1 < byteArray.length) value += ((byteArray[pos * 4 + 1] & 0xff) << 8); + if (pos + 2 < byteArray.length) value += ((byteArray[pos * 4 + 2] & 0xff) << 16); + } + res[pos++] = Float.intBitsToFloat(value); + } + return res; + } + + private byte[] getBytes(ReadableByteBuf buf, MutableInt length, ColumnDecoder column) + throws SQLDataException { + switch (column.getType()) { + case BLOB: + case TINYBLOB: + case MEDIUMBLOB: + case LONGBLOB: + case STRING: + case VARSTRING: + case VARCHAR: + case GEOMETRY: + byte[] arr = new byte[length.get()]; + buf.readBytes(arr); + return arr; + + default: + buf.skip(length.get()); + throw new SQLDataException( + String.format("Data type %s cannot be decoded as byte[]", column.getType())); + } + } + + @Override + public void encodeText(Writer encoder, Context context, Object value, Calendar cal, Long maxLen) + throws IOException { + byte[] encoded = toByteArray((Float[]) value); + encoder.writeBytes(ByteArrayCodec.BINARY_PREFIX); + encoder.writeBytesEscaped( + encoded, + encoded.length, + (context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) != 0); + encoder.writeByte('\''); + } + + @Override + public void encodeBinary( + final Writer encoder, + final Context context, + final Object value, + final Calendar cal, + final Long maxLength) + throws IOException { + encoder.writeBytes(toByteArray((Float[]) value)); + } + + public int getBinaryEncodeType() { + return DataType.BLOB.get(); + } +} diff --git a/src/main/resources/META-INF/services/org.mariadb.jdbc.plugin.Codec b/src/main/resources/META-INF/services/org.mariadb.jdbc.plugin.Codec index 3335aa228..ccbc0bd3d 100644 --- a/src/main/resources/META-INF/services/org.mariadb.jdbc.plugin.Codec +++ b/src/main/resources/META-INF/services/org.mariadb.jdbc.plugin.Codec @@ -32,3 +32,5 @@ org.mariadb.jdbc.plugin.codec.TimeCodec org.mariadb.jdbc.plugin.codec.TimestampCodec org.mariadb.jdbc.plugin.codec.UuidCodec org.mariadb.jdbc.plugin.codec.ZonedDateTimeCodec +org.mariadb.jdbc.plugin.codec.FloatArrayCodec +org.mariadb.jdbc.plugin.codec.FloatObjectArrayCodec \ No newline at end of file diff --git a/src/test/java/org/mariadb/jdbc/integration/StatementTest.java b/src/test/java/org/mariadb/jdbc/integration/StatementTest.java index 977c93411..6ee3fce7d 100644 --- a/src/test/java/org/mariadb/jdbc/integration/StatementTest.java +++ b/src/test/java/org/mariadb/jdbc/integration/StatementTest.java @@ -1232,7 +1232,7 @@ public void ensureClassDefined() { Type it = codec.getClass().getGenericInterfaces()[0]; ParameterizedType parameterizedType = (ParameterizedType) it; Type typeParameter = parameterizedType.getActualTypeArguments()[0]; - if (!"byte[]".equals(codec.className())) + if (!"byte[]".equals(codec.className()) && !"[F".equals(codec.className())) assertEquals(((Class) typeParameter).getName(), codec.className()); } } diff --git a/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java b/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java new file mode 100644 index 000000000..7a3073672 --- /dev/null +++ b/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: LGPL-2.1-or-later +// Copyright (c) 2012-2014 Monty Program Ab +// Copyright (c) 2015-2024 MariaDB Corporation Ab +package org.mariadb.jdbc.integration.codec; + +import static org.junit.jupiter.api.Assertions.*; + +import java.sql.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mariadb.jdbc.Statement; + +public class FloatArrayCodecTest extends CommonCodecTest { + @AfterAll + public static void drop() throws SQLException { + Statement stmt = sharedConn.createStatement(); + stmt.execute("DROP TABLE IF EXISTS BinaryCodec"); + stmt.execute("DROP TABLE IF EXISTS BinaryCodec2"); + } + + @BeforeAll + public static void beforeAll2() throws SQLException { + drop(); + Statement stmt = sharedConn.createStatement(); + stmt.execute( + "CREATE TABLE BinaryCodec (t1 blob) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"); + stmt.execute("FLUSH TABLES"); + } + + @Test + public void floatArray() throws SQLException { + Statement stmt = sharedConn.createStatement(); + stmt.execute("TRUNCATE TABLE BinaryCodec"); + floatArray(sharedConn); + floatArray(sharedConnBinary); + } + + private void floatArray(Connection con) throws SQLException { + float[] val = new float[] {1, 2, 3}; + byte[] expectedConverstion = + new byte[] {0x00, 0x00, (byte) 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40}; + try (PreparedStatement prep = con.prepareStatement("INSERT INTO BinaryCodec(t1) VALUES (?)")) { + prep.setObject(1, val); + prep.execute(); + } + try (PreparedStatement prep = con.prepareStatement("SELECT * FROM BinaryCodec")) { + ResultSet rs = prep.executeQuery(); + assertTrue(rs.next()); + assertArrayEquals(expectedConverstion, rs.getBytes(1)); + float[] res = rs.getObject(1, float[].class); + assertArrayEquals(val, res); + } + } + + @Test + public void floatObjectArray() throws SQLException { + Statement stmt = sharedConn.createStatement(); + stmt.execute("TRUNCATE TABLE BinaryCodec"); + floatObjectArray(sharedConn); + floatObjectArray(sharedConnBinary); + } + + private void floatObjectArray(Connection con) throws SQLException { + Float[] val = new Float[] {1f, 2f, 3f}; + byte[] expectedConverstion = + new byte[] {0x00, 0x00, (byte) 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40}; + try (PreparedStatement prep = con.prepareStatement("INSERT INTO BinaryCodec(t1) VALUES (?)")) { + prep.setObject(1, val); + prep.execute(); + } + try (PreparedStatement prep = con.prepareStatement("SELECT * FROM BinaryCodec")) { + ResultSet rs = prep.executeQuery(); + assertTrue(rs.next()); + assertArrayEquals(expectedConverstion, rs.getBytes(1)); + Float[] res = rs.getObject(1, Float[].class); + assertArrayEquals(val, res); + } + } +}