import asyncio
import logging
import math
import time
from typing import Literal, Optional, Union
from tqdm import tqdm
from .common import (
default_model_name,
default_safety_settings,
resolve_generation_max_concurrency,
)
from .key_management import SmartKeyPool
from .monitoring import GenerationMonitor
from .prompts import (
parse_personas,
persona_to_persona_generation_prompt_tmpl,
text_to_persona_generation_prompt_tmpl,
)
from .providers import DocumentProvider, InMemoryDocumentProvider, LLMFactory
from .storage import BaseStorage, JSONLStorage
from .types import Document, PersonaEntry
# Fixed number of personas every successful LLM call must yield
# (the "fixed-width" contract enforced by _select_fixed_width_personas).
EXPECTED_PERSONA_COUNT = 5
# How many times to re-prompt before giving up on satisfying the contract.
MAX_PERSONA_GENERATION_ATTEMPTS = 3
class PersonaGenerationContractError(ValueError):
    """Raised when persona generation cannot satisfy the fixed-width contract.

    The contract requires a generation call to produce at least
    ``EXPECTED_PERSONA_COUNT`` unique, non-empty persona descriptions.
    """
class PersonaGenerator:
    """Generates persona descriptions from document text and from other personas."""

    def __init__(
        self,
        api_key: str | SmartKeyPool,
        model_name: str | None = None,
        safety_settings: list[dict[str, str]] | None = None,
        model_provider_name: Literal["gemini", "openai", "deepseek"] = "gemini",
        storage: Optional[BaseStorage] = None,
        monitor: Optional[GenerationMonitor] = None,
        max_concurrency: int | None = None,
    ):
        """Configure providers, storage, monitoring and concurrency limits.

        Args:
            api_key: A raw key string or an already-built SmartKeyPool.
            model_name: Model override; package default when None.
            safety_settings: Safety override; package default when None.
            model_provider_name: Which LLM backend to use.
            storage: Persistence backend; JSONLStorage when None.
            monitor: Optional generation metrics sink.
            max_concurrency: Cap on concurrent LLM calls; resolved per
                provider when None.
        """
        # Wrap a bare key in a pool so rotation/error-reporting is uniform.
        if isinstance(api_key, SmartKeyPool):
            self.key_pool = api_key
        else:
            self.key_pool = SmartKeyPool.from_single_key(api_key)
        self.model_provider_name = model_provider_name
        # Fall back to package-wide defaults when not supplied.
        self.model_name = default_model_name if model_name is None else model_name
        if safety_settings is None:
            self.safety_settings = default_safety_settings
        else:
            self.safety_settings = safety_settings
        self.storage = storage or JSONLStorage()
        self.monitor = monitor
        # Semaphore bounds the async fan-out across all generate calls.
        self.max_concurrency = self._resolve_max_concurrency(max_concurrency)
        self.semaphore = asyncio.Semaphore(self.max_concurrency)
def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
    """Delegate to the shared provider-aware concurrency resolver."""
    return resolve_generation_max_concurrency(self.model_provider_name, max_concurrency)
def _normalize_persona_text(self, persona: str) -> str:
return " ".join(persona.split()).strip()
def _normalize_personas(self, personas: list[str]) -> list[str]:
normalized_personas: list[str] = []
seen: set[str] = set()
for persona in personas:
normalized = self._normalize_persona_text(persona)
if not normalized or normalized in seen:
continue
seen.add(normalized)
normalized_personas.append(normalized)
return normalized_personas
def _select_fixed_width_personas(self, personas: list[str]) -> list[str]:
    """Return exactly EXPECTED_PERSONA_COUNT unique, normalized personas.

    Raises:
        PersonaGenerationContractError: when fewer unique personas than
            the contract requires are available.
    """
    candidates = self._normalize_personas(personas)
    if len(candidates) >= EXPECTED_PERSONA_COUNT:
        return candidates[:EXPECTED_PERSONA_COUNT]
    raise PersonaGenerationContractError(
        "Persona generation did not return enough unique personas"
    )
