Commit 4472681b authored by Giner, Aaron
added documentation, performed some code cleanup
parent cb2e93ad
import torch
from rake_nltk import Rake
from sentence_transformers import SentenceTransformer, util
# extracts keywords from a statement and returns embeddings for the statement
def get_embeddings_kw(statement, model=None):
if model is None:
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# extract kw
rake_nltk_var = Rake()
rake_nltk_var.extract_keywords_from_text(statement)
keyword_extracted = rake_nltk_var.get_ranked_phrases()
# combine keywords and deduplicate words
kw = " ".join(set(" ".join([k for k in keyword_extracted]).split(" ")))
embedding_kw = model.encode(kw)
return embedding_kw
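# Illustrative usage sketch (not part of the original commit): the statement below is a
# hypothetical example; get_embeddings_kw returns a single embedding of the RAKE keywords.
def example_keyword_embedding():
    emb = get_embeddings_kw("Aaron walked to the market to buy fresh apples")
    return emb.shape  # e.g. (384,) for all-MiniLM-L6-v2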
# returns embeddings for a statement
#
def get_embeddings(statement, model=None):
"""
:param statement: Input statement.
:param model: Embeddings model to use (defaults to all-MiniLM-L6-v2 if None).
:return: Text embedding for the statement.
"""
if model is None:
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
......@@ -31,9 +17,15 @@ def get_embeddings(statement, model=None):
return embeddings
# must not be in tensor form
#
# returns similarity between embeddings (inputs must be 1 dimensional)
def cos_sim(embeddings1, embeddings2):
"""
Calculate cosine similarity score between text embeddings.
:param embeddings1: text embedding 1 (must not be in tensor form)
:param embeddings2: text embedding 2 (must not be in tensor form)
:return: Cosine similarity score.
"""
embeddings1 = torch.tensor(embeddings1)
embeddings2 = torch.tensor(embeddings2)
......
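# Illustrative usage sketch (not part of the original commit): compares two hypothetical
# statements using the full-sentence embeddings and cos_sim defined above.
def example_statement_similarity():
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    emb_a = get_embeddings("Aaron met Sarah at the market", model)
    emb_b = get_embeddings("Aaron ran into Sarah while shopping", model)
    return cos_sim(emb_a, emb_b)  # higher values indicate more similar statements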
......@@ -6,10 +6,22 @@ decay_factor = 0.995
def filter_memories(memories, queries, datetime, w_recency=0.5, w_importance=1, w_relevance=3):
"""
Based on a set of query statements, sorts the input memories by three factors: recency,
importance and relevance. The score for each memory is calculated for each query statement separately and the
highest score is used.
:param memories: Set of input memories.
:param queries: Set of query statements.
:param datetime: Date/time string.
:param w_recency: Recency weight.
:param w_importance: Importance weight.
:param w_relevance: Relevance weight.
:return: List of sorted memories by score.
"""
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# queries = parse_dt.replace_all(datetime, queries)
# print(queries)
# perform datetime replacement for the most recent user input and add it to the list of query memories.
queries.append(parse_dt.replace_all(datetime, queries[-1:])[0])
embedding_memories = [embeddings.get_embeddings(m["Content"], model) for m in memories]
embedding_queries = [embeddings.get_embeddings(q, model) for q in queries]
......@@ -32,13 +44,17 @@ def filter_memories(memories, queries, datetime, w_recency=0.5, w_importance=1,
return memories
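# Illustrative scoring sketch (an assumption, not the collapsed implementation above): one
# plausible per-memory score combining exponentially decayed recency, stored importance and
# embedding relevance, using the same weights filter_memories accepts. "HrsSinceCreation"
# appears on the request memories elsewhere in this repo; the "Poignancy" field name and the
# 1-10 scale are assumptions.
def example_memory_score(memory, memory_embedding, query_embedding,
                         w_recency=0.5, w_importance=1, w_relevance=3):
    recency = decay_factor ** memory.get("HrsSinceCreation", 0)
    importance = memory.get("Poignancy", 0) / 10  # assumed 1-10 poignancy scale
    relevance = float(embeddings.cos_sim(memory_embedding, query_embedding))
    return w_recency * recency + w_importance * importance + w_relevance * relevance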
def memories_to_string(memories, include_date_created=False, include_nodeId=False):
def memories_to_string(memories, include_nodeId=False):
"""
Converts a list of input memories to a string according to the specified format.
:param memories: List of memories.
:param include_nodeId: Whether the node-id should be prepended to the memory item.
:return: A string of all memories.
"""
mem_string = ""
print(memories)
for mem in memories:
prefix = "- "
prefix += (str(mem["NodeId"]) + ". ") if include_nodeId else ""
prefix += (("[" + mem["LastAccessed"] + "] ") if include_date_created else " ")
mem_string += prefix + mem["Content"].strip() + "\n"
return mem_string
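# Illustrative usage (hypothetical memory dict; only the fields used above are filled in):
def example_memories_to_string():
    memories = [{"NodeId": 3, "Content": "Aaron likes apples",
                 "LastAccessed": "Monday, June 3, 2024"}]
    # With include_nodeId=True each line is prefixed with "- " and the node id.
    return memories_to_string(memories, include_nodeId=True)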
......@@ -8,6 +8,10 @@ import error
class Model:
"""
The Model class is a further abstraction layer that acts as an interface for inference using ChatGPT and local
models.
"""
def __init__(self, model_name, model_file, model_tok, temperature, allow_system_prompt,
ctx_size=8192, api_key="", use_chat_gpt=False,):
self.model_name = model_name
......
import dateutil.parser
import parsedatetime as pdt
def replace_all(dt, statements):
"""
Replaces the first occurrence of a referenced date/time in natural language with a date/time-string of the specified
format.
:param dt: input datetime as a string (usually the current date/time)
:param statements: a list of strings in which to replace the date/time.
:return: The input statements with replaced date/time strings.
"""
cal = pdt.Calendar()
dt = dateutil.parser.parse(dt)
......
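# Illustrative usage sketch (the body above is collapsed in this diff, so the exact output
# format is not shown): relative references such as "tomorrow" are expected to be replaced
# with a concrete date derived from dt.
def example_replace_all():
    dt = "Monday, June 3, 2024, 10:00 AM"
    return replace_all(dt, ["Let's meet tomorrow at noon at the market."])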
......@@ -15,6 +15,11 @@ import util
def parse_request(request):
"""
Parses incoming request packets and delegates to appropriate routines. Loads a Model object for the specified model.
:param request: Request packet.
:return: Response packet.
"""
try:
request = json.loads(request)
requestType = request["Type"]
......@@ -77,6 +82,12 @@ def parse_request(request):
def query_chat(request, llm: model.Model):
"""
Handles the Chat Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
NUM_MEMORIES_TO_USE = 25
parameters = request["Data"]
......@@ -111,7 +122,6 @@ def query_chat(request, llm: model.Model):
if res is not False:
break
else:
print("LIMIT EXCEEDED, RETRYING")
chat_messages = chat_messages[2:]
memories_accessed = [str(mem["NodeId"]) for mem in memories[:NUM_MEMORIES_TO_USE]]
......@@ -120,6 +130,12 @@ def query_chat(request, llm: model.Model):
def query_reflection(request, llm: model.Model):
"""
Handles the Reflection Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
memories = request["Memories"]
parameters["Memories"] = memory_util.memories_to_string(memories)
......@@ -140,8 +156,6 @@ def query_reflection(request, llm: model.Model):
memory_util.filter_memories(memories, [q], parameters["DateTime"], 0, 0, 1)
parameters["Memories"] = memory_util.memories_to_string(memories[:5], include_nodeId=True)
print("Question: ", q)
messages_a = [
{"role": "user",
"content": util.load_template("reflection_a").format(**parameters)},
......@@ -153,15 +167,16 @@ def query_reflection(request, llm: model.Model):
insights.extend(insights_q)
print("QUESTION:")
print(q)
for x in insights_q:
print(" - " + x)
return {"Memories": insights}
def query_poignancy(request, llm: model.Model):
"""
Handles the Poignancy Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
messages = [
{"role": "user",
......@@ -174,14 +189,24 @@ def query_poignancy(request, llm: model.Model):
def add_poignancy_to_memories(memories, llm: model.Model):
"""
Adds a poignancy score to each new memory element using the poignancy template.
:param memories: Memories for which the poignancy shall be determined.
:param llm: Model to use.
:return: None; a poignancy score is appended to each memory string in place.
"""
for i in range(len(memories)):
memories[i] += "~" + query_poignancy({"Data": {"Memory": memories[i]}}, llm)["Text"]
# deduplicates a list of statements using their similarity score
# deduplication takes length of statements into account and will keep the
# longer one (potentially containing more information, but not guaranteed)
def deduplicate_sim(statements, threshold=0.85):
"""
Deduplicates a list of statements using their similarity score. Deduplication takes the length of the statements
into account and keeps the longer one (which potentially contains more information, but this is not guaranteed).
:param statements: List of statements to deduplicate.
:param threshold: Similarity threshold for comparison.
:return: Deduplicated list.
"""
st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# calculate embedding for each input statement
......@@ -215,8 +240,15 @@ def deduplicate_sim(statements, threshold=0.85):
return dedup
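# Illustrative usage (hypothetical statements): the two near-identical entries are expected
# to collapse into the longer one; the unrelated entry is kept.
def example_deduplicate_sim():
    statements = ["Aaron likes apples",
                  "Aaron likes apples very much",
                  "Sarah works at the library"]
    return deduplicate_sim(statements, threshold=0.85)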
# remove statements that are already present in the agent's memory based on similarity
#
def deduplicate_sim_mem(statements, memories, threshold=0.85):
"""
Removes statements that are already present in the agent's memory based on similarity.
:param statements: List of new statements to check.
:param memories: List of memories already in the memory stream.
:param threshold: Similarity threshold for comparison.
:return: Deduplicated list.
"""
st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# convert input memories and statements to embedding
......@@ -239,6 +271,15 @@ def deduplicate_sim_mem(statements, memories, threshold=0.85):
def query_chat_extract(template, template_val, request, llm: model.Model, ctx_window=20):
"""
Performs chat extraction using the specified templates.
:param template: The template to use for extraction.
:param template_val: The template to use for validation of extracted statements.
:param request: The request packet.
:param llm: The model to use.
:param ctx_window: The size of the sliding window for extraction.
:return: List of extracted statements.
"""
parameters = request["Data"]
chat = parameters["Conversation"]
lines = chat.split("\n")
......@@ -266,19 +307,20 @@ def query_chat_extract(template, template_val, request, llm: model.Model, ctx_wi
extracted.extend(res)
# use a deduplication prompt/sim to remove duplicate statements from collected list
if ctx_window < len(lines): # only need to deduplicate when ctx window is smaller than convo length
# statements = query_deduplicate(statements, llm)
extracted = deduplicate_sim(extracted)
# validate extracted statements against original conversation
# parameters["Conversation"] = chat
# extracted = conversation_validate_statements(parameters, statements, llm)
return extracted
def query_chat_summary(request, llm: model.Model):
"""
Handles the ChatSummary Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
new_memories = []
......@@ -295,11 +337,8 @@ def query_chat_summary(request, llm: model.Model):
agent_facts = query_chat_extract("chat_summary_agent_facts", "chat_validate_statement", request, llm)
new_memories.extend(agent_facts)
print(len(new_memories))
new_memories = deduplicate_sim(new_memories, 0.75)
print(len(new_memories))
new_memories = deduplicate_sim_mem(new_memories, request["Memories"], 0.75)
print(len(new_memories))
add_poignancy_to_memories(new_memories, llm)
messages_summary = [
......@@ -314,6 +353,12 @@ def query_chat_summary(request, llm: model.Model):
def query_chat_extract_plan(request, llm: model.Model):
"""
Handles the ChatExtractPlan Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
# alternative 1
......@@ -333,6 +378,11 @@ def query_chat_extract_plan(request, llm: model.Model):
def parse_datetime(statement):
"""
Parses a statement for a datetime occurrence based on the format specified by the RegEx.
:param statement: Statement to search.
:return: The dt object or None.
"""
try:
datetime_str = re.search(
"[A-Z][a-z]+, [A-Z][a-z]+ [0-9]+, [0-9]{4}(, [0-9]{1,2}:[0-9]{1,2}( )*((pm|am)|(AM|PM)))?",
......@@ -345,9 +395,14 @@ def parse_datetime(statement):
return None
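# Illustrative usage (hypothetical statement containing a date that matches the regex above;
# per the docstring, a dt object is returned, or None if nothing matches):
def example_parse_datetime():
    return parse_datetime("The next appointment is on Monday, June 3, 2024, 10:30 AM")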
# takes a list of plan statements, parses the referenced time and date and removes past events.
# additionally returns true if any of the plans are for the current date (for replanning)
def plans_validate(statements, c_datetime):
"""
Takes a list of plan statements, parses the referenced time and date, and removes past events. Additionally, returns
true if any of the plans are for the current date (for replanning).
:param statements: List of plans to validate.
:param c_datetime: The current datetime.
:return: List of valid future plans and an indication of whether one of the plans takes place on the current day.
"""
any_plans_for_today = False
valid = []
c_datetime = dateutil.parser.parse(c_datetime)
......@@ -364,18 +419,13 @@ def plans_validate(statements, c_datetime):
return valid, any_plans_for_today
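# Illustrative usage sketch (hypothetical plan statements; the collapsed body above determines
# how the dates inside the statements are parsed):
def example_plans_validate():
    plans = ["Go to the market on Monday, June 3, 2024, 10:30 AM",
             "Visit the library on Tuesday, June 4, 2024"]
    return plans_validate(plans, "Monday, June 3, 2024, 9:00 AM")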
# obsolete
def get_calendar():
calendar = ""
start_date = datetime.now() - timedelta(days=14)
for i in range(14):
date = start_date + timedelta(days=i)
date_str = date.strftime("%A, %B %d, %Y")
calendar += date_str + "\n"
return calendar
def generate_context(request, llm: model.Model):
"""
Handles the Context Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
memories = request["Memories"]
memories.sort(key=lambda x: x["HrsSinceCreation"], reverse=True)
......@@ -401,12 +451,16 @@ def generate_context(request, llm: model.Model):
context += action
print(context)
return {"Text": context}
def query_relationship(request, llm: model.Model):
"""
Performs the relationship query used for the Context request.
:param request: Request packet.
:param llm: Model to use.
:return: Relationship response.
"""
parameters = request["Data"]
memories = request["Memories"]
......@@ -423,6 +477,12 @@ def query_relationship(request, llm: model.Model):
def query_agent_action(request, llm: model.Model):
"""
Determines the agent's current action for the Context request.
:param request: Request packet.
:param llm: Model to use.
:return: Agent's action response.
"""
parameters = request["Data"]
messages = [
{"role": "user",
......@@ -435,6 +495,12 @@ def query_agent_action(request, llm: model.Model):
def query_plan_day(request, llm: model.Model):
"""
Handles the PlanDay Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
memories = request["Memories"]
......@@ -447,33 +513,16 @@ def query_plan_day(request, llm: model.Model):
day_plan = llm.query(messages)
"""
for h in range(24):
time_start = datetime.strptime(str(h) + ":00", "%H:%M")
time_start = time_start.strftime("%I:%M%p")
time_end = datetime.strptime(str((h + 1) % 24) + ":00", "%H:%M")
time_end = time_end.strftime("%I:%M%p")
messages = [
{"role": "user",
"content": util.load_template("plan_day_decomp").format(time_start=time_start, time_end=time_end,
memories=memories_str, plan=day_plan,
**parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
res = time_start + "-" + time_end + ":" + run_query(prompt, llm)
print()
print(res)
print()
"""
return {"Text": day_plan}
def query_replan_day(request, llm: model.Model):
"""
Handles the ReplanDay Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
memories = request["Memories"]
......@@ -490,6 +539,12 @@ def query_replan_day(request, llm: model.Model):
def query_custom(request, llm):
"""
Handles the Custom Request.
:param request: Request packet.
:param llm: Model to use.
:return: Response packet.
"""
parameters = request["Data"]
messages = [
{"role": "user",
......@@ -500,8 +555,15 @@ def query_custom(request, llm):
return {"Text": response}
# returns a list of validated statements
def validate_statements(template, statements, parameters, llm: model.Model):
def conversation_validate_statements(template, statements, parameters, llm: model.Model):
"""
Returns a list of validated statements based on the LLM response.
:param template: The validation template to use.
:param statements: A list of statements to validate.
:param parameters: A parameters dictionary containing all data needed for the query (as defined in the template).
:param llm: The model to use.
:return: List of validated statements.
"""
valid = []
for statement in statements:
message_validate = [
......@@ -513,22 +575,4 @@ def validate_statements(template, statements, parameters, llm: model.Model):
if "yes" in val.lower():
valid.append(statement)
return valid
# returns a list of validated statements
def conversation_validate_statements(template, parameters, statements, llm: model.Model):
valid = []
for statement in statements:
message_validate = [
{"role": "user",
"content": util.load_template(template).format(statement=statement,
**parameters)},
]
val = llm.query(message_validate)
print(statement, val)
if "yes" in val.lower():
valid.append(statement)
return valid
\ No newline at end of file
# https://realpython.com/python-sockets/
import json
import socket
import threading
import time
import queries
import select
# Define host and port here.
HOST = "127.0.0.1"
PORT = 65432
# Define the maximum number of sessions here.
MAX_CONNECTIONS = 60
connections = []
......@@ -20,6 +22,10 @@ STREAM_BUFFER_SIZE = 8192
class ClientHandler(threading.Thread):
"""
A ClientHandler thread is responsible for communicating with one client, transmitting the requests to the
RequestHandler queue and returning the result.
"""
def __init__(self, conn, addr, r):
super().__init__()
self._stop_event = threading.Event()
......@@ -31,11 +37,19 @@ class ClientHandler(threading.Thread):
self.request_handler = r
def signal_response(self, response):
"""
Used to signal the ClientHandler thread to continue execution when the result arrives.
:param response: The JSON-response as a string.
"""
with self.response_condition:
self.response = response
self.response_condition.notify()
def run(self):
"""
Action loop for the ClientHandler thread. Continually waits for requests from the client and puts the
ClientHandler to sleep once a request has been received and transmitted to the RequestHandler.
"""
print(f"Connected with {self.addr}")
with self.conn:
while not self._stop_event.is_set():
......@@ -45,8 +59,6 @@ class ClientHandler(threading.Thread):
while True:
try:
received = self.conn.recv(STREAM_BUFFER_SIZE)
# print("Received: " + str(len(received)))
# print(received)
total_rec += len(received)
if len(received) == 0:
raise Exception
......@@ -57,11 +69,9 @@ class ClientHandler(threading.Thread):
if len(data) == 0:
continue
# print("Total data receieved: " + str(total_rec))
request = data.decode("utf-8")
with self.response_condition:
print("ADDING PACKET TO QUEUE")
self.request_handler.queue_request(request, self)
self.response_condition.wait()
......@@ -80,6 +90,10 @@ class ClientHandler(threading.Thread):
class RequestHandler(threading.Thread):
"""
The RequestHandler thread is responsible for handling all client requests and transmitting the results back to the
ClientHandler threads.
"""
def __init__(self):
super().__init__()
self._stop_event = threading.Event()
......@@ -87,11 +101,20 @@ class RequestHandler(threading.Thread):
self.request_condition = threading.Condition()
def queue_request(self, request, client):
"""
Used to add a request to the queue and signal the RequestHandler thread to continue execution.
:param request: The request packet.
:param client: The ClientHandler thread which should be signalled upon request fulfillment.
"""
with self.request_condition:
self.request_queue.append((request, client))
self.request_condition.notify()
def run(self):
"""
Action loop for the RequestHandler. Continually handles requests as long as the queue is not empty; otherwise it
goes to sleep. Signals the ClientHandler threads once a response is ready.
"""
while not self._stop_event.is_set():
with self.request_condition:
while len(self.request_queue) <= 0 and not self._stop_event.is_set():
......@@ -113,6 +136,9 @@ class RequestHandler(threading.Thread):
class ExitHandler(threading.Thread):
"""
Separate thread that enables easy termination of the server by pressing the "e" key.
"""
def stop(self):
pass
......@@ -126,6 +152,10 @@ class ExitHandler(threading.Thread):
def start_server():
"""
Main server routine. Starts RequestHandler and ExitHandler threads at the start and
ClientHandler threads for every connected client. Keeps track of maximum connections.
"""
global s
e = ExitHandler()
r = RequestHandler()
......
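# Illustrative client sketch (not part of the original commit): one way a client might send a
# request packet to this server. The packet fields are inferred from queries.parse_request;
# the "Custom" request type and the empty "Data" contents are assumptions, and a single recv
# is used only for brevity (a real client would loop until the full response has arrived).
def example_client_request():
    payload = json.dumps({"Type": "Custom", "Data": {}, "Memories": []})
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((HOST, PORT))
        sock.sendall(payload.encode("utf-8"))
        return sock.recv(STREAM_BUFFER_SIZE).decode("utf-8")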
......@@ -5,26 +5,33 @@ TEMPLATES_PATH = "llm_server/templates/"
def load_template(template):
"""
:param template: template file name
:return: content of template file as string
"""
return open(TEMPLATES_PATH + template + ".txt").read()
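# Illustrative usage (the template name "reflection_a" is referenced in queries.py; the
# placeholder names inside it are not shown in this diff):
#   prompt = load_template("reflection_a").format(**parameters)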
MISTRAL_TOK = "mistralai/Mistral-7B-Instruct-v0.1"
MISTRAL_GGUF = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
MISTRAL_FILE = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
LLAMA_TOK = "meta-llama/Llama-2-7b-chat-hf"
LLAMA_API_KEY = "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"
LLAMA_NAME = "TheBloke/Llama-2-7B-Chat-GGUF"
LLAMA_FILE = "llama-2-7b.Q4_K_M.gguf"
KW_EN_CORE = "X:/LLM Models/nlm/en_core_web_sm-3.7.1-py3-none-any.whl"
def load_tokenizer(model, token=""):
"""
Loads the tokenizer for the specified model.
:param model: Hugging Face model repo id
:param token: Hugging Face repo access token
:return: tokenizer object
"""
return AutoTokenizer.from_pretrained(model, token=token)
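# Illustrative usage sketch: loads the Mistral tokenizer named by MISTRAL_TOK above;
# a cached or downloadable copy of the repo is assumed.
def example_load_mistral_tokenizer():
    return load_tokenizer(MISTRAL_TOK)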
def load_model(model, ctx_size, temperature):
"""
Loads the specified large language model via llama.cpp.
:param model: path to the GGUF model file (passed to LlamaCpp as model_path)
:param ctx_size: maximum input/output tokens (input maximum is 4096 for most models)
:param temperature: model temperature
:return: llm object to use for inference
"""
llm = LlamaCpp(
model_path=model,
max_tokens=500,
......