# Source code for afterimage.providers.document_providers

# document_providers.py
from __future__ import annotations

import glob
import json
import logging
import math
import random
from abc import abstractmethod
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Protocol, runtime_checkable

from ..types import Document

logger = logging.getLogger(__name__)


[docs] @runtime_checkable class DocumentProvider(Protocol): """ Unified DocumentProvider protocol. Minimal required method for implementations: - _load_documents() -> list[Document] Public helpers (provided by protocol defaults below): - get_documents(n: int) -> list[Document] - get_all() -> list[Document] - sample(n: int) -> list[Document] - report_doc_usage(document_id: str) -> int - set_target_context_usage_count(target_context_usage_count: int | None) - mark_fully_covered(document_id: str) - clear_cache() - __len__(), __iter__(), __getitem__(i) """ @abstractmethod def _load_documents(self) -> list[Document]: """Load (and return) all documents. Implementations may cache internally.""" ... # --- default helpers (implementations can override) --- def _ensure_usage_tracking_state(self, docs: list[Document]) -> None: """Initialize and refresh usage/weight bookkeeping for loaded documents.""" if not hasattr(self, "_doc_usage_counts"): self._doc_usage_counts = {} if not hasattr(self, "_doc_sampling_weights"): self._doc_sampling_weights = {} if not hasattr(self, "_fully_covered_doc_ids"): self._fully_covered_doc_ids = set() if not hasattr(self, "target_context_usage_count"): self.target_context_usage_count = None if not hasattr(self, "_target_context_usage_count_explicit"): self._target_context_usage_count_explicit = False doc_ids = {doc.id for doc in docs} self._doc_usage_counts = { doc_id: self._doc_usage_counts.get(doc_id, 0) for doc_id in doc_ids } self._doc_sampling_weights = { doc_id: self._doc_sampling_weights.get(doc_id, 1.0) for doc_id in doc_ids } self._fully_covered_doc_ids = { doc_id for doc_id in self._fully_covered_doc_ids if doc_id in doc_ids } self._recalculate_sampling_weights(docs) def _get_sampling_weight(self, document_id: str) -> float: if document_id in self._fully_covered_doc_ids: return 0.0 usage_count = self._doc_usage_counts.get(document_id, 0) target_usage_count = self.target_context_usage_count if target_usage_count is not None: if 
target_usage_count <= 0: return 0.0 return float(max(target_usage_count - usage_count, 0)) return 1.0 / (usage_count + 1) def _recalculate_sampling_weights(self, docs: list[Document]) -> None: for doc in docs: self._doc_sampling_weights[doc.id] = self._get_sampling_weight(doc.id) def _weighted_sample_without_replacement( self, docs: list[Document], k: int ) -> list[Document]: remaining_docs = list(docs) sampled_docs: list[Document] = [] while remaining_docs and len(sampled_docs) < k: weights = [self._doc_sampling_weights.get(doc.id, 0.0) for doc in remaining_docs] if not any(weight > 0 for weight in weights): break selected_index = random.choices( range(len(remaining_docs)), weights=weights, k=1, )[0] sampled_docs.append(remaining_docs.pop(selected_index)) return sampled_docs
[docs] def get_all(self) -> list[Document]: """Return all documents (loads once if implementation caches).""" docs = self._load_documents() self._ensure_usage_tracking_state(docs) return docs
[docs] def get_documents(self, n: int) -> list[Document]: """Return up to n random documents. If n is math.inf, return all documents.""" if n is None: n = math.inf if n == math.inf: return self.get_all() docs = self.get_all() k = min(int(n), len(docs)) if k == 0: return [] active_docs = [ doc for doc in docs if self._doc_sampling_weights.get(doc.id, 0.0) > 0 ] if not active_docs: return [] if k >= len(active_docs): return list(active_docs) return self._weighted_sample_without_replacement(active_docs, k)
[docs] def sample(self, n: int) -> list[Document]: """Alias for get_documents.""" return self.get_documents(n)
[docs] def report_doc_usage(self, document_id: str) -> int: """Record usage for a document and refresh sampling weights.""" docs = self.get_all() if document_id not in self._doc_usage_counts: raise KeyError(f"Unknown document id: {document_id}") self._doc_usage_counts[document_id] += 1 self._recalculate_sampling_weights(docs) return self._doc_usage_counts[document_id]
[docs] def set_target_context_usage_count( self, target_context_usage_count: int | None, ) -> None: """Update the target usage count used by weight calculation.""" self.target_context_usage_count = target_context_usage_count self._target_context_usage_count_explicit = target_context_usage_count is not None docs = self.get_all() self._recalculate_sampling_weights(docs)
[docs] def get_target_context_usage_count(self) -> int | None: """Return the current target usage count, if configured.""" self.get_all() return self.target_context_usage_count
[docs] def mark_fully_covered(self, document_id: str) -> None: """Exclude a document from future weighted sampling.""" self.get_all() if document_id not in self._doc_sampling_weights: raise KeyError(f"Unknown document id: {document_id}") self._fully_covered_doc_ids.add(document_id) self._doc_sampling_weights[document_id] = 0.0
[docs] def clear_cache(self) -> None: """Optional: implementations can override to clear internal caches.""" # protocol default: no-op return None
def __len__(self) -> int: """Length if supported (may force load).""" return len(self.get_all()) def __iter__(self) -> Iterable[Document]: return iter(self.get_all()) def __getitem__(self, index: int) -> Document: """Index access (forces load).""" return self.get_all()[index]
# ---------- Concrete implementations ----------
[docs] class InMemoryDocumentProvider(DocumentProvider): """Simple provider backed by a list of strings.""" def __init__( self, texts: list[str | Document], target_context_usage_count: int | None = None, ): if not isinstance(texts, list) or not ( all(isinstance(d, str) for d in texts) or all(isinstance(d, Document) for d in texts) ): raise TypeError("texts must be a list[str|Document] but got: " + str(texts)) self.target_context_usage_count = target_context_usage_count self._target_context_usage_count_explicit = ( target_context_usage_count is not None ) self._documents = ( [Document(text=text) for text in texts] if len(texts) > 0 and isinstance(texts[0], str) else texts ) def _load_documents(self) -> list[Document]: return self._documents
[docs] def clear_cache(self) -> None: # nothing to clear for in-memory return None
[docs] class FileSystemDocumentProvider(DocumentProvider): """Load text files matched by a glob pattern.""" def __init__( self, path_pattern: str, encoding: str = "utf-8", recursive: bool = False, min_length: int = 1, cache: bool = True, target_context_usage_count: int | None = None, ): self.pattern = path_pattern self.encoding = encoding self.recursive = recursive self.min_length = min_length self.target_context_usage_count = target_context_usage_count self._target_context_usage_count_explicit = ( target_context_usage_count is not None ) self._cache_enabled = bool(cache) self._cache: Optional[list[Document]] = None def _find_files(self) -> List[str]: return glob.glob(self.pattern, recursive=self.recursive) def _load_documents(self) -> list[Document]: if self._cache_enabled and self._cache is not None: return self._cache files = self._find_files() if not files: raise FileNotFoundError(f"No files matching pattern: {self.pattern}") docs: list[Document] = [] for path in files: try: with open(path, "r", encoding=self.encoding) as f: text = f.read().strip() if len(text) >= self.min_length: docs.append(Document(id=str(Path(path).resolve()), text=text)) except Exception as exc: logger.warning("Failed to read %s: %s", path, exc) continue if not docs: raise ValueError( f"No documents found after filtering for pattern: {self.pattern}" ) if self._cache_enabled: self._cache = docs return docs
[docs] def clear_cache(self) -> None: self._cache = None
[docs] class DirectoryDocumentProvider(DocumentProvider): """Search a directory for several filename patterns (txt/md/jsonl etc).""" def __init__( self, directory: str | Path, file_patterns: Optional[List[str]] = None, encoding: str = "utf-8", recursive: bool = True, min_length: int = 1, cache: bool = True, target_context_usage_count: int | None = None, ): self.directory = Path(directory) self.patterns = file_patterns or ["*.txt", "*.md"] self.encoding = encoding self.recursive = recursive self.min_length = min_length self.target_context_usage_count = target_context_usage_count self._target_context_usage_count_explicit = ( target_context_usage_count is not None ) self._cache_enabled = bool(cache) self._cache: Optional[list[Document]] = None def _find_files(self) -> List[Path]: patterns = self.patterns files: List[Path] = [] for p in patterns: glob_pat = f"**/{p}" if self.recursive else p files.extend(self.directory.glob(glob_pat)) return files def _load_documents(self) -> list[Document]: if self._cache_enabled and self._cache is not None: return self._cache files = self._find_files() if not files: raise FileNotFoundError( f"No files found in {self.directory} for {self.patterns}" ) docs: list[Document] = [] for path in files: if not path.is_file(): continue try: with path.open("r", encoding=self.encoding) as f: text = f.read().strip() if len(text) >= self.min_length: docs.append(Document(id=str(path.resolve()), text=text)) except Exception as exc: logger.debug("skip %s: %s", path, exc) continue if not docs: raise ValueError("No valid documents found in directory after filtering") if self._cache_enabled: self._cache = docs return docs
[docs] def clear_cache(self) -> None: self._cache = None
[docs] class JSONLDocumentProvider(DocumentProvider): """Load text fields from one or more JSONL files. content_key selects which key from each JSON object to use. """ def __init__( self, path_pattern: str, content_key: str = "text", encoding: str = "utf-8", recursive: bool = False, cache: bool = True, max_docs: Optional[int] = None, target_context_usage_count: int | None = None, ): self.pattern = path_pattern self.content_key = content_key self.encoding = encoding self.recursive = recursive self.target_context_usage_count = target_context_usage_count self._target_context_usage_count_explicit = ( target_context_usage_count is not None ) self._cache_enabled = bool(cache) self._cache: Optional[list[Document]] = None self._max_docs = max_docs def _find_files(self) -> List[str]: return glob.glob(self.pattern, recursive=self.recursive) def _load_documents(self) -> list[Document]: if self._cache_enabled and self._cache is not None: return self._cache files = self._find_files() if not files: raise FileNotFoundError(f"No JSONL files matching: {self.pattern}") docs: list[Document] = [] for fp in files: try: with open(fp, "r", encoding=self.encoding) as f: for line_number, line in enumerate(f, start=1): if not line.strip(): continue try: obj = json.loads(line) except json.JSONDecodeError: logger.debug("invalid json line in %s - skipping", fp) continue if isinstance(obj, dict) and self.content_key in obj: val = obj[self.content_key] if isinstance(val, str) and val.strip(): docs.append( Document( id=f"{Path(fp).resolve()}:{line_number}", text=val.strip(), ) ) if self._max_docs and len(docs) >= self._max_docs: break except Exception as exc: logger.warning("Failed to read %s: %s", fp, exc) continue if self._max_docs and len(docs) >= self._max_docs: break if not docs: raise ValueError("No documents extracted from JSONL files") if self._cache_enabled: self._cache = docs return docs
[docs] def clear_cache(self) -> None: self._cache = None
# Qdrant optional provider try: from qdrant_client import QdrantClient # type: ignore from qdrant_client.http.models import Filter, ScoredPoint # type: ignore
[docs] class QdrantDocumentProvider(DocumentProvider): """Load text payloads from a Qdrant collection via scroll. Note: requires qdrant-client package. """ def __init__( self, client: QdrantClient, collection_name: str, content_key: str = "text", batch_size: int = 500, scroll_filter: Optional[Filter] = None, with_payload_keys: Optional[List[str]] = None, cache: bool = True, max_docs: Optional[int] = None, target_context_usage_count: int | None = None, ): self.client = client self.collection_name = collection_name self.content_key = content_key self.batch_size = batch_size self.scroll_filter = scroll_filter self.target_context_usage_count = target_context_usage_count self._target_context_usage_count_explicit = ( target_context_usage_count is not None ) self.with_payload_keys = ( [content_key] if with_payload_keys is None else with_payload_keys ) self._cache_enabled = bool(cache) self._cache: Optional[list[Document]] = None self._max_docs = max_docs def _scroll_once(self, offset: Optional[int] = None) -> List[ScoredPoint]: # using client.scroll - returns (points, next_page) depending on client version resp = self.client.scroll( collection_name=self.collection_name, offset=offset, limit=self.batch_size, scroll_filter=self.scroll_filter, with_payload=True, with_vectors=False, ) # qdrant client may return a tuple; safe-guard: if isinstance(resp, tuple) or isinstance(resp, list): points = resp[0] else: points = resp return points def _load_documents(self) -> list[Document]: if self._cache_enabled and self._cache is not None: return self._cache docs: list[Document] = [] offset = None while True: points = self._scroll_once(offset) if not points: break for p in points: if getattr(p, "payload", None) and self.content_key in p.payload: val = p.payload[self.content_key] if isinstance(val, str) and val.strip(): docs.append(Document(id=str(p.id), text=val.strip())) if self._max_docs and len(docs) >= self._max_docs: break if self._max_docs and len(docs) >= self._max_docs: break 
offset = points[-1].id if points else None if len(points) < self.batch_size: break if not docs: raise ValueError( f"No documents found in Qdrant collection {self.collection_name}" ) if self._cache_enabled: self._cache = docs return docs
[docs] def clear_cache(self) -> None: self._cache = None
except Exception: # Qdrant not installed - define a lightweight placeholder for type-checkers/usage errors. QdrantDocumentProvider = None # type: ignore # ---------- small usage / test snippet ---------- if __name__ == "__main__": # quick smoke test mem = InMemoryDocumentProvider(["a", "b", "c", "d"]) assert len(mem) == 4 assert len(mem.get_documents(2)) == 2 assert len(mem.get_documents(math.inf)) == 4 # JSONL/FileSystem/Directory providers: create small files and test manually as needed. print("document_providers module loaded - smoke tests passed.")