"""
|
|
All functions around bots
|
|
|
|
"""

from uuid import uuid4
from collections import namedtuple
import os, hashlib, traceback, logging
from datetime import datetime, date

from elasticsearch_dsl import connections

from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.llms import Ollama
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader
from langchain_community.embeddings import OllamaEmbeddings

from langchain_elasticsearch import ElasticsearchStore

from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate

from lib.models import Chatbot, Text, User


ollama_url = os.getenv("OLLAMA_URI")
elastic_url = os.getenv("ELASTIC_URI")

# fail fast if the service endpoints are not configured
assert ollama_url
assert elastic_url

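# Lightweight stand-in for langchain's Document: add_documents() appears to
# read only the page_content and metadata attributes, which this namedtuple
# provides via duck typing.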
ESDocument = namedtuple('Document', ['page_content', 'metadata'])


# TODO: needs to be reimplemented with another mechanism like Celery to manage
# longer-running tasks and give feedback to the frontend
# (see the sketch after train_text below)
def train_text(bot_id, text):
    """
    Caution: long-running request!
    """
    txt_md5 = hashlib.md5(text.encode()).hexdigest()
    t = Text.get(id=txt_md5, ignore=404)
    if t is not None:
        # this exact text was already indexed
        return True
    else:
        bot = Chatbot.get(id=bot_id)
        user = User.get(id=bot.creator_id)

        # use the md5 as the document id so retraining the same text is idempotent
        t = Text(meta={'id': txt_md5})
        t.text = text
        t.md5 = txt_md5

        # add metadata
        t.creation_date = datetime.now()
        t.creator_id = user.meta.id
        t.save()

        # train with the given text
        chunk_size = 1536
        chunk_overlap = 200

        documents = []
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len
        )
        for i, s in enumerate(splitter.split_text(text)):
            documents.append(ESDocument(
                page_content=s,
                metadata={
                    "segment_nr": i,
                    "text_id": t.meta.id,
                    "chunk_size": chunk_size,
                    "chunk_overlap": chunk_overlap
                }
            ))

        embeddings = OllamaEmbeddings(model=bot.llm_model, base_url=ollama_url)

        vector_store = ElasticsearchStore(
            es_url=elastic_url,
            index_name="chatbot_" + bot_id.lower(),
            embedding=embeddings
        )

        uuids = [str(uuid4()) for _ in range(len(documents))]
        vector_store.add_documents(documents=documents, ids=uuids)

        return True
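
# A minimal sketch of the Celery rework the TODO above asks for, assuming a
# broker reachable via a (hypothetical) CELERY_BROKER_URI env var; the
# frontend could then poll the returned task id instead of blocking on this
# long-running request:
#
#   from celery import Celery
#
#   celery_app = Celery("bots", broker=os.getenv("CELERY_BROKER_URI"))
#
#   @celery_app.task
#   def train_text_task(bot_id, text):
#       return train_text(bot_id, text)
#
#   # In the request handler: enqueue and return the task id immediately.
#   # task = train_text_task.delay(bot_id, text)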


# TODO: add history
def ask_bot(question, bot_id):
    bot = Chatbot.get(id=bot_id)
    llm = Ollama(
        model=bot.llm_model,
        base_url=ollama_url
    )
    query = bot.system_prompt + " " + question
    for chunk in llm.stream(query):
        yield chunk
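
# Usage sketch: ask_bot() is a generator, so the caller streams tokens as they
# arrive (the bot id is hypothetical):
#
#   for token in ask_bot("What is the warranty period?", "a1b2c3"):
#       print(token, end="", flush=True)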


def ask_bot2(question, bot_id):
    """
    Ask a chatbot. Uses retrieval-augmented generation (RAG) when a vector
    index exists for the bot, otherwise falls back to the plain LLM.
    """
    bot = Chatbot.get(id=bot_id)
    prompt = question
    system_prompt = bot.system_prompt + "\n\n{context}"

    rag_index = "chatbot_" + bot_id.lower()

    if connections.get_connection().indices.exists(index=rag_index):

        vector_store = ElasticsearchStore(
            es_url=elastic_url,
            index_name=rag_index,
            distance_strategy="COSINE",
            embedding=OllamaEmbeddings(model=bot.llm_model, base_url=ollama_url)
        )

        def gen_func():
            llm = Ollama(
                model=bot.llm_model,
                base_url=ollama_url
            )

            ch_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt),
                    ("human", "{input}"),
                ]
            )

            retriever = vector_store.as_retriever()
            question_answer_chain = create_stuff_documents_chain(llm, ch_prompt)
            rag_chain = create_retrieval_chain(retriever, question_answer_chain)
            # the chain streams dicts; only the "answer" chunks carry tokens
            for chunk in rag_chain.stream({"input": prompt}):
                print(chunk, flush=True)
                if "answer" in chunk:
                    yield chunk["answer"]

        def get_score_docs():
            k = 4
            start_vec_search = datetime.now().timestamp()
            scoredocs = vector_store.similarity_search_with_score(prompt, k=k)
            vec_search_duration = round(datetime.now().timestamp() - start_vec_search, 2)
            print("Vec search duration: ", vec_search_duration, flush=True)
            xs = []
            for doc, score in scoredocs:
                xs.append([score, dict(doc)])
            return xs

        return {
            "answer_generator": gen_func,
            "get_score_docs": get_score_docs
        }

    else:
        # no vector index yet: answer with the plain LLM and return no sources
        def gen_func():
            llm = Ollama(
                model=bot.llm_model,
                base_url=ollama_url
            )
            query = bot.system_prompt + " " + question
            for chunk in llm.stream(query):
                yield chunk

        return {
            "answer_generator": gen_func,
            "get_score_docs": lambda: []
        }
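
# Usage sketch: ask_bot2() returns a dict with a token generator and a
# callable exposing the retrieved chunks plus their similarity scores
# (the bot id is hypothetical):
#
#   result = ask_bot2("What is the warranty period?", "a1b2c3")
#   answer = "".join(result["answer_generator"]())
#   sources = result["get_score_docs"]()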


from ollama import Client as OllamaClient


def download_llm(model):
    # Pull the model through the Ollama HTTP API. os.system() blocks until
    # the pull finishes and only exposes the shell exit code; see the
    # client-based sketch below for an alternative.
    s = """curl %s/api/pull -d '{ "name": "%s" }' """ % (ollama_url, model)
    print(os.system(s.strip()), flush=True)
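

# A sketch using the ollama Python client instead of shelling out to curl,
# assuming the installed `ollama` package provides Client.pull() with
# stream=True (true for recent releases; verify against the version pinned
# in this project before swapping it in):
def download_llm_via_client(model):
    client = OllamaClient(host=ollama_url)
    # with stream=True, pull() yields progress events instead of blocking
    for progress in client.pull(model, stream=True):
        print(progress, flush=True)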
|
|