def _generate_personas_with_retry(
    self,
    generate_content,
    prompt: str,
):
    """Call *generate_content* until the fixed-width contract is met.

    Retries up to MAX_PERSONA_GENERATION_ATTEMPTS times. Contract
    violations are retried; any other exception propagates immediately.

    Returns:
        Tuple of (selected personas, raw LLM response).

    Raises:
        PersonaGenerationContractError: when all attempts fall short.
    """
    failure: PersonaGenerationContractError | None = None
    attempt = 0
    while attempt < MAX_PERSONA_GENERATION_ATTEMPTS:
        attempt += 1
        response = generate_content(prompt)
        try:
            selected = self._select_fixed_width_personas(
                parse_personas(response.text)
            )
        except PersonaGenerationContractError as contract_error:
            failure = contract_error
        else:
            return selected, response
    if failure is not None:
        raise failure
    raise PersonaGenerationContractError(
        "Persona generation could not satisfy the fixed-width contract"
    )
async def _agenerate_personas_with_retry(
    self,
    generate_content,
    prompt: str,
):
    """Async twin of _generate_personas_with_retry.

    Awaits *generate_content* up to MAX_PERSONA_GENERATION_ATTEMPTS
    times; only contract violations are retried.

    Returns:
        Tuple of (selected personas, raw LLM response).

    Raises:
        PersonaGenerationContractError: when all attempts fall short.
    """
    failure: PersonaGenerationContractError | None = None
    attempt = 0
    while attempt < MAX_PERSONA_GENERATION_ATTEMPTS:
        attempt += 1
        response = await generate_content(prompt)
        try:
            selected = self._select_fixed_width_personas(
                parse_personas(response.text)
            )
        except PersonaGenerationContractError as contract_error:
            failure = contract_error
        else:
            return selected, response
    if failure is not None:
        raise failure
    raise PersonaGenerationContractError(
        "Persona generation could not satisfy the fixed-width contract"
    )
def _track_success(
self,
*,
start_time: float,
response,
operation: str,
text_length: int,
generation: int | None = None,
) -> None:
if self.monitor is None:
return
metadata = {
"operation": operation,
"text_length": text_length,
}
if generation is not None:
metadata["generation"] = generation
self.monitor.track_generation(
duration=time.time() - start_time,
success=True,
prompt_token_count=response.prompt_token_count,
completion_token_count=response.completion_token_count,
total_token_count=response.total_token_count,
model_name=response.model_name,
metadata=metadata,
)
def _track_failure(
self,
*,
start_time: float,
error: Exception,
operation: str,
text_length: int,
generation: int | None = None,
) -> None:
if self.monitor is None:
return
metadata = {
"operation": operation,
"text_length": text_length,
}
if generation is not None:
metadata["generation"] = generation
self.monitor.track_generation(
duration=time.time() - start_time,
success=False,
error=str(error),
metadata=metadata,
)
def expected_persona_count(self, n_iterations: int) -> int:
    """Total personas produced per document across all generation layers.

    Layer 0 yields EXPECTED_PERSONA_COUNT personas and every deeper
    layer multiplies the previous layer's size by the same factor, so
    the total is the geometric sum of layers 0..n_iterations.

    Raises:
        ValueError: if n_iterations is negative.
    """
    if n_iterations < 0:
        raise ValueError("n_iterations must be >= 0")
    return sum(
        EXPECTED_PERSONA_COUNT ** layer for layer in range(1, n_iterations + 2)
    )
