using types; way too hard (used gemini a lit bit for help); should have done it from start

This commit is contained in:
2025-08-27 15:52:17 +05:30
parent 23299b7bb2
commit 5216794aeb
10 changed files with 289 additions and 93 deletions

View File

@@ -7,40 +7,23 @@ This module handles:
- Interactive Terms of Service acceptance workflows
"""
from __future__ import annotations
import logging
import threading
import concurrent.futures
from datetime import datetime, timezone
from typing import Any, Dict, List
from google.api_core import exceptions as google_exceptions
from google.cloud import resourcemanager_v3, api_keys_v2
from google.oauth2.credentials import Credentials
from google.cloud.resourcemanager_v3.types import Project as CloudProject
from google.cloud.api_keys_v2.types import Key as CloudKey
from . import config, gcp_api, database, utils
from .exceptions import TermsOfServiceNotAcceptedError
class TempKey:
"""Mock key object compatible with database operations.
Provides a temporary representation of an API key for database insertion
when direct API key string retrieval is not possible.
Attributes:
key_string (str): The actual API key string
uid (str): Unique identifier of the key
name (str): Full resource name of the key
display_name (str): Human-readable display name
create_time (datetime): Key creation timestamp
update_time (datetime): Last update timestamp
restrictions (api_keys_v2.Restrictions): Key usage restrictions
"""
def __init__(self, cloud_key, key_string):
self.key_string = key_string
self.uid = cloud_key.uid
self.name = cloud_key.name
self.display_name = cloud_key.display_name
self.create_time = cloud_key.create_time
self.update_time = cloud_key.update_time
self.restrictions = cloud_key.restrictions
from .types import Account, Project as LocalProject, ApiKeysDatabase, TempKey
class TosAcceptanceHelper:
@@ -55,18 +38,20 @@ class TosAcceptanceHelper:
prompt_in_progress (bool): Indicates active prompt display status
"""
def __init__(self):
def __init__(self) -> None:
self.lock = threading.Lock()
self.prompted_event = threading.Event()
self.prompt_in_progress = False
def _enable_api_with_interactive_retry(project_id, creds, dry_run, tos_helper):
def _enable_api_with_interactive_retry(
project_id: str, creds: Credentials, dry_run: bool, tos_helper: TosAcceptanceHelper
) -> bool:
"""Attempts to enable API with retry logic for ToS acceptance.
Args:
project_id (str): Target GCP project ID
creds (Credentials: Authenticated Google credentials
creds (Credentials): Authenticated Google credentials
dry_run (bool): Simulation mode flag
tos_helper (TosAcceptanceHelper): ToS workflow coordinator
@@ -95,7 +80,13 @@ def _enable_api_with_interactive_retry(project_id, creds, dry_run, tos_helper):
return False
def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
def reconcile_project_keys(
project: CloudProject,
creds: Credentials,
dry_run: bool,
db_lock: threading.Lock,
account_entry: Account,
) -> bool:
"""Reconciles cloud and local database API key states.
Args:
@@ -108,7 +99,7 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
Returns:
bool: True if Gemini key exists, False otherwise
"""
project_id = project.project_id
project_id: str = project.project_id
logging.info(f"Reconciling keys for {project_id}")
gemini_key_exists = False
@@ -116,7 +107,7 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
api_keys_client = api_keys_v2.ApiKeysClient(credentials=creds)
parent = f"projects/{project_id}/locations/global"
cloud_keys_list = list(api_keys_client.list_keys(parent=parent))
cloud_keys_list: List[CloudKey] = list(api_keys_client.list_keys(parent=parent))
for key in cloud_keys_list:
if key.display_name in [
config.GEMINI_API_KEY_DISPLAY_NAME,
@@ -124,7 +115,7 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
]:
gemini_key_exists = True
cloud_keys = {key.uid: key for key in cloud_keys_list}
cloud_keys: Dict[str, CloudKey] = {key.uid: key for key in cloud_keys_list}
project_entry = next(
(
@@ -136,7 +127,7 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
)
if not project_entry:
project_entry = {
project_entry: LocalProject = {
"project_info": {
"project_id": project.project_id,
"project_name": project.display_name,
@@ -205,8 +196,13 @@ def reconcile_project_keys(project, creds, dry_run, db_lock, account_entry):
def _create_and_process_new_project(
project_number, creds, dry_run, db_lock, account_entry, tos_helper
):
project_number: str,
creds: Credentials,
dry_run: bool,
db_lock: threading.Lock,
account_entry: Account,
tos_helper: TosAcceptanceHelper,
) -> None:
"""Creates and initializes new GCP project with API key.
Args:
@@ -234,7 +230,7 @@ def _create_and_process_new_project(
)
operation = resource_manager.create_project(project=project_to_create)
logging.info(f"Awaiting project creation: {display_name}")
created_project = operation.result()
created_project: CloudProject = operation.result()
logging.info(f"Project created: {display_name}")
if _enable_api_with_interactive_retry(project_id, creds, dry_run, tos_helper):
@@ -253,8 +249,14 @@ def _create_and_process_new_project(
def process_project_for_action(
project, creds, action, dry_run, db_lock, account_entry, tos_helper
):
project: CloudProject,
creds: Credentials,
action: str,
dry_run: bool,
db_lock: threading.Lock,
account_entry: Account,
tos_helper: TosAcceptanceHelper,
) -> None:
"""Executes specified action on a single GCP project.
Args:
@@ -266,7 +268,7 @@ def process_project_for_action(
account_entry (dict): Account data structure
tos_helper (TosAcceptanceHelper): ToS workflow coordinator
"""
project_id = project.project_id
project_id: str = project.project_id
logging.info(f"Processing {project_id} ({project.display_name})")
if action == "create":
@@ -294,8 +296,14 @@ def process_project_for_action(
def process_account(
email, creds, action, api_keys_data, schema, dry_run=False, max_workers=5
):
email: str,
creds: Credentials,
action: str,
api_keys_data: ApiKeysDatabase,
schema: Dict[str, Any],
dry_run: bool = False,
max_workers: int = 5,
) -> None:
"""Orchestrates account-level key management operations.
Args:
@@ -324,7 +332,7 @@ def process_account(
None,
)
if not account_entry:
account_entry = {
account_entry: Account = {
"account_details": {
"email": email,
"authentication_details": {
@@ -338,7 +346,9 @@ def process_account(
try:
resource_manager = resourcemanager_v3.ProjectsClient(credentials=creds)
existing_projects = list(resource_manager.search_projects())
existing_projects: List[CloudProject] = list(
resource_manager.search_projects()
)
if not existing_projects and action == "create":
logging.warning(f"No projects found for {email}")
@@ -398,4 +408,4 @@ def process_account(
logging.error(f"API error processing {email}: {err}")
if not dry_run:
database.save_keys_to_json(api_keys_data, config.API_KEYS_DATABASE_FILE, schema)
database.save_keys_to_json(api_keys_data, config.API_KEYS_DATABASE_FILE, schema)

View File

@@ -6,10 +6,14 @@ Handles OAuth2 credential management including:
- Credential storage/retrieval
"""
from __future__ import annotations
import os
import json
import logging
import time
from typing import Optional
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request
@@ -18,7 +22,9 @@ from . import config
logger = logging.getLogger(__name__)
def get_and_refresh_credentials(email, max_retries=3, retry_delay=5):
def get_and_refresh_credentials(
email: str, max_retries: int = 3, retry_delay: int = 5
) -> Optional[Credentials]:
"""Manages credential lifecycle with automated refresh and retry.
Args:
@@ -30,7 +36,7 @@ def get_and_refresh_credentials(email, max_retries=3, retry_delay=5):
Credentials: Valid credentials or None if unrecoverable
"""
token_file = os.path.join(config.CREDENTIALS_DIR, f"{email}.json")
creds = None
creds: Optional[Credentials] = None
if os.path.exists(token_file):
try:
creds = Credentials.from_authorized_user_file(token_file, config.SCOPES)
@@ -69,7 +75,9 @@ def get_and_refresh_credentials(email, max_retries=3, retry_delay=5):
return None
def run_interactive_auth(email, max_retries=3, retry_delay=5):
def run_interactive_auth(
email: str, max_retries: int = 3, retry_delay: int = 5
) -> Optional[Credentials]:
"""Executes interactive OAuth2 flow with error handling.
Args:
@@ -88,7 +96,7 @@ def run_interactive_auth(email, max_retries=3, retry_delay=5):
flow = InstalledAppFlow.from_client_secrets_file(
config.CLIENT_SECRETS_FILE, config.SCOPES
)
creds = flow.run_local_server(port=0)
creds: Credentials = flow.run_local_server(port=0)
token_file = os.path.join(config.CREDENTIALS_DIR, f"{email}.json")
with open(token_file, "w") as token:
token.write(creds.to_json())

View File

@@ -7,24 +7,27 @@ Contains:
"""
import os
from typing import List
# --- DIRECTORIES ---
CREDENTIALS_DIR = "credentials"
LOG_DIR = "logs"
SCHEMA_DIR = "schemas"
CREDENTIALS_DIR: str = "credentials"
LOG_DIR: str = "logs"
SCHEMA_DIR: str = "schemas"
# --- FILENAMES ---
EMAILS_FILE = "emails.txt"
CLIENT_SECRETS_FILE = "credentials.json"
API_KEYS_DATABASE_FILE = "api_keys_database.json"
EMAILS_FILE: str = "emails.txt"
CLIENT_SECRETS_FILE: str = "credentials.json"
API_KEYS_DATABASE_FILE: str = "api_keys_database.json"
# --- SCHEMA ---
API_KEYS_SCHEMA_FILE = os.path.join(SCHEMA_DIR, "v1", "api_keys_database.schema.json")
API_KEYS_SCHEMA_FILE: str = os.path.join(
SCHEMA_DIR, "v1", "api_keys_database.schema.json"
)
# --- GOOGLE API ---
SCOPES = [
SCOPES: List[str] = [
"https://www.googleapis.com/auth/cloud-platform",
]
GENERATIVE_LANGUAGE_API = "generativelanguage.googleapis.com"
GEMINI_API_KEY_DISPLAY_NAME = "Gemini API Key"
GENERATIVE_LANGUAGE_API_KEY_DISPLAY_NAME = "Generative Language API Key"
GENERATIVE_LANGUAGE_API: str = "generativelanguage.googleapis.com"
GEMINI_API_KEY_DISPLAY_NAME: str = "Gemini API Key"
GENERATIVE_LANGUAGE_API_KEY_DISPLAY_NAME: str = "Generative Language API Key"

View File

@@ -7,15 +7,24 @@ Implements:
- Data versioning and backup
"""
from __future__ import annotations
import os
import json
import logging
import sys
from datetime import datetime, timezone
from typing import Any, Dict, List
import jsonschema
from google.cloud.resourcemanager_v3.types import Project as CloudProject
from google.cloud.api_keys_v2.types import Key as CloudKey
from .types import Account, ApiKeysDatabase, Project as LocalProject, TempKey
def load_schema(filename):
def load_schema(filename: str) -> Dict[str, Any]:
"""Validates and loads JSON schema definition.
Args:
@@ -38,10 +47,18 @@ def load_schema(filename):
sys.exit(1)
def load_keys_database(filename, schema):
def load_keys_database(filename: str, schema: Dict[str, Any]) -> ApiKeysDatabase:
"""Loads and validates the JSON database of API keys."""
now = datetime.now(timezone.utc).isoformat()
empty_db: ApiKeysDatabase = {
"schema_version": "1.0.0",
"accounts": [],
"generation_timestamp_utc": now,
"last_modified_utc": now,
}
if not os.path.exists(filename):
return {"schema_version": "1.0.0", "accounts": []}
return empty_db
with open(filename, "r") as f:
try:
data = json.load(f)
@@ -54,10 +71,12 @@ def load_keys_database(filename, schema):
f"Database file '{filename}' is not valid. {e.message}. Starting fresh."
)
return {"schema_version": "1.0.0", "accounts": []}
return empty_db
def save_keys_to_json(data, filename, schema):
def save_keys_to_json(
data: ApiKeysDatabase, filename: str, schema: Dict[str, Any]
) -> None:
"""Validates and saves the API key data to a single JSON file."""
now = datetime.now(timezone.utc).isoformat()
data["generation_timestamp_utc"] = data.get("generation_timestamp_utc", now)
@@ -73,7 +92,9 @@ def save_keys_to_json(data, filename, schema):
sys.exit(1)
def add_key_to_database(account_entry, project, key_object):
def add_key_to_database(
account_entry: Account, project: CloudProject, key_object: TempKey | CloudKey
) -> None:
"""Adds a new API key's details to the data structure."""
project_id = project.project_id
@@ -86,7 +107,7 @@ def add_key_to_database(account_entry, project, key_object):
None,
)
if not project_entry:
project_entry = {
project_entry: LocalProject = {
"project_info": {
"project_id": project_id,
"project_name": project.display_name,
@@ -97,7 +118,7 @@ def add_key_to_database(account_entry, project, key_object):
}
account_entry["projects"].append(project_entry)
api_targets = []
api_targets: List[Dict[str, List[str]]] = []
if key_object.restrictions and key_object.restrictions.api_targets:
for target in key_object.restrictions.api_targets:
api_targets.append({"service": target.service, "methods": []})
@@ -134,7 +155,9 @@ def add_key_to_database(account_entry, project, key_object):
)
def remove_keys_from_database(account_entry, project_id, deleted_keys_uids):
def remove_keys_from_database(
account_entry: Account, project_id: str, deleted_keys_uids: List[str]
) -> None:
"""Removes deleted API keys from the data structure."""
project_entry = next(
(
@@ -158,4 +181,4 @@ def remove_keys_from_database(account_entry, project_id, deleted_keys_uids):
if num_removed > 0:
logging.info(
f" Removed {num_removed} key(s) from local database for project {project_id}"
)
)

View File

@@ -6,6 +6,8 @@ Defines domain-specific exceptions for:
- API operation constraints
"""
from __future__ import annotations
class TermsOfServiceNotAcceptedError(Exception):
"""Indicates unaccepted Terms of Service for critical API operations.
@@ -15,7 +17,7 @@ class TermsOfServiceNotAcceptedError(Exception):
url (str): URL for Terms of Service acceptance portal
"""
def __init__(self, message, url):
def __init__(self, message: str, url: str) -> None:
self.message = message
self.url = url
super().__init__(self.message)
super().__init__(self.message)

View File

@@ -2,14 +2,22 @@
Functions for interacting with Google Cloud Platform APIs.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import List, Optional
from google.cloud import service_usage_v1, api_keys_v2
from google.api_core import exceptions as google_exceptions
from google.oauth2.credentials import Credentials
from . import config, exceptions
def enable_api(project_id, credentials, dry_run=False):
def enable_api(
project_id: str, credentials: Credentials, dry_run: bool = False
) -> bool:
"""Manages Generative Language API enablement with error handling.
Args:
@@ -62,7 +70,9 @@ def enable_api(project_id, credentials, dry_run=False):
return False
def create_api_key(project_id, credentials, dry_run=False):
def create_api_key(
project_id: str, credentials: Credentials, dry_run: bool = False
) -> Optional[api_keys_v2.Key]:
"""Generates restricted API key with security constraints.
Args:
@@ -121,9 +131,11 @@ def create_api_key(project_id, credentials, dry_run=False):
return None
def delete_api_keys(project_id, credentials, dry_run=False):
def delete_api_keys(
project_id: str, credentials: Credentials, dry_run: bool = False
) -> List[str]:
"""Deletes all API keys with the display name 'Gemini API Key' and returns their UIDs."""
deleted_keys_uids = []
deleted_keys_uids: List[str] = []
try:
api_keys_client = api_keys_v2.ApiKeysClient(credentials=credentials)
parent = f"projects/{project_id}/locations/global"
@@ -166,4 +178,4 @@ def delete_api_keys(project_id, credentials, dry_run=False):
logging.error(
f" An API error occurred while deleting keys for project {project_id}: {err}"
)
return []
return []

View File

@@ -2,15 +2,21 @@
Main entry point for the Gemini Key Management script.
"""
from __future__ import annotations
import argparse
import logging
import sys
import os
import concurrent.futures
from typing import List, Dict
from google.oauth2.credentials import Credentials
from . import utils, config, auth, database, actions
def main():
def main() -> None:
"""Orchestrates API key lifecycle management workflow.
Handles:
@@ -75,7 +81,7 @@ def main():
schema = database.load_schema(config.API_KEYS_SCHEMA_FILE)
api_keys_data = database.load_keys_database(config.API_KEYS_DATABASE_FILE, schema)
emails_to_process = []
emails_to_process: List[str] = []
if args.email:
emails_to_process.append(args.email)
elif args.action == "delete":
@@ -89,8 +95,8 @@ def main():
logging.info("No emails found in emails.txt. Exiting.")
sys.exit(1)
creds_map = {}
emails_needing_interactive_auth = []
creds_map: Dict[str, Credentials] = {}
emails_needing_interactive_auth: List[str] = []
logging.info("Checking credentials and refreshing tokens for all accounts...")
@@ -156,4 +162,4 @@ def main():
else:
logging.warning(
f"Skipping account {email} because authentication was not successful."
)
)

122
gemini_key_manager/types.py Normal file
View File

@@ -0,0 +1,122 @@
"""
This module defines the core data structures for the Gemini Key Management system
using TypedDicts to ensure type safety and clarity. These structures mirror the
JSON schema for the API keys database, providing a single source of truth for
data shapes throughout the application.
"""
from __future__ import annotations
from typing import List, Literal, TYPE_CHECKING, TypedDict
from datetime import datetime
if TYPE_CHECKING:
from google.cloud.api_keys_v2.types import Key as CloudKey
from google.cloud.api_keys_v2.types import Restrictions as CloudRestrictions
class ApiTarget(TypedDict):
"""Represents a single API target for key restrictions."""
service: str
methods: List[str]
class Restrictions(TypedDict):
"""Defines the API restrictions for a key."""
api_targets: List[ApiTarget]
class KeyDetails(TypedDict):
"""Contains the detailed information for an API key."""
key_string: str
key_id: str
key_name: str
display_name: str
creation_timestamp_utc: str
last_updated_timestamp_utc: str
class ApiKey(TypedDict):
"""Represents a single API key, including its details and restrictions."""
key_details: KeyDetails
restrictions: Restrictions
state: Literal["ACTIVE", "INACTIVE"]
class ProjectInfo(TypedDict):
"""Contains metadata about a Google Cloud project."""
project_id: str
project_name: str
project_number: str
state: str
class Project(TypedDict):
"""Represents a Google Cloud project and its associated API keys."""
project_info: ProjectInfo
api_keys: List[ApiKey]
class AuthenticationDetails(TypedDict):
"""Holds authentication information for a Google account."""
token_file: str
scopes: List[str]
class AccountDetails(TypedDict):
"""Contains details for a single Google account."""
email: str
authentication_details: AuthenticationDetails
class Account(TypedDict):
"""Represents a single user account and all its associated projects."""
account_details: AccountDetails
projects: List[Project]
class ApiKeysDatabase(TypedDict):
"""
Defines the root structure of the JSON database file, holding all account
and key information.
"""
schema_version: str
accounts: List[Account]
generation_timestamp_utc: str
last_modified_utc: str
class TempKey:
"""
A temporary, mock-like key object used for database operations when a full
cloud key object is not available or necessary. It provides a compatible
structure for functions that expect a key-like object.
Attributes:
key_string (str): The actual API key string.
uid (str): The unique identifier of the key.
name (str): The full resource name of the key.
display_name (str): The human-readable display name.
create_time (datetime): The timestamp of key creation.
update_time (datetime): The timestamp of the last update.
restrictions (CloudRestrictions): The usage restrictions for the key.
"""
def __init__(self, cloud_key: "CloudKey", key_string: str) -> None:
self.key_string: str = key_string
self.uid: str = cloud_key.uid
self.name: str = cloud_key.name
self.display_name: str = cloud_key.display_name
self.create_time: datetime = cloud_key.create_time
self.update_time: datetime = cloud_key.update_time
self.restrictions: "CloudRestrictions" = cloud_key.restrictions

View File

@@ -1,18 +1,23 @@
"""
Utility functions for the Gemini Key Management script.
"""
from __future__ import annotations
import logging
import os
import sys
import random
import string
from datetime import datetime, timezone
from typing import List
from colorama import Fore, Style, init
from . import config
class ColoredFormatter(logging.Formatter):
"""Adds ANSI color coding to log output based on severity.
Attributes:
LOG_COLORS (dict): Maps log levels to color codes
"""
@@ -25,7 +30,7 @@ class ColoredFormatter(logging.Formatter):
logging.CRITICAL: Fore.RED + Style.BRIGHT,
}
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
"""Formats the log record with appropriate colors."""
color = self.LOG_COLORS.get(record.levelno)
message = super().format(record)
@@ -39,15 +44,16 @@ class ColoredFormatter(logging.Formatter):
message = color + message + Style.RESET_ALL
return message
def setup_logging():
def setup_logging() -> None:
"""Configures dual logging to file and colorized console output.
Creates:
- Rotating file handler with full debug details
- Stream handler with color-coded brief format
Ensures proper directory structure for log files
"""
init(autoreset=True) # Initialize Colorama
init(autoreset=True) # Initialize Colorama
if not os.path.exists(config.LOG_DIR):
os.makedirs(config.LOG_DIR)
@@ -63,7 +69,7 @@ def setup_logging():
logger.handlers.clear()
# File handler for detailed, non-colored logging
file_handler = logging.FileHandler(log_filepath, encoding='utf-8')
file_handler = logging.FileHandler(log_filepath, encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [%(name)s:%(module)s:%(lineno)d] - %(message)s"
)
@@ -78,7 +84,8 @@ def setup_logging():
logging.info(f"Logging initialized. Log file: {log_filepath}")
def load_emails_from_file(filename):
def load_emails_from_file(filename: str) -> List[str]:
"""Loads a list of emails from a text file, ignoring comments."""
if not os.path.exists(filename):
logging.error(f"Email file not found at '{filename}'")
@@ -86,9 +93,12 @@ def load_emails_from_file(filename):
return []
with open(filename, "r") as f:
# Ignore empty lines and lines starting with #
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
return [
line.strip() for line in f if line.strip() and not line.startswith("#")
]
def generate_random_string(length=10):
def generate_random_string(length: int = 10) -> str:
"""Generates a random alphanumeric string of a given length."""
letters_and_digits = string.ascii_lowercase + string.digits
return ''.join(random.choice(letters_and_digits) for i in range(length))
return "".join(random.choice(letters_and_digits) for _ in range(length))

View File

@@ -10,5 +10,5 @@ dependencies = [
"google-cloud-resource-manager>=1.14.2",
"google-cloud-service-usage>=1.13.1",
"jsonschema>=4.25.1",
"colorama>=0.4.6",
"colorama>=0.4.6"
]