Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spark driver #1863

50 changes: 50 additions & 0 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
Expand All @@ -35,6 +36,8 @@ public abstract class HiveServer2Connection : AdbcConnection
internal TTransport? transport;
internal TCLIService.Client? client;
internal TSessionHandle? sessionHandle;
private string? _vendorVersion;
birschick-bq marked this conversation as resolved.
Show resolved Hide resolved
private string? _vendorName;

internal HiveServer2Connection(IReadOnlyDictionary<string, string> properties)
{
Expand All @@ -46,6 +49,30 @@ internal TCLIService.Client Client
get { return this.client ?? throw new InvalidOperationException("connection not open"); }
}

/// <summary>
/// Gets the DBMS version string reported by the server for <c>CLI_DBMS_VER</c>.
/// </summary>
/// <remarks>
/// The value is requested from the server the first time it is needed and then
/// kept in <c>_vendorVersion</c>. While no non-null value has been stored, every
/// access repeats the request; <c>null</c> is returned if the request fails.
/// </remarks>
protected string? VendorVersion
{
    get
    {
        // Already resolved on an earlier access: no server round trip needed.
        if (_vendorVersion != null)
        {
            return _vendorVersion;
        }

        if (TryGetInfoType(TGetInfoType.CLI_DBMS_VER, out string? reportedVersion))
        {
            _vendorVersion = reportedVersion;
        }

        return _vendorVersion;
    }
}

/// <summary>
/// Gets the DBMS product name reported by the server for <c>CLI_DBMS_NAME</c>.
/// </summary>
/// <remarks>
/// The value is requested from the server the first time it is needed and then
/// kept in <c>_vendorName</c>. While no non-null value has been stored, every
/// access repeats the request; <c>null</c> is returned if the request fails.
/// </remarks>
protected string? VendorName
{
    get
    {
        // Only go to the server while nothing has been cached yet.
        if (_vendorName == null && TryGetInfoType(TGetInfoType.CLI_DBMS_NAME, out string? value))
        {
            _vendorName = value;
        }
        return _vendorName;
    }
}

internal async Task OpenAsync()
{
TProtocol protocol = await CreateProtocolAsync();
Expand Down Expand Up @@ -103,6 +130,29 @@ protected Schema GetSchema()
return SchemaParser.GetArrowSchema(response.Schema);
}

/// <summary>
/// Requests a single GetInfo item from the server for the current session and
/// returns its string representation.
/// </summary>
/// <param name="infoType">The info item to request.</param>
/// <param name="value">
/// Receives the string value carried by the response. It is <c>null</c> when the
/// server reports an error, and may also be <c>null</c> when the response holds
/// no string value.
/// </param>
/// <returns>
/// <c>false</c> when the response status is <c>ERROR_STATUS</c> (the error is
/// written to the trace output as a warning); otherwise <c>true</c>.
/// </returns>
/// <exception cref="InvalidOperationException">
/// Thrown when no session has been created or the connection is not open.
/// </exception>
private bool TryGetInfoType(TGetInfoType infoType, out string? value)
{
    TGetInfoReq req = new()
    {
        SessionHandle = this.sessionHandle ?? throw new InvalidOperationException("session not created"),
        InfoType = infoType,
    };

    // The callers are synchronous property getters, so this call has to block.
    // GetAwaiter().GetResult() rethrows the original exception, whereas .Result
    // would wrap it in an AggregateException and hide the real failure type.
    TGetInfoResp getInfoResp = Client.GetInfo(req).GetAwaiter().GetResult();
    if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
    {
        Trace.TraceWarning("{0}, Error Code={1}, SQLState={2}",
            getInfoResp.Status.ErrorMessage,
            getInfoResp.Status.ErrorCode,
            getInfoResp.Status.SqlState);
        value = null;
        return false;
    }

    // Guard against a response without an info value instead of dereferencing it blindly.
    value = getInfoResp.InfoValue?.StringValue;
    return true;
}

sealed class GetObjectsReader : IArrowArrayStream
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed unused code: `sealed class GetObjectsReader : IArrowArrayStream`.

{
HiveServer2Connection? connection;
Expand Down
47 changes: 38 additions & 9 deletions csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,17 @@ public class SparkConnection : HiveServer2Connection
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorName
AdbcInfoCode.VendorName,
AdbcInfoCode.VendorSql,
AdbcInfoCode.VendorVersion,
};

