# Source code for afterimage.providers.embedding_providers

"""Async embedding providers (API and process-based).

This module defines a single async contract (:class:`EmbeddingProvider`) and
concrete backends for OpenAI-compatible APIs, Google Gemini, and local
SentenceTransformer models running in worker processes so the asyncio event
loop is not blocked by CPU/GPU embedding work.
"""

from __future__ import annotations

import asyncio
import os
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Optional, Protocol, Sequence, runtime_checkable

from google import genai
from openai import AsyncOpenAI

from ..key_management import SmartKeyPool


def _chunk_list(items: Sequence[str], size: int) -> list[list[str]]:
    """Split ``items`` into contiguous chunks of at most ``size`` elements.

    Args:
        items: Strings to partition.
        size: Maximum chunk length; must be positive.

    Returns:
        List of chunks, each a list of strings. Empty ``items`` yields ``[]``.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError("Batch size must be positive")
    return [list(items[i : i + size]) for i in range(0, len(items), size)]


async def _aclose_genai_client(client: genai.Client) -> None:
    """Best-effort shutdown of async HTTP resources on a google-genai client.

    Args:
        client: Client instance that may have opened aiohttp/httpx clients.
    """
    try:
        if hasattr(client, "aio"):
            api_client = client.aio._api_client
            if (
                hasattr(api_client, "_aiohttp_session")
                and api_client._aiohttp_session
            ):
                await api_client._aiohttp_session.close()
            if (
                hasattr(api_client, "_async_httpx_client")
                and api_client._async_httpx_client
            ):
                await api_client._async_httpx_client.aclose()
    except Exception:
        pass


@runtime_checkable
class EmbeddingProvider(Protocol):
    """Protocol for async text embedding backends.

    Implementations return one dense vector per input string, preserve
    order, and may batch requests internally. Call :meth:`aclose` when the
    provider is no longer needed (required for process-based providers).
    """

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed each string into a floating-point vector.

        Args:
            texts: Input strings. An empty list returns ``[]``.

        Returns:
            Embeddings in the same order as ``texts``; each embedding is a
            list of floats (dimension is model-specific).

        Note:
            Callers must not assume a single HTTP or IPC round trip; large
            inputs may be split into batches by the implementation.
        """
        ...

    async def aclose(self) -> None:
        """Release resources held by this provider (pools, clients).

        Implementations should make this idempotent (safe to call multiple
        times). API-only providers may use a no-op.
        """
        ...
