Skip to content

Commit

Permalink
Allow configuring endpoint for Azure FS
Browse files Browse the repository at this point in the history
  • Loading branch information
electrum committed Aug 19, 2024
1 parent 8edf5b4 commit c463301
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 37 deletions.
6 changes: 6 additions & 0 deletions docs/src/main/sphinx/object-storage/file-system-azure.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ system support:
authentication used with `NONE`. Use `ACCESS_KEY` for
[](azure-access-key-authentication) or and `OAUTH` for
[](azure-oauth-authentication).
* - `azure.endpoint`
- Hostname suffix of the Azure storage endpoint.
Defaults to `core.windows.net` for the global Azure cloud.
Use `core.usgovcloudapi.net` for the Azure US Government cloud,
`core.cloudapi.de` for the Azure Germany cloud,
or `core.chinacloudapi.cn` for the Azure China cloud.
* - `azure.read-block-size`
- [Data size](prop-type-data-size) for blocks during read operations. Defaults
to `4MB`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class AzureFileSystem
private final HttpClient httpClient;
private final TracingOptions tracingOptions;
private final AzureAuth azureAuth;
private final String endpoint;
private final int readBlockSizeBytes;
private final long writeBlockSizeBytes;
private final int maxWriteConcurrency;
Expand All @@ -72,6 +73,7 @@ public AzureFileSystem(
HttpClient httpClient,
TracingOptions tracingOptions,
AzureAuth azureAuth,
String endpoint,
DataSize readBlockSize,
DataSize writeBlockSize,
int maxWriteConcurrency,
Expand All @@ -80,6 +82,7 @@ public AzureFileSystem(
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.tracingOptions = requireNonNull(tracingOptions, "tracingOptions is null");
this.azureAuth = requireNonNull(azureAuth, "azureAuth is null");
this.endpoint = requireNonNull(endpoint, "endpoint is null");
this.readBlockSizeBytes = toIntExact(readBlockSize.toBytes());
this.writeBlockSizeBytes = writeBlockSize.toBytes();
checkArgument(maxWriteConcurrency >= 0, "maxWriteConcurrency is negative");
Expand Down Expand Up @@ -450,6 +453,14 @@ private boolean isHierarchicalNamespaceEnabled(AzureLocation location)
}
}

private String validatedEndpoint(AzureLocation location)
{
if (!location.endpoint().equals(endpoint)) {
throw new IllegalArgumentException("Location does not match configured Azure endpoint: " + location);
}
return location.endpoint();
}

private BlobClient createBlobClient(AzureLocation location)
{
return createBlobContainerClient(location).getBlobClient(location.path());
Expand All @@ -462,7 +473,7 @@ private BlobContainerClient createBlobContainerClient(AzureLocation location)
BlobContainerClientBuilder builder = new BlobContainerClientBuilder()
.httpClient(httpClient)
.clientOptions(new ClientOptions().setTracingOptions(tracingOptions))
.endpoint(String.format("https://%s.blob.core.windows.net", location.account()));
.endpoint("https://%s.blob.%s".formatted(location.account(), validatedEndpoint(location)));
azureAuth.setAuth(location.account(), builder);
location.container().ifPresent(builder::containerName);
return builder.buildClient();
Expand All @@ -475,7 +486,7 @@ private DataLakeFileSystemClient createFileSystemClient(AzureLocation location)
DataLakeServiceClientBuilder builder = new DataLakeServiceClientBuilder()
.httpClient(httpClient)
.clientOptions(new ClientOptions().setTracingOptions(tracingOptions))
.endpoint(String.format("https://%s.dfs.core.windows.net", location.account()));
.endpoint("https://%s.dfs.%s".formatted(location.account(), validatedEndpoint(location)));
azureAuth.setAuth(location.account(), builder);
DataLakeServiceClient client = builder.buildClient();
DataLakeFileSystemClient fileSystemClient = client.getFileSystemClient(location.container().orElseThrow());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.airlift.configuration.Config;
import io.airlift.units.DataSize;
import io.airlift.units.DataSize.Unit;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;

public class AzureFileSystemConfig
Expand All @@ -28,7 +29,7 @@ public enum AuthType
}

private AuthType authType = AuthType.DEFAULT;

private String endpoint = "core.windows.net";
private DataSize readBlockSize = DataSize.of(4, Unit.MEGABYTE);
private DataSize writeBlockSize = DataSize.of(4, Unit.MEGABYTE);
private int maxWriteConcurrency = 8;
Expand All @@ -47,6 +48,19 @@ public AzureFileSystemConfig setAuthType(AuthType authType)
return this;
}

@NotEmpty
public String getEndpoint()
{
return endpoint;
}

@Config("azure.endpoint")
public AzureFileSystemConfig setEndpoint(String endpoint)
{
this.endpoint = endpoint;
return this;
}

@NotNull
public DataSize getReadBlockSize()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class AzureFileSystemFactory
implements TrinoFileSystemFactory
{
private final AzureAuth auth;
private final String endpoint;
private final DataSize readBlockSize;
private final DataSize writeBlockSize;
private final int maxWriteConcurrency;
Expand All @@ -51,6 +52,7 @@ public AzureFileSystemFactory(OpenTelemetry openTelemetry, AzureAuth azureAuth,
{
this(openTelemetry,
azureAuth,
config.getEndpoint(),
config.getReadBlockSize(),
config.getWriteBlockSize(),
config.getMaxWriteConcurrency(),
Expand All @@ -60,12 +62,14 @@ public AzureFileSystemFactory(OpenTelemetry openTelemetry, AzureAuth azureAuth,
public AzureFileSystemFactory(
OpenTelemetry openTelemetry,
AzureAuth azureAuth,
String endpoint,
DataSize readBlockSize,
DataSize writeBlockSize,
int maxWriteConcurrency,
DataSize maxSingleUploadSize)
{
this.auth = requireNonNull(azureAuth, "azureAuth is null");
this.endpoint = requireNonNull(endpoint, "endpoint is null");
this.readBlockSize = requireNonNull(readBlockSize, "readBlockSize is null");
this.writeBlockSize = requireNonNull(writeBlockSize, "writeBlockSize is null");
checkArgument(maxWriteConcurrency >= 0, "maxWriteConcurrency is negative");
Expand All @@ -89,7 +93,7 @@ public void destroy()
@Override
public TrinoFileSystem create(ConnectorIdentity identity)
{
return new AzureFileSystem(httpClient, tracingOptions, auth, readBlockSize, writeBlockSize, maxWriteConcurrency, maxSingleUploadSize);
return new AzureFileSystem(httpClient, tracingOptions, auth, endpoint, readBlockSize, writeBlockSize, maxWriteConcurrency, maxSingleUploadSize);
}

public static HttpClient createAzureHttpClient(OkHttpClient okHttpClient, HttpClientOptions clientOptions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class AzureLocation
{
private static final String INVALID_LOCATION_MESSAGE = "Invalid Azure location. Expected form is 'abfs://[<containerName>@]<accountName>.dfs.core.windows.net/<filePath>': %s";
private static final String INVALID_LOCATION_MESSAGE = "Invalid Azure location. Expected form is 'abfs://[<containerName>@]<accountName>.dfs.<endpoint>/<filePath>': %s";

// https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules
private static final CharMatcher CONTAINER_VALID_CHARACTERS = CharMatcher.inRange('a', 'z').or(CharMatcher.inRange('0', '9')).or(CharMatcher.is('-'));
Expand All @@ -32,7 +32,16 @@ class AzureLocation
private final Location location;
private final String scheme;
private final String account;
private final String endpoint;

/**
* Creates a new location based on the endpoint, storage account, container and blob path parsed from the location.
* <p>
* Locations follow the conventions used by
* <a href="https://docs.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction-abfs-uri">ABFS URI</a>
* that follows the following convention
* <pre>{@code abfs://<container-name>@<storage-account-name>.dfs.<endpoint>/<blob_path>}</pre>
*/
public AzureLocation(Location location)
{
this.location = requireNonNull(location, "location is null");
Expand Down Expand Up @@ -67,28 +76,19 @@ public AzureLocation(Location location)
this.location);
this.account = host.substring(0, accountSplit);

// host must end with ".dfs.core.windows.net"
checkArgument(host.substring(accountSplit).equals(".dfs.core.windows.net"), INVALID_LOCATION_MESSAGE, location);
// host must contain ".dfs." before endpoint
checkArgument(host.substring(accountSplit).startsWith(".dfs."), INVALID_LOCATION_MESSAGE, location);

// endpoint is the part after ".dfs."
this.endpoint = host.substring(accountSplit + ".dfs.".length());
checkArgument(!endpoint.isEmpty(), INVALID_LOCATION_MESSAGE, location);

// storage account is interpolated into URL host name, so perform extra checks
checkArgument(STORAGE_ACCOUNT_VALID_CHARACTERS.matchesAllOf(account),
"Invalid Azure storage account name. Valid characters are 'a-z' and '0-9': %s",
location);
}

/**
* Creates a new {@link AzureLocation} based on the storage account, container and blob path parsed from the location.
* <p>
* Locations follow the conventions used by
* <a href="https://docs.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction-abfs-uri">ABFS URI</a>
* that follows the following convention
* <pre>{@code abfs://<container-name>@<storage-account-name>.dfs.core.windows.net/<blob_path>}</pre>
*/
public static AzureLocation from(String location)
{
return new AzureLocation(Location.of(location));
}

public Location location()
{
return location;
Expand All @@ -104,6 +104,11 @@ public String account()
return account;
}

public String endpoint()
{
return endpoint;
}

public String path()
{
return location.path();
Expand All @@ -126,9 +131,10 @@ public String toString()

public Location baseLocation()
{
return Location.of("%s://%s%s.dfs.core.windows.net/".formatted(
return Location.of("%s://%s%s.dfs.%s/".formatted(
scheme,
container().map(container -> container + "@").orElse(""),
account()));
account(),
endpoint));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void testDefaults()
{
assertRecordedDefaults(recordDefaults(AzureFileSystemConfig.class)
.setAuthType(AuthType.DEFAULT)
.setEndpoint("core.windows.net")
.setReadBlockSize(DataSize.of(4, Unit.MEGABYTE))
.setWriteBlockSize(DataSize.of(4, Unit.MEGABYTE))
.setMaxWriteConcurrency(8)
Expand All @@ -43,6 +44,7 @@ public void testExplicitPropertyMappings()
{
Map<String, String> properties = ImmutableMap.<String, String>builder()
.put("azure.auth-type", "oauth")
.put("azure.endpoint", "core.usgovcloudapi.net")
.put("azure.read-block-size", "3MB")
.put("azure.write-block-size", "5MB")
.put("azure.max-write-concurrency", "7")
Expand All @@ -51,6 +53,7 @@ public void testExplicitPropertyMappings()

AzureFileSystemConfig expected = new AzureFileSystemConfig()
.setAuthType(AuthType.OAUTH)
.setEndpoint("core.usgovcloudapi.net")
.setReadBlockSize(DataSize.of(3, Unit.MEGABYTE))
.setWriteBlockSize(DataSize.of(5, Unit.MEGABYTE))
.setMaxWriteConcurrency(7)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,23 @@ class TestAzureLocation
@Test
void test()
{
assertValid("abfs://[email protected]/some/path/file", "account", "container", "some/path/file");
assertValid("abfss://[email protected]/some/path/file", "account", "container", "some/path/file", "abfss");
assertValid("abfs://[email protected]/some/path/file", "account", "container", "some/path/file", "abfs", "core.windows.net");
assertValid("abfss://[email protected]/some/path/file", "account", "container", "some/path/file", "abfss", "core.windows.net");

assertValid("abfs://[email protected]/some/path/file", "account", "container-stuff", "some/path/file");
assertValid("abfs://[email protected]/some/path/file", "account", "container2", "some/path/file");
assertValid("abfs://account.dfs.core.windows.net/some/path/file", "account", null, "some/path/file");
assertValid("abfs://[email protected]/some/path/file", "account", "container-stuff", "some/path/file", "abfs", "core.windows.net");
assertValid("abfs://[email protected]/some/path/file", "account", "container2", "some/path/file", "abfs", "core.windows.net");
assertValid("abfs://account.dfs.core.windows.net/some/path/file", "account", null, "some/path/file", "abfs", "core.windows.net");

assertValid("abfs://[email protected]/file", "account", "container", "file");
assertValid("abfs://[email protected]///f///i///l///e///", "account0", "container", "//f///i///l///e///");
assertValid("abfs://[email protected]/file", "account", "container", "file", "abfs", "core.windows.net");
assertValid("abfs://[email protected]///f///i///l///e///", "account0", "container", "//f///i///l///e///", "abfs", "core.windows.net");

// other endpoints are allowed
assertValid("abfs://[email protected]/some/path/file", "account", "container", "some/path/file", "abfs", "core.usgovcloudapi.net");
assertValid("abfss://[email protected]/some/path/file", "account", "container", "some/path/file", "abfss", "core.usgovcloudapi.net");

// only abfs and abfss schemes allowed
assertInvalid("https://[email protected]/some/path/file");

// host must have at least to labels
assertInvalid("abfs://container@account/some/path/file");
assertInvalid("abfs://container@/some/path/file");
Expand All @@ -54,32 +59,29 @@ void test()
assertInvalid("abfs://[email protected]/some/path/file");
assertInvalid("abfs://[email protected]/some/path/file");
assertInvalid("abfs://[email protected]/some/path/file");

// account is only a-z and 0-9
assertInvalid("abfs://[email protected]/some/path/file");
assertInvalid("abfs://container@ac_count.dfs.core.windows.net/some/path/file");
assertInvalid("abfs://container@ac$count.dfs.core.windows.net/some/path/file");
// host must end with .dfs.core.windows.net

// host must contain .dfs. after account
assertInvalid("abfs://[email protected]/some/path/file");
// host must be just account.dfs.core.windows.net
assertInvalid("abfs://[email protected]/some/path/file");
}

private static void assertValid(String uri, String expectedAccount, String expectedContainer, String expectedPath, String expectedScheme)
private static void assertValid(String uri, String expectedAccount, String expectedContainer, String expectedPath, String expectedScheme, String expectedEndpoint)
{
Location location = Location.of(uri);
AzureLocation azureLocation = new AzureLocation(location);
assertThat(azureLocation.location()).isEqualTo(location);
assertThat(azureLocation.account()).isEqualTo(expectedAccount);
assertThat(azureLocation.endpoint()).isEqualTo(expectedEndpoint);
assertThat(azureLocation.container()).isEqualTo(Optional.ofNullable(expectedContainer));
assertThat(azureLocation.path()).contains(expectedPath);
assertThat(azureLocation.baseLocation().scheme()).isEqualTo(Optional.of(expectedScheme));
}

private static void assertValid(String uri, String expectedAccount, String expectedContainer, String expectedPath)
{
assertValid(uri, expectedAccount, expectedContainer, expectedPath, "abfs");
}

private static void assertInvalid(String uri)
{
Location location = Location.of(uri);
Expand Down

0 comments on commit c463301

Please sign in to comment.