import asyncio
import logging
import random
import time
import traceback
import warnings
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
from tqdm.asyncio import tqdm
from .base import (
BaseGenerator,
BaseInstructionGeneratorCallback,
BaseRespondentPromptModifierCallback,
BaseStoppingCallback,
)
from .callbacks import FixedNumberStoppingCallback
from .common import (
default_model_name,
default_safety_settings,
resolve_generation_max_concurrency,
)
from .evaluator import (
ConversationJudge,
ConversationJudgeConfig,
default_embedding_provider_config,
)
from .key_management import SmartKeyPool
from .monitoring import GenerationMonitor
from .prompts import get_correspondent_instruction_generation_prompt
from .providers import ChatSession, LLMFactory
from .providers.embedding_providers import EmbeddingProvider
from .storage import BaseStorage, JSONLStorage
from .types import (
Conversation,
ConversationEntry,
ConversationWithContext,
EvaluatedConversationWithContext,
GenerationState,
GradeSchema,
Role,
)
logger = logging.getLogger(__name__)
class ConversationGenerator(BaseGenerator):
"""Generates conversations between a correspondent (question generator) and a respondent (answer generator) asynchronously.
Args:
respondent_prompt: System prompt to the respondent, e.g., assistant that you want you fine-tune on this dataset
api_key: Either a single API key string or a SmartKeyPool instance for LLM use
correspondent_prompt: System prompt to the correspondent, e.g., model that roleplays a user of the assistant
that you want to fine-tune on this dataset
model_name: Model name to use
safety_settings: Safety settings for the model
auto_improve: Whether to try to improve low-quality generations
evaluator_model_name: Model name for the evaluator LLM when auto_improve is True.
embedding_provider: Optional shared :class:`~afterimage.providers.embedding_providers.EmbeddingProvider` for embedding metrics.
embedding_provider_config: JSON-style config for :class:`~afterimage.providers.embedding_providers.EmbeddingProviderFactory` when ``embedding_provider`` is omitted (defaults by chat provider).
judge_config: Optional :class:`~afterimage.evaluator.ConversationJudgeConfig` (aggregation and grade thresholds).
model_provider_name: Provider used for accessing LLMs. Supported values are `"gemini"`, `"openai"`, and `"deepseek"`.
storage: Storage implementation for saving conversations.
If `None`, creates JSONLStorage with datetime-based filename.
monitor: GenerationMonitor instance for tracking generation metrics.
If `None`, a default one is created.
        instruction_generator_callback: Callback for instruction generation. Can also be passed to the generate() method (deprecated).
        respondent_prompt_modifier: Callback to modify respondent prompts. Can also be passed to the generate() method (deprecated).
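
    Example:
        A minimal sketch, run inside an async context; ``my_callback`` stands in
        for any :class:`BaseInstructionGeneratorCallback` instance, and the prompt
        and key are hypothetical placeholders::

            generator = ConversationGenerator(
                respondent_prompt="You are a helpful cooking assistant.",
                api_key="YOUR_GEMINI_API_KEY",
                instruction_generator_callback=my_callback,
            )
            await generator.generate(num_dialogs=10, max_turns=3)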
"""
def __init__(
self,
respondent_prompt: str,
api_key: str | SmartKeyPool,
correspondent_prompt: str | None = None,
model_name: str | None = None,
safety_settings: List[Dict[str, str]] | None = None,
auto_improve: bool = False,
evaluator_model_name: str | None = None,
model_provider_name: Literal["gemini", "openai", "deepseek"] = "gemini",
embedding_provider: EmbeddingProvider | None = None,
embedding_provider_config: dict[str, Any] | None = None,
judge_config: ConversationJudgeConfig | None = None,
storage: Optional[BaseStorage] = None,
monitor: Optional[GenerationMonitor] = None,
instruction_generator_callback: BaseInstructionGeneratorCallback | None = None,
respondent_prompt_modifier: BaseRespondentPromptModifierCallback | None = None,
):
self.monitor: GenerationMonitor = (
monitor or GenerationMonitor()
) # ensure it's always created
self.key_pool: SmartKeyPool = (
api_key
if isinstance(api_key, SmartKeyPool)
else SmartKeyPool.from_single_key(api_key)
)
self.model_provider_name = model_provider_name
self.model_name = model_name if model_name is not None else default_model_name
self.monitor.log_info(
"Model info set",
model_provider=self.model_provider_name,
model_name=self.model_name,
num_api_keys=len(self.key_pool),
)
self.safety_settings = (
safety_settings if safety_settings is not None else default_safety_settings
)
        # Users should pass at least one of `correspondent_prompt` and `instruction_generator_callback`.
        # If both are passed, the correspondent prompt is used as passed.
        # If neither is passed, raise an error.
if correspondent_prompt is None and instruction_generator_callback is None:
error = ValueError(
"At least one of `correspondent_prompt` or `instruction_generator_callback` should be passed."
)
self.monitor.log_error(
"failed to initialize because correspondent prompt and instruction generator callback are both None",
error,
)
raise error
if correspondent_prompt is None:
warning_msg = "A correspondent prompt will be automatically created because you did not pass one."
self.monitor.log_warning(warning_msg)
warnings.warn(warning_msg)
self.respondent_prompt = respondent_prompt
self.monitor.log_info(
"Respondent prompt set",
respondent_prompt=respondent_prompt,
)
self.correspondent_prompt = correspondent_prompt
self.instruction_generator_callback = instruction_generator_callback
self.respondent_prompt_modifier = respondent_prompt_modifier
self.evaluator = None
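        # When auto-improvement is enabled, build a ConversationJudge that shares
        # this generator's key pool and monitor, reusing a caller-supplied embedding
        # provider or constructing one from the (possibly default) embedding config.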
if auto_improve:
evaluator_model_name = (
evaluator_model_name
if evaluator_model_name is not None
else self.model_name
)
evaluator_llm = LLMFactory.create(
self.model_provider_name,
evaluator_model_name,
self.key_pool,
safety_settings=self.safety_settings,
)
if embedding_provider is not None:
self.evaluator = ConversationJudge(
llm=evaluator_llm,
embedding_provider=embedding_provider,
monitor=self.monitor,
config=judge_config,
)
else:
embed_cfg = (
embedding_provider_config
if embedding_provider_config is not None
else default_embedding_provider_config(self.model_provider_name)
)
self.evaluator = ConversationJudge.from_factory(
evaluator_llm,
key_pool=self.key_pool,
model_provider_name=self.model_provider_name,
embedding_provider_config=embed_cfg,
monitor=self.monitor,
config=judge_config,
)
self.initiators = []
self.storage = storage or JSONLStorage()
if (
self.instruction_generator_callback
and self.instruction_generator_callback.monitor is None
):
self.instruction_generator_callback.set_monitor(self.monitor)
async def create_correspondent_prompt(self, assistant_prompt: str) -> str:
"""Create a correspondent prompt based on the assistant prompt."""
start_time = time.time()
api_key: str | None = None
try:
prompt = get_correspondent_instruction_generation_prompt(
assistant_prompt=assistant_prompt
)
api_key = await self.key_pool.aget_next_key()
model = LLMFactory.create(
self.model_provider_name,
self.model_name,
api_key=api_key,
safety_settings=self.safety_settings,
)
response = await model.agenerate_content(prompt=prompt, temperature=0.7)
            # Strip the <user_system_prompt> wrapper tags. str.removeprefix/removesuffix
            # remove the exact tag, whereas lstrip/rstrip would strip any of the tag's
            # characters from the ends of the prompt itself.
            prompt = (
                response.text.strip()
                .removeprefix("<user_system_prompt>")
                .removesuffix("</user_system_prompt>")
                .strip()
            )
if self.monitor:
self.monitor.track_generation(
duration=time.time() - start_time,
success=True,
prompt_token_count=response.prompt_token_count,
completion_token_count=response.completion_token_count,
total_token_count=response.total_token_count,
model_name=response.model_name,
metadata={"operation": "correspondent_prompt_generation"},
)
return prompt
except Exception as e:
if self.monitor:
self.monitor.track_generation(
duration=time.time() - start_time,
success=False,
error=str(e),
metadata={
"operation": "correspondent_prompt_generation",
"error_type": e.__class__.__name__,
},
)
if api_key is not None:
await self.key_pool.areport_error(api_key)
raise
async def create_model(self, prompt: str) -> ChatSession:
"""Creates and initializes a chat model with the given prompt."""
start_time = time.time()
api_key: str | None = None
try:
api_key = await self.key_pool.aget_next_key()
model = LLMFactory.create(
self.model_provider_name,
self.model_name,
api_key=api_key,
system_instruction=prompt,
safety_settings=self.safety_settings,
)
chat = await model.astart_chat()
if self.monitor:
self.monitor.record_metric(
"model_creation_time",
time.time() - start_time,
metadata={"success": True},
)
return chat
except Exception as e:
if self.monitor:
self.monitor.record_metric(
"model_creation_time",
time.time() - start_time,
metadata={
"success": False,
"error": str(e),
"error_type": e.__class__.__name__,
},
)
if api_key is not None:
await self.key_pool.areport_error(api_key)
raise
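
    # Sketch of using `create_model` directly (hypothetical prompt; `generator`
    # is an already-constructed ConversationGenerator):
    #
    #     session = await generator.create_model("You are a helpful assistant.")
    #     response = await session.asend_message("Hello!")
    #     print(response.text)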
async def ask(
self, correspondent: ChatSession, answer: str | ConversationEntry
) -> str:
"""Generates a question from the correspondent based on the given answer."""
start_time = time.time()
try:
response = await correspondent.asend_message(answer)
question = response.text
self.monitor.track_generation(
duration=time.time() - start_time,
success=True,
prompt_token_count=response.prompt_token_count,
completion_token_count=response.completion_token_count,
total_token_count=response.total_token_count,
finish_reason=response.finish_reason,
model_name=response.model_name,
metadata={
"operation": "question_generation",
},
)
return question
except Exception as e:
self.monitor.track_generation(
duration=time.time() - start_time,
success=False,
error=str(e),
metadata={
"operation": "question_generation",
"error_type": e.__class__.__name__,
},
)
raise
async def answer(
self, respondent: ChatSession, question: str | ConversationEntry
) -> ConversationEntry:
"""Generates an answer from the respondent based on the given question."""
start_time = time.time()
try:
response = await respondent.asend_message(question)
answer = response.text
self.monitor.track_generation(
duration=time.time() - start_time,
success=True,
prompt_token_count=response.prompt_token_count,
completion_token_count=response.completion_token_count,
total_token_count=response.total_token_count,
finish_reason=response.finish_reason,
model_name=response.model_name,
metadata={
"operation": "answer_generation",
},
)
return ConversationEntry(
role=Role.ASSISTANT,
content=answer,
reasoning_content=response.reasoning_content,
)
except Exception as e:
self.monitor.track_generation(
duration=time.time() - start_time,
success=False,
error=str(e),
metadata={
"operation": "answer_generation",
"error_type": e.__class__.__name__,
},
)
raise
async def go(
self,
turns: int = 1,
first_question: str | None = None,
check_for_near_duplicates: bool = False,
correspondent_prompt: str | None = None,
respondent_prompt: str | None = None,
) -> List[ConversationEntry]:
"""Simulates a multi-turn conversation between the correspondent and respondent."""
start_time = time.time()
conversation = []
try:
if correspondent_prompt is None:
correspondent_prompt = self.correspondent_prompt
# If still None, create it using generator's method
if correspondent_prompt is None:
correspondent_prompt = await self.create_correspondent_prompt(
self.respondent_prompt
)
if respondent_prompt is None:
respondent_prompt = self.respondent_prompt
correspondent = await self.create_model(correspondent_prompt)
respondent = await self.create_model(respondent_prompt)
question = first_question or await self.ask(
correspondent, "Ask your first question."
)
self.initiators.append(question)
conversation.append(ConversationEntry(role=Role.USER, content=question))
for turn in range(turns):
answer_entry = await self.answer(respondent, question)
conversation.append(answer_entry)
if (turn + 1) == turns:
break
else:
question = await self.ask(correspondent, answer_entry)
conversation.append(
ConversationEntry(role=Role.USER, content=question)
)
self.initiators.append(question)
self.monitor.track_generation(
duration=time.time() - start_time,
success=True,
conversation_length=len(conversation) // 2,
metadata={
"operation": "conversation_generation",
"planned_turns": turns,
"actual_turns": len(conversation) // 2,
},
)
return conversation
except Exception as e:
self.monitor.track_generation(
duration=time.time() - start_time,
success=False,
error=str(e),
metadata={
"operation": "conversation_generation",
"error_type": e.__class__.__name__,
"completed_turns": len(conversation) // 2,
},
)
raise
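
    # Sketch of driving `go` directly (hypothetical first question; `generator`
    # is an already-constructed ConversationGenerator):
    #
    #     entries = await generator.go(turns=2, first_question="How do I poach an egg?")
    #
    # The returned list alternates USER/ASSISTANT ConversationEntry items, two per turn.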
async def generate_single(
self,
max_turns: int,
check_for_near_duplicates: bool = False,
instruction_generator_callback: BaseInstructionGeneratorCallback | None = None,
respondent_prompt_modifier: BaseRespondentPromptModifierCallback | None = None,
) -> AsyncGenerator[Union[EvaluatedConversationWithContext, Conversation], None]:
"""Generates conversations for a single session and yields them."""
# ainitialize ensures correspondent_prompt is set
await self.ainitialize(instruction_generator_callback)
correspondent_prompt = self.correspondent_prompt
respondent_prompt = self.respondent_prompt
turns = random.randint(1, max_turns)
if instruction_generator_callback:
if hasattr(instruction_generator_callback, "acall"):
gen_instructions = await instruction_generator_callback.acall(
correspondent_prompt
)
else:
gen_instructions = await asyncio.to_thread(
instruction_generator_callback, correspondent_prompt
)
for instruction in gen_instructions.instructions:
instruction_context = gen_instructions.context
persona = gen_instructions.persona
response_context = None
current_respondent_prompt = respondent_prompt
if respondent_prompt_modifier:
if hasattr(respondent_prompt_modifier, "acall"):
modified_respondent_prompt = (
await respondent_prompt_modifier.acall(
respondent_prompt,
context=instruction_context,
instruction=instruction,
)
)
else:
modified_respondent_prompt = await asyncio.to_thread(
respondent_prompt_modifier,
respondent_prompt,
context=instruction_context,
instruction=instruction,
)
current_respondent_prompt = modified_respondent_prompt.prompt
response_context = modified_respondent_prompt.context
conversation = await self.go(
turns=turns,
first_question=instruction,
check_for_near_duplicates=check_for_near_duplicates,
correspondent_prompt=correspondent_prompt,
respondent_prompt=current_respondent_prompt,
)
def build_conversation_row(
generated_conversation,
) -> ConversationWithContext:
return ConversationWithContext(
conversations=generated_conversation,
instruction_context=instruction_context,
response_context=response_context,
persona=persona,
metadata={
"context_id": gen_instructions.context_id,
"context_ids": gen_instructions.context_ids,
"persona_name": persona,
"persona_generation_depth": (
gen_instructions.persona_generation_depth
),
},
)
conversation_row = build_conversation_row(conversation)
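                # Prime the improvement loop: when an evaluator is configured, keep
                # regenerating until the judge grades the row above NEEDS_IMPROVEMENT.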
evaluation_grade = GradeSchema.NOT_ACCEPTABLE
while self.evaluator and evaluation_grade in [
GradeSchema.NOT_ACCEPTABLE,
GradeSchema.BAD,
GradeSchema.NEEDS_IMPROVEMENT,
]:
evaluated_conversation = await self.evaluator.aevaluate_row(
conversation_row
)
if evaluated_conversation.evaluation.overall_grade in [
GradeSchema.NOT_ACCEPTABLE,
GradeSchema.BAD,
GradeSchema.NEEDS_IMPROVEMENT,
]:
conversation = await self.go(
turns=turns,
first_question=instruction,
check_for_near_duplicates=check_for_near_duplicates,
correspondent_prompt=correspondent_prompt,
respondent_prompt=current_respondent_prompt,
)
conversation_row = build_conversation_row(conversation)
else:
evaluation_grade = (
evaluated_conversation.evaluation.overall_grade
)
conversation_row = evaluated_conversation
yield conversation_row
else:
raise ValueError("An `instruction_generator_callback` must be provided.")
async def generate(
self,
num_dialogs: int | None = None,
max_turns: int = 1,
stopping_criteria: Optional[List[BaseStoppingCallback]] = None,
instruction_generator_callback: BaseInstructionGeneratorCallback | None = None,
respondent_prompt_modifier: BaseRespondentPromptModifierCallback | None = None,
max_concurrency: int | None = None,
) -> None:
"""Generates multiple conversation dialogs until stopping criteria is met.
Args:
num_dialogs: Number of dialogs to generate. Defaults to 5 if no other stopping criteria is specified.
max_turns: Maximum number of turns per dialog. Actual number of turns is randomly sampled from 1 .. max_turns.
stopping_criteria: A list of callbacks to determine when to stop generation. If num_dialogs is specified, FixedNumberStoppingCallback is added to this list automatically.
instruction_generator_callback: Callback for instruction generation.
Deprecated: Pass this to the constructor instead. Defaults to None.
respondent_prompt_modifier: Callback to modify respondent prompts.
Deprecated: Pass this to the constructor instead. Defaults to None.
max_concurrency: Number of concurrent generations. Defaults to 8 for
DeepSeek and 4 for other providers.
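
        Example:
            A typical call, as a sketch (assumes the generator was constructed
            with an instruction generator callback)::

                await generator.generate(num_dialogs=20, max_turns=4)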
"""
if instruction_generator_callback is not None:
warnings.warn(
"Passing `instruction_generator_callback` to `generate()` is deprecated. "
"Please pass it to the constructor instead.",
DeprecationWarning,
stacklevel=2,
)
else:
instruction_generator_callback = self.instruction_generator_callback
if respondent_prompt_modifier is not None:
warnings.warn(
"Passing `respondent_prompt_modifier` to `generate()` is deprecated. "
"Please pass it to the constructor instead.",
DeprecationWarning,
stacklevel=2,
)
else:
respondent_prompt_modifier = self.respondent_prompt_modifier
if instruction_generator_callback is None:
error = ValueError("An `instruction_generator_callback` must be provided.")
self.monitor.log_error("No instruction generator callback set", error)
raise error
else:
self.monitor.log_info(
"Instruction generator callback set",
type=instruction_generator_callback.__class__.__name__,
)
if instruction_generator_callback.monitor is None:
instruction_generator_callback.set_monitor(self.monitor)
if self.correspondent_prompt is None:
self.monitor.log_info("No correspondent prompt set, initializing...")
await self.ainitialize(instruction_generator_callback)
# Handle stopping criteria
        # Copy so the appends below don't mutate a caller-supplied list
        final_stopping_criteria = list(stopping_criteria) if stopping_criteria else []
if num_dialogs is not None:
final_stopping_criteria.append(FixedNumberStoppingCallback(n=num_dialogs))
# Default if nothing specified
if not final_stopping_criteria:
num_dialogs = 5
final_stopping_criteria.append(FixedNumberStoppingCallback(n=num_dialogs))
self._configure_context_sampling(
instruction_generator_callback,
final_stopping_criteria,
)
self._configure_persona_sampling(
instruction_generator_callback,
num_requested=num_dialogs,
stopping_criteria=final_stopping_criteria,
)
state = GenerationState(
num_requested=num_dialogs or 0,
monitor=self.monitor,
stop_event=asyncio.Event(),
)
resolved_max_concurrency = self._resolve_max_concurrency(max_concurrency)
semaphore = asyncio.Semaphore(resolved_max_concurrency)
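        # Each worker holds the semaphore for one full `generate_single` pass, so
        # at most `resolved_max_concurrency` conversations are generated concurrently.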
async def save_conversations(conversations: list[ConversationWithContext]):
if conversations:
if hasattr(self.storage, "asave_conversations"):
await self.storage.asave_conversations(conversations)
else:
await asyncio.to_thread(
self.storage.save_conversations, conversations
)
async def worker_task():
while not state.stop_event.is_set():
async with semaphore:
if state.stop_event.is_set():
break
try:
async for conversation in self.generate_single(
max_turns=max_turns,
instruction_generator_callback=instruction_generator_callback,
respondent_prompt_modifier=respondent_prompt_modifier,
):
# Update state and check stopping criteria
state.update(conversation)
self._record_context_usage(
instruction_generator_callback,
conversation,
)
for criteria in final_stopping_criteria:
if await criteria.should_stop(state):
self.monitor.log_info(
"Stopping criteria met, stopping generation...",
criteria=criteria.__class__.__name__,
)
state.stop_event.set()
break
await save_conversations([conversation])
if state.stop_event.is_set():
break
except Exception as e:
logger.error(f"Error in generation: {e}")
traceback.print_exc()
if self.monitor:
self.monitor.record_metric("error_rate", 1.0)
continue
pbar = tqdm(total=num_dialogs, desc="Generating...", unit="conversation")
tasks: list[asyncio.Task] = []
# Create initial set of tasks
for _ in range(resolved_max_concurrency):
tasks.append(asyncio.create_task(worker_task()))
last_count = 0
while not state.stop_event.is_set() or any(not t.done() for t in tasks):
# Update progress bar
if state.num_generated > last_count:
pbar.update(state.num_generated - last_count)
last_count = state.num_generated
if state.stop_event.is_set():
for t in tasks:
if not t.done():
t.cancel()
break
await asyncio.sleep(0.1)
if all(t.done() for t in tasks):
break
pbar.update(state.num_generated - last_count)
pbar.close()
# Wait for any remaining tasks to finish/cancel
try:
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
self.monitor.log_error("Error while trying to finalize generation", error=e)
self.monitor.record_metric("error_rate", 1.0)
traceback.print_exc()
finally:
self.monitor.log_info("Generation complete")
self.monitor.visualize_metrics()
    def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
        """Resolves the effective concurrency limit for the configured provider."""
        return resolve_generation_max_concurrency(
            self.model_provider_name,
            max_concurrency,
        )