class _NoOpAcloseMixin: """Mixin adding a no-op :meth:`aclose` for stateless API providers.""" async def aclose(self) -> None: """See :meth:`EmbeddingProvider.aclose`.""" return None
class OpenAIEmbeddingProvider(_NoOpAcloseMixin):
    """Embeddings via the OpenAI async client (``embeddings.create``).

    Supports OpenAI and OpenAI-compatible servers via ``base_url``. Uses
    :class:`~afterimage.key_management.SmartKeyPool` for key rotation and
    error reporting consistent with chat providers.
    """

    def __init__(
        self,
        api_key: str | SmartKeyPool,
        model: str = "text-embedding-3-small",
        *,
        base_url: Optional[str] = None,
        max_batch_size: int = 128,
        extra_create_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Initialize the OpenAI embedding provider.

        Args:
            api_key: A single API key or a
                :class:`~afterimage.key_management.SmartKeyPool`.
            model: Embedding model id passed to ``embeddings.create``.
            base_url: Optional base URL for compatible APIs (e.g. proxies).
            max_batch_size: Maximum number of texts per ``embeddings.create``
                call.
            extra_create_kwargs: Additional keyword arguments forwarded to
                ``embeddings.create`` (e.g. dimensions for matryoshka models).
        """
        self.key_pool = (
            api_key
            if isinstance(api_key, SmartKeyPool)
            else SmartKeyPool.from_single_key(api_key)
        )
        self.model = model
        self.base_url = base_url
        self.max_batch_size = max_batch_size
        self._extra_create_kwargs = extra_create_kwargs or {}

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Compute embeddings for ``texts`` using the configured model.

        Args:
            texts: Non-empty list of strings to embed, or empty for no work.

        Returns:
            One embedding per input string, in order.

        Raises:
            Exception: Propagates API errors after reporting the key to the
                pool.
        """
        if not texts:
            return []
        api_key = await self.key_pool.aget_next_key()
        client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
        try:
            out: list[list[float]] = []
            for batch in _chunk_list(texts, self.max_batch_size):
                response = await client.embeddings.create(
                    model=self.model,
                    input=batch,
                    **self._extra_create_kwargs,
                )
                # The API may return items out of order; restore input order.
                ordered = sorted(response.data, key=lambda d: d.index)
                out.extend([list(d.embedding) for d in ordered])
            return out
        except Exception:
            await self.key_pool.areport_error(api_key)
            raise
        finally:
            # Fix: the client (and its HTTP connection pool) was previously
            # leaked on every call; mirror GeminiEmbeddingProvider's
            # cleanup-in-finally behavior.
            await client.close()
class GeminiEmbeddingProvider(_NoOpAcloseMixin):
    """Embeddings via Google Gemini ``client.aio.models.embed_content``.

    Uses the async Gemini client, closes transient HTTP resources after each
    :meth:`embed` call, and integrates with
    :class:`~afterimage.key_management.SmartKeyPool`.
    """

    def __init__(
        self,
        api_key: str | SmartKeyPool,
        model: str = "text-embedding-004",
        *,
        max_batch_size: int = 128,
    ):
        """Initialize the Gemini embedding provider.

        Args:
            api_key: A single API key or a
                :class:`~afterimage.key_management.SmartKeyPool`.
            model: Gemini embedding model resource name or id.
            max_batch_size: Maximum number of strings per ``embed_content``
                call.
        """
        if isinstance(api_key, SmartKeyPool):
            self.key_pool = api_key
        else:
            self.key_pool = SmartKeyPool.from_single_key(api_key)
        self.model = model
        self.max_batch_size = max_batch_size

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Compute embeddings for ``texts`` using the configured Gemini model.

        Args:
            texts: Non-empty list of strings to embed, or empty for no work.

        Returns:
            One embedding per input string, in order.

        Raises:
            ValueError: If the API returns an embedding without ``values``.
            Exception: Propagates API errors after reporting the key to the
                pool.
        """
        if not texts:
            return []
        key = await self.key_pool.aget_next_key()
        client = genai.Client(api_key=key, vertexai=False)
        try:
            vectors: list[list[float]] = []
            for batch in _chunk_list(texts, self.max_batch_size):
                response = await client.aio.models.embed_content(
                    model=self.model,
                    contents=batch,
                )
                for embedding in response.embeddings or []:
                    values = embedding.values
                    if values is None:
                        raise ValueError("Gemini embedding missing values")
                    vectors.append(list(values))
            return vectors
        except Exception:
            await self.key_pool.areport_error(key)
            raise
        finally:
            # Transient per-call client: always release its HTTP resources.
            await _aclose_genai_client(client)
# --- Process pool workers (top-level for pickling on Windows) --- _worker_model: Any = None def _process_pool_init(model_name: str) -> None: """Load ``SentenceTransformer`` once in each worker process. Args: model_name: HuggingFace model id or local path. Raises: ImportError: If ``sentence_transformers`` is not installed. """ global _worker_model try: from sentence_transformers import SentenceTransformer except ImportError as e: raise ImportError( "ProcessEmbeddingProvider requires sentence-transformers. " 'Install with pip install "afterimage[embeddings-local]" or ' "pip install sentence-transformers" ) from e _worker_model = SentenceTransformer(model_name) def _process_pool_embed_batch(texts: list[str]) -> list[list[float]]: """Encode a batch in a worker process (must run after pool initializer). Args: texts: Batch of strings to encode. Returns: List of embedding vectors as plain Python floats. Raises: RuntimeError: If the worker model was not initialized. """ global _worker_model if _worker_model is None: raise RuntimeError("Embedding worker model not initialized") import numpy as np arr = _worker_model.encode(texts, convert_to_numpy=True) if isinstance(arr, np.ndarray): return arr.tolist() return [row.tolist() for row in arr]
class ProcessEmbeddingProvider:
    """Local embeddings using SentenceTransformer in a process pool.

    Inference runs in child processes so the host asyncio loop is not
    blocked. The model is loaded once per worker via the pool initializer.
    Call :meth:`aclose` to shut down workers when finished.
    """

    def __init__(
        self,
        model_name: str,
        *,
        max_workers: int = 2,
        max_batch_size: int = 64,
    ):
        """Initialize configuration; the process pool starts on first :meth:`embed`.

        Args:
            model_name: HuggingFace model id or path passed to
                SentenceTransformer.
            max_workers: Number of worker processes.
            max_batch_size: Maximum strings passed to each worker call.

        Raises:
            ValueError: If ``max_workers`` is less than 1.
        """
        if max_workers < 1:
            raise ValueError("max_workers must be at least 1")
        self._model_name = model_name
        self._max_workers = max_workers
        self._max_batch_size = max_batch_size
        self._executor: ProcessPoolExecutor | None = None
        self._closed = False

    def _get_executor(self) -> ProcessPoolExecutor:
        """Lazily create the process pool.

        Returns:
            Shared executor for this provider instance.

        Raises:
            RuntimeError: If :meth:`aclose` was already called.
        """
        if self._closed:
            raise RuntimeError("ProcessEmbeddingProvider is closed")
        if self._executor is None:
            self._executor = ProcessPoolExecutor(
                max_workers=self._max_workers,
                initializer=_process_pool_init,
                initargs=(self._model_name,),
            )
        return self._executor

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Encode texts in worker processes via :func:`asyncio.loop.run_in_executor`.

        Args:
            texts: Non-empty list of strings to embed, or empty for no work.

        Returns:
            One embedding per input string, in order.

        Raises:
            RuntimeError: If the provider was closed before or during use.
            ImportError: In workers if sentence-transformers is missing.
        """
        if not texts:
            return []
        loop = asyncio.get_running_loop()
        executor = self._get_executor()
        out: list[list[float]] = []
        for batch in _chunk_list(texts, self._max_batch_size):
            part: list[list[float]] = await loop.run_in_executor(
                executor,
                _process_pool_embed_batch,
                batch,
            )
            out.extend(part)
        return out

    async def aclose(self) -> None:
        """Shut down the process pool and mark this provider closed.

        Waits for workers to finish without blocking the event loop.
        Idempotent after the first call.
        """
        # Mark closed first so no new executor can be created mid-shutdown.
        self._closed = True
        executor, self._executor = self._executor, None
        if executor is not None:
            loop = asyncio.get_running_loop()
            # Fix: shutdown(wait=True) blocks until every worker drains;
            # previously this stalled the event loop. Run it in the default
            # thread executor instead.
            await loop.run_in_executor(None, lambda: executor.shutdown(wait=True))
class EmbeddingProviderFactory:
    """Factory for constructing :class:`EmbeddingProvider` instances from config.

    Config dictionaries are intended to be JSON-serializable (aside from
    embedding-specific nested structures). Key names and the ``type`` value
    are matched case-insensitively.
    """

    @staticmethod
    def _resolve_pool(
        cfg: dict[str, Any],
        api_key: str | None,
        key_pool: SmartKeyPool | None,
        env_var: str,
        provider_label: str,
    ) -> SmartKeyPool:
        """Resolve a key pool from explicit pool, key argument, config, or env.

        Args:
            cfg: Lower-cased config mapping (may hold ``api_key``).
            api_key: Optional default key when ``key_pool`` is omitted.
            key_pool: Shared pool; takes precedence over everything else.
            env_var: Environment variable to fall back to.
            provider_label: Human-readable provider name for the error message.

        Returns:
            A :class:`SmartKeyPool` for the provider.

        Raises:
            ValueError: If no key source is available.
        """
        if key_pool is not None:
            return key_pool
        key = api_key or cfg.get("api_key") or os.environ.get(env_var)
        if not key:
            raise ValueError(
                f"{provider_label} embedding provider needs key_pool, api_key, "
                f"config['api_key'], or {env_var}"
            )
        return SmartKeyPool.from_single_key(str(key))

    @staticmethod
    def create(
        config: dict[str, Any],
        *,
        api_key: str | None = None,
        key_pool: SmartKeyPool | None = None,
    ) -> EmbeddingProvider:
        """Build a provider from a configuration mapping.

        Args:
            config: Must include ``type`` (``"openai"``, ``"gemini"``, or
                ``"process"``). Optional keys:

                * ``model`` — embedding model id for API providers.
                * ``model_path`` — HuggingFace id or path for ``process``
                  (also ``model`` is accepted for process).
                * ``base_url`` — OpenAI-compatible base URL.
                * ``workers`` — process count for ``process`` (default ``2``).
                * ``max_batch_size`` — chunk size for batched calls.
                * ``api_key`` — inline secret when not using ``key_pool``.
            api_key: Optional default API key when ``key_pool`` is omitted
                (OpenAI/Gemini only).
            key_pool: Optional shared pool; takes precedence over ``api_key``
                and env vars for API providers.

        Returns:
            A concrete :class:`EmbeddingProvider`.

        Raises:
            ValueError: If ``type`` is missing, unknown, or required keys are
                absent.
        """
        cfg = {k.lower(): v for k, v in config.items()}
        raw_type = cfg.get("type")
        if not raw_type:
            raise ValueError("Embedding config requires 'type'")
        # Fix: the class contract is case-insensitive matching, but the type
        # value was previously compared verbatim ("OpenAI" was rejected).
        provider_type = str(raw_type).strip().lower()

        max_batch = cfg.get("max_batch_size")
        max_batch_int = int(max_batch) if max_batch is not None else None

        if provider_type == "openai":
            pool = EmbeddingProviderFactory._resolve_pool(
                cfg, api_key, key_pool, "OPENAI_API_KEY", "OpenAI"
            )
            kwargs: dict[str, Any] = {
                "api_key": pool,
                "model": str(cfg.get("model", "text-embedding-3-small")),
                "base_url": cfg.get("base_url"),
            }
            if max_batch_int is not None:
                kwargs["max_batch_size"] = max_batch_int
            return OpenAIEmbeddingProvider(**kwargs)

        if provider_type == "gemini":
            pool = EmbeddingProviderFactory._resolve_pool(
                cfg, api_key, key_pool, "GEMINI_API_KEY", "Gemini"
            )
            gkwargs: dict[str, Any] = {
                "api_key": pool,
                "model": str(cfg.get("model", "text-embedding-004")),
            }
            if max_batch_int is not None:
                gkwargs["max_batch_size"] = max_batch_int
            return GeminiEmbeddingProvider(**gkwargs)

        if provider_type == "process":
            model_name = cfg.get("model_path") or cfg.get("model")
            if not model_name:
                raise ValueError(
                    "Process embedding provider requires 'model_path' or "
                    "'model' (HuggingFace id or path)"
                )
            pkwargs: dict[str, Any] = {
                "model_name": str(model_name),
                "max_workers": int(cfg.get("workers", 2)),
            }
            if max_batch_int is not None:
                pkwargs["max_batch_size"] = max_batch_int
            return ProcessEmbeddingProvider(**pkwargs)

        raise ValueError(f"Unknown embedding provider type: {provider_type!r}")