const string InfoDriverName = "ADBC Spark Driver";
// TODO: Make this dynamically return current version
const string InfoDriverVersion = "1.0.0";
const string InfoVendorName = "Spark";
const string InfoVendorName = "Spark SQL";
const string InfoDriverArrowVersion = "1.0.0";
const bool InfoVendorSql = true;
const int DecimalPrecisionDefault = 10;
const int DecimalScaleDefault = 0;

Expand Down Expand Up @@ -137,6 +141,7 @@ public override AdbcStatement CreateStatement()
public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
{
const int strValTypeID = 0;
const int boolValTypeId = 1;

UnionType infoUnionType = new UnionType(
new Field[]
Expand Down Expand Up @@ -178,8 +183,11 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
ArrowBuffer.Builder<byte> typeBuilder = new ArrowBuffer.Builder<byte>();
ArrowBuffer.Builder<int> offsetBuilder = new ArrowBuffer.Builder<int>();
StringArray.Builder stringInfoBuilder = new StringArray.Builder();
BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder();

int nullCount = 0;
int arrayLength = codes.Count;
int offset = 0;

foreach (AdbcInfoCode code in codes)
{
Expand All @@ -188,32 +196,53 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
case AdbcInfoCode.DriverName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverArrowVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(InfoVendorName);
offsetBuilder.Append(offset++);
string vendorName = VendorName ?? InfoVendorName;
stringInfoBuilder.Append(vendorName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(offset++);
string? vendorVersion = VendorVersion;
stringInfoBuilder.Append(vendorVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorSql:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(boolValTypeId);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.Append(InfoVendorSql);
break;
default:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.AppendNull();
nullCount++;
break;
}
Expand All @@ -231,7 +260,7 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
IArrowArray[] childrenArrays = new IArrowArray[]
{
stringInfoBuilder.Build(),
new BooleanArray.Builder().Build(),
booleanInfoBuilder.Build(),
new Int64Array.Builder().Build(),
new Int32Array.Builder().Build(),
new ListArray.Builder(StringType.Default).Build(),
Expand Down
77 changes: 73 additions & 4 deletions csharp/test/Drivers/Apache/Spark/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,30 @@ public async Task CanGetInfo()
{
AdbcConnection adbcConnection = NewConnection();

using IArrowArrayStream stream = adbcConnection.GetInfo(new List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName });
// Test the supported info codes
List<AdbcInfoCode> handledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.VendorName,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorVersion,
AdbcInfoCode.VendorSql
};
using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes);

RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> expectedValues = new List<string>() { "DriverName", "DriverVersion", "VendorName" };
List<string> expectedValues = new List<string>()
{
"DriverName",
"DriverVersion",
"VendorName",
"DriverArrowVersion",
"VendorVersion",
"VendorSql"
};

for (int i = 0; i < infoNameArray.Length; i++)
{
Expand All @@ -98,8 +116,59 @@ public async Task CanGetInfo()

Assert.Contains(value.ToString(), expectedValues);

StringArray stringArray = (StringArray)valueArray.Fields[0];
Console.WriteLine($"{value}={stringArray.GetString(i)}");
switch (value)
{
case AdbcInfoCode.VendorSql:
// TODO: How does external developer know the second field is the boolean field?
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
bool? boolValue = booleanArray.GetValue(i);
OutputHelper?.WriteLine($"{value}={boolValue}");
Assert.True(boolValue);
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
string stringValue = stringArray.GetString(i);
OutputHelper?.WriteLine($"{value}={stringValue}");
Assert.NotNull(stringValue);
break;
}
}

// Test the unhandled info codes.
List<AdbcInfoCode> unhandledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.VendorArrowVersion,
AdbcInfoCode.VendorSubstrait,
AdbcInfoCode.VendorSubstraitMaxVersion
};
using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes);

recordBatch = await stream2.ReadNextRecordBatchAsync();
infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> unexpectedValues = new List<string>()
{
"VendorArrowVersion",
"VendorSubstrait",
"VendorSubstraitMaxVersion"
};
for (int i = 0; i < infoNameArray.Length; i++)
{
AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i);
DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value");

Assert.Contains(value.ToString(), unexpectedValues);
switch (value)
{
case AdbcInfoCode.VendorSql:
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
Assert.Null(booleanArray.GetValue(i));
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
Assert.Null(stringArray.GetString(i));
break;
}
}
}

Expand Down
Loading