"""
All functions around bots
"""
from uuid import uuid4
from collections import namedtuple
import os, hashlib, traceback, logging
from datetime import datetime, date
from elasticsearch_dsl import connections
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_elasticsearch import ElasticsearchStore
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from lib.models import Chatbot, Text, User

ollama_url = os.getenv("OLLAMA_URI")
elastic_url = os.getenv("ELASTIC_URI")
assert ollama_url, "OLLAMA_URI environment variable is not set"
assert elastic_url, "ELASTIC_URI environment variable is not set"
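
#ESDocument mimics the page_content/metadata interface of LangChain's Document
#class, so the chunks built below can be passed to ElasticsearchStore.add_documents().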
ESDocument = namedtuple('Document', ['page_content', 'metadata'])


#TODO: needs to be reimplemented with a mechanism like Celery to manage longer-running tasks and give feedback to the frontend
def train_text(bot_id, text):
    """
    Caution: Long running request!
    Stores the text (deduplicated by its MD5 hash), splits it into chunks,
    embeds them and indexes them in the bot's Elasticsearch vector index.
    """
    txt_md5 = hashlib.md5(text.encode()).hexdigest()
    t = Text.get(id=txt_md5, ignore=404)
    if t is not None:
        #text was already stored and trained
        return True

    bot = Chatbot.get(id=bot_id)
    user = User.get(id=bot.creator_id)

    t = Text(meta={'id': txt_md5})
    t.text = text
    t.md5 = txt_md5
    #add meta data
    t.creation_date = datetime.now()
    t.creator_id = user.meta.id
    t.save()

    #train with given text
    chunk_size = 1536
    chunk_overlap = 200
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len
    )
    documents = [
        ESDocument(
            page_content=s,
            metadata={
                "segment_nr": i,
                "text_id": t.meta.id,
                "chunk_size": chunk_size,
                "chunk_overlap": chunk_overlap
            }
        )
        for i, s in enumerate(splitter.split_text(text))
    ]

    embeddings = OllamaEmbeddings(model=bot.llm_model, base_url=ollama_url)
    vector_store = ElasticsearchStore(
        es_url=elastic_url,
        index_name="chatbot_" + bot_id.lower(),
        embedding=embeddings
    )
    uuids = [str(uuid4()) for _ in range(len(documents))]
    vector_store.add_documents(documents=documents, ids=uuids)
    return True
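
#Example usage (a sketch; the bot id is hypothetical):
#
#    train_text("my-bot-id", "Some domain knowledge the bot should answer from.")
#
#The call blocks until embedding and indexing are done, see the TODO above.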


#TODO: add history
def ask_bot(question, bot_id):
    """
    Streams an answer from the bot's LLM without any retrieval step.
    """
    bot = Chatbot.get(id=bot_id)
    llm = Ollama(
        model=bot.llm_model,
        base_url=ollama_url
    )
    query = bot.system_prompt + " " + question
    for chunk in llm.stream(query):
        yield chunk
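
#Example (a sketch; the bot id is hypothetical):
#
#    for chunk in ask_bot("What can you do?", "my-bot-id"):
#        print(chunk, end="", flush=True)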


def ask_bot2(question, bot_id):
    """
    Asks a chatbot. Runs a RAG chain over the bot's Elasticsearch vector index
    if one exists, otherwise falls back to a plain LLM call. Returns a dict
    with a streaming answer generator and a function returning the scored
    source documents.
    """
    bot = Chatbot.get(id=bot_id)
    prompt = question
    system_prompt = bot.system_prompt + "\n\n{context}"
    rag_index = "chatbot_" + bot_id.lower()

    if connections.get_connection().indices.exists(index=rag_index):
        vector_store = ElasticsearchStore(
            es_url=elastic_url,
            index_name=rag_index,
            distance_strategy="COSINE",
            embedding=OllamaEmbeddings(model=bot.llm_model, base_url=ollama_url)
        )

        def gen_func():
            llm = Ollama(
                model=bot.llm_model,
                base_url=ollama_url
            )
            ch_prompt = ChatPromptTemplate.from_messages([
                ("system", system_prompt),
                ("human", "{input}"),
            ])
            retriever = vector_store.as_retriever()
            question_answer_chain = create_stuff_documents_chain(llm, ch_prompt)
            rag_chain = create_retrieval_chain(retriever, question_answer_chain)
            for chunk in rag_chain.stream({"input": prompt}):
                print(chunk, flush=True)
                if "answer" in chunk:
                    yield chunk["answer"]

        def get_score_docs():
            k = 4
            start_vec_search = datetime.now().timestamp()
            scoredocs = vector_store.similarity_search_with_score(prompt, k=k)
            vec_search_duration = round(datetime.now().timestamp() - start_vec_search, 2)
            print("Vec search duration:", vec_search_duration, flush=True)
            return [[score, dict(doc)] for doc, score in scoredocs]

        return {
            "answer_generator": gen_func,
            "get_score_docs": get_score_docs
        }
    else:
        def gen_func():
            #no vector index yet: plain LLM call with the system prompt prepended
            llm = Ollama(
                model=bot.llm_model,
                base_url=ollama_url
            )
            query = bot.system_prompt + " " + question
            for chunk in llm.stream(query):
                yield chunk

        return {
            "answer_generator": gen_func,
            "get_score_docs": lambda: []
        }
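
#Example of consuming the result (a sketch; the bot id is hypothetical):
#
#    result = ask_bot2("Who are you?", "my-bot-id")
#    for chunk in result["answer_generator"]():
#        print(chunk, end="", flush=True)
#    scored = result["get_score_docs"]()   #list of [score, document] pairs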


from ollama import Client as OllamaClient


def download_llm(model):
    """
    Caution: Long running request!
    Tells the Ollama server to pull the given model via its HTTP API.
    """
    s = """curl %s/api/pull -d '{ "name": "%s" }' """ % (ollama_url, model)
    print(os.system(s.strip()), flush=True)
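
#A sketch of the same pull using the ollama Python client instead of shelling
#out to curl (download_llm_via_client is a hypothetical name, not referenced
#elsewhere in the app):
def download_llm_via_client(model):
    """
    Pulls a model through the ollama client library.
    """
    ollama_client = OllamaClient(host=ollama_url)
    return ollama_client.pull(model)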