forked from Azure-Samples/azure-search-openai-demo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanageacl.py
266 lines (238 loc) · 11.4 KB
/
manageacl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import argparse
import asyncio
import json
import logging
import os
from typing import Any, Union
from urllib.parse import urljoin
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import AzureDeveloperCliCredential
from azure.search.documents.aio import SearchClient
from azure.search.documents.indexes.aio import SearchIndexClient
from azure.search.documents.indexes.models import (
SearchFieldDataType,
SimpleField,
)
from load_azd_env import load_azd_env
# Module-level logger used by every method below. The script entry point
# raises this logger (and only this one) to INFO when --verbose is passed.
logger = logging.getLogger("scripts")
class ManageAcl:
    """
    Manually enable document level access control on a search index and manually set access control values using the [manageacl.ps1](./scripts/manageacl.ps1) script.

    One instance represents a single command: construct it with the action
    and its arguments, then await ``run()``. Besides the per-document ACL
    actions, it can add the ACL fields to an index (``enable_acls``) and
    backfill empty ``storageUrl`` values (``update_storage_urls``).
    """

    def __init__(
        self,
        service_name: str,
        index_name: str,
        url: str,
        acl_action: str,
        acl_type: str,
        acl: str,
        credentials: Union[AsyncTokenCredential, AzureKeyCredential],
    ):
        """
        Initializes the command

        Parameters
        ----------
        service_name
            Name of the Azure Search service
        index_name
            Name of the Azure Search index
        url
            Full Blob storage URL of the document to manage acls for. For the
            update_storage_urls action it is instead the base URL that each
            document's sourcefile is joined onto.
        acl_action
            Action to take regarding the index or document. Valid values include enable_acls (turn acls on for the entire index), view (print acls for the document), remove_all (remove all acls), remove (remove a specific acl), add (add a specific acl), or update_storage_urls (fill in empty storageUrl values)
        acl_type
            Type of acls to manage. Valid values include groups or oids.
        acl
            The actual value of the acl, if the acl action is add or remove
        credentials
            Credentials for the azure search service
        """
        self.service_name = service_name
        self.index_name = index_name
        self.credentials = credentials
        self.url = url
        self.acl_action = acl_action
        self.acl_type = acl_type
        self.acl = acl

    async def run(self):
        """
        Execute the configured action against the search service.

        Raises
        ------
        ValueError
            If ``acl_action`` is not one of the supported actions.
        """
        endpoint = f"https://{self.service_name}.search.windows.net"

        # enable_acls changes the index schema, so it needs an index client
        # rather than the document-level search client used by everything else.
        if self.acl_action == "enable_acls":
            await self.enable_acls(endpoint)
            return

        document_actions = {
            "view": self.view_acl,
            "remove": self.remove_acl,
            "remove_all": self.remove_all_acls,
            "add": self.add_acl,
            "update_storage_urls": self.update_storage_urls,
        }
        handler = document_actions.get(self.acl_action)
        if handler is None:
            # Reject bad input before a client is opened for nothing.
            raise ValueError(f"Unknown action {self.acl_action}")

        async with SearchClient(
            endpoint=endpoint, index_name=self.index_name, credential=self.credentials
        ) as search_client:
            await handler(search_client)

    async def view_acl(self, search_client: SearchClient):
        """Print, as JSON, the ``acl_type`` values of the first search document found for ``url``."""
        documents = await self.get_documents(search_client)
        if documents:
            # Assumes the acls are consistent across all sections of the document
            print(json.dumps(documents[0][self.acl_type]))

    async def remove_acl(self, search_client: SearchClient):
        """Remove ``acl`` from the ``acl_type`` field of every search document found for ``url``."""
        documents_to_merge = []
        for document in await self.get_documents(search_client):
            current_acls = document[self.acl_type]
            if self.acl in current_acls:
                remaining_acls = [acl_value for acl_value in current_acls if acl_value != self.acl]
                documents_to_merge.append({"id": document["id"], self.acl_type: remaining_acls})
            else:
                logger.info("Search document %s does not have %s acl %s", document["id"], self.acl_type, self.acl)

        if documents_to_merge:
            logger.info("Removing acl %s from %d search documents", self.acl, len(documents_to_merge))
            await search_client.merge_documents(documents=documents_to_merge)
        else:
            logger.info("Not updating any search documents")

    async def remove_all_acls(self, search_client: SearchClient):
        """Clear the ``acl_type`` field of every search document found for ``url``."""
        documents_to_merge = []
        for document in await self.get_documents(search_client):
            if document[self.acl_type]:
                documents_to_merge.append({"id": document["id"], self.acl_type: []})
            else:
                logger.info("Search document %s already has no %s acls", document["id"], self.acl_type)

        if documents_to_merge:
            logger.info("Removing all %s acls from %d search documents", self.acl_type, len(documents_to_merge))
            await search_client.merge_documents(documents=documents_to_merge)
        else:
            logger.info("Not updating any search documents")

    async def add_acl(self, search_client: SearchClient):
        """Append ``acl`` to the ``acl_type`` field of every search document found for ``url`` that lacks it."""
        documents_to_merge = []
        for document in await self.get_documents(search_client):
            current_acls = document[self.acl_type]
            if self.acl not in current_acls:
                documents_to_merge.append({"id": document["id"], self.acl_type: [*current_acls, self.acl]})
            else:
                logger.info("Search document %s already has %s acl %s", document["id"], self.acl_type, self.acl)

        if documents_to_merge:
            logger.info("Adding acl %s to %d search documents", self.acl, len(documents_to_merge))
            await search_client.merge_documents(documents=documents_to_merge)
        else:
            logger.info("Not updating any search documents")

    async def get_documents(self, search_client: SearchClient):
        """
        Return every search document whose ``storageUrl`` equals ``url``.

        Only the ``id`` and ``acl_type`` fields are selected. The result is a
        list so callers can iterate it more than once or index into it.
        """
        # OData string literals escape a single quote by doubling it. Without
        # this, a URL containing an apostrophe breaks or alters the filter.
        escaped_url = str(self.url).replace("'", "''")
        url_filter = f"storageUrl eq '{escaped_url}'"
        results = await search_client.search("", filter=url_filter, select=["id", self.acl_type])
        found_documents = [document async for document in results]
        logger.info("Found %d search documents with storageUrl %s", len(found_documents), self.url)
        return found_documents

    async def enable_acls(self, endpoint: str):
        """
        Ensure the index has the ``oids``, ``groups`` and ``storageUrl`` fields.

        Missing fields are appended to the index definition; existing ones are
        left as they are. The definition is always written back.
        """
        string_collection = SearchFieldDataType.Collection(SearchFieldDataType.String)
        # Field name -> SimpleField options, in the order they are appended.
        required_fields = {
            "oids": {"type": string_collection, "filterable": True},
            "groups": {"type": string_collection, "filterable": True},
            "storageUrl": {"type": "Edm.String", "filterable": True, "facetable": False},
        }

        async with SearchIndexClient(endpoint=endpoint, credential=self.credentials) as search_index_client:
            logger.info("Enabling acls for index %s", self.index_name)
            index_definition = await search_index_client.get_index(self.index_name)
            existing_names = {field.name for field in index_definition.fields}
            for field_name, field_options in required_fields.items():
                if field_name not in existing_names:
                    index_definition.fields.append(SimpleField(name=field_name, **field_options))
            await search_index_client.create_or_update_index(index_definition)

    async def update_storage_urls(self, search_client: SearchClient):
        """
        Fill in ``storageUrl`` for search documents where it is empty.

        The new value is ``urljoin(url, sourcefile)``. Documents with exactly
        one oid are skipped because they may be user uploads.
        """
        empty_url_filter = "storageUrl eq ''"
        results = await search_client.search(
            "", filter=empty_url_filter, select=["id", "storageUrl", "oids", "sourcefile"]
        )

        found_count = 0
        documents_to_merge = []
        async for document in results:
            found_count += 1
            if len(document["oids"]) == 1:
                logger.warning(
                    "Not updating storage URL of document %s as it has only one oid and may be user uploaded",
                    document["id"],
                )
                continue
            # NOTE: urljoin resolves sourcefile relative to the base, so the base
            # URL should end with "/" or its last path segment is replaced.
            storage_url = urljoin(self.url, document["sourcefile"])
            documents_to_merge.append({"id": document["id"], "storageUrl": storage_url})
            logger.info("Adding storage URL %s for document %s", storage_url, document["id"])

        if documents_to_merge:
            logger.info("Updating storage URL for %d search documents", len(documents_to_merge))
            await search_client.merge_documents(documents=documents_to_merge)
        elif found_count == 0:
            logger.info("No documents found with empty storageUrl value")
        else:
            logger.info("Not updating any search documents")
async def main(args: Any):
    """
    Build a ManageAcl command from the parsed CLI arguments and run it.

    The search service and index names are read from the AZURE_SEARCH_SERVICE
    and AZURE_SEARCH_INDEX environment variables after load_azd_env() is called.
    """
    load_azd_env()

    # Use the current user identity to connect to Azure services unless a key is explicitly set for any of them
    if args.tenant_id is None:
        azd_credential = AzureDeveloperCliCredential()
    else:
        azd_credential = AzureDeveloperCliCredential(tenant_id=args.tenant_id, process_timeout=60)

    # An explicit search key takes precedence over the azd identity.
    search_credential: Union[AsyncTokenCredential, AzureKeyCredential] = (
        azd_credential if args.search_key is None else AzureKeyCredential(args.search_key)
    )

    service_name = os.environ["AZURE_SEARCH_SERVICE"]
    index_name = os.environ["AZURE_SEARCH_INDEX"]
    command = ManageAcl(
        service_name=service_name,
        index_name=index_name,
        url=args.url,
        acl_action=args.acl_action,
        acl_type=args.acl_type,
        acl=args.acl,
        credentials=search_credential,
    )
    await command.run()
if __name__ == "__main__":
    # Script entry point: parse the command line, configure logging, check the
    # argument combination, then run the async main().
    parser = argparse.ArgumentParser(
        description="Manage ACLs in a search index",
        epilog="Example: manageacl.py --acl-action enable_acls",
    )
    parser.add_argument(
        "--search-key",
        required=False,
        help="Optional. Use this Azure AI Search account key instead of the current user identity to login",
    )
    parser.add_argument("--acl-type", required=False, choices=["oids", "groups"], help="Optional. Type of ACL")
    parser.add_argument(
        "--acl-action",
        required=False,
        choices=["remove", "add", "view", "remove_all", "enable_acls", "update_storage_urls"],
        help="Optional. Whether to remove or add the ACL to the document, or enable acls on the index",
    )
    parser.add_argument("--acl", required=False, default=None, help="Optional. Value of ACL to add or remove.")
    parser.add_argument("--url", required=False, help="Optional. Storage URL of document to update ACLs for")
    parser.add_argument(
        "--tenant-id", required=False, help="Optional. Use this to define the Azure directory where to authenticate)"
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.WARNING, format="%(message)s")
        # We only set the level to INFO for our logger,
        # to avoid seeing the noisy INFO level logs from the Azure SDKs
        logger.setLevel(logging.INFO)

    # Only the two index-wide actions can run without an ACL type.
    index_wide_actions = ("enable_acls", "update_storage_urls")
    if not args.acl_type and args.acl_action not in index_wide_actions:
        print("Must specify either --acl-type or --acl-action enable_acls or --acl-action update_storage_urls")
        # exit() is a helper injected by the site module for interactive use
        # and may be missing (e.g. python -S); SystemExit is always available.
        raise SystemExit(1)

    asyncio.run(main(args))