Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/default creds inheritance #246

Merged
merged 7 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions src/dfcx_scrapi/core/scrapi_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Base for other SCRAPI classes."""

# Copyright 2023 Google LLC
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,8 +22,9 @@
import threading
import vertexai
from collections import defaultdict
from typing import Dict, Any, Iterable
from typing import Dict, Any, Iterable, List

from google.auth import default
from google.api_core import exceptions
from google.cloud.dialogflowcx_v3beta1 import types
from google.oauth2 import service_account
Expand Down Expand Up @@ -68,49 +69,54 @@

ALL_GENERATIVE_MODELS = ALL_GEMINI_MODELS + TEXT_GENERATION_MODELS

# Define global scopes used for all Dialogflow CX Requests
GLOBAL_SCOPES = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/dialogflow",
]

class ScrapiBase:
"""Core Class for managing Auth and other shared functions."""

global_scopes = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/dialogflow",
]

def __init__(
self,
creds_path: str = None,
creds_dict: Dict[str, str] = None,
creds: service_account.Credentials = None,
scope=False,
agent_id=None,
scope: List[str] = None,
agent_id: str = None,
):

self.scopes = ScrapiBase.global_scopes
self.scopes = GLOBAL_SCOPES
if scope:
self.scopes += scope

if creds:
self.creds = creds
self.creds.refresh(Request())
self.token = self.creds.token

elif creds_path:
self.creds = service_account.Credentials.from_service_account_file(
creds_path, scopes=self.scopes
)
self.creds.refresh(Request())
self.token = self.creds.token

elif creds_dict:
self.creds = service_account.Credentials.from_service_account_info(
creds_dict, scopes=self.scopes
)
self.creds.refresh(Request())
self.token = self.creds.token

else:
self.creds = None
self.token = None
self.creds, _ = default()
self.creds.refresh(Request())
self.token = self.creds.token
self._check_and_update_scopes(self.creds)

if agent_id:
self.agent_id = agent_id
self.agent_id = agent_id

self.api_calls_dict = defaultdict(int)

Expand Down Expand Up @@ -223,13 +229,22 @@ def dict_to_struct(some_dict: Dict[str, Any]):
@staticmethod
def parse_agent_id(resource_id: str):
"""Attempts to parse Agent ID from provided Resource ID."""
try:
agent_id = "/".join(resource_id.split("/")[:6])
except IndexError as err:
logging.error("IndexError - path too short? %s", resource_id)
raise err
parts = resource_id.split("/")
if len(parts) < 6:
raise ValueError(
"Resource ID is too short to contain an Agent ID: {}".format(
resource_id
)
)

return agent_id
if parts[4] != "agents":
raise ValueError(
"Resource ID does not contain an agent ID: {}".format(
resource_id
)
)

return "/".join(parts[:6])

@staticmethod
def _parse_resource_path(
Expand Down Expand Up @@ -400,6 +415,14 @@ def is_valid_sys_instruct_model(llm_model: str) -> bool:

return valid_sys_instruct

def _check_and_update_scopes(self, creds: Any):
"""Update Credentials scopes if possible based on creds type."""
if creds.requires_scopes:
self.creds.scopes.extend(GLOBAL_SCOPES)

else:
logging.info("Found user creds, skipping global scopes...")

def build_generative_model(
self,
llm_model: str,
Expand Down
27 changes: 17 additions & 10 deletions src/dfcx_scrapi/tools/dataframe_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
import logging
import time
from typing import Dict, List
from typing import Dict, List, Any
import gspread
import pandas as pd
import numpy as np
Expand All @@ -33,7 +33,7 @@
from dfcx_scrapi.core.pages import Pages
from dfcx_scrapi.core.transition_route_groups import TransitionRouteGroups

GLOBAL_SCOPE = [
SHEETS_SCOPE = [
"https://spreadsheets.google.com/feeds",
"https://www.googleapis.com/auth/drive",
]
Expand All @@ -52,8 +52,8 @@ def __init__(
self,
creds_path: str = None,
creds_dict: dict = None,
creds=None,
scope=False,
creds: Any = None,
scope: List[str] = None,
):
super().__init__(
creds_path=creds_path,
Expand All @@ -62,14 +62,16 @@ def __init__(
scope=scope,
)

scopes = GLOBAL_SCOPE
self._check_and_update_sheets_scopes()

if scope:
scopes += scope

self.creds.scopes.extend(scopes)
if hasattr(self.creds, "service_account_email") and self.creds.service_account_email:
self.sheets_client = gspread.authorize(self.creds)
else:
logging.warning(
"Application Default Credentials (ADC) found and Sheets Client"
" could not be authorized. Use Service Account or Oauth2 user"
" credentials if you require Sheets access.")

self.sheets_client = gspread.authorize(self.creds)
self.entities = EntityTypes(creds=self.creds)
self.intents = Intents(creds=self.creds)
self.flows = Flows(creds=self.creds)
Expand Down Expand Up @@ -138,6 +140,11 @@ def _remap_intent_values(original_intent: types.Intent) -> types.Intent:

return new_intent

def _check_and_update_sheets_scopes(self):
"""Update Credentials scopes if possible based on creds type."""
if self.creds.requires_scopes:
self.creds.scopes.extend(SHEETS_SCOPE)

def _update_intent_from_dataframe(
self,
intent_id: str,
Expand Down
Loading
Loading