from abc import abstractmethod
from pathlib import Path
from typing import List, Optional, Protocol, Dict, Any
import json
from filelock import FileLock
from datetime import datetime
import asyncio
from .types import (
ConversationWithContext,
EvaluatedConversationWithContext,
Document,
)
from pydantic import BaseModel
class BaseStorage(Protocol):
    """Protocol defining the interface for storage implementations.

    Backends persist two kinds of records: conversations (optionally carrying
    evaluation results) and source documents.  Each write operation has a
    synchronous and an asynchronous variant; reads support simple
    limit/offset pagination.
    """

    @abstractmethod
    def save_conversations(
        self,
        conversations: List[
            EvaluatedConversationWithContext | ConversationWithContext | BaseModel
        ],
    ) -> None:
        """Persist a batch of conversation records."""
        pass

    @abstractmethod
    async def asave_conversations(
        self,
        conversations: List[
            ConversationWithContext | EvaluatedConversationWithContext | BaseModel
        ],
    ) -> None:
        """Asynchronously persist a batch of conversation records."""
        pass

    @abstractmethod
    def load_conversations(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> List[ConversationWithContext]:
        """Load stored conversations.

        Args:
            limit: Maximum number of records to return.
            offset: Number of records to skip from the start.
        """
        pass

    @abstractmethod
    def load_documents(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> List[Document]:
        """Load stored documents.

        Args:
            limit: Maximum number of records to return.
            offset: Number of records to skip from the start.
        """
        pass

    @abstractmethod
    def save_documents(self, documents: List[Document]) -> None:
        """Persist a batch of documents."""
        pass

    @abstractmethod
    async def asave_documents(self, documents: List[Document]) -> None:
        """Asynchronously persist a batch of documents."""
        pass

class JSONLStorage(BaseStorage):
    """Stores conversations and documents in JSONL format (one JSON object per line).

    All reads and writes are serialized across processes with
    ``filelock.FileLock``, using a ``<file>.lock`` sibling file next to each
    data file.
    """

    def __init__(
        self,
        conversations_path: Optional[str | Path] = None,
        documents_path: Optional[str | Path] = None,
        encoding: str = "utf-8",
        lock_timeout: int = 30,
    ):
        """Initialize the storage.

        Args:
            conversations_path: Target JSONL file for conversations; defaults
                to a timestamped file in the current directory.
            documents_path: Target JSONL file for documents; same default rule.
            encoding: Text encoding used for all file I/O.
            lock_timeout: Seconds to wait for the file lock before raising.
        """
        self.encoding = encoding
        self.lock_timeout = lock_timeout
        self.conversations_path = (
            Path(conversations_path)
            if conversations_path
            else self._get_default_path("conversations")
        )
        # Lock files live next to the data files: "<name>.jsonl.lock".
        self.conversations_lock_path = self.conversations_path.with_suffix(
            self.conversations_path.suffix + ".lock"
        )
        self.documents_path = (
            Path(documents_path)
            if documents_path
            else self._get_default_path("documents")
        )
        self.documents_lock_path = self.documents_path.with_suffix(
            self.documents_path.suffix + ".lock"
        )

    @staticmethod
    def _get_default_path(prefix: str) -> Path:
        """Return a timestamped default path, e.g. ``conversations_20240101_120000.jsonl``."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return Path(f"{prefix}_{timestamp}.jsonl")

    def _append_jsonl(
        self, path: Path, lock_path: Path, models: List[BaseModel]
    ) -> None:
        """Append one JSON line per model to *path* under the file lock."""
        with FileLock(lock_path, timeout=self.lock_timeout):
            # "a" creates the file when missing, so no mode switch is needed.
            with open(path, "a", encoding=self.encoding) as f:
                for model in models:
                    # mode="json" converts datetimes and other non-native
                    # types up front, so json.dumps cannot raise TypeError.
                    f.write(
                        json.dumps(model.model_dump(mode="json"), ensure_ascii=False)
                        + "\n"
                    )

    def _load_jsonl(
        self,
        path: Path,
        lock_path: Path,
        model_cls,
        limit: int | None,
        offset: int | None,
    ) -> list:
        """Read models from *path*, skipping *offset* lines and stopping at *limit*.

        Truthiness checks (not ``is None``) are kept deliberately so that an
        offset/limit of 0 behaves like "no offset"/"no limit", as before.
        """
        if not path.exists():
            return []
        records = []
        with FileLock(lock_path, timeout=self.lock_timeout):
            with open(path, "r", encoding=self.encoding) as f:
                for idx, line in enumerate(f):
                    if offset and idx < offset:
                        continue
                    if not line.strip():
                        # Tolerate blank lines (e.g. trailing newline).
                        continue
                    records.append(model_cls(**json.loads(line)))
                    if limit and len(records) >= limit:
                        break
        return records

    def save_conversations(
        self,
        conversations: List[
            ConversationWithContext | EvaluatedConversationWithContext | BaseModel
        ],
    ) -> None:
        """Append conversations to the conversations JSONL file."""
        self._append_jsonl(
            self.conversations_path, self.conversations_lock_path, conversations
        )

    async def asave_conversations(
        self,
        conversations: List[
            ConversationWithContext | EvaluatedConversationWithContext | BaseModel
        ],
    ) -> None:
        """Append conversations asynchronously (file I/O runs in the default executor)."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.save_conversations, conversations)

    def load_conversations(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> List[EvaluatedConversationWithContext]:
        """Load conversations from the JSONL file.

        Args:
            limit: Maximum number of conversations to load
            offset: Number of conversations to skip

        Returns:
            List of conversations
        """
        return self._load_jsonl(
            self.conversations_path,
            self.conversations_lock_path,
            EvaluatedConversationWithContext,
            limit,
            offset,
        )

    def load_documents(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> List[Document]:
        """Load documents from the JSONL file.

        Args:
            limit: Maximum number of documents to load
            offset: Number of documents to skip

        Returns:
            List of documents
        """
        return self._load_jsonl(
            self.documents_path,
            self.documents_lock_path,
            Document,
            limit,
            offset,
        )

    def save_documents(self, documents: List[Document]) -> None:
        """Append documents to the documents JSONL file."""
        self._append_jsonl(self.documents_path, self.documents_lock_path, documents)

    async def asave_documents(self, documents: List[Document]) -> None:
        """Append documents asynchronously (file I/O runs in the default executor)."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.save_documents, documents)

class SQLStorage(BaseStorage):
    """Stores conversations and documents using SQLAlchemy.

    Requires the optional ``sqlalchemy`` dependency.  A synchronous engine
    backs the sync methods; an async engine/session backs the async ones.
    """

    def __init__(
        self,
        url: str,
        conversations_table_name: str = "conversations",
        documents_table_name: str = "documents",
        metadata_fields: Optional[List[str]] = None,
        batch_size: int = 100,
    ):
        """Create engines, define the tables, and ensure the schema exists.

        Args:
            url: SQLAlchemy database URL.
                NOTE(review): the same URL is handed to both ``create_engine``
                and ``create_async_engine``; it must name a driver usable in
                both modes — confirm against deployment configuration.
            conversations_table_name: Table for conversation records.
            documents_table_name: Table for document records.
            metadata_fields: Optional metadata keys to index.  The index DDL
                uses PostgreSQL JSON expression-index syntax.
            batch_size: Number of rows per INSERT batch.

        Raises:
            ImportError: If sqlalchemy is not installed.
        """
        try:
            from sqlalchemy import (
                create_engine,
                MetaData,
                Table,
                Column,
                Integer,
                String,
                JSON,
                DateTime,
                text,
            )
            from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        except ImportError:
            raise ImportError("SQL storage requires 'sqlalchemy'.")
        self.engine = create_engine(url)
        self.async_engine = create_async_engine(url)
        self.metadata = MetaData()
        self.batch_size = batch_size
        self.async_session_maker = async_sessionmaker(self.async_engine)
        self.conversations_table = Table(
            conversations_table_name,
            self.metadata,
            Column("id", Integer, primary_key=True),
            Column("conversations", JSON),
            Column("instruction_context", String, nullable=True),
            Column("response_context", String, nullable=True),
            Column("metadata", JSON, nullable=True),
            Column("timestamp", DateTime),
            Column("evaluation", JSON, nullable=True),
        )
        self.documents_table = Table(
            documents_table_name,
            self.metadata,
            Column("_id", Integer, primary_key=True),
            Column("id", String),
            Column("text", String),
            Column("personas", JSON, nullable=True),
            Column("metadata", JSON, nullable=True),
        )
        # Create tables if they don't exist.
        self.metadata.create_all(self.engine)
        # Create indexes for metadata fields.  Engine.execute() was removed
        # in SQLAlchemy 2.0 (this module already relies on the 2.0-only
        # async_sessionmaker), so run the DDL on a connection instead.
        if metadata_fields:
            with self.engine.begin() as conn:
                for field in metadata_fields:
                    idx_name = f"idx_{conversations_table_name}_{field}"
                    # NOTE(review): identifiers are interpolated into raw SQL;
                    # field/table names must come from trusted configuration.
                    conn.execute(
                        text(
                            f"CREATE INDEX IF NOT EXISTS {idx_name} "
                            f"ON {conversations_table_name} ((metadata->'{field}'))"
                        )
                    )

    @staticmethod
    def _conversation_record(conv: BaseModel) -> Dict[str, Any]:
        """Map a pydantic model to a row dict for the conversations table.

        Conversation-shaped models (those exposing a ``conversations`` field)
        map onto the dedicated columns; any other BaseModel is stored whole in
        the JSON ``conversations`` column.  Both shapes emit the same key set
        so batched executemany inserts stay homogeneous.
        """
        data = conv.model_dump()
        if "conversations" in data:
            return {
                "conversations": data["conversations"],
                "instruction_context": data.get("instruction_context"),
                "response_context": data.get("response_context"),
                "metadata": data.get("metadata", {}),
                "evaluation": data.get("evaluation"),
                "timestamp": datetime.now(),
            }
        # Fallback for generic models: store the whole object in the JSON
        # 'conversations' column.
        return {
            "conversations": data,
            "instruction_context": None,
            "response_context": None,
            "metadata": {},
            "evaluation": None,
            "timestamp": datetime.now(),
        }

    def save_conversations(
        self,
        conversations: List[ConversationWithContext | EvaluatedConversationWithContext],
    ) -> None:
        """Save conversations to the database in batches.

        Args:
            conversations: List of conversations to save
        """
        records = [self._conversation_record(conv) for conv in conversations]
        with self.engine.begin() as conn:
            for i in range(0, len(records), self.batch_size):
                batch = records[i : i + self.batch_size]
                conn.execute(self.conversations_table.insert(), batch)

    async def asave_conversations(
        self,
        conversations: List[
            ConversationWithContext | EvaluatedConversationWithContext | BaseModel
        ],
    ) -> None:
        """Save conversations to the database asynchronously, in batches."""
        records = [self._conversation_record(conv) for conv in conversations]
        async with self.async_session_maker() as session:
            async with session.begin():
                for i in range(0, len(records), self.batch_size):
                    batch = records[i : i + self.batch_size]
                    await session.execute(self.conversations_table.insert(), batch)

    def load_conversations(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        order_by: Optional[List[tuple]] = None,
    ) -> List[EvaluatedConversationWithContext]:
        """Load conversations from the database with filtering and sorting.

        Args:
            limit: Maximum number of conversations to load
            offset: Number of conversations to skip
            filters: Dict of field-value pairs for filtering; keys prefixed
                with ``metadata.`` filter inside the JSON metadata column.
            order_by: List of (field, direction) tuples; direction -1 sorts
                descending, anything else ascending.

        Returns:
            List of conversations
        """
        query = self.conversations_table.select()
        if filters:
            for field, value in filters.items():
                if field.startswith("metadata."):
                    # JSON-path filter into the metadata column.
                    _, key = field.split(".", 1)
                    query = query.where(
                        self.conversations_table.c.metadata[key] == value
                    )
                else:
                    # Plain column equality filter.
                    query = query.where(
                        getattr(self.conversations_table.c, field) == value
                    )
        if order_by:
            for field, direction in order_by:
                col = getattr(self.conversations_table.c, field)
                query = query.order_by(col.desc() if direction == -1 else col)
        else:
            # Default sort: newest first.
            query = query.order_by(self.conversations_table.c.timestamp.desc())
        if offset:
            query = query.offset(offset)
        if limit:
            query = query.limit(limit)
        with self.engine.connect() as conn:
            result = conn.execute(query)
            return [
                EvaluatedConversationWithContext(
                    conversations=row.conversations,
                    instruction_context=row.instruction_context,
                    response_context=row.response_context,
                    metadata=row.metadata,
                    evaluation=row.evaluation,
                )
                for row in result
            ]

    def load_documents(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        order_by: Optional[List[tuple]] = None,
    ) -> List[Document]:
        """Load documents from the database with filtering and sorting.

        Args:
            limit: Maximum number of documents to load
            offset: Number of documents to skip
            filters: Dict of field-value pairs for filtering; keys prefixed
                with ``metadata.`` filter inside the JSON metadata column.
            order_by: List of (field, direction) tuples; direction -1 sorts
                descending, anything else ascending.

        Returns:
            List of documents
        """
        query = self.documents_table.select()
        if filters:
            for field, value in filters.items():
                if field.startswith("metadata."):
                    # JSON-path filter into the metadata column.
                    _, key = field.split(".", 1)
                    query = query.where(self.documents_table.c.metadata[key] == value)
                else:
                    # Plain column equality filter.
                    query = query.where(getattr(self.documents_table.c, field) == value)
        if order_by:
            for field, direction in order_by:
                col = getattr(self.documents_table.c, field)
                query = query.order_by(col.desc() if direction == -1 else col)
        if offset:
            query = query.offset(offset)
        if limit:
            query = query.limit(limit)
        with self.engine.connect() as conn:
            result = conn.execute(query)
            return [
                Document(
                    id=row.id,
                    text=row.text,
                    personas=row.personas,
                    metadata=row.metadata,
                )
                for row in result
            ]

    def save_documents(self, documents: List[Document]) -> None:
        """Save documents to the database in batches."""
        records = [document.model_dump(mode="json") for document in documents]
        with self.engine.begin() as conn:
            for i in range(0, len(records), self.batch_size):
                batch = records[i : i + self.batch_size]
                conn.execute(self.documents_table.insert(), batch)

    async def asave_documents(self, documents: List[Document]) -> None:
        """Save documents to the database asynchronously, in batches."""
        records = [document.model_dump(mode="json") for document in documents]
        async with self.async_session_maker() as session:
            async with session.begin():
                for i in range(0, len(records), self.batch_size):
                    batch = records[i : i + self.batch_size]
                    await session.execute(self.documents_table.insert(), batch)