import logging
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
from pathlib import Path
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from typing import List
import yaml
from argparse import ArgumentParser
import os
# Configure the root logger at import time with INFO as the minimum level.
logging.basicConfig(level=logging.INFO)
# Module-level logger for this script.
logger = logging.getLogger(__name__)
# Collection name used when connection_args has no "collection" key.
DEFAULT_COLLECTION_NAME = "rasa"
def extract_documents(docs_folder: str) -> List[Document]:
    """Load every document found under a folder.

    Args:
        docs_folder: Path of the folder to read documents from.

    Returns:
        The documents produced by the directory loader.

    Raises:
        SystemExit: If ``docs_folder`` does not exist.
    """
    source_dir = Path(docs_folder)
    logger.debug(f"Extracting files from: {source_dir.absolute()}")

    if source_dir.exists():
        directory_loader = DirectoryLoader(
            docs_folder,
            loader_cls=UnstructuredFileLoader,
            show_progress=True,
        )
        return directory_loader.load()

    raise SystemExit(f"Directory '{docs_folder}' does not exist.")
def create_chunks(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
    """Split documents into chunks using RecursiveCharacterTextSplitter.

    Args:
        documents: Documents to be split.
        chunk_size: Target size of each chunk, measured with ``len``.
        chunk_overlap: Overlap between chunks, measured with ``len``.

    Returns:
        The chunks produced by the splitter.
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "length_function": len,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    chunks = splitter.split_documents(documents)
    return chunks
def embeddings_factory(type: str, config: dict = None):
    """Build the embeddings model for the requested provider.

    Args:
        type: Name of the embeddings provider, matched case-insensitively.
            Only "openai" is supported. (The name shadows the builtin but is
            kept because it is part of the public signature.)
        config: Keyword arguments forwarded to the embeddings constructor.
            ``None`` is treated as "no extra arguments".

    Returns:
        An ``OpenAIEmbeddings`` instance built from ``config``.

    Raises:
        ValueError: If ``type`` is not a supported provider.
    """
    if type.lower() == "openai":
        # Unpacking None would raise TypeError, so fall back to an empty dict.
        return OpenAIEmbeddings(**(config or {}))
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an obscure error inside the vector store.
    raise ValueError(f"Embeddings type '{type}' not supported.")
def create_qdrant_collection(
    embeddings: Embeddings,
    docs: List[Document],
    connection_args: dict,
) -> Qdrant:
    """Create a Qdrant collection from the given document chunks.

    Args:
        embeddings: Embeddings model object used to embed the chunks.
        docs: The document chunks to store.
        connection_args: Connection settings. Recognised keys are ``host``
            (default ``None``), ``port`` (default ``6333``), ``path``
            (default ``None``) and ``collection`` (default
            ``DEFAULT_COLLECTION_NAME``).

    Returns:
        The Qdrant vector store returned by ``Qdrant.from_documents``.
    """
    host = connection_args.get("host", None)
    port = connection_args.get("port", 6333)
    path = connection_args.get("path", None)
    collection_name = connection_args.get("collection", DEFAULT_COLLECTION_NAME)

    return Qdrant.from_documents(
        docs,
        embeddings,
        host=host,
        port=port,
        collection_name=collection_name,
        path=path,
    )
def validate_destination(destination: str):
    """Ensure the configured destination is a supported vector store.

    Args:
        destination: Name of the destination, matched case-insensitively.
            Only "qdrant" is supported.

    Raises:
        SystemExit: If ``destination`` is not a string (for example ``None``
            when the key is absent from the config) or is not supported.
    """
    supported_destinations = ("qdrant",)
    # Check the type first: calling .lower() on None would raise
    # AttributeError instead of the intended clean exit.
    if not isinstance(destination, str) or destination.lower() not in supported_destinations:
        raise SystemExit(f"Destination '{destination}' not supported.")
def validate_embeddings_type(embeddings_type: str):
    """Check that the embeddings provider is supported and usable.

    Args:
        embeddings_type: Name of the embeddings provider, matched
            case-insensitively. Only "openai" is supported.

    Raises:
        SystemExit: If the provider is not supported, or if the provider is
            "openai" and the OPENAI_API_KEY environment variable is not set.
    """
    provider = embeddings_type.lower()

    if provider != "openai":
        raise SystemExit(f"Embeddings type '{embeddings_type}' not supported.")

    # From here on the provider is known to be OpenAI, which needs an API key.
    if "OPENAI_API_KEY" not in os.environ:
        raise SystemExit("OPENAI_API_KEY environment variable not set.")
def main():
    """Run the ingestion pipeline described by a YAML config file.

    Parses the ``--config`` argument, loads the YAML file, extracts documents
    from the configured folder, splits them into chunks, builds the embeddings
    model and stores the chunks in the configured destination.

    Config keys read: ``docs_folder`` (default "data/documents"),
    ``chunk_size`` (default 1000), ``chunk_overlap`` (default 20),
    ``embeddings`` (default "openai"), ``destination``, ``openai_config`` and
    ``connection_args``.

    Raises:
        SystemExit: If the config file is not a YAML mapping, a required key
            is missing, or a configured value is not supported.
    """
    parser = ArgumentParser(
        prog="ingest.py",
        description="Extract documents from a folder and load them into a vector store.",
        epilog="Example: python ingest.py --config config.yaml",
    )
    parser.add_argument('-c', '--config', required=True, help='config file path')
    args = parser.parse_args()

    # Use a context manager so the config file handle is always closed.
    # NOTE(review): yaml.load with FullLoader is kept as in the original; if
    # the config can come from an untrusted source, consider yaml.safe_load.
    with open(args.config, encoding="utf-8") as config_file:
        opt = yaml.load(config_file, Loader=yaml.FullLoader)

    # An empty file loads as None and a top-level list/scalar is not a
    # mapping; either would otherwise crash on opt.update below.
    if not isinstance(opt, dict):
        raise SystemExit(f"Config file '{args.config}' must contain a YAML mapping.")

    # Command-line arguments take precedence over keys from the file.
    opt.update(vars(args))

    docs_folder = opt.get("docs_folder", "data/documents")
    chunk_size = opt.get("chunk_size", 1000)
    chunk_overlap = opt.get("chunk_overlap", 20)
    embeddings_type = opt.get("embeddings", "openai")
    destination = opt.get("destination")

    try:
        openai_config = opt["openai_config"]
    except KeyError:
        raise SystemExit("openai_config not found in config file.") from None

    try:
        connection_args = opt["connection_args"]
    except KeyError:
        raise SystemExit("connection_args not found in config file.") from None

    # Exit cleanly when the key is absent instead of failing on None.lower().
    if destination is None:
        raise SystemExit("destination not found in config file.")

    validate_destination(destination)
    validate_embeddings_type(embeddings_type)

    docs = extract_documents(docs_folder)
    logger.info(f"{len(docs)} documents extracted.")

    chunks = create_chunks(docs, chunk_size, chunk_overlap)
    logger.info(f"{len(chunks)} chunks created.")

    # Log the first few chunks as a quick sanity check of the splitting.
    for i, chunk in enumerate(chunks[:3]):
        logger.info(f"chunk {i}")
        logger.info(chunk)

    embeddings = embeddings_factory(embeddings_type, openai_config)

    if destination.lower() == "qdrant":
        create_qdrant_collection(
            embeddings=embeddings,
            docs=chunks,
            connection_args=connection_args,
        )
        logger.info(f"Qdrant collection created with arguments {connection_args}")
    else:
        raise SystemExit(f"Destination '{destination}' not supported. Only qdrant is supported.")
# Run the ingestion pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()