From 4472681b06075c13d921581df229ffca9e1c9146 Mon Sep 17 00:00:00 2001 From: Aaron Giner <aaron.giner@student.tugraz.at> Date: Tue, 30 Jul 2024 03:39:16 +0200 Subject: [PATCH] added documentation, performed some code cleanup --- python/llm_server/embeddings.py | 34 ++---- python/llm_server/memory_util.py | 26 +++- python/llm_server/model.py | 4 + python/llm_server/parse_dt.py | 8 ++ python/llm_server/queries.py | 200 +++++++++++++++++++------------ python/llm_server/server.py | 42 ++++++- python/llm_server/util.py | 31 +++-- 7 files changed, 223 insertions(+), 122 deletions(-) diff --git a/python/llm_server/embeddings.py b/python/llm_server/embeddings.py index 8c8b00a..598254d 100644 --- a/python/llm_server/embeddings.py +++ b/python/llm_server/embeddings.py @@ -1,28 +1,14 @@ import torch -from rake_nltk import Rake from sentence_transformers import SentenceTransformer, util -# extracts keywords from a statement and returns embeddings for the statement -def get_embeddings_kw(statement, model=None): - if model is None: - model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') - - # extract kw - rake_nltk_var = Rake() - rake_nltk_var.extract_keywords_from_text(statement) - keyword_extracted = rake_nltk_var.get_ranked_phrases() - - # combine keywords and deduplicate words - kw = " ".join(set(" ".join([k for k in keyword_extracted]).split(" "))) - - embedding_kw = model.encode(kw) - - return embedding_kw - - -# returns embeddings for a statement +# def get_embeddings(statement, model=None): + """ + :param statement: Input + :param model: Embeddings model to use (default specified below) + :return: Returns text embedding for a statement + """ if model is None: model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') @@ -31,9 +17,15 @@ def get_embeddings(statement, model=None): return embeddings -# must not be in tensor form +# # returns similarity between embeddings (inputs must be 1 dimensional) def cos_sim(embeddings1, embeddings2): + """ + Calculate cosine similarity score between text embeddings. + :param embeddings1: text embeddding 1 (must not be in tensor form) + :param embeddings2: text embedding 2 (must not be in tensor form) + :return: Cosine similarity score. + """ embeddings1 = torch.tensor(embeddings1) embeddings2 = torch.tensor(embeddings2) diff --git a/python/llm_server/memory_util.py b/python/llm_server/memory_util.py index aa466f5..5c712ce 100644 --- a/python/llm_server/memory_util.py +++ b/python/llm_server/memory_util.py @@ -6,10 +6,22 @@ decay_factor = 0.995 def filter_memories(memories, queries, datetime, w_recency=0.5, w_importance=1, w_relevance=3): + """ + Based on a set of query statements, performs a sorting of the input memories based on the factors: recency, + importance and relevance. The score for each memory is calculated for each query statement separately and the + highest score is used. + :param memories: Set of input memories. + :param queries: Set of query statements. + :param datetime: Date/time string. + :param w_recency: Recency weight. + :param w_importance: Importance weight. + :param w_relevance: Relevance weight. + :return: List of sorted memories by score. + """ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') - # queries = parse_dt.replace_all(datetime, queries) - # print(queries) + # perform datetime replacement for the most recent user input and add it to the list of query memories. 
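+    # (illustrative) e.g. with DateTime "Tuesday, July 30, 2024", a query such as "are you free tomorrow?"
+    # is rewritten to name the explicit date, so relevance scoring can also match memories that
+    # store absolute dates.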
+ queries.append(parse_dt.replace_all(datetime, queries[-1:])[0]) embedding_memories = [embeddings.get_embeddings(m["Content"], model) for m in memories] embedding_queries = [embeddings.get_embeddings(q, model) for q in queries] @@ -32,13 +44,17 @@ def filter_memories(memories, queries, datetime, w_recency=0.5, w_importance=1, return memories -def memories_to_string(memories, include_date_created=False, include_nodeId=False): +def memories_to_string(memories, include_nodeId=False): + """ + Converts a list of input memories to a string according to the specified format. + :param memories: List of memories. + :param include_nodeId: Whether the node-id should be prepended to the memory item. + :return: A string of all memories. + """ mem_string = "" - print(memories) for mem in memories: prefix = "- " prefix += (str(mem["NodeId"]) + ". ") if include_nodeId else "" - prefix += (("[" + mem["LastAccessed"] + "] ") if include_date_created else " ") mem_string += prefix + mem["Content"].strip() + "\n" return mem_string diff --git a/python/llm_server/model.py b/python/llm_server/model.py index 8af9ebf..9022cee 100644 --- a/python/llm_server/model.py +++ b/python/llm_server/model.py @@ -8,6 +8,10 @@ import error class Model: + """ + The model class is a further abstraction layer and acts as an interface for inference using ChatGPT and local + models. + """ def __init__(self, model_name, model_file, model_tok, temperature, allow_system_prompt, ctx_size=8192, api_key="", use_chat_gpt=False,): self.model_name = model_name diff --git a/python/llm_server/parse_dt.py b/python/llm_server/parse_dt.py index 3346584..c0a5f16 100644 --- a/python/llm_server/parse_dt.py +++ b/python/llm_server/parse_dt.py @@ -1,7 +1,15 @@ import dateutil.parser import parsedatetime as pdt + def replace_all(dt, statements): + """ + Replaces the first occurrence of a referenced date/time in natural language with a date/time-string of the specified + format. + :param dt: input datetime as a string (usually the current date/time) + :param statements: a list of strings to replace the date/time. + :return: The input statements with replaced date/time strings. + """ cal = pdt.Calendar() dt = dateutil.parser.parse(dt) diff --git a/python/llm_server/queries.py b/python/llm_server/queries.py index 30c4bb7..791270c 100644 --- a/python/llm_server/queries.py +++ b/python/llm_server/queries.py @@ -15,6 +15,11 @@ import util def parse_request(request): + """ + Parses incoming request packets and delegates to appropriate routines. Loads a Model object for the specified model. + :param request: Request packet. + :return: Response packet. + """ try: request = json.loads(request) requestType = request["Type"] @@ -77,6 +82,12 @@ def parse_request(request): def query_chat(request, llm: model.Model): + """ + Handles the Chat Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ NUM_MEMORIES_TO_USE = 25 parameters = request["Data"] @@ -111,7 +122,6 @@ def query_chat(request, llm: model.Model): if res is not False: break else: - print("LIMIT EXCEEDED, RETRYING") chat_messages = chat_messages[2:] memories_accessed = [str(mem["NodeId"]) for mem in memories[:NUM_MEMORIES_TO_USE]] @@ -120,6 +130,12 @@ def query_chat(request, llm: model.Model): def query_reflection(request, llm: model.Model): + """ + Handles the Reflection Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. 
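+
+    Example response shape (illustrative):
+        {"Memories": ["<insight statement derived from the agent's memories>", ...]}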
+ """ parameters = request["Data"] memories = request["Memories"] parameters["Memories"] = memory_util.memories_to_string(memories) @@ -140,8 +156,6 @@ def query_reflection(request, llm: model.Model): memory_util.filter_memories(memories, [q], parameters["DateTime"], 0, 0, 1) parameters["Memories"] = memory_util.memories_to_string(memories[:5], include_nodeId=True) - print("Question: ", q) - messages_a = [ {"role": "user", "content": util.load_template("reflection_a").format(**parameters)}, @@ -153,15 +167,16 @@ def query_reflection(request, llm: model.Model): insights.extend(insights_q) - print("QUESTION:") - print(q) - for x in insights_q: - print(" - " + x) - return {"Memories": insights} def query_poignancy(request, llm: model.Model): + """ + Handles the Poignancy Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] messages = [ {"role": "user", @@ -174,14 +189,24 @@ def query_poignancy(request, llm: model.Model): def add_poignancy_to_memories(memories, llm: model.Model): + """ + Adds a poignancy score to a new memory element using the poignancy template. + :param memories: Memories, for which the poignancy shall be determined. + :param llm: Model to use. + :return: List of memories with added poignancy score. + """ for i in range(len(memories)): memories[i] += "~" + query_poignancy({"Data": {"Memory": memories[i]}}, llm)["Text"] -# deduplicates a list of statements using their similarity score -# deduplication takes length of statements into account and will keep the -# longer one (potentially containing more information, but not guaranteed) def deduplicate_sim(statements, threshold=0.85): + """ + Deduplicates a list of statements using their similarity score. Deduplication takes length of statements + into account and will keep the longer one (potentially containing more information, but not guaranteed) + :param statements: List of statements to deduplicate. + :param threshold: Similarity threshold for comparison. + :return: Deduplicated list. + """ st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') # calculate embedding for each input statement @@ -215,8 +240,15 @@ def deduplicate_sim(statements, threshold=0.85): return dedup -# remove statements that are already present in the agent's memory based on similarity +# def deduplicate_sim_mem(statements, memories, threshold=0.85): + """ + Remove statements that are already present in the agent's memory based on similarity. + :param statements: List of new statements to check. + :param memories: List of memories already in the memory stream. + :param threshold: Similarity threshold for comparison. + :return: Deduplicated list. + """ st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') # convert input memories and statements to embedding @@ -239,6 +271,15 @@ def deduplicate_sim_mem(statements, memories, threshold=0.85): def query_chat_extract(template, template_val, request, llm: model.Model, ctx_window=20): + """ + Performs chat extraction using the specified templates. + :param template: The template to use for extraction. + :param template_val: The template to use for validation of extracted statements. + :param request: The request packet. + :param llm: The model to use. + :param ctx_window: The size of the sliding window for extraction. + :return: List of extracted statements. 
+ """ parameters = request["Data"] chat = parameters["Conversation"] lines = chat.split("\n") @@ -266,19 +307,20 @@ def query_chat_extract(template, template_val, request, llm: model.Model, ctx_wi extracted.extend(res) # use a deduplication prompt/sim to remove duplicate statements from collected list - if ctx_window < len(lines): # only need to deduplicate when ctx window is smaller than convo length # statements = query_deduplicate(statements, llm) extracted = deduplicate_sim(extracted) - # validate extracted statements against original conversation - # parameters["Conversation"] = chat - # extracted = conversation_validate_statements(parameters, statements, llm) - return extracted def query_chat_summary(request, llm: model.Model): + """ + Handles the ChatSummary Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] new_memories = [] @@ -295,11 +337,8 @@ def query_chat_summary(request, llm: model.Model): agent_facts = query_chat_extract("chat_summary_agent_facts", "chat_validate_statement", request, llm) new_memories.extend(agent_facts) - print(len(new_memories)) new_memories = deduplicate_sim(new_memories, 0.75) - print(len(new_memories)) new_memories = deduplicate_sim_mem(new_memories, request["Memories"], 0.75) - print(len(new_memories)) add_poignancy_to_memories(new_memories, llm) messages_summary = [ @@ -314,6 +353,12 @@ def query_chat_summary(request, llm: model.Model): def query_chat_extract_plan(request, llm: model.Model): + """ + Handles the ChatExtractPlan Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] # alternative 1 @@ -333,6 +378,11 @@ def query_chat_extract_plan(request, llm: model.Model): def parse_datetime(statement): + """ + Parses a statement for a datetime occurrence based on the format specified by the RegEx. + :param statement: Statement to search. + :return: The dt object or None. + """ try: datetime_str = re.search( "[A-Z][a-z]+, [A-Z][a-z]+ [0-9]+, [0-9]{4}(, [0-9]{1,2}:[0-9]{1,2}( )*((pm|am)|(AM|PM)))?", @@ -345,9 +395,14 @@ def parse_datetime(statement): return None -# takes a list of plan statements, parses the referenced time and date and removes past events. -# additionally returns true if any of the plans are for the current date (for replanning) def plans_validate(statements, c_datetime): + """ + Takes a list of plan statements, parses the referenced time and date and removes past events. Additionally, returns + true if any of the plans are for the current date (for replanning) + :param statements: List of plans to validate. + :param c_datetime: The current datetime. + :return: List of valid future plans and an indication of whether one of the plans takes place on the current day. + """ any_plans_for_today = False valid = [] c_datetime = dateutil.parser.parse(c_datetime) @@ -364,18 +419,13 @@ def plans_validate(statements, c_datetime): return valid, any_plans_for_today -# obsolete -def get_calendar(): - calendar = "" - start_date = datetime.now() - timedelta(days=14) - for i in range(14): - date = start_date + timedelta(days=i) - date_str = date.strftime("%A, %B %d, %Y") - calendar += date_str + "\n" - return calendar - - def generate_context(request, llm: model.Model): + """ + Handles the Context Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. 
+ """ parameters = request["Data"] memories = request["Memories"] memories.sort(key=lambda x: x["HrsSinceCreation"], reverse=True) @@ -401,12 +451,16 @@ def generate_context(request, llm: model.Model): context += action - print(context) - return {"Text": context} def query_relationship(request, llm: model.Model): + """ + Performs the relationship query used for the Context request. + :param request: Request packet. + :param llm: Model to use. + :return: Relationship response. + """ parameters = request["Data"] memories = request["Memories"] @@ -423,6 +477,12 @@ def query_relationship(request, llm: model.Model): def query_agent_action(request, llm: model.Model): + """ + Determines the agents current action for the Context request. + :param request: Request packet. + :param llm: Model to use. + :return: Agent's action response. + """ parameters = request["Data"] messages = [ {"role": "user", @@ -435,6 +495,12 @@ def query_agent_action(request, llm: model.Model): def query_plan_day(request, llm: model.Model): + """ + Handles the PlanDay Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] memories = request["Memories"] @@ -447,33 +513,16 @@ def query_plan_day(request, llm: model.Model): day_plan = llm.query(messages) - """ - for h in range(24): - time_start = datetime.strptime(str(h) + ":00", "%H:%M") - time_start = time_start.strftime("%I:%M%p") - time_end = datetime.strptime(str((h + 1) % 24) + ":00", "%H:%M") - time_end = time_end.strftime("%I:%M%p") - - messages = [ - {"role": "user", - "content": util.load_template("plan_day_decomp").format(time_start=time_start, time_end=time_end, - memories=memories_str, plan=day_plan, - **parameters)}, - ] - - prompt = tokenizer.apply_chat_template(messages, tokenize=False) - - res = time_start + "-" + time_end + ":" + run_query(prompt, llm) - - print() - print(res) - print() - """ - return {"Text": day_plan} def query_replan_day(request, llm: model.Model): + """ + Handles the ReplanDay Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] memories = request["Memories"] @@ -490,6 +539,12 @@ def query_replan_day(request, llm: model.Model): def query_custom(request, llm): + """ + Handles the Custom Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] messages = [ {"role": "user", @@ -500,8 +555,15 @@ def query_custom(request, llm): return {"Text": response} -# returns a list of validated statements -def validate_statements(template, statements, parameters, llm: model.Model): +def conversation_validate_statements(template, statements, parameters, llm: model.Model): + """ + # Returns a list of validated statements based on the LLM response. + :param template: The validation template to use. + :param statements: A list of statements to validate. + :param parameters: A parameters dictionary containing all data needed for the query (as defined in the template). + :param llm: The model to use. + :return: List of validated statements. 
+ """ valid = [] for statement in statements: message_validate = [ @@ -513,22 +575,4 @@ def validate_statements(template, statements, parameters, llm: model.Model): if "yes" in val.lower(): valid.append(statement) - return valid - - -# returns a list of validated statements -def conversation_validate_statements(template, parameters, statements, llm: model.Model): - valid = [] - for statement in statements: - message_validate = [ - {"role": "user", - "content": util.load_template(template).format(statement=statement, - **parameters)}, - ] - - val = llm.query(message_validate) - print(statement, val) - if "yes" in val.lower(): - valid.append(statement) - return valid \ No newline at end of file diff --git a/python/llm_server/server.py b/python/llm_server/server.py index a5502ae..7009c3a 100644 --- a/python/llm_server/server.py +++ b/python/llm_server/server.py @@ -1,14 +1,16 @@ # https://realpython.com/python-sockets/ -import json import socket import threading import time import queries -import select + +# Define host and port here. HOST = "127.0.0.1" PORT = 65432 + +# Define the maximum number of sessions here. MAX_CONNECTIONS = 60 connections = [] @@ -20,6 +22,10 @@ STREAM_BUFFER_SIZE = 8192 class ClientHandler(threading.Thread): + """ + A ClientHandler thread is responsible for communicating with one client, transmitting the requests to the + RequestHandler queue and returning the result. + """ def __init__(self, conn, addr, r): super().__init__() self._stop_event = threading.Event() @@ -31,11 +37,19 @@ class ClientHandler(threading.Thread): self.request_handler = r def signal_response(self, response): + """ + Used to signal the ClientHandler thread to continue execution when the result arrives. + :param response: The JSON-response as a string. + """ with self.response_condition: self.response = response self.response_condition.notify() def run(self): + """ + Action loop for the ClientHandler thread. Continually waits for requests from the client. Puts ClientHandler to + sleep once request is received and has been transmitted to RequestHandler. + """ print(f"Connected with {self.addr}") with self.conn: while not self._stop_event.is_set(): @@ -45,8 +59,6 @@ class ClientHandler(threading.Thread): while True: try: received = self.conn.recv(STREAM_BUFFER_SIZE) - # print("Received: " + str(len(received))) - # print(received) total_rec += len(received) if len(received) == 0: raise Exception @@ -57,11 +69,9 @@ class ClientHandler(threading.Thread): if len(data) == 0: continue - # print("Total data receieved: " + str(total_rec)) request = data.decode("utf-8") with self.response_condition: - print("ADDING PACKET TO QUEUE") self.request_handler.queue_request(request, self) self.response_condition.wait() @@ -80,6 +90,10 @@ class ClientHandler(threading.Thread): class RequestHandler(threading.Thread): + """ + The RequestHandler thread is responsible for handling all client requests and transmitting the results back to the + ClientHandler threads. + """ def __init__(self): super().__init__() self._stop_event = threading.Event() @@ -87,11 +101,20 @@ class RequestHandler(threading.Thread): self.request_condition = threading.Condition() def queue_request(self, request, client): + """ + Used add a request to the queue and signal the RequestHandler thread to continue execution. + :param request: The request packet. + :param client: The ClientHandler thread which should be signalled upon request fulfillment. 
+ """ with self.request_condition: self.request_queue.append((request, client)) self.request_condition.notify() def run(self): + """ + Action loop for RequestHandler. Continually handles requests as long as queue is not empty, otherwise goes to + sleep. Signals ClientHandler threads once response is ready. + """ while not self._stop_event.is_set(): with self.request_condition: while len(self.request_queue) <= 0 and not self._stop_event.is_set(): @@ -113,6 +136,9 @@ class RequestHandler(threading.Thread): class ExitHandler(threading.Thread): + """ + Separate thread to enable easy termination of server by pressing the "e" key. + """ def stop(self): pass @@ -126,6 +152,10 @@ class ExitHandler(threading.Thread): def start_server(): + """ + Main server routine. Starts RequestHandler and ExitHandler threads at the start and + ClientHandler threads for every connected client. Keeps track of maximum connections. + """ global s e = ExitHandler() r = RequestHandler() diff --git a/python/llm_server/util.py b/python/llm_server/util.py index 99a4d8e..954b09f 100644 --- a/python/llm_server/util.py +++ b/python/llm_server/util.py @@ -5,26 +5,33 @@ TEMPLATES_PATH = "llm_server/templates/" def load_template(template): + """ + :param template: template file name + :return: content of template file as string + """ return open(TEMPLATES_PATH + template + ".txt").read() -MISTRAL_TOK = "mistralai/Mistral-7B-Instruct-v0.1" -MISTRAL_GGUF = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" -MISTRAL_FILE = "mistral-7b-instruct-v0.1.Q4_K_M.gguf" - -LLAMA_TOK = "meta-llama/Llama-2-7b-chat-hf" -LLAMA_API_KEY = "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" -LLAMA_NAME = "TheBloke/Llama-2-7B-Chat-GGUF" -LLAMA_FILE = "llama-2-7b.Q4_K_M.gguf" - -KW_EN_CORE = "X:/LLM Models/nlm/en_core_web_sm-3.7.1-py3-none-any.whl" - - def load_tokenizer(model, token=""): + """ + Loads the tokenizer for the specified model + + :param model: huggingface repo url + :param token: huggingface repo access token + :return: tokenizer object + """ return AutoTokenizer.from_pretrained(model, token=token) def load_model(model, ctx_size, temperature): + """ + Loads the specified large language model. + + :param model: huggingface repo url + :param ctx_size: maximum input/output tokens (input maximum is 4096 for most models) + :param temperature: model temperature + :return: llm object to use for inference + """ llm = LlamaCpp( model_path=model, max_tokens=500, -- GitLab