def _resolve_active_doc_count(
self,
docs_to_process: list[Document],
provider,
) -> int:
if hasattr(provider, "_doc_sampling_weights"):
active_doc_count = sum(
1
for doc in docs_to_process
if provider._doc_sampling_weights.get(doc.id, 0.0) > 0
)
else:
active_doc_count = len(docs_to_process)
return max(active_doc_count, 1)
def _resolve_target_per_document(
self,
provider,
docs_to_process: list[Document],
target_data_count: int | None,
num_random_contexts: int,
) -> int | None:
target_context_usage_count = None
if hasattr(provider, "get_target_context_usage_count"):
target_context_usage_count = provider.get_target_context_usage_count()
else:
target_context_usage_count = getattr(
provider,
"target_context_usage_count",
None,
)
if isinstance(target_context_usage_count, int) and target_context_usage_count > 0:
contexts_per_row = max(int(num_random_contexts), 1)
return max(math.ceil(target_context_usage_count / contexts_per_row), 1)
if target_data_count is None:
return None
requested = max(int(target_data_count), 1)
active_doc_count = self._resolve_active_doc_count(docs_to_process, provider)
return max(math.ceil(requested / active_doc_count), 1)
def _resolve_auto_n_iterations(self, target_per_document: int | None) -> int:
    """Pick the iteration depth whose total persona count is nearest the target.

    A missing or non-positive target means no extra iterations. When the
    target falls between two depths' totals, ties resolve toward the
    shallower depth.
    """
    if target_per_document is None or target_per_document <= 0:
        return 0
    shallow_depth = 0
    shallow_total = self.expected_persona_count(shallow_depth)
    if target_per_document <= shallow_total:
        return shallow_depth
    deep_depth, deep_total = shallow_depth, shallow_total
    # Advance until the cumulative total reaches or passes the target,
    # keeping the last depth that undershot it.
    while deep_total < target_per_document:
        shallow_depth, shallow_total = deep_depth, deep_total
        deep_depth += 1
        deep_total = self.expected_persona_count(deep_depth)
    undershoot = target_per_document - shallow_total
    overshoot = deep_total - target_per_document
    return shallow_depth if undershoot <= overshoot else deep_depth
def _resolve_generation_iterations(
self,
*,
provider,
docs_to_process: list[Document],
n_iterations: int | None,
target_data_count: int | None,
num_random_contexts: int,
) -> int:
if n_iterations is not None:
if n_iterations < 0:
raise ValueError("n_iterations must be >= 0")
return n_iterations
target_per_document = self._resolve_target_per_document(
provider=provider,
docs_to_process=docs_to_process,
target_data_count=target_data_count,
num_random_contexts=num_random_contexts,
)
resolved_iterations = self._resolve_auto_n_iterations(target_per_document)
if self.monitor is not None:
self.monitor.log_info(
"Resolved persona generation iterations",
n_iterations=resolved_iterations,
target_per_document=target_per_document,
active_doc_count=self._resolve_active_doc_count(docs_to_process, provider),
)
return resolved_iterations
def generate_from_text(self, text: str) -> list[str]:
    """Generate a fixed-width batch of personas from raw document text.

    Raises:
        PersonaGenerationContractError: retries exhausted without the
            contract being met.
        Exception: any provider/transport error, after reporting it to
            the key pool.
    """
    key = self.key_pool.get_next_key()
    llm = LLMFactory.create(
        self.model_provider_name,
        self.model_name,
        key,
        safety_settings=self.safety_settings,
    )
    started_at = time.time()
    try:
        prompt = text_to_persona_generation_prompt_tmpl.format(text=text)
        personas, response = self._generate_personas_with_retry(
            llm.generate_content, prompt
        )
        self._track_success(
            start_time=started_at,
            response=response,
            operation="text_to_persona_generation",
            text_length=len(text),
        )
        return personas
    except Exception as error:
        # Contract violations are not the key's fault; only report
        # provider/transport errors back to the pool.
        if not isinstance(error, PersonaGenerationContractError):
            self.key_pool.report_error(key)
        self._track_failure(
            start_time=started_at,
            error=error,
            operation="text_to_persona_generation",
            text_length=len(text),
        )
        raise
async def agenerate_from_text(self, text: str) -> list[str]:
    """Async variant of generate_from_text, bounded by the shared semaphore.

    Raises:
        PersonaGenerationContractError: retries exhausted without the
            contract being met.
        Exception: any provider/transport error, after reporting it to
            the key pool.
    """
    async with self.semaphore:
        key = await self.key_pool.aget_next_key()
        llm = LLMFactory.create(
            self.model_provider_name,
            self.model_name,
            key,
            safety_settings=self.safety_settings,
        )
        started_at = time.time()
        try:
            prompt = text_to_persona_generation_prompt_tmpl.format(text=text)
            personas, response = await self._agenerate_personas_with_retry(
                llm.agenerate_content, prompt
            )
            self._track_success(
                start_time=started_at,
                response=response,
                operation="text_to_persona_generation",
                text_length=len(text),
            )
            return personas
        except Exception as error:
            # Contract violations are not the key's fault; only report
            # provider/transport errors back to the pool.
            if not isinstance(error, PersonaGenerationContractError):
                await self.key_pool.areport_error(key)
            self._track_failure(
                start_time=started_at,
                error=error,
                operation="text_to_persona_generation",
                text_length=len(text),
            )
            raise
def generate_from_persona(self, persona: str, generation: int = 1) -> list[str]:
    """Expand one persona into a fixed-width batch of related personas.

    Args:
        persona: Source persona description.
        generation: Depth label recorded with monitoring metadata.

    Raises:
        PersonaGenerationContractError: retries exhausted without the
            contract being met.
        Exception: any provider/transport error, after reporting it to
            the key pool.
    """
    key = self.key_pool.get_next_key()
    llm = LLMFactory.create(
        self.model_provider_name,
        self.model_name,
        key,
        safety_settings=self.safety_settings,
    )
    started_at = time.time()
    try:
        prompt = persona_to_persona_generation_prompt_tmpl.format(personas=persona)
        personas, response = self._generate_personas_with_retry(
            llm.generate_content, prompt
        )
        self._track_success(
            start_time=started_at,
            response=response,
            operation="persona_to_persona_generation",
            text_length=len(persona),
            generation=generation,
        )
        return personas
    except Exception as error:
        # Contract violations are not the key's fault; only report
        # provider/transport errors back to the pool.
        if not isinstance(error, PersonaGenerationContractError):
            self.key_pool.report_error(key)
        self._track_failure(
            start_time=started_at,
            error=error,
            operation="persona_to_persona_generation",
            text_length=len(persona),
            generation=generation,
        )
        raise
async def agenerate_from_persona(
    self, persona: str, generation: int = 1
) -> list[str]:
    """Async variant of generate_from_persona, bounded by the shared semaphore.

    Args:
        persona: Source persona description.
        generation: Depth label recorded with monitoring metadata.

    Raises:
        PersonaGenerationContractError: retries exhausted without the
            contract being met.
        Exception: any provider/transport error, after reporting it to
            the key pool.
    """
    async with self.semaphore:
        key = await self.key_pool.aget_next_key()
        llm = LLMFactory.create(
            self.model_provider_name,
            self.model_name,
            key,
            safety_settings=self.safety_settings,
        )
        started_at = time.time()
        try:
            prompt = persona_to_persona_generation_prompt_tmpl.format(
                personas=persona
            )
            personas, response = await self._agenerate_personas_with_retry(
                llm.agenerate_content, prompt
            )
            self._track_success(
                start_time=started_at,
                response=response,
                operation="persona_to_persona_generation",
                text_length=len(persona),
                generation=generation,
            )
            return personas
        except Exception as error:
            # Contract violations are not the key's fault; only report
            # provider/transport errors back to the pool.
            if not isinstance(error, PersonaGenerationContractError):
                await self.key_pool.areport_error(key)
            self._track_failure(
                start_time=started_at,
                error=error,
                operation="persona_to_persona_generation",
                text_length=len(persona),
                generation=generation,
            )
            raise
async def _agenerate_persona_chains(
    self, base_personas: list[str], depth: int
) -> list[PersonaEntry]:
    """Iteratively expand personas into *depth* deeper generations.

    Each layer fans out one persona-to-persona call per current persona;
    failed calls are logged and skipped. An entry is recorded for every
    attempted layer (even an empty one), and expansion stops early once
    a layer yields nothing.
    """
    entries: list[PersonaEntry] = []
    frontier = base_personas
    for layer in range(1, depth + 1):
        outcomes = await asyncio.gather(
            *(self.agenerate_from_persona(p, generation=layer) for p in frontier),
            return_exceptions=True,
        )
        expanded: list[str] = []
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                logging.warning(f"Persona generation failed: {outcome}")
            else:
                expanded.extend(outcome)
        entries.append(
            PersonaEntry(
                descriptions=expanded,
                metadata={"generation_depth": layer},
            )
        )
        frontier = expanded
        if not expanded:
            break
    return entries
async def generate_from_documents(
    self,
    documents: Union[DocumentProvider, list[str]],
    max_docs: int | None = None,
    n_iterations: int | None = None,
    target_data_count: int | None = None,
    num_random_contexts: int = 1,
):
    """Generate persona chains for every document, persist and write back.

    Args:
        documents: A DocumentProvider, or a plain list of strings that
            is wrapped in an InMemoryDocumentProvider.
        max_docs: Optional cap on how many documents to process.
        n_iterations: Explicit chain depth; when None the depth is
            derived from target_data_count / provider hints.
        target_data_count: Desired overall persona volume used when
            deriving the depth automatically.
        num_random_contexts: Contexts consumed per generated row when
            deriving the depth from a provider usage target.

    Raises:
        Exception: the first task error, re-raised after all tasks settle.
    """
    if isinstance(documents, list):
        documents = InMemoryDocumentProvider(documents)
    if max_docs is not None and max_docs < len(documents):
        docs_to_process = documents.get_documents(n=max_docs)
    else:
        docs_to_process = documents.get_all()
    resolved_iterations = self._resolve_generation_iterations(
        provider=documents,
        docs_to_process=docs_to_process,
        n_iterations=n_iterations,
        target_data_count=target_data_count,
        num_random_contexts=num_random_contexts,
    )
    pbar = tqdm(total=len(docs_to_process), desc="Generating Personas...")
    task_errors: list[Exception] = []
    completed_results: list[tuple[Document, Document]] = []

    async def worker_task(doc: Document):
        # Work on a deep copy so a partial failure never corrupts the input doc.
        enriched_doc = doc.model_copy(deep=True)
        base_personas = await self.agenerate_from_text(doc.text)
        persona_entries = [
            PersonaEntry(
                descriptions=base_personas, metadata={"generation_depth": 0}
            )
        ]
        if resolved_iterations > 0:
            deeper_personas = await self._agenerate_persona_chains(
                base_personas, depth=resolved_iterations
            )
            persona_entries.extend(deeper_personas)
        pbar.update(1)
        enriched_doc.personas.extend(persona_entries)
        return doc, enriched_doc

    tasks = [asyncio.create_task(worker_task(doc)) for doc in docs_to_process]
    # Drain tasks as they finish so one failure does not abort the rest.
    for future in asyncio.as_completed(tasks):
        try:
            original_doc, enriched_doc = await future
            completed_results.append((original_doc, enriched_doc))
        except Exception as e:
            logging.error(f"A task failed: {e}", exc_info=True)
            task_errors.append(e)
    pbar.close()
    # Surface the first captured error only after every task has settled.
    if task_errors:
        raise task_errors[0]
    enriched_docs = [enriched_doc for _, enriched_doc in completed_results]
    if self.storage and enriched_docs:
        await self.storage.asave_documents(enriched_docs)
    # Propagate the generated personas back onto the caller's original objects.
    for original_doc, enriched_doc in completed_results:
        original_doc.personas = enriched_doc.personas