Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spark driver #1863

91 changes: 31 additions & 60 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
Expand All @@ -35,17 +36,29 @@ public abstract class HiveServer2Connection : AdbcConnection
internal TTransport? transport;
internal TCLIService.Client? client;
internal TSessionHandle? sessionHandle;
private readonly Lazy<string> _vendorVersion;
private readonly Lazy<string> _vendorName;

/// <summary>
/// Stores the connection properties and sets up the lazily fetched vendor name and version.
/// </summary>
/// <param name="properties">The connection properties kept on this instance.</param>
internal HiveServer2Connection(IReadOnlyDictionary<string, string> properties)
{
    // "PublicationOnly" is thread-safe initialization in which the first thread that
    // successfully produces a value sets it. A factory call that throws is not remembered,
    // so the next read of Value tries again until a call returns without an exception.
    // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
    const LazyThreadSafetyMode retryUntilSuccess = LazyThreadSafetyMode.PublicationOnly;

    _vendorName = new Lazy<string>(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), retryUntilSuccess);
    _vendorVersion = new Lazy<string>(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), retryUntilSuccess);
    this.properties = properties;
}

/// <summary>
/// Gets the Thrift service client held by this connection.
/// </summary>
/// <exception cref="InvalidOperationException">The <c>client</c> field has not been set.</exception>
internal TCLIService.Client Client => this.client ?? throw new InvalidOperationException("connection not open");

/// <summary>
/// Gets the value of the lazily initialized vendor version field.
/// </summary>
protected string VendorVersion
{
    get { return _vendorVersion.Value; }
}

/// <summary>
/// Gets the value of the lazily initialized vendor name field.
/// </summary>
protected string VendorName
{
    get { return _vendorName.Value; }
}

internal async Task OpenAsync()
{
TProtocol protocol = await CreateProtocolAsync();
Expand Down Expand Up @@ -81,6 +94,24 @@ protected void PollForResponse()
} while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE);
}

/// <summary>
/// Sends a GetInfo request on the current session and returns the string value of the reply.
/// </summary>
/// <param name="infoType">The kind of information to request from the server.</param>
/// <returns>The <c>StringValue</c> of the response's <c>InfoValue</c>.</returns>
/// <exception cref="InvalidOperationException">No session handle has been set.</exception>
/// <exception cref="HiveServer2Exception">The response status code is <c>ERROR_STATUS</c>.</exception>
private string GetInfoTypeStringValue(TGetInfoType infoType)
{
    var session = this.sessionHandle ?? throw new InvalidOperationException("session not created");
    TGetInfoReq request = new TGetInfoReq
    {
        SessionHandle = session,
        InfoType = infoType,
    };

    // Synchronous wait: the calling thread blocks until the GetInfo task finishes.
    TGetInfoResp response = Client.GetInfo(request).Result;
    var status = response.Status;
    if (status.StatusCode != TStatusCode.ERROR_STATUS)
    {
        return response.InfoValue.StringValue;
    }

    // Carry the server-reported error code and SQL state on the exception.
    throw new HiveServer2Exception(status.ErrorMessage)
        .SetNativeError(status.ErrorCode)
        .SetSqlState(status.SqlState);
}

public override void Dispose()
{
Expand All @@ -102,65 +133,5 @@ protected Schema GetSchema()
TGetResultSetMetadataResp response = this.Client.GetResultSetMetadata(request).Result;
return SchemaParser.GetArrowSchema(response.Schema);
}

sealed class GetObjectsReader : IArrowArrayStream
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed unused code: `sealed class GetObjectsReader : IArrowArrayStream`.

{
HiveServer2Connection? connection;
Schema schema;
List<TSparkArrowBatch>? batches;
int index;
IArrowReader? reader;

/// <summary>
/// Creates a reader that pulls result batches through the given connection.
/// </summary>
/// <param name="connection">Connection whose client and operation handle are used to fetch results.</param>
/// <param name="schema">Schema exposed by the reader and passed along when decoding each batch.</param>
public GetObjectsReader(HiveServer2Connection connection, Schema schema)
{
    this.schema = schema;
    this.connection = connection;
}

/// <summary>
/// Gets the schema supplied when this reader was constructed.
/// </summary>
public Schema Schema => this.schema;

/// <summary>
/// Returns the next record batch, fetching more results through the connection
/// whenever the locally buffered Arrow batches have been used up.
/// </summary>
/// <param name="cancellationToken">Token passed to the active reader and to the FetchResults call.</param>
/// <returns>
/// The next non-null batch, or null once nothing is buffered and the connection
/// reference has been cleared.
/// </returns>
public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
while (true)
{
// Step 1: drain the reader for the current Arrow batch, if one is active.
if (this.reader != null)
{
RecordBatch? next = await this.reader.ReadNextRecordBatchAsync(cancellationToken);
if (next != null)
{
return next;
}
// The active reader is exhausted; drop it and move on to the next buffered batch.
this.reader = null;
}

// Step 2: open a reader over the next buffered batch and loop back to step 1.
if (this.batches != null && this.index < this.batches.Count)
{
this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch));
continue;
}

// Step 3: every buffered batch has been consumed; reset the buffer state.
this.batches = null;
this.index = 0;

// A null connection means an earlier fetch reported no more rows: end of stream.
if (this.connection == null)
{
return null;
}

// Step 4: fetch the next set of results.
// NOTE(review): 50000 is presumably the maximum row count per fetch — confirm against TFetchResultsReq.
TFetchResultsReq request = new TFetchResultsReq(this.connection.operationHandle, TFetchOrientation.FETCH_NEXT, 50000);
TFetchResultsResp response = await this.connection.Client.FetchResults(request, cancellationToken);
this.batches = response.Results.ArrowBatches;

// Clearing the connection stops further fetches; the batches just received are
// still served by steps 1 and 2 before null is returned.
if (!response.HasMoreRows)
{
this.connection = null;
}
}
}

/// <summary>
/// No-op: this implementation releases nothing.
/// </summary>
public void Dispose()
{
    // Empty body: no statements run on dispose.
}
}
}
}
61 changes: 50 additions & 11 deletions csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using System.Diagnostics;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
Expand All @@ -43,16 +44,20 @@ public class SparkConnection : HiveServer2Connection
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorName
AdbcInfoCode.VendorName,
AdbcInfoCode.VendorSql,
AdbcInfoCode.VendorVersion,
};

const string ProductVersionDefault = "1.0.0";
const string InfoDriverName = "ADBC Spark Driver";
const string InfoDriverVersion = "1.0.0";
const string InfoVendorName = "Spark";
const string InfoDriverArrowVersion = "1.0.0";
const bool InfoVendorSql = true;
const int DecimalPrecisionDefault = 10;
const int DecimalScaleDefault = 0;

private readonly Lazy<string> _productVersion;

internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000);

internal static readonly Dictionary<string, string> timestampConfig = new Dictionary<string, string>
Expand Down Expand Up @@ -83,8 +88,11 @@ private enum ColumnTypeId
/// <summary>
/// Creates a Spark connection and sets up lazy resolution of the driver's product version.
/// </summary>
/// <param name="properties">Connection properties forwarded to the base connection.</param>
internal SparkConnection(IReadOnlyDictionary<string, string> properties)
    : base(properties)
{
    // The version is resolved on first read of ProductVersion, not at construction time.
    _productVersion = new Lazy<string>(GetProductVersion, LazyThreadSafetyMode.PublicationOnly);
}

/// <summary>
/// Gets the value of the lazily initialized product version field.
/// </summary>
protected string ProductVersion
{
    get { return _productVersion.Value; }
}

protected override async ValueTask<TProtocol> CreateProtocolAsync()
{
Trace.TraceError($"create protocol with {properties.Count} properties.");
Expand Down Expand Up @@ -137,6 +145,7 @@ public override AdbcStatement CreateStatement()
public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
{
const int strValTypeID = 0;
const int boolValTypeId = 1;

UnionType infoUnionType = new UnionType(
new Field[]
Expand Down Expand Up @@ -178,8 +187,11 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
ArrowBuffer.Builder<byte> typeBuilder = new ArrowBuffer.Builder<byte>();
ArrowBuffer.Builder<int> offsetBuilder = new ArrowBuffer.Builder<int>();
StringArray.Builder stringInfoBuilder = new StringArray.Builder();
BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder();

int nullCount = 0;
int arrayLength = codes.Count;
int offset = 0;

foreach (AdbcInfoCode code in codes)
{
Expand All @@ -188,32 +200,53 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
case AdbcInfoCode.DriverName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(InfoDriverVersion);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(ProductVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverArrowVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(InfoVendorName);
offsetBuilder.Append(offset++);
string vendorName = VendorName;
stringInfoBuilder.Append(vendorName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(offset++);
string? vendorVersion = VendorVersion;
stringInfoBuilder.Append(vendorVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorSql:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(boolValTypeId);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.Append(InfoVendorSql);
break;
default:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.AppendNull();
nullCount++;
break;
}
Expand All @@ -231,7 +264,7 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
IArrowArray[] childrenArrays = new IArrowArray[]
{
stringInfoBuilder.Build(),
new BooleanArray.Builder().Build(),
booleanInfoBuilder.Build(),
new Int64Array.Builder().Build(),
new Int32Array.Builder().Build(),
new ListArray.Builder(StringType.Default).Build(),
Expand Down Expand Up @@ -749,6 +782,12 @@ private static bool TryParse(string input, out Decimal128Type? value)
return true;
}
}

/// <summary>
/// Resolves the product version of the assembly that contains this driver.
/// </summary>
/// <returns>
/// The product version from the assembly file's version resource. When the assembly has no
/// file location, the assembly's informational version is used instead. If neither is
/// available, <c>ProductVersionDefault</c> is returned.
/// </returns>
private string GetProductVersion()
{
    Assembly assembly = Assembly.GetExecutingAssembly();

    // Location is an empty string when the assembly was not loaded from a file on disk
    // (for example a single-file deployment or a load from a byte array), and
    // FileVersionInfo.GetVersionInfo throws for an empty path. Without this guard every
    // request for the driver version would fail in those hosts.
    string location = assembly.Location;
    if (string.IsNullOrEmpty(location))
    {
        return assembly.GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion
            ?? ProductVersionDefault;
    }

    return FileVersionInfo.GetVersionInfo(location).ProductVersion ?? ProductVersionDefault;
}
}

internal struct TableInfoPair
Expand Down
77 changes: 73 additions & 4 deletions csharp/test/Drivers/Apache/Spark/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,30 @@ public async Task CanGetInfo()
{
AdbcConnection adbcConnection = NewConnection();

using IArrowArrayStream stream = adbcConnection.GetInfo(new List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName });
// Test the supported info codes
List<AdbcInfoCode> handledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.VendorName,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorVersion,
AdbcInfoCode.VendorSql
};
using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes);

RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> expectedValues = new List<string>() { "DriverName", "DriverVersion", "VendorName" };
List<string> expectedValues = new List<string>()
{
"DriverName",
"DriverVersion",
"VendorName",
"DriverArrowVersion",
"VendorVersion",
"VendorSql"
};

for (int i = 0; i < infoNameArray.Length; i++)
{
Expand All @@ -98,8 +116,59 @@ public async Task CanGetInfo()

Assert.Contains(value.ToString(), expectedValues);

StringArray stringArray = (StringArray)valueArray.Fields[0];
Console.WriteLine($"{value}={stringArray.GetString(i)}");
switch (value)
{
case AdbcInfoCode.VendorSql:
// TODO: How does external developer know the second field is the boolean field?
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
bool? boolValue = booleanArray.GetValue(i);
OutputHelper?.WriteLine($"{value}={boolValue}");
Assert.True(boolValue);
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
string stringValue = stringArray.GetString(i);
OutputHelper?.WriteLine($"{value}={stringValue}");
Assert.NotNull(stringValue);
break;
}
}

// Test the unhandled info codes.
List<AdbcInfoCode> unhandledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.VendorArrowVersion,
AdbcInfoCode.VendorSubstrait,
AdbcInfoCode.VendorSubstraitMaxVersion
};
using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes);

recordBatch = await stream2.ReadNextRecordBatchAsync();
infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> unexpectedValues = new List<string>()
{
"VendorArrowVersion",
"VendorSubstrait",
"VendorSubstraitMaxVersion"
};
for (int i = 0; i < infoNameArray.Length; i++)
{
AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i);
DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value");

Assert.Contains(value.ToString(), unexpectedValues);
switch (value)
{
case AdbcInfoCode.VendorSql:
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
Assert.Null(booleanArray.GetValue(i));
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
Assert.Null(stringArray.GetString(i));
break;
}
}
}

Expand Down
Loading