Source code for afterimage.base

import asyncio
import logging

from .common import GeneratedInstructions
from .metadata_utils import extract_unique_context_ids
from .monitoring import GenerationMonitor
from .types import (
    EvaluatedConversationWithContext,
    GeneratedResponsePrompt,
    GenerationState,
)

# Module-level logger, namespaced to this module per stdlib logging convention.
logger = logging.getLogger(__name__)


class BaseGenerator:
    """Intended to serve as the base class for all generator classes"""

    # NOTE(review): this class reads several attributes it never sets here
    # (self.correspondent_prompt, self.respondent_prompt,
    # self.instruction_generator_callback, self.monitor, self.storage) and
    # calls self.create_correspondent_prompt; subclasses / an __init__ outside
    # this view are presumably responsible for providing them — confirm.

    def _iter_stopping_callbacks(self, callbacks):
        """Yield callbacks, recursively flattening composite callback containers."""
        # `callbacks or []` tolerates None as well as an empty container.
        for callback in callbacks or []:
            yield callback
            # Composite callbacks expose their children via a private
            # `_callbacks` attribute; recurse into them when present.
            nested_callbacks = getattr(callback, "_callbacks", None)
            if nested_callbacks:
                yield from self._iter_stopping_callbacks(nested_callbacks)

    def _infer_target_context_usage_count(
        self,
        provider,
        stopping_criteria,
    ) -> int | None:
        """Infer a context usage target from stopping callbacks bound to a provider.

        Args:
            provider: The document/context provider to match callbacks against.
            stopping_criteria: Iterable of stopping callbacks (possibly nested).

        Returns:
            The largest positive ``target_visits`` found on a callback whose
            ``provider`` is identical (``is``) to *provider*, or None if no
            callback supplies one.
        """
        inferred_targets: list[int] = []
        for callback in self._iter_stopping_callbacks(stopping_criteria):
            callback_provider = getattr(callback, "provider", None)
            target_visits = getattr(callback, "target_visits", None)
            # Only trust targets from callbacks bound to this exact provider
            # instance, and only positive integer values.
            if (
                callback_provider is provider
                and isinstance(target_visits, int)
                and target_visits > 0
            ):
                inferred_targets.append(target_visits)
        return max(inferred_targets) if inferred_targets else None

    def _configure_context_sampling(
        self,
        instruction_generator_callback,
        stopping_criteria,
    ) -> None:
        """Configure provider sampling weights from stopping criteria when possible."""
        provider = getattr(instruction_generator_callback, "provider", None)
        # Bail out for callbacks without a provider, or providers that do not
        # support target-count configuration.
        if provider is None or not hasattr(provider, "set_target_context_usage_count"):
            return
        # An explicitly configured provider target takes precedence; do not
        # overwrite it with an inferred value.
        if getattr(provider, "_target_context_usage_count_explicit", False):
            return
        inferred_target = self._infer_target_context_usage_count(
            provider,
            stopping_criteria,
        )
        # May legitimately pass None, clearing/defaulting the target.
        provider.set_target_context_usage_count(inferred_target)
        if getattr(self, "monitor", None) is not None:
            self.monitor.log_info(
                "Configured document sampling target",
                target_context_usage_count=inferred_target,
                provider_type=provider.__class__.__name__,
            )

    def _configure_persona_sampling(
        self,
        instruction_generator_callback,
        num_requested: int | None,
        stopping_criteria=None,
    ) -> None:
        """Configure persona-aware callbacks with the effective request target.

        Args:
            instruction_generator_callback: Callback that may expose a
                ``configure_persona_sampling`` hook.
            num_requested: Explicit request target; when None, the target is
                inferred from FixedNumberStoppingCallback instances.
            stopping_criteria: Optional stopping callbacks used for inference.
        """
        configure_persona_sampling = getattr(
            instruction_generator_callback,
            "configure_persona_sampling",
            None,
        )
        # Callback does not support persona sampling — nothing to configure.
        if configure_persona_sampling is None:
            return
        inferred_num_requested = num_requested
        if inferred_num_requested is None:
            # Infer the request size from fixed-number stopping callbacks.
            # Matching by class name (not isinstance) — presumably to avoid an
            # import cycle; verify against the callback module.
            fixed_targets: list[int] = []
            for callback in self._iter_stopping_callbacks(stopping_criteria):
                callback_n = getattr(callback, "n", None)
                if (
                    callback.__class__.__name__ == "FixedNumberStoppingCallback"
                    and isinstance(callback_n, int)
                    and callback_n > 0
                ):
                    fixed_targets.append(callback_n)
            inferred_num_requested = max(fixed_targets) if fixed_targets else None
        configure_persona_sampling(num_requested=inferred_num_requested)
        if getattr(self, "monitor", None) is not None:
            self.monitor.log_info(
                "Configured persona sampling target",
                target_personas_per_document=getattr(
                    instruction_generator_callback,
                    "_persona_target_per_document",
                    None,
                ),
                callback_type=instruction_generator_callback.__class__.__name__,
            )

    def _record_context_usage(
        self,
        instruction_generator_callback,
        item,
    ) -> None:
        """Report successful context usage back to the document provider."""
        provider = getattr(instruction_generator_callback, "provider", None)
        if provider is None or not hasattr(provider, "report_doc_usage"):
            return
        # Context ids are carried in the item's metadata dict; anything else
        # (missing or non-dict) is silently ignored.
        metadata = getattr(item, "metadata", None)
        if not isinstance(metadata, dict):
            return
        for context_id in extract_unique_context_ids(metadata):
            provider.report_doc_usage(context_id)

    def log_correspondent_prompt(self, correspondent_prompt: str | None) -> None:
        """Log the correspondent prompt in a standardized format.

        Args:
            correspondent_prompt: The correspondent prompt to log, or None if not set.
        """
        self.monitor.log_info(
            "Correspondent prompt set",
            correspondent_prompt=correspondent_prompt,
        )

    async def ainitialize(self, instruction_generator_callback=None):
        """Initializes the generator by creating the correspondent prompt if it doesn't exist."""
        if self.correspondent_prompt is None:
            # Use provided callback if given, otherwise use instance attribute
            callback = (
                instruction_generator_callback or self.instruction_generator_callback
            )
            # Try to use callback first if available
            if callback is not None:
                if hasattr(callback, "acreate_correspondent_prompt"):
                    created_prompt = await callback.acreate_correspondent_prompt(
                        self.respondent_prompt
                    )
                else:
                    # Callback only has a sync implementation: run it off the
                    # event loop.
                    created_prompt = await asyncio.to_thread(
                        callback.create_correspondent_prompt, self.respondent_prompt
                    )
                if created_prompt is not None:
                    self.correspondent_prompt = created_prompt
                    self.log_correspondent_prompt(self.correspondent_prompt)
                    return
            # Fallback to generator's method
            self.correspondent_prompt = await self.create_correspondent_prompt(
                self.respondent_prompt
            )
            self.log_correspondent_prompt(self.correspondent_prompt)

    def initialize(self, instruction_generator_callback=None):
        """Initializes the generator by creating the correspondent prompt if it doesn't exist."""
        if self.correspondent_prompt is None:
            # Use provided callback if given, otherwise use instance attribute
            callback = (
                instruction_generator_callback or self.instruction_generator_callback
            )
            # Try to use callback first if available
            if callback is not None:
                created_prompt = callback.create_correspondent_prompt(
                    self.respondent_prompt
                )
                if created_prompt is not None:
                    self.correspondent_prompt = created_prompt
                    self.log_correspondent_prompt(self.correspondent_prompt)
                    return
            # Fallback to generator's method
            self.correspondent_prompt = self.create_correspondent_prompt(
                self.respondent_prompt
            )
            self.log_correspondent_prompt(self.correspondent_prompt)

    def load_conversations(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[EvaluatedConversationWithContext]:
        # Thin delegation to the configured storage backend.
        return self.storage.load_conversations(limit=limit, offset=offset)

    def _ensure_monitor(self):
        """Ensures that a monitor exists, creating a default one if necessary."""
        if self.monitor is None:
            self.monitor = GenerationMonitor()
class BaseStoppingCallback:
    """Base class for callbacks that decide when to stop generation."""

    async def should_stop(self, state: GenerationState) -> bool:
        """Decide whether the generation loop should terminate.

        Subclasses must override this coroutine; the base implementation
        always raises.

        Args:
            state: The current state of the generation process.

        Returns:
            True if generation should stop, False otherwise.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class BaseInstructionGeneratorCallback:
    """Intended to serve as the base class for all custom instruction generator callbacks"""

    def __call__(self, original_prompt: str) -> GeneratedInstructions:
        """Invoke the synchronous generator and validate its return type.

        Args:
            original_prompt: The prompt to generate instructions from.

        Returns:
            The ``GeneratedInstructions`` produced by :meth:`generate`.

        Raises:
            AssertionError: If :meth:`generate` returns the wrong type.
        """
        instructions = self.generate(original_prompt)
        assert isinstance(instructions, GeneratedInstructions), (
            f".generate() method should return an instance of GeneratedInstructions, but found {type(instructions)}"
        )
        return instructions

    def generate(self, original_prompt) -> GeneratedInstructions:
        """Produce instructions for *original_prompt*; must be overridden."""
        raise NotImplementedError

    async def acall(self, original_prompt: str) -> GeneratedInstructions:
        """Invoke the asynchronous generator and validate its return type.

        Args:
            original_prompt: The prompt to generate instructions from.

        Returns:
            The ``GeneratedInstructions`` produced by :meth:`agenerate`.

        Raises:
            AssertionError: If :meth:`agenerate` returns the wrong type.
        """
        instructions = await self.agenerate(original_prompt)
        assert isinstance(instructions, GeneratedInstructions), (
            f".agenerate() method should return an instance of GeneratedInstructions, but found {type(instructions)}"
        )
        return instructions

    async def agenerate(self, original_prompt) -> GeneratedInstructions:
        """Async counterpart of :meth:`generate`; must be overridden."""
        raise NotImplementedError

    def create_correspondent_prompt(self, respondent_prompt: str) -> str | None:
        """Create a correspondent prompt based on the respondent prompt.

        This method can be overridden by subclasses to customize correspondent prompt creation.
        By default, returns None, which means the conversation generator should handle it.

        Args:
            respondent_prompt: The prompt for the respondent (assistant)

        Returns:
            str | None: The correspondent prompt, or None if the generator should handle it
        """
        return None

    # FIX: return annotation was `-> str`, but the default implementation
    # returns None (and the docstring already documented that) — aligned with
    # the synchronous counterpart's `str | None`.
    async def acreate_correspondent_prompt(self, respondent_prompt: str) -> str | None:
        """Create a correspondent prompt based on the respondent prompt asynchronously.

        This method can be overridden by subclasses to customize correspondent prompt creation.
        By default, returns None, which means the conversation generator should handle it.

        Args:
            respondent_prompt: The prompt for the respondent (assistant)

        Returns:
            str | None: The correspondent prompt, or None if the generator should handle it
        """
        return None
class BaseRespondentPromptModifierCallback:
    """Intended to serve as the base class for all custom respondent prompt modifier callbacks"""

    def __call__(
        self, respondent_prompt: str, context: str, instruction: str
    ) -> GeneratedResponsePrompt:
        """Invoke the synchronous modifier and validate its return type.

        Args:
            respondent_prompt: The current respondent (assistant) prompt.
            context: The context to incorporate.
            instruction: The instruction driving the modification.

        Returns:
            The ``GeneratedResponsePrompt`` produced by :meth:`generate`.

        Raises:
            AssertionError: If :meth:`generate` returns the wrong type.
        """
        modified_prompt = self.generate(respondent_prompt, context, instruction)
        # FIX: error message previously named a non-existent type
        # `GeneratedRespondentPrompt`; the checked type is GeneratedResponsePrompt.
        assert isinstance(modified_prompt, GeneratedResponsePrompt), (
            f".generate() method is expected to return an instance of `GeneratedResponsePrompt`, but found {type(modified_prompt)}"
        )
        return modified_prompt

    def generate(
        self, respondent_prompt, context, instruction
    ) -> GeneratedResponsePrompt:
        """Produce the modified respondent prompt; must be overridden."""
        raise NotImplementedError

    async def acall(
        self, respondent_prompt: str, context: str, instruction: str
    ) -> GeneratedResponsePrompt:
        """Invoke the asynchronous modifier and validate its return type.

        Args:
            respondent_prompt: The current respondent (assistant) prompt.
            context: The context to incorporate.
            instruction: The instruction driving the modification.

        Returns:
            The ``GeneratedResponsePrompt`` produced by :meth:`agenerate`.

        Raises:
            AssertionError: If :meth:`agenerate` returns the wrong type.
        """
        modified_prompt = await self.agenerate(respondent_prompt, context, instruction)
        # FIX: same message correction as in __call__ (`GeneratedRespondentPrompt`
        # -> `GeneratedResponsePrompt`).
        assert isinstance(modified_prompt, GeneratedResponsePrompt), (
            f".agenerate() method is expected to return an instance of `GeneratedResponsePrompt`, but found {type(modified_prompt)}"
        )
        return modified_prompt

    async def agenerate(
        self, respondent_prompt, context, instruction
    ) -> GeneratedResponsePrompt:
        """Async counterpart of :meth:`generate`; must be overridden."""
        raise NotImplementedError

    def _maybe_augment_context(self, instruction: str, current_context: str) -> str:
        """Return the (optionally) augmented context.

        Delegates to an ``augment_context`` hook when the subclass defines one;
        otherwise returns *current_context* unchanged.
        """
        if hasattr(self, "augment_context"):
            return self.augment_context(instruction, current_context)
        else:
            return current_context