better format; using ruff now

2025-08-25 14:10:22 +05:30
parent 5311c68c58
commit 23299b7bb2
7 changed files with 346 additions and 146 deletions
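Note: the diff below is consistent with dropping unused imports and reformatting the whole package. Assuming a default Ruff setup (the exact invocation is not recorded in the commit), the equivalent commands would be roughly:

    ruff check --fix .
    ruff format .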


@@ -6,9 +6,9 @@ This module handles:
 - Thread-safe database interactions
 - Interactive Terms of Service acceptance workflows
 """
+
 import logging
 import threading
-import time
 import concurrent.futures
 from datetime import datetime, timezone
 from google.api_core import exceptions as google_exceptions
@@ -16,6 +16,7 @@ from google.cloud import resourcemanager_v3, api_keys_v2
 from . import config, gcp_api, database, utils
 from .exceptions import TermsOfServiceNotAcceptedError
 
+
 class TempKey:
     """Mock key object compatible with database operations.
@@ -31,6 +32,7 @@ class TempKey:
         update_time (datetime): Last update timestamp
         restrictions (api_keys_v2.Restrictions): Key usage restrictions
     """
+
     def __init__(self, cloud_key, key_string):
         self.key_string = key_string
         self.uid = cloud_key.uid
@@ -40,6 +42,7 @@ class TempKey:
         self.update_time = cloud_key.update_time
         self.restrictions = cloud_key.restrictions
 
+
 class TosAcceptanceHelper:
     """Manages Terms of Service acceptance workflow with thread synchronization.
@@ -51,11 +54,13 @@ class TosAcceptanceHelper:
         prompted_event (threading.Event): Signals ToS acceptance completion
         prompt_in_progress (bool): Indicates active prompt display status
     """
+
     def __init__(self):
         self.lock = threading.Lock()
         self.prompted_event = threading.Event()
         self.prompt_in_progress = False
 
+
 def _enable_api_with_interactive_retry(project_id, creds, dry_run, tos_helper):
     """Attempts to enable API with retry logic for ToS acceptance.
@@ -89,6 +94,7 @@ def _enable_api_with_interactive_retry(project_id, creds, dry_run, tos_helper):
         logging.error(f"API enablement error for {project_id}: {e}", exc_info=True)
         return False
 
+
 def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
     """Reconciles cloud and local database API key states.
@@ -112,27 +118,40 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
         cloud_keys_list = list(api_keys_client.list_keys(parent=parent))
         for key in cloud_keys_list:
-            if key.display_name in [config.GEMINI_API_KEY_DISPLAY_NAME, config.GENERATIVE_LANGUAGE_API_KEY_DISPLAY_NAME]:
+            if key.display_name in [
+                config.GEMINI_API_KEY_DISPLAY_NAME,
+                config.GENERATIVE_LANGUAGE_API_KEY_DISPLAY_NAME,
+            ]:
                 gemini_key_exists = True
 
         cloud_keys = {key.uid: key for key in cloud_keys_list}
 
-        project_entry = next((p for p in account_entry["projects"] if p.get("project_info", {}).get("project_id") == project_id), None)
+        project_entry = next(
+            (
+                p
+                for p in account_entry["projects"]
+                if p.get("project_info", {}).get("project_id") == project_id
+            ),
+            None,
+        )
         if not project_entry:
             project_entry = {
                 "project_info": {
                     "project_id": project.project_id,
                     "project_name": project.display_name,
-                    "project_number": project.name.split('/')[-1],
-                    "state": str(project.state)
+                    "project_number": project.name.split("/")[-1],
+                    "state": str(project.state),
                 },
-                "api_keys": []
+                "api_keys": [],
             }
             with db_lock:
                 account_entry["projects"].append(project_entry)
 
-        local_keys = {key['key_details']['key_id']: key for key in project_entry.get('api_keys', [])}
+        local_keys = {
+            key["key_details"]["key_id"]: key
+            for key in project_entry.get("api_keys", [])
+        }
         cloud_uids = set(cloud_keys.keys())
         local_uids = set(local_keys.keys())
@@ -152,7 +171,9 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
                 continue
 
             try:
-                key_string_response = api_keys_client.get_key_string(name=key_object.name)
+                key_string_response = api_keys_client.get_key_string(
+                    name=key_object.name
+                )
                 hydrated_key = TempKey(key_object, key_string_response.key_string)
                 with db_lock:
                     database.add_key_to_database(account_entry, project, hydrated_key)
@@ -168,8 +189,10 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
                 continue
 
             with db_lock:
-                local_keys[uid]['state'] = 'INACTIVE'
-                local_keys[uid]['key_details']['last_updated_timestamp_utc'] = datetime.now(timezone.utc).isoformat()
+                local_keys[uid]["state"] = "INACTIVE"
+                local_keys[uid]["key_details"]["last_updated_timestamp_utc"] = (
+                    datetime.now(timezone.utc).isoformat()
+                )
 
         return gemini_key_exists
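Note: the reconciliation above reduces to two set differences over key UIDs. A self-contained sketch with toy data (the UIDs are illustrative):

    cloud_uids = {"uid-1", "uid-2"}  # keys currently in GCP
    local_uids = {"uid-2", "uid-3"}  # keys recorded in the local JSON database

    to_hydrate = cloud_uids - local_uids     # fetch key strings, add to database
    to_deactivate = local_uids - cloud_uids  # mark as INACTIVE locally

    assert to_hydrate == {"uid-1"}
    assert to_deactivate == {"uid-3"}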
@@ -180,7 +203,10 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
         logging.error(f"API error during reconciliation: {err}")
         return False
 
-def _create_and_process_new_project(project_number, creds, dry_run, db_lock, account_entry, tos_helper):
+
+def _create_and_process_new_project(
+    project_number, creds, dry_run, db_lock, account_entry, tos_helper
+):
     """Creates and initializes new GCP project with API key.
 
     Args:
@@ -203,7 +229,9 @@ def _create_and_process_new_project(project_number, creds, dry_run, db_lock, acc
     try:
         resource_manager = resourcemanager_v3.ProjectsClient(credentials=creds)
-        project_to_create = resourcemanager_v3.Project(project_id=project_id, display_name=display_name)
+        project_to_create = resourcemanager_v3.Project(
+            project_id=project_id, display_name=display_name
+        )
         operation = resource_manager.create_project(project=project_to_create)
         logging.info(f"Awaiting project creation: {display_name}")
         created_project = operation.result()
@@ -214,14 +242,19 @@ def _create_and_process_new_project(project_number, creds, dry_run, db_lock, acc
             key_object = gcp_api.create_api_key(project_id, creds, dry_run=dry_run)
             if key_object:
                 with db_lock:
-                    database.add_key_to_database(account_entry, created_project, key_object)
+                    database.add_key_to_database(
+                        account_entry, created_project, key_object
+                    )
         else:
             logging.error(f"API enablement failed for {display_name}")
 
     except Exception as e:
         logging.error(f"Project creation failed: {e}", exc_info=True)
 
-def process_project_for_action(project, creds, action, dry_run, db_lock, account_entry, tos_helper):
+
+def process_project_for_action(
+    project, creds, action, dry_run, db_lock, account_entry, tos_helper
+):
     """Executes specified action on a single GCP project.
 
     Args:
@@ -236,8 +269,10 @@ def process_project_for_action(project, creds, action, dry_run, db_lock, account
     project_id = project.project_id
     logging.info(f"Processing {project_id} ({project.display_name})")
 
-    if action == 'create':
-        gemini_key_exists = reconcile_project_keys(project, creds, dry_run, db_lock, account_entry)
+    if action == "create":
+        gemini_key_exists = reconcile_project_keys(
+            project, creds, dry_run, db_lock, account_entry
+        )
         if gemini_key_exists:
             logging.info(f"Existing Gemini key in {project_id}")
             return
@@ -247,15 +282,20 @@ def process_project_for_action(project, creds, action, dry_run, db_lock, account
         if key_object:
             with db_lock:
                 database.add_key_to_database(account_entry, project, key_object)
-    elif action == 'delete':
+    elif action == "delete":
         deleted_keys_uids = gcp_api.delete_api_keys(project_id, creds, dry_run=dry_run)
         if deleted_keys_uids:
             with db_lock:
-                database.remove_keys_from_database(account_entry, project_id, deleted_keys_uids)
+                database.remove_keys_from_database(
+                    account_entry, project_id, deleted_keys_uids
+                )
 
     logging.info(f"Completed processing {project_id}")
 
-def process_account(email, creds, action, api_keys_data, schema, dry_run=False, max_workers=5):
+
+def process_account(
+    email, creds, action, api_keys_data, schema, dry_run=False, max_workers=5
+):
     """Orchestrates account-level key management operations.
 
     Args:
@@ -275,17 +315,24 @@ def process_account(email, creds, action, api_keys_data, schema, dry_run=False,
         logging.warning(f"Invalid credentials for {email}")
         return
 
-    account_entry = next((acc for acc in api_keys_data["accounts"] if acc.get("account_details", {}).get("email") == email), None)
+    account_entry = next(
+        (
+            acc
+            for acc in api_keys_data["accounts"]
+            if acc.get("account_details", {}).get("email") == email
+        ),
+        None,
+    )
     if not account_entry:
         account_entry = {
             "account_details": {
                 "email": email,
                 "authentication_details": {
                     "token_file": f"{config.CREDENTIALS_DIR}/{email}.json",
-                    "scopes": config.SCOPES
-                }
+                    "scopes": config.SCOPES,
+                },
             },
-            "projects": []
+            "projects": [],
         }
         api_keys_data["accounts"].append(account_entry)
@@ -293,13 +340,15 @@ def process_account(email, creds, action, api_keys_data, schema, dry_run=False,
     resource_manager = resourcemanager_v3.ProjectsClient(credentials=creds)
     existing_projects = list(resource_manager.search_projects())
 
-    if not existing_projects and action == 'create':
+    if not existing_projects and action == "create":
         logging.warning(f"No projects found for {email}")
         logging.warning("Possible reasons: No projects or unaccepted ToS")
-        logging.warning(f"Verify ToS: https://console.cloud.google.com/iam-admin/settings?user={email}")
+        logging.warning(
+            f"Verify ToS: https://console.cloud.google.com/iam-admin/settings?user={email}"
+        )
 
     projects_to_create_count = 0
-    if action == 'create':
+    if action == "create":
         if len(existing_projects) < 12:
             projects_to_create_count = 12 - len(existing_projects)
@@ -309,12 +358,33 @@ def process_account(email, creds, action, api_keys_data, schema, dry_run=False,
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = []
         for project in existing_projects:
-            futures.append(executor.submit(process_project_for_action, project, creds, action, dry_run, db_lock, account_entry, tos_helper))
+            futures.append(
+                executor.submit(
+                    process_project_for_action,
+                    project,
+                    creds,
+                    action,
+                    dry_run,
+                    db_lock,
+                    account_entry,
+                    tos_helper,
+                )
+            )
 
-        if action == 'create' and projects_to_create_count > 0:
+        if action == "create" and projects_to_create_count > 0:
             for i in range(len(existing_projects), 12):
                 project_number = str(i + 1).zfill(2)
-                futures.append(executor.submit(_create_and_process_new_project, project_number, creds, dry_run, db_lock, account_entry, tos_helper))
+                futures.append(
+                    executor.submit(
+                        _create_and_process_new_project,
+                        project_number,
+                        creds,
+                        dry_run,
+                        db_lock,
+                        account_entry,
+                        tos_helper,
+                    )
+                )
 
         for future in concurrent.futures.as_completed(futures):
             try:
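Note: the fan-out/fan-in above is the standard concurrent.futures idiom: submit everything, then drain with as_completed so each worker's failure surfaces per future. A runnable miniature with a placeholder task:

    import concurrent.futures

    def task(n):
        return n * n

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(task, n) for n in range(12)]
        for future in concurrent.futures.as_completed(futures):
            print(future.result())  # results arrive in completion order, not submit order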


@@ -5,15 +5,14 @@ Handles OAuth2 credential management including:
 - Interactive authentication flows
 - Credential storage/retrieval
 """
+
 import os
 import json
 import logging
 import time
-import google.auth
 from google.oauth2.credentials import Credentials
-import google_auth_oauthlib.flow
 from google_auth_oauthlib.flow import InstalledAppFlow
-from google.auth.transport import requests
+from google.auth.transport.requests import Request
 from . import config
 
 logger = logging.getLogger(__name__)
@@ -36,7 +35,9 @@ def get_and_refresh_credentials(email, max_retries=3, retry_delay=5):
     try:
         creds = Credentials.from_authorized_user_file(token_file, config.SCOPES)
     except (ValueError, json.JSONDecodeError):
-        logging.warning(f"Could not decode token file for {email}. Re-authentication will be required.")
+        logging.warning(
+            f"Could not decode token file for {email}. Re-authentication will be required."
+        )
         return None
 
     if creds and creds.valid:
@@ -45,22 +46,29 @@ def get_and_refresh_credentials(email, max_retries=3, retry_delay=5):
     if creds and creds.expired and creds.refresh_token:
         for attempt in range(max_retries):
             try:
-                logging.info(f"Refreshing credentials for {email} (attempt {attempt + 1}/{max_retries})...")
-                creds.refresh(google.auth.transport.requests.Request())
+                logging.info(
+                    f"Refreshing credentials for {email} (attempt {attempt + 1}/{max_retries})..."
+                )
+                creds.refresh(Request())
                 with open(token_file, "w") as token:
                     token.write(creds.to_json())
                 logging.info(f"Successfully refreshed credentials for {email}.")
                 return creds
             except Exception as e:
-                logging.warning(f"Failed to refresh credentials for {email} on attempt {attempt + 1}: {e}")
+                logging.warning(
+                    f"Failed to refresh credentials for {email} on attempt {attempt + 1}: {e}"
+                )
                 if attempt < max_retries - 1:
                     time.sleep(retry_delay)
-        logging.error(f"Failed to refresh credentials for {email} after {max_retries} attempts.")
+        logging.error(
+            f"Failed to refresh credentials for {email} after {max_retries} attempts."
+        )
         return None
 
     return None
 
+
 def run_interactive_auth(email, max_retries=3, retry_delay=5):
     """Executes interactive OAuth2 flow with error handling.
@@ -74,8 +82,10 @@ def run_interactive_auth(email, max_retries=3, retry_delay=5):
     """
     for attempt in range(max_retries):
         try:
-            logging.info(f"Please authenticate with: {email} (attempt {attempt + 1}/{max_retries})")
-            flow = google_auth_oauthlib.flow.InstalledAppFlow.from_client_secrets_file(
+            logging.info(
+                f"Please authenticate with: {email} (attempt {attempt + 1}/{max_retries})"
+            )
+            flow = InstalledAppFlow.from_client_secrets_file(
                 config.CLIENT_SECRETS_FILE, config.SCOPES
             )
             creds = flow.run_local_server(port=0)
@@ -84,7 +94,9 @@ def run_interactive_auth(email, max_retries=3, retry_delay=5):
             token.write(creds.to_json())
             return creds
         except Exception as e:
-            logging.error(f"An unexpected error occurred during authentication for {email} on attempt {attempt + 1}: {e}")
+            logging.error(
+                f"An unexpected error occurred during authentication for {email} on attempt {attempt + 1}: {e}"
+            )
             if attempt < max_retries - 1:
                 logging.info(f"Retrying authentication in {retry_delay} seconds...")
                 time.sleep(retry_delay)
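Note: the import change above (Request imported directly) keeps the standard google-auth refresh pattern intact. Sketched standalone, assuming a previously saved authorized-user file named token.json:

    from google.oauth2.credentials import Credentials
    from google.auth.transport.requests import Request

    creds = Credentials.from_authorized_user_file("token.json")
    if creds.expired and creds.refresh_token:
        creds.refresh(Request())  # one HTTP call to the token endpoint
        with open("token.json", "w") as f:
            f.write(creds.to_json())  # persist the refreshed access token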


@@ -5,6 +5,7 @@ Contains:
 - API endpoint configurations
 - Security scopes and schema locations
 """
+
 import os
 
 # --- DIRECTORIES ---
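Note: other files in this diff reference config.SCOPES, config.CREDENTIALS_DIR, and the two key display names. A hypothetical minimal config.py consistent with those call sites (values are illustrative guesses, except "Gemini API Key", which appears verbatim in gcp_api below):

    import os

    # --- DIRECTORIES ---
    CREDENTIALS_DIR = os.path.join(os.path.dirname(__file__), "credentials")

    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
    GEMINI_API_KEY_DISPLAY_NAME = "Gemini API Key"
    GENERATIVE_LANGUAGE_API_KEY_DISPLAY_NAME = "Generative Language API Key"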


@@ -6,13 +6,14 @@ Implements:
 - Key lifecycle tracking
 - Data versioning and backup
 """
+
 import os
 import json
 import logging
 import sys
 from datetime import datetime, timezone
 import jsonschema
-from . import config
 
+
 def load_schema(filename):
     """Validates and loads JSON schema definition.
@@ -36,13 +37,11 @@ def load_schema(filename):
         logging.error(f"Could not decode JSON schema from {filename}.")
         sys.exit(1)
 
+
 def load_keys_database(filename, schema):
     """Loads and validates the JSON database of API keys."""
     if not os.path.exists(filename):
-        return {
-            "schema_version": "1.0.0",
-            "accounts": []
-        }
+        return {"schema_version": "1.0.0", "accounts": []}
     with open(filename, "r") as f:
         try:
             data = json.load(f)
@@ -51,12 +50,12 @@ def load_keys_database(filename, schema):
         except json.JSONDecodeError:
             logging.warning(f"Could not decode JSON from {filename}. Starting fresh.")
         except jsonschema.ValidationError as e:
-            logging.warning(f"Database file '{filename}' is not valid. {e.message}. Starting fresh.")
-
-    return {
-        "schema_version": "1.0.0",
-        "accounts": []
-    }
+            logging.warning(
+                f"Database file '{filename}' is not valid. {e.message}. Starting fresh."
+            )
+    return {"schema_version": "1.0.0", "accounts": []}
+
 
 def save_keys_to_json(data, filename, schema):
     """Validates and saves the API key data to a single JSON file."""
@@ -73,20 +72,28 @@ def save_keys_to_json(data, filename, schema):
     logging.error(f"Validation Error: {e.message}")
     sys.exit(1)
 
+
 def add_key_to_database(account_entry, project, key_object):
     """Adds a new API key's details to the data structure."""
     project_id = project.project_id
-    project_entry = next((p for p in account_entry["projects"] if p.get("project_info", {}).get("project_id") == project_id), None)
+    project_entry = next(
+        (
+            p
+            for p in account_entry["projects"]
+            if p.get("project_info", {}).get("project_id") == project_id
+        ),
+        None,
+    )
     if not project_entry:
         project_entry = {
             "project_info": {
                 "project_id": project_id,
                 "project_name": project.display_name,
-                "project_number": project.name.split('/')[-1],
-                "state": str(project.state)
+                "project_number": project.name.split("/")[-1],
+                "state": str(project.state),
             },
-            "api_keys": []
+            "api_keys": [],
         }
         account_entry["projects"].append(project_entry)
@@ -104,31 +111,51 @@ def add_key_to_database(account_entry, project, key_object):
             "creation_timestamp_utc": key_object.create_time.isoformat(),
             "last_updated_timestamp_utc": key_object.update_time.isoformat(),
         },
-        "restrictions": {
-            "api_targets": api_targets
-        },
-        "state": "ACTIVE"
+        "restrictions": {"api_targets": api_targets},
+        "state": "ACTIVE",
     }
 
-    existing_key = next((k for k in project_entry["api_keys"] if k.get("key_details", {}).get("key_id") == key_object.uid), None)
+    existing_key = next(
+        (
+            k
+            for k in project_entry["api_keys"]
+            if k.get("key_details", {}).get("key_id") == key_object.uid
+        ),
+        None,
+    )
     if not existing_key:
         project_entry["api_keys"].append(new_key_entry)
-        logging.info(f"  Added key {key_object.uid} to local database for project {project_id}")
+        logging.info(
+            f"  Added key {key_object.uid} to local database for project {project_id}"
+        )
     else:
-        logging.warning(f"  Key {key_object.uid} already exists in local database for project {project_id}")
+        logging.warning(
+            f"  Key {key_object.uid} already exists in local database for project {project_id}"
+        )
 
+
 def remove_keys_from_database(account_entry, project_id, deleted_keys_uids):
     """Removes deleted API keys from the data structure."""
-    project_entry = next((p for p in account_entry["projects"] if p.get("project_info", {}).get("project_id") == project_id), None)
+    project_entry = next(
+        (
+            p
+            for p in account_entry["projects"]
+            if p.get("project_info", {}).get("project_id") == project_id
+        ),
+        None,
+    )
     if not project_entry:
         return
 
     initial_key_count = len(project_entry["api_keys"])
     project_entry["api_keys"] = [
-        key for key in project_entry["api_keys"]
+        key
+        for key in project_entry["api_keys"]
         if key.get("key_details", {}).get("key_id") not in deleted_keys_uids
     ]
     final_key_count = len(project_entry["api_keys"])
     num_removed = initial_key_count - final_key_count
     if num_removed > 0:
-        logging.info(f"  Removed {num_removed} key(s) from local database for project {project_id}")
+        logging.info(
+            f"  Removed {num_removed} key(s) from local database for project {project_id}"
+        )
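Note: both the load and save paths gate on jsonschema validation. The core check, as a standalone sketch with a toy schema (not the project's real schema file):

    import jsonschema

    schema = {
        "type": "object",
        "properties": {"schema_version": {"type": "string"}},
        "required": ["schema_version"],
    }
    data = {"schema_version": "1.0.0", "accounts": []}
    jsonschema.validate(instance=data, schema=schema)  # raises ValidationError on mismatch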


@@ -6,6 +6,7 @@ Defines domain-specific exceptions for:
 - API operation constraints
 """
 
+
 class TermsOfServiceNotAcceptedError(Exception):
     """Indicates unaccepted Terms of Service for critical API operations.
@@ -13,6 +14,7 @@ class TermsOfServiceNotAcceptedError(Exception):
         message (str): Human-readable error description
         url (str): URL for Terms of Service acceptance portal
     """
+
     def __init__(self, message, url):
         self.message = message
         self.url = url
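Note: gcp_api raises this exception with a message and an acceptance URL, and the retry helper catches it to drive the interactive prompt. A compact usage sketch (the project id in the message is a placeholder):

    try:
        raise TermsOfServiceNotAcceptedError(
            "ToS not accepted for project example-project.",
            url="https://console.developers.google.com/terms/generative-language-api",
        )
    except TermsOfServiceNotAcceptedError as e:
        print(f"{e.message}\nAccept at: {e.url}")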


@@ -1,13 +1,13 @@
 """
 Functions for interacting with Google Cloud Platform APIs.
 """
+
 import logging
-import time
-import concurrent.futures
 from datetime import datetime, timezone
-from google.cloud import resourcemanager_v3, service_usage_v1, api_keys_v2
+from google.cloud import service_usage_v1, api_keys_v2
 from google.api_core import exceptions as google_exceptions
-from . import config, utils, exceptions
+from . import config, exceptions
 
 
 def enable_api(project_id, credentials, dry_run=False):
     """Manages Generative Language API enablement with error handling.
@@ -28,7 +28,9 @@ def enable_api(project_id, credentials, dry_run=False):
     service_usage_client = service_usage_v1.ServiceUsageClient(credentials=credentials)
 
     try:
-        logging.info(f"  Attempting to enable Generative Language API for project {project_id}...")
+        logging.info(
+            f"  Attempting to enable Generative Language API for project {project_id}..."
+        )
         if dry_run:
             logging.info(f"  [DRY RUN] Would enable API for project {project_id}")
             return True
@@ -37,22 +39,29 @@ def enable_api(project_id, credentials, dry_run=False):
         operation = service_usage_client.enable_service(request=enable_request)
         # Wait for the operation to complete.
         operation.result()
-        logging.info(f"  Successfully enabled Generative Language API for project {project_id}")
+        logging.info(
+            f"  Successfully enabled Generative Language API for project {project_id}"
+        )
         return True
     except google_exceptions.PermissionDenied:
-        logging.warning(f"  Permission denied to enable API for project {project_id}. Skipping.")
+        logging.warning(
+            f"  Permission denied to enable API for project {project_id}. Skipping."
+        )
         return False
     except google_exceptions.GoogleAPICallError as err:
-        if 'UREQ_TOS_NOT_ACCEPTED' in str(err):
-            tos_url = "https://console.developers.google.com/terms/generative-language-api"
+        if "UREQ_TOS_NOT_ACCEPTED" in str(err):
+            tos_url = (
+                "https://console.developers.google.com/terms/generative-language-api"
+            )
             raise exceptions.TermsOfServiceNotAcceptedError(
                 f"Terms of Service for the Generative Language API have not been accepted for project {project_id}.",
-                url=tos_url
+                url=tos_url,
             )
         logging.error(f"  Error enabling API for project {project_id}: {err}")
         return False
 
+
 def create_api_key(project_id, credentials, dry_run=False):
     """Generates restricted API key with security constraints.
@@ -98,15 +107,20 @@ def create_api_key(project_id, credentials, dry_run=False):
         logging.info("  Creating API key...")
         operation = api_keys_client.create_key(request=request)
         result = operation.result()
-        logging.info(f"  Successfully created restricted API key for project {project_id}")
+        logging.info(
+            f"  Successfully created restricted API key for project {project_id}"
+        )
         return result
     except google_exceptions.PermissionDenied:
-        logging.warning(f"  Permission denied to create API key for project {project_id}. Skipping.")
+        logging.warning(
+            f"  Permission denied to create API key for project {project_id}. Skipping."
+        )
         return None
     except google_exceptions.GoogleAPICallError as err:
         logging.error(f"  Error creating API key for project {project_id}: {err}")
         return None
 
+
 def delete_api_keys(project_id, credentials, dry_run=False):
     """Deletes all API keys with the display name 'Gemini API Key' and returns their UIDs."""
     deleted_keys_uids = []
@@ -115,13 +129,21 @@ def delete_api_keys(project_id, credentials, dry_run=False):
         parent = f"projects/{project_id}/locations/global"
         keys = api_keys_client.list_keys(parent=parent)
 
-        keys_to_delete = [key for key in keys if key.display_name == config.GEMINI_API_KEY_DISPLAY_NAME]
+        keys_to_delete = [
+            key
+            for key in keys
+            if key.display_name == config.GEMINI_API_KEY_DISPLAY_NAME
+        ]
 
         if not keys_to_delete:
-            logging.info(f"  No '{config.GEMINI_API_KEY_DISPLAY_NAME}' found to delete.")
+            logging.info(
+                f"  No '{config.GEMINI_API_KEY_DISPLAY_NAME}' found to delete."
+            )
             return []
 
-        logging.info(f"  Found {len(keys_to_delete)} key(s) with display name '{config.GEMINI_API_KEY_DISPLAY_NAME}'. Deleting...")
+        logging.info(
+            f"  Found {len(keys_to_delete)} key(s) with display name '{config.GEMINI_API_KEY_DISPLAY_NAME}'. Deleting..."
+        )
         for key in keys_to_delete:
             if dry_run:
                 logging.info(f"  [DRY RUN] Would delete key: {key.uid}")
@@ -137,8 +159,11 @@ def delete_api_keys(project_id, credentials, dry_run=False):
                 logging.error(f"  Error deleting key {key.uid}: {err}")
         return deleted_keys_uids
     except google_exceptions.PermissionDenied:
-        logging.warning(f"  Permission denied to list or delete API keys for project {project_id}. Skipping.")
+        logging.warning(
+            f"  Permission denied to list or delete API keys for project {project_id}. Skipping."
+        )
    except google_exceptions.GoogleAPICallError as err:
-        logging.error(f"  An API error occurred while deleting keys for project {project_id}: {err}")
+        logging.error(
+            f"  An API error occurred while deleting keys for project {project_id}: {err}"
+        )
     return []
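Note: enable_api drives a long-running Service Usage operation. The essential call sequence, sketched standalone (the service name is inferred from the Generative Language references in this file; the project id is a placeholder):

    from google.cloud import service_usage_v1

    client = service_usage_v1.ServiceUsageClient()
    request = service_usage_v1.EnableServiceRequest(
        name="projects/my-project/services/generativelanguage.googleapis.com"
    )
    operation = client.enable_service(request=request)
    operation.result()  # blocks until the enable operation completes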


@@ -1,6 +1,7 @@
 """
 Main entry point for the Gemini Key Management script.
 """
+
 import argparse
 import logging
 import sys
@@ -8,6 +9,7 @@ import os
 import concurrent.futures
 from . import utils, config, auth, database, actions
 
+
 def main():
     """Orchestrates API key lifecycle management workflow.
@@ -17,23 +19,53 @@ def main():
     - Multi-account processing
     - Thread pool execution
     """
-    parser = argparse.ArgumentParser(description="Manage Gemini API keys in Google Cloud projects.")
-    parser.add_argument("action", choices=['create', 'delete'], help="The action to perform: 'create' or 'delete' API keys.")
-    parser.add_argument("--email", help="Specify a single email address to process. Required for 'delete'. If not provided for 'create', emails will be read from emails.txt.")
-    parser.add_argument("--dry-run", action="store_true", help="Simulate the run without making any actual changes to Google Cloud resources.")
-    parser.add_argument("--max-workers", type=int, default=5, help="The maximum number of concurrent projects to process.")
-    parser.add_argument("--auth-retries", type=int, default=3, help="Number of retries for a failed authentication attempt.")
-    parser.add_argument("--auth-retry-delay", type=int, default=5, help="Delay in seconds between authentication retries.")
+    parser = argparse.ArgumentParser(
+        description="Manage Gemini API keys in Google Cloud projects."
+    )
+    parser.add_argument(
+        "action",
+        choices=["create", "delete"],
+        help="The action to perform: 'create' or 'delete' API keys.",
+    )
+    parser.add_argument(
+        "--email",
+        help="Specify a single email address to process. Required for 'delete'. If not provided for 'create', emails will be read from emails.txt.",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Simulate the run without making any actual changes to Google Cloud resources.",
+    )
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=5,
+        help="The maximum number of concurrent projects to process.",
+    )
+    parser.add_argument(
+        "--auth-retries",
+        type=int,
+        default=3,
+        help="Number of retries for a failed authentication attempt.",
+    )
+    parser.add_argument(
+        "--auth-retry-delay",
+        type=int,
+        default=5,
+        help="Delay in seconds between authentication retries.",
+    )
     args = parser.parse_args()
 
     utils.setup_logging()
     logging.info(f"Program arguments: {vars(args)}")
 
-    if args.action == 'delete' and not args.email:
+    if args.action == "delete" and not args.email:
         parser.error("the --email argument is required for the 'delete' action")
 
     if not os.path.exists(config.CLIENT_SECRETS_FILE):
-        logging.error(f"OAuth client secrets file not found at '{config.CLIENT_SECRETS_FILE}'")
+        logging.error(
+            f"OAuth client secrets file not found at '{config.CLIENT_SECRETS_FILE}'"
+        )
         logging.error("Please follow the setup instructions in README.md to create it.")
         sys.exit(1)
@@ -46,8 +78,10 @@ def main():
     emails_to_process = []
     if args.email:
         emails_to_process.append(args.email)
-    elif args.action == 'delete':
-        logging.error("The 'delete' action requires the --email argument to specify which account's keys to delete.")
+    elif args.action == "delete":
+        logging.error(
+            "The 'delete' action requires the --email argument to specify which account's keys to delete."
+        )
         sys.exit(1)
     else:  # action is 'create' and no email provided
         emails_to_process = utils.load_emails_from_file(config.EMAILS_FILE)
@@ -60,8 +94,18 @@ def main():
     logging.info("Checking credentials and refreshing tokens for all accounts...")
-    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-        future_to_email = {executor.submit(auth.get_and_refresh_credentials, email, max_retries=args.auth_retries, retry_delay=args.auth_retry_delay): email for email in emails_to_process}
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=args.max_workers
+    ) as executor:
+        future_to_email = {
+            executor.submit(
+                auth.get_and_refresh_credentials,
+                email,
+                max_retries=args.auth_retries,
+                retry_delay=args.auth_retry_delay,
+            ): email
+            for email in emails_to_process
+        }
 
         for future in concurrent.futures.as_completed(future_to_email):
             email = future_to_email[future]
@@ -72,25 +116,44 @@ def main():
                 else:
                     emails_needing_interactive_auth.append(email)
             except Exception as exc:
-                logging.error(f"Credential check for {email} generated an exception: {exc}", exc_info=True)
+                logging.error(
+                    f"Credential check for {email} generated an exception: {exc}",
+                    exc_info=True,
+                )
                 emails_needing_interactive_auth.append(email)
 
     if emails_needing_interactive_auth:
-        logging.info(f"\n--- INTERACTIVE AUTHENTICATION REQUIRED ---")
-        logging.info(f"The following accounts require manual authentication: {', '.join(sorted(emails_needing_interactive_auth))}")
+        logging.info("\n--- INTERACTIVE AUTHENTICATION REQUIRED ---")
+        logging.info(
+            f"The following accounts require manual authentication: {', '.join(sorted(emails_needing_interactive_auth))}"
+        )
        for email in sorted(emails_needing_interactive_auth):
-            creds = auth.run_interactive_auth(email, max_retries=args.auth_retries, retry_delay=args.auth_retry_delay)
+            creds = auth.run_interactive_auth(
+                email, max_retries=args.auth_retries, retry_delay=args.auth_retry_delay
+            )
             if creds:
                 logging.info(f"Successfully authenticated {email}.")
                 creds_map[email] = creds
             else:
-                logging.warning(f"Authentication failed or was cancelled for {email}. This account will be skipped.")
+                logging.warning(
+                    f"Authentication failed or was cancelled for {email}. This account will be skipped."
+                )
 
     logging.info("\n--- Credential checking complete ---")
 
     for email in emails_to_process:
         if email in creds_map:
-            actions.process_account(email, creds_map[email], args.action, api_keys_data, schema, dry_run=args.dry_run, max_workers=args.max_workers)
+            actions.process_account(
+                email,
+                creds_map[email],
+                args.action,
+                api_keys_data,
+                schema,
+                dry_run=args.dry_run,
+                max_workers=args.max_workers,
+            )
         else:
-            logging.warning(f"Skipping account {email} because authentication was not successful.")
+            logging.warning(
+                f"Skipping account {email} because authentication was not successful."
+            )
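Note: matching the arguments defined above, typical invocations would look like the following (the package's module path is not shown in this diff, so <package> is a placeholder):

    python -m <package> create --dry-run --max-workers 8
    python -m <package> delete --email user@example.com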