From 4472681b06075c13d921581df229ffca9e1c9146 Mon Sep 17 00:00:00 2001 From: Aaron Giner <aaron.giner@student.tugraz.at> Date: Tue, 30 Jul 2024 03:39:16 +0200 Subject: [PATCH] added documentation, performed some code cleanup --- python/llm_server/embeddings.py | 34 ++---- python/llm_server/memory_util.py | 26 +++- python/llm_server/model.py | 4 + python/llm_server/parse_dt.py | 8 ++ python/llm_server/queries.py | 200 +++++++++++++++++++------------ python/llm_server/server.py | 42 ++++++- python/llm_server/util.py | 31 +++-- 7 files changed, 223 insertions(+), 122 deletions(-) diff --git a/python/llm_server/embeddings.py b/python/llm_server/embeddings.py index 8c8b00a..598254d 100644 --- a/python/llm_server/embeddings.py +++ b/python/llm_server/embeddings.py @@ -1,28 +1,14 @@ import torch -from rake_nltk import Rake from sentence_transformers import SentenceTransformer, util -# extracts keywords from a statement and returns embeddings for the statement -def get_embeddings_kw(statement, model=None): - if model is None: - model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') - - # extract kw - rake_nltk_var = Rake() - rake_nltk_var.extract_keywords_from_text(statement) - keyword_extracted = rake_nltk_var.get_ranked_phrases() - - # combine keywords and deduplicate words - kw = " ".join(set(" ".join([k for k in keyword_extracted]).split(" "))) - - embedding_kw = model.encode(kw) - - return embedding_kw - - -# returns embeddings for a statement +# def get_embeddings(statement, model=None): + """ + :param statement: Input + :param model: Embeddings model to use (default specified below) + :return: Returns text embedding for a statement + """ if model is None: model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') @@ -31,9 +17,15 @@ def get_embeddings(statement, model=None): return embeddings -# must not be in tensor form +# # returns similarity between embeddings (inputs must be 1 dimensional) def cos_sim(embeddings1, embeddings2): + """ + Calculate cosine similarity score between text embeddings. + :param embeddings1: text embeddding 1 (must not be in tensor form) + :param embeddings2: text embedding 2 (must not be in tensor form) + :return: Cosine similarity score. + """ embeddings1 = torch.tensor(embeddings1) embeddings2 = torch.tensor(embeddings2) diff --git a/python/llm_server/memory_util.py b/python/llm_server/memory_util.py index aa466f5..5c712ce 100644 --- a/python/llm_server/memory_util.py +++ b/python/llm_server/memory_util.py @@ -6,10 +6,22 @@ decay_factor = 0.995 def filter_memories(memories, queries, datetime, w_recency=0.5, w_importance=1, w_relevance=3): + """ + Based on a set of query statements, performs a sorting of the input memories based on the factors: recency, + importance and relevance. The score for each memory is calculated for each query statement separately and the + highest score is used. + :param memories: Set of input memories. + :param queries: Set of query statements. + :param datetime: Date/time string. + :param w_recency: Recency weight. + :param w_importance: Importance weight. + :param w_relevance: Relevance weight. + :return: List of sorted memories by score. + """ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') - # queries = parse_dt.replace_all(datetime, queries) - # print(queries) + # perform datetime replacement for the most recent user input and add it to the list of query memories. 
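+    # (illustrative) e.g. with DateTime "Tuesday, July 30, 2024", a query such as "are you free tomorrow?"
+    # is rewritten to name the explicit date, so relevance scoring can also match memories that
+    # store absolute dates.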
+ queries.append(parse_dt.replace_all(datetime, queries[-1:])[0]) embedding_memories = [embeddings.get_embeddings(m["Content"], model) for m in memories] embedding_queries = [embeddings.get_embeddings(q, model) for q in queries] @@ -32,13 +44,17 @@ def filter_memories(memories, queries, datetime, w_recency=0.5, w_importance=1, return memories -def memories_to_string(memories, include_date_created=False, include_nodeId=False): +def memories_to_string(memories, include_nodeId=False): + """ + Converts a list of input memories to a string according to the specified format. + :param memories: List of memories. + :param include_nodeId: Whether the node-id should be prepended to the memory item. + :return: A string of all memories. + """ mem_string = "" - print(memories) for mem in memories: prefix = "- " prefix += (str(mem["NodeId"]) + ". ") if include_nodeId else "" - prefix += (("[" + mem["LastAccessed"] + "] ") if include_date_created else " ") mem_string += prefix + mem["Content"].strip() + "\n" return mem_string diff --git a/python/llm_server/model.py b/python/llm_server/model.py index 8af9ebf..9022cee 100644 --- a/python/llm_server/model.py +++ b/python/llm_server/model.py @@ -8,6 +8,10 @@ import error class Model: + """ + The model class is a further abstraction layer and acts as an interface for inference using ChatGPT and local + models. + """ def __init__(self, model_name, model_file, model_tok, temperature, allow_system_prompt, ctx_size=8192, api_key="", use_chat_gpt=False,): self.model_name = model_name diff --git a/python/llm_server/parse_dt.py b/python/llm_server/parse_dt.py index 3346584..c0a5f16 100644 --- a/python/llm_server/parse_dt.py +++ b/python/llm_server/parse_dt.py @@ -1,7 +1,15 @@ import dateutil.parser import parsedatetime as pdt + def replace_all(dt, statements): + """ + Replaces the first occurrence of a referenced date/time in natural language with a date/time-string of the specified + format. + :param dt: input datetime as a string (usually the current date/time) + :param statements: a list of strings to replace the date/time. + :return: The input statements with replaced date/time strings. + """ cal = pdt.Calendar() dt = dateutil.parser.parse(dt) diff --git a/python/llm_server/queries.py b/python/llm_server/queries.py index 30c4bb7..791270c 100644 --- a/python/llm_server/queries.py +++ b/python/llm_server/queries.py @@ -15,6 +15,11 @@ import util def parse_request(request): + """ + Parses incoming request packets and delegates to appropriate routines. Loads a Model object for the specified model. + :param request: Request packet. + :return: Response packet. + """ try: request = json.loads(request) requestType = request["Type"] @@ -77,6 +82,12 @@ def parse_request(request): def query_chat(request, llm: model.Model): + """ + Handles the Chat Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ NUM_MEMORIES_TO_USE = 25 parameters = request["Data"] @@ -111,7 +122,6 @@ def query_chat(request, llm: model.Model): if res is not False: break else: - print("LIMIT EXCEEDED, RETRYING") chat_messages = chat_messages[2:] memories_accessed = [str(mem["NodeId"]) for mem in memories[:NUM_MEMORIES_TO_USE]] @@ -120,6 +130,12 @@ def query_chat(request, llm: model.Model): def query_reflection(request, llm: model.Model): + """ + Handles the Reflection Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. 
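+
+    Example response shape (illustrative):
+        {"Memories": ["<insight statement derived from the agent's memories>", ...]}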
+ """ parameters = request["Data"] memories = request["Memories"] parameters["Memories"] = memory_util.memories_to_string(memories) @@ -140,8 +156,6 @@ def query_reflection(request, llm: model.Model): memory_util.filter_memories(memories, [q], parameters["DateTime"], 0, 0, 1) parameters["Memories"] = memory_util.memories_to_string(memories[:5], include_nodeId=True) - print("Question: ", q) - messages_a = [ {"role": "user", "content": util.load_template("reflection_a").format(**parameters)}, @@ -153,15 +167,16 @@ def query_reflection(request, llm: model.Model): insights.extend(insights_q) - print("QUESTION:") - print(q) - for x in insights_q: - print(" - " + x) - return {"Memories": insights} def query_poignancy(request, llm: model.Model): + """ + Handles the Poignancy Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] messages = [ {"role": "user", @@ -174,14 +189,24 @@ def query_poignancy(request, llm: model.Model): def add_poignancy_to_memories(memories, llm: model.Model): + """ + Adds a poignancy score to a new memory element using the poignancy template. + :param memories: Memories, for which the poignancy shall be determined. + :param llm: Model to use. + :return: List of memories with added poignancy score. + """ for i in range(len(memories)): memories[i] += "~" + query_poignancy({"Data": {"Memory": memories[i]}}, llm)["Text"] -# deduplicates a list of statements using their similarity score -# deduplication takes length of statements into account and will keep the -# longer one (potentially containing more information, but not guaranteed) def deduplicate_sim(statements, threshold=0.85): + """ + Deduplicates a list of statements using their similarity score. Deduplication takes length of statements + into account and will keep the longer one (potentially containing more information, but not guaranteed) + :param statements: List of statements to deduplicate. + :param threshold: Similarity threshold for comparison. + :return: Deduplicated list. + """ st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') # calculate embedding for each input statement @@ -215,8 +240,15 @@ def deduplicate_sim(statements, threshold=0.85): return dedup -# remove statements that are already present in the agent's memory based on similarity +# def deduplicate_sim_mem(statements, memories, threshold=0.85): + """ + Remove statements that are already present in the agent's memory based on similarity. + :param statements: List of new statements to check. + :param memories: List of memories already in the memory stream. + :param threshold: Similarity threshold for comparison. + :return: Deduplicated list. + """ st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') # convert input memories and statements to embedding @@ -239,6 +271,15 @@ def deduplicate_sim_mem(statements, memories, threshold=0.85): def query_chat_extract(template, template_val, request, llm: model.Model, ctx_window=20): + """ + Performs chat extraction using the specified templates. + :param template: The template to use for extraction. + :param template_val: The template to use for validation of extracted statements. + :param request: The request packet. + :param llm: The model to use. + :param ctx_window: The size of the sliding window for extraction. + :return: List of extracted statements. 
+ """ parameters = request["Data"] chat = parameters["Conversation"] lines = chat.split("\n") @@ -266,19 +307,20 @@ def query_chat_extract(template, template_val, request, llm: model.Model, ctx_wi extracted.extend(res) # use a deduplication prompt/sim to remove duplicate statements from collected list - if ctx_window < len(lines): # only need to deduplicate when ctx window is smaller than convo length # statements = query_deduplicate(statements, llm) extracted = deduplicate_sim(extracted) - # validate extracted statements against original conversation - # parameters["Conversation"] = chat - # extracted = conversation_validate_statements(parameters, statements, llm) - return extracted def query_chat_summary(request, llm: model.Model): + """ + Handles the ChatSummary Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] new_memories = [] @@ -295,11 +337,8 @@ def query_chat_summary(request, llm: model.Model): agent_facts = query_chat_extract("chat_summary_agent_facts", "chat_validate_statement", request, llm) new_memories.extend(agent_facts) - print(len(new_memories)) new_memories = deduplicate_sim(new_memories, 0.75) - print(len(new_memories)) new_memories = deduplicate_sim_mem(new_memories, request["Memories"], 0.75) - print(len(new_memories)) add_poignancy_to_memories(new_memories, llm) messages_summary = [ @@ -314,6 +353,12 @@ def query_chat_summary(request, llm: model.Model): def query_chat_extract_plan(request, llm: model.Model): + """ + Handles the ChatExtractPlan Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] # alternative 1 @@ -333,6 +378,11 @@ def query_chat_extract_plan(request, llm: model.Model): def parse_datetime(statement): + """ + Parses a statement for a datetime occurrence based on the format specified by the RegEx. + :param statement: Statement to search. + :return: The dt object or None. + """ try: datetime_str = re.search( "[A-Z][a-z]+, [A-Z][a-z]+ [0-9]+, [0-9]{4}(, [0-9]{1,2}:[0-9]{1,2}( )*((pm|am)|(AM|PM)))?", @@ -345,9 +395,14 @@ def parse_datetime(statement): return None -# takes a list of plan statements, parses the referenced time and date and removes past events. -# additionally returns true if any of the plans are for the current date (for replanning) def plans_validate(statements, c_datetime): + """ + Takes a list of plan statements, parses the referenced time and date and removes past events. Additionally, returns + true if any of the plans are for the current date (for replanning) + :param statements: List of plans to validate. + :param c_datetime: The current datetime. + :return: List of valid future plans and an indication of whether one of the plans takes place on the current day. + """ any_plans_for_today = False valid = [] c_datetime = dateutil.parser.parse(c_datetime) @@ -364,18 +419,13 @@ def plans_validate(statements, c_datetime): return valid, any_plans_for_today -# obsolete -def get_calendar(): - calendar = "" - start_date = datetime.now() - timedelta(days=14) - for i in range(14): - date = start_date + timedelta(days=i) - date_str = date.strftime("%A, %B %d, %Y") - calendar += date_str + "\n" - return calendar - - def generate_context(request, llm: model.Model): + """ + Handles the Context Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. 
+ """ parameters = request["Data"] memories = request["Memories"] memories.sort(key=lambda x: x["HrsSinceCreation"], reverse=True) @@ -401,12 +451,16 @@ def generate_context(request, llm: model.Model): context += action - print(context) - return {"Text": context} def query_relationship(request, llm: model.Model): + """ + Performs the relationship query used for the Context request. + :param request: Request packet. + :param llm: Model to use. + :return: Relationship response. + """ parameters = request["Data"] memories = request["Memories"] @@ -423,6 +477,12 @@ def query_relationship(request, llm: model.Model): def query_agent_action(request, llm: model.Model): + """ + Determines the agents current action for the Context request. + :param request: Request packet. + :param llm: Model to use. + :return: Agent's action response. + """ parameters = request["Data"] messages = [ {"role": "user", @@ -435,6 +495,12 @@ def query_agent_action(request, llm: model.Model): def query_plan_day(request, llm: model.Model): + """ + Handles the PlanDay Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] memories = request["Memories"] @@ -447,33 +513,16 @@ def query_plan_day(request, llm: model.Model): day_plan = llm.query(messages) - """ - for h in range(24): - time_start = datetime.strptime(str(h) + ":00", "%H:%M") - time_start = time_start.strftime("%I:%M%p") - time_end = datetime.strptime(str((h + 1) % 24) + ":00", "%H:%M") - time_end = time_end.strftime("%I:%M%p") - - messages = [ - {"role": "user", - "content": util.load_template("plan_day_decomp").format(time_start=time_start, time_end=time_end, - memories=memories_str, plan=day_plan, - **parameters)}, - ] - - prompt = tokenizer.apply_chat_template(messages, tokenize=False) - - res = time_start + "-" + time_end + ":" + run_query(prompt, llm) - - print() - print(res) - print() - """ - return {"Text": day_plan} def query_replan_day(request, llm: model.Model): + """ + Handles the ReplanDay Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] memories = request["Memories"] @@ -490,6 +539,12 @@ def query_replan_day(request, llm: model.Model): def query_custom(request, llm): + """ + Handles the Custom Request. + :param request: Request packet. + :param llm: Model to use. + :return: Response packet. + """ parameters = request["Data"] messages = [ {"role": "user", @@ -500,8 +555,15 @@ def query_custom(request, llm): return {"Text": response} -# returns a list of validated statements -def validate_statements(template, statements, parameters, llm: model.Model): +def conversation_validate_statements(template, statements, parameters, llm: model.Model): + """ + # Returns a list of validated statements based on the LLM response. + :param template: The validation template to use. + :param statements: A list of statements to validate. + :param parameters: A parameters dictionary containing all data needed for the query (as defined in the template). + :param llm: The model to use. + :return: List of validated statements. 
+ """ valid = [] for statement in statements: message_validate = [ @@ -513,22 +575,4 @@ def validate_statements(template, statements, parameters, llm: model.Model): if "yes" in val.lower(): valid.append(statement) - return valid - - -# returns a list of validated statements -def conversation_validate_statements(template, parameters, statements, llm: model.Model): - valid = [] - for statement in statements: - message_validate = [ - {"role": "user", - "content": util.load_template(template).format(statement=statement, - **parameters)}, - ] - - val = llm.query(message_validate) - print(statement, val) - if "yes" in val.lower(): - valid.append(statement) - return valid \ No newline at end of file diff --git a/python/llm_server/server.py b/python/llm_server/server.py index a5502ae..7009c3a 100644 --- a/python/llm_server/server.py +++ b/python/llm_server/server.py @@ -1,14 +1,16 @@ # https://realpython.com/python-sockets/ -import json import socket import threading import time import queries -import select + +# Define host and port here. HOST = "127.0.0.1" PORT = 65432 + +# Define the maximum number of sessions here. MAX_CONNECTIONS = 60 connections = [] @@ -20,6 +22,10 @@ STREAM_BUFFER_SIZE = 8192 class ClientHandler(threading.Thread): + """ + A ClientHandler thread is responsible for communicating with one client, transmitting the requests to the + RequestHandler queue and returning the result. + """ def __init__(self, conn, addr, r): super().__init__() self._stop_event = threading.Event() @@ -31,11 +37,19 @@ class ClientHandler(threading.Thread): self.request_handler = r def signal_response(self, response): + """ + Used to signal the ClientHandler thread to continue execution when the result arrives. + :param response: The JSON-response as a string. + """ with self.response_condition: self.response = response self.response_condition.notify() def run(self): + """ + Action loop for the ClientHandler thread. Continually waits for requests from the client. Puts ClientHandler to + sleep once request is received and has been transmitted to RequestHandler. + """ print(f"Connected with {self.addr}") with self.conn: while not self._stop_event.is_set(): @@ -45,8 +59,6 @@ class ClientHandler(threading.Thread): while True: try: received = self.conn.recv(STREAM_BUFFER_SIZE) - # print("Received: " + str(len(received))) - # print(received) total_rec += len(received) if len(received) == 0: raise Exception @@ -57,11 +69,9 @@ class ClientHandler(threading.Thread): if len(data) == 0: continue - # print("Total data receieved: " + str(total_rec)) request = data.decode("utf-8") with self.response_condition: - print("ADDING PACKET TO QUEUE") self.request_handler.queue_request(request, self) self.response_condition.wait() @@ -80,6 +90,10 @@ class ClientHandler(threading.Thread): class RequestHandler(threading.Thread): + """ + The RequestHandler thread is responsible for handling all client requests and transmitting the results back to the + ClientHandler threads. + """ def __init__(self): super().__init__() self._stop_event = threading.Event() @@ -87,11 +101,20 @@ class RequestHandler(threading.Thread): self.request_condition = threading.Condition() def queue_request(self, request, client): + """ + Used add a request to the queue and signal the RequestHandler thread to continue execution. + :param request: The request packet. + :param client: The ClientHandler thread which should be signalled upon request fulfillment. 
+ """ with self.request_condition: self.request_queue.append((request, client)) self.request_condition.notify() def run(self): + """ + Action loop for RequestHandler. Continually handles requests as long as queue is not empty, otherwise goes to + sleep. Signals ClientHandler threads once response is ready. + """ while not self._stop_event.is_set(): with self.request_condition: while len(self.request_queue) <= 0 and not self._stop_event.is_set(): @@ -113,6 +136,9 @@ class RequestHandler(threading.Thread): class ExitHandler(threading.Thread): + """ + Separate thread to enable easy termination of server by pressing the "e" key. + """ def stop(self): pass @@ -126,6 +152,10 @@ class ExitHandler(threading.Thread): def start_server(): + """ + Main server routine. Starts RequestHandler and ExitHandler threads at the start and + ClientHandler threads for every connected client. Keeps track of maximum connections. + """ global s e = ExitHandler() r = RequestHandler() diff --git a/python/llm_server/util.py b/python/llm_server/util.py index 99a4d8e..954b09f 100644 --- a/python/llm_server/util.py +++ b/python/llm_server/util.py @@ -5,26 +5,33 @@ TEMPLATES_PATH = "llm_server/templates/" def load_template(template): + """ + :param template: template file name + :return: content of template file as string + """ return open(TEMPLATES_PATH + template + ".txt").read() -MISTRAL_TOK = "mistralai/Mistral-7B-Instruct-v0.1" -MISTRAL_GGUF = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" -MISTRAL_FILE = "mistral-7b-instruct-v0.1.Q4_K_M.gguf" - -LLAMA_TOK = "meta-llama/Llama-2-7b-chat-hf" -LLAMA_API_KEY = "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" -LLAMA_NAME = "TheBloke/Llama-2-7B-Chat-GGUF" -LLAMA_FILE = "llama-2-7b.Q4_K_M.gguf" - -KW_EN_CORE = "X:/LLM Models/nlm/en_core_web_sm-3.7.1-py3-none-any.whl" - - def load_tokenizer(model, token=""): + """ + Loads the tokenizer for the specified model + + :param model: huggingface repo url + :param token: huggingface repo access token + :return: tokenizer object + """ return AutoTokenizer.from_pretrained(model, token=token) def load_model(model, ctx_size, temperature): + """ + Loads the specified large language model. + + :param model: huggingface repo url + :param ctx_size: maximum input/output tokens (input maximum is 4096 for most models) + :param temperature: model temperature + :return: llm object to use for inference + """ llm = LlamaCpp( model_path=model, max_tokens=500, -- GitLab