import asyncio
from typing import Optional
from ..base import (
BaseRespondentPromptModifierCallback,
)
from ..prompts import (
default_rag_respondent_prompt_with_context,
default_respondent_prompt_with_context,
)
from ..retrievers import ContextRetriever
from ..types import (
GeneratedResponsePrompt,
)
class WithContextRespondentPromptModifier(BaseRespondentPromptModifierCallback):
    """Modifies respondent prompt by adding context.

    Args:
        prompt_template: Custom prompt template. If None, uses
            `default_respondent_prompt_with_context`. If it contains `{prompt}`
            and/or `{context}`, they will be replaced by the respondent prompt
            and the context, respectively. The template is rendered with
            `str.format`, so any other literal braces in it must be escaped as
            `{{` and `}}`. If the template contains neither placeholder it is
            not used, and the respondent prompt is passed through unchanged.
    """

    def __init__(self, prompt_template: Optional[str] = None):
        self.prompt_template = (
            prompt_template
            if prompt_template is not None
            else default_respondent_prompt_with_context
        )
        # Placeholder presence is detected once so each call formats only the
        # fields the template actually declares.
        self.should_inject_prompt = "{prompt}" in self.prompt_template
        self.should_inject_context = "{context}" in self.prompt_template

    def _format_prompt(self, respondent_prompt: str, context: str) -> str:
        """Render the template with whichever placeholders it declares.

        Shared by `generate` and `agenerate` so the sync and async paths build
        prompts identically.

        Args:
            respondent_prompt: Value substituted for `{prompt}`.
            context: Value substituted for `{context}`.

        Returns:
            The formatted template, or `respondent_prompt` unchanged when the
            template contains neither `{prompt}` nor `{context}`.
        """
        fields: dict[str, str] = {}
        if self.should_inject_prompt:
            fields["prompt"] = respondent_prompt
        if self.should_inject_context:
            fields["context"] = context
        if not fields:
            return respondent_prompt
        return self.prompt_template.format(**fields)

    def generate(
        self, respondent_prompt: str, context: str, instruction: str
    ) -> GeneratedResponsePrompt:
        """Generates a modified respondent prompt by injecting context and instructions.

        Args:
            respondent_prompt: The original prompt for the respondent
            context: Additional context to be included
            instruction: The instruction associated with the prompt; forwarded
                to `_maybe_augment_context` (defined outside this class —
                presumably on the base callback) together with `context`.

        Returns:
            GeneratedResponsePrompt containing the modified prompt and context
        """
        additional_context = self._maybe_augment_context(instruction, context)
        return GeneratedResponsePrompt(
            prompt=self._format_prompt(respondent_prompt, additional_context),
            context=additional_context,
        )

    async def agenerate(
        self, respondent_prompt: str, context: str, instruction: str
    ) -> GeneratedResponsePrompt:
        """Generates a modified respondent prompt by injecting context and instructions asynchronously.

        If the instance defines `augment_context_async` (as the RAG subclass
        does), it is awaited to build the context. Otherwise this falls back to
        the synchronous `_maybe_augment_context`, which runs directly on the
        event-loop thread.

        Args:
            respondent_prompt: The original prompt for the respondent
            context: Additional context to be included
            instruction: The instruction associated with the prompt

        Returns:
            GeneratedResponsePrompt containing the modified prompt and context
        """
        # Single lookup instead of hasattr + attribute access.
        augment_async = getattr(self, "augment_context_async", None)
        if augment_async is not None:
            additional_context = await augment_async(instruction, context)
        else:
            additional_context = self._maybe_augment_context(instruction, context)
        return GeneratedResponsePrompt(
            prompt=self._format_prompt(respondent_prompt, additional_context),
            context=additional_context,
        )
class WithRAGRespondentPromptModifier(WithContextRespondentPromptModifier):
    """Modifies respondent prompt by adding relevant context using a retrieval strategy.

    Args:
        retriever: Strategy for retrieving relevant context
        prompt_template: Custom prompt template. If None, uses
            `default_rag_respondent_prompt_with_context`.
    """

    def __init__(
        self,
        retriever: ContextRetriever,
        prompt_template: Optional[str] = None,
    ):
        super().__init__(
            prompt_template
            if prompt_template is not None
            else default_rag_respondent_prompt_with_context
        )
        self.retriever = retriever

    @staticmethod
    def _combine_contexts(current_context: str, rag_context: str) -> str:
        """Merge caller-supplied context with retrieved context.

        Shared by the sync and async augmentation paths.

        Args:
            current_context: Context supplied by the caller (may be empty).
            rag_context: Context returned by the retriever (may be empty).

        Returns:
            `rag_context` when there is no current context; `current_context`
            when retrieval produced nothing (avoids a dangling section header);
            otherwise both, joined under an "Additional relevant information"
            header.
        """
        if not current_context:
            return rag_context
        if not rag_context:
            return current_context
        return (
            f"{current_context}\n\nAdditional relevant information:\n{rag_context}"
        )

    def augment_context(self, instruction: str, current_context: str) -> str:
        """Augment existing context with relevant information using the retriever.

        Args:
            instruction: The current instruction/question, used as the
                retrieval query.
            current_context: Any existing context

        Returns:
            str: Combined context from both sources
        """
        rag_context = self.retriever.get_context(instruction)
        return self._combine_contexts(current_context, rag_context)

    async def augment_context_async(
        self, instruction: str, current_context: str
    ) -> str:
        """Async RAG augmentation; prefers ``retriever.aget_context`` when defined.

        When the retriever has no ``aget_context``, its synchronous
        ``get_context`` is run via ``asyncio.to_thread`` so the blocking call
        does not stall the event loop.

        Args:
            instruction: The current instruction/question, used as the
                retrieval query.
            current_context: Any existing context

        Returns:
            str: Combined context from both sources
        """
        aget_context = getattr(self.retriever, "aget_context", None)
        if aget_context is not None:
            rag_context = await aget_context(instruction)
        else:
            rag_context = await asyncio.to_thread(
                self.retriever.get_context, instruction
            )
        return self._combine_contexts(current_context, rag_context)