Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ def store_cloud_connection(
).id
else:
azure_account_key_ref_id = None
if connection.azure_sas_token is not None:
azure_sas_token_ref_id = store_secret(
db,
SecretInput(name=connection.connection_name + "azure_sas_token", value=connection.azure_sas_token),
user_id,
).id
else:
azure_sas_token_ref_id = None

db_cloud_connection = DBCloudStorageConnection(
connection_name=connection.connection_name,
Expand All @@ -231,6 +239,7 @@ def store_cloud_connection(
azure_client_id=connection.azure_client_id,
azure_account_key_id=azure_account_key_ref_id,
azure_client_secret_id=azure_client_secret_ref_id,
azure_sas_token_id=azure_sas_token_ref_id,
# Common fields
endpoint_url=connection.endpoint_url,
verify_ssl=connection.verify_ssl,
Expand Down Expand Up @@ -290,6 +299,12 @@ def get_cloud_connection_schema(db: Session, connection_name: str, user_id: int)
if secret_record:
azure_client_secret = decrypt_secret(secret_record.encrypted_value)

azure_sas_token = None
if db_connection.azure_sas_token_id:
secret_record = db.query(Secret).filter(Secret.id == db_connection.azure_sas_token_id).first()
if secret_record:
azure_sas_token = decrypt_secret(secret_record.encrypted_value)

# Construct the full Pydantic model
return FullCloudStorageConnection(
connection_name=db_connection.connection_name,
Expand All @@ -305,6 +320,7 @@ def get_cloud_connection_schema(db: Session, connection_name: str, user_id: int)
azure_tenant_id=db_connection.azure_tenant_id,
azure_client_id=db_connection.azure_client_id,
azure_client_secret=azure_client_secret,
azure_sas_token=azure_sas_token,
endpoint_url=db_connection.endpoint_url,
verify_ssl=db_connection.verify_ssl,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Literal

import boto3
from azure.storage.blob import BlobServiceClient, ContainerClient
from botocore.exceptions import ClientError

from flowfile_core.schemas.cloud_storage_schemas import FullCloudStorageConnection
Expand Down Expand Up @@ -258,3 +259,125 @@ def ensure_path_has_wildcard_pattern(resource_path: str, file_format: Literal["c
if not resource_path.endswith(f"*.{file_format}"):
resource_path = resource_path.rstrip("/") + f"/**/*.{file_format}"
return resource_path


def get_first_file_from_adls_dir(source: str, storage_options: dict[str, Any] = None) -> str:
"""
Get the first file from an Azure ADLS directory path.

Parameters
----------
source : str
ADLS path with wildcards (e.g., 'az://container/prefix/**/*.parquet' or
'abfs://container@account.dfs.core.windows.net/prefix/*.parquet')

storage_options: dict
Storage options containing authentication details

Returns
-------
str
ADLS URI of the first file found

Raises
------
ValueError
If source path is invalid or no files found
Exception
If ADLS access fails
"""
if not (source.startswith("az://") or source.startswith("abfs://")):
raise ValueError("Source must be a valid ADLS URI starting with 'az://' or 'abfs://'")

container_name, prefix, account_name = _parse_adls_path(source)
file_extension = _get_file_extension(source)
base_prefix = _remove_wildcards_from_prefix(prefix)

blob_service_client = _create_adls_client(account_name, storage_options)
container_client = blob_service_client.get_container_client(container_name)

# List blobs with the given prefix
first_file = _get_first_adls_file(container_client, base_prefix, file_extension)

# Return first file URI in az:// format
return f"az://{container_name}/{first_file['name']}"


def _parse_adls_path(source: str) -> tuple[str, str, str]:
"""
Parse ADLS URI into container name, prefix, and account name.

Supports both formats:
- az://container/prefix/path
- abfs://container@account.dfs.core.windows.net/prefix/path
"""
if source.startswith("az://"):
# Format: az://container/prefix/path
path_parts = source[5:].split("/", 1) # Remove 'az://'
container_name = path_parts[0]
prefix = path_parts[1] if len(path_parts) > 1 else ""
account_name = None # Will be extracted from storage_options
elif source.startswith("abfs://"):
# Format: abfs://container@account.dfs.core.windows.net/prefix/path
path_parts = source[7:].split("/", 1) # Remove 'abfs://'
container_and_account = path_parts[0]
prefix = path_parts[1] if len(path_parts) > 1 else ""

# Extract container and account
if "@" in container_and_account:
container_name, account_part = container_and_account.split("@", 1)
account_name = account_part.split(".")[0] # Extract account name from FQDN
else:
container_name = container_and_account
account_name = None
else:
raise ValueError("Invalid ADLS URI format")

return container_name, prefix, account_name


def _create_adls_client(account_name: str | None, storage_options: dict[str, Any] | None) -> BlobServiceClient:
    """Build an Azure ``BlobServiceClient`` from the supplied storage options.

    Authentication is attempted in priority order: account key, SAS token,
    then service principal (tenant id + client id + client secret).
    """
    if storage_options is None:
        raise ValueError("Storage options are required for ADLS connections")

    # Fall back to the storage options when the URI did not carry an account name.
    resolved_account = storage_options.get("account_name") if account_name is None else account_name
    if not resolved_account:
        raise ValueError("Azure account name is required")

    account_url = f"https://{resolved_account}.blob.core.windows.net"

    if "account_key" in storage_options:
        return BlobServiceClient(account_url=account_url, credential=storage_options["account_key"])

    if "sas_token" in storage_options:
        return BlobServiceClient(account_url=account_url, credential=storage_options["sas_token"])

    if all(key in storage_options for key in ("client_id", "client_secret", "tenant_id")):
        # Service principal authentication; imported lazily so azure-identity
        # is only required when this auth method is actually used.
        from azure.identity import ClientSecretCredential

        credential = ClientSecretCredential(
            tenant_id=storage_options["tenant_id"],
            client_id=storage_options["client_id"],
            client_secret=storage_options["client_secret"],
        )
        return BlobServiceClient(account_url=account_url, credential=credential)

    raise ValueError("No valid authentication method found in storage options")


def _get_first_adls_file(container_client: ContainerClient, base_prefix: str, file_extension: str) -> dict[str, Any]:
"""List all files in ADLS container with given prefix and return the first match."""
try:
blob_list = container_client.list_blobs(name_starts_with=base_prefix)

for blob in blob_list:
if blob.name.endswith(f".{file_extension}"):
return {"name": blob.name}

raise ValueError(f"No {file_extension} files found in container with prefix {base_prefix}")
except Exception as e:
raise ValueError(f"Failed to list files in ADLS container: {e}")
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from flowfile_core.flowfile.flow_data_engine.cloud_storage_reader import (
CloudStorageReader,
ensure_path_has_wildcard_pattern,
get_first_file_from_adls_dir,
get_first_file_from_s3_dir,
)
from flowfile_core.flowfile.flow_data_engine.create import funcs as create_funcs
Expand Down Expand Up @@ -589,7 +590,14 @@ def _get_schema_from_first_file_in_dir(
"""Infers the schema by scanning the first file in a cloud directory."""
try:
scan_func = getattr(pl, "scan_" + file_format)
first_file_ref = get_first_file_from_s3_dir(source, storage_options=storage_options)
# Determine storage type and use appropriate function
if source.startswith("s3://"):
first_file_ref = get_first_file_from_s3_dir(source, storage_options=storage_options)
elif source.startswith("az://") or source.startswith("abfs://"):
first_file_ref = get_first_file_from_adls_dir(source, storage_options=storage_options)
else:
raise ValueError(f"Unsupported cloud storage URI format: {source}")

return convert_stats_to_column_info(
FlowDataEngine._create_schema_stats_from_pl_schema(
scan_func(first_file_ref, storage_options=storage_options).collect_schema()
Expand Down
3 changes: 3 additions & 0 deletions flowfile_core/flowfile_core/schemas/cloud_storage_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class FullCloudStorageConnectionWorkerInterface(AuthSettingsInput):
azure_tenant_id: str | None = None
azure_client_id: str | None = None
azure_client_secret: str | None = None
azure_sas_token: str | None = None

# Common
endpoint_url: str | None = None
Expand All @@ -81,6 +82,7 @@ class FullCloudStorageConnection(AuthSettingsInput):
azure_tenant_id: str | None = None
azure_client_id: str | None = None
azure_client_secret: SecretStr | None = None
azure_sas_token: SecretStr | None = None

# Common
endpoint_url: str | None = None
Expand Down Expand Up @@ -111,6 +113,7 @@ def get_worker_interface(self, user_id: int) -> "FullCloudStorageConnectionWorke
azure_account_key=encrypt_for_worker(self.azure_account_key, user_id),
azure_client_id=self.azure_client_id,
azure_client_secret=encrypt_for_worker(self.azure_client_secret, user_id),
azure_sas_token=encrypt_for_worker(self.azure_sas_token, user_id),
endpoint_url=self.endpoint_url,
verify_ssl=self.verify_ssl,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<label for="storage-type" class="form-label">Storage Type</label>
<select id="storage-type" v-model="connection.storageType" class="form-input" required>
<option value="s3">AWS S3</option>
<!-- <option value="adls">Azure Data Lake Storage</option> -->
<option value="adls">Azure Data Lake Storage</option>
</select>
</div>

Expand Down Expand Up @@ -200,6 +200,29 @@
</div>
</div>
</template>

<!-- Azure SAS Token (for sas_token auth) -->
<div v-if="connection.authMethod === 'sas_token'" class="form-field">
<label for="azure-sas-token" class="form-label">Azure SAS Token</label>
<div class="password-field">
<input
id="azure-sas-token"
v-model="connection.azureSasToken"
:type="showAzureSasToken ? 'text' : 'password'"
class="form-input"
placeholder="SAS token"
:required="connection.authMethod === 'sas_token'"
/>
<button
type="button"
class="toggle-visibility"
aria-label="Toggle Azure SAS token visibility"
@click="showAzureSasToken = !showAzureSasToken"
>
<i :class="showAzureSasToken ? 'fa-solid fa-eye-slash' : 'fa-solid fa-eye'"></i>
</button>
</div>
</div>
</template>

<!-- Common Fields -->
Expand Down Expand Up @@ -299,6 +322,7 @@ watch(
const showAwsSecret = ref(false);
const showAzureKey = ref(false);
const showAzureSecret = ref(false);
const showAzureSasToken = ref(false);

// Computed property for available auth methods based on storage type
const availableAuthMethods = computed(() => {
Expand Down Expand Up @@ -353,6 +377,8 @@ const isValid = computed(() => {
!!connection.value.azureClientId &&
!!connection.value.azureClientSecret
);
} else if (connection.value.authMethod === "sas_token") {
return !!connection.value.azureSasToken;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export interface PythonFullCloudStorageConnection extends PythonAuthSettingsInpu
azure_tenant_id?: string;
azure_client_id?: string;
azure_client_secret?: string;
azure_sas_token?: string;

// Common
endpoint_url?: string;
Expand All @@ -59,6 +60,7 @@ export interface FullCloudStorageConnection extends AuthSettingsInput {
azureTenantId?: string;
azureClientId?: string;
azureClientSecret?: string;
azureSasToken?: string;

// Common
endpointUrl?: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const toPythonFormat = (
azure_tenant_id: connection.azureTenantId,
azure_client_id: connection.azureClientId,
azure_client_secret: connection.azureClientSecret,
azure_sas_token: connection.azureSasToken,

// Common
endpoint_url: connection.endpointUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class FullCloudStorageConnection(BaseModel):
azure_tenant_id: str | None = None
azure_client_id: str | None = None
azure_client_secret: SecretStr | None = None
azure_sas_token: SecretStr | None = None

# Common
endpoint_url: str | None = None
Expand All @@ -84,6 +85,10 @@ def get_storage_options(self) -> dict[str, Any]:
"""
if self.storage_type == "s3":
return self._get_s3_storage_options()
elif self.storage_type == "adls":
return self._get_adls_storage_options()
else:
raise ValueError(f"Unsupported storage type: {self.storage_type}")

def _get_s3_storage_options(self) -> dict[str, Any]:
"""Build S3-specific storage options."""
Expand Down Expand Up @@ -127,6 +132,44 @@ def _get_s3_storage_options(self) -> dict[str, Any]:

return storage_options

def _get_adls_storage_options(self) -> dict[str, Any]:
    """Build Azure ADLS-specific storage options.

    Returns a dict keyed by the connection's ``auth_method``: account key,
    service-principal credentials, or a SAS token. Secrets are decrypted
    just-in-time; unset fields are simply omitted, so an unrecognized
    auth method yields (at most) only the account name.
    """
    storage_options: dict[str, Any] = {}

    # Common options
    if self.azure_account_name:
        storage_options["account_name"] = self.azure_account_name

    if self.auth_method == "access_key":
        # Account key authentication
        if self.azure_account_key:
            storage_options["account_key"] = decrypt_secret(
                self.azure_account_key.get_secret_value()
            ).get_secret_value()

    elif self.auth_method == "service_principal":
        # Service principal authentication
        if self.azure_tenant_id:
            storage_options["tenant_id"] = self.azure_tenant_id
        if self.azure_client_id:
            storage_options["client_id"] = self.azure_client_id
        if self.azure_client_secret:
            storage_options["client_secret"] = decrypt_secret(
                self.azure_client_secret.get_secret_value()
            ).get_secret_value()

    elif self.auth_method == "sas_token":
        # SAS token authentication
        if self.azure_sas_token:
            storage_options["sas_token"] = decrypt_secret(
                self.azure_sas_token.get_secret_value()
            ).get_secret_value()

    return storage_options


class WriteSettings(BaseModel):
"""Settings for writing to cloud storage"""
Expand Down
Loading
Loading