Commit 3d26999f authored by Giner, Aaron

Some model-querying rewrites for better customizability; added OpenAI API usage capability.

parent b7334aff
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from openai import OpenAI

import util


class Model:
    def __init__(self, model_name, model_path, temperature, ctx_size=4096, api_key="", use_chat_gpt=False):
        self.model_name = model_name
        self.model_path = model_path
        self.temperature = temperature
        self.ctx_size = ctx_size
        self.api_key = api_key
        self.use_chat_gpt = use_chat_gpt

        self.llm = None
        if not self.use_chat_gpt:
            # local model: load the GGUF weights and the matching HF chat tokenizer
            self.llm = util.load_model(model=self.model_path, ctx_size=self.ctx_size, temperature=temperature)
            self.tokenizer = util.load_tokenizer(model=self.model_name, token=util.LLAMA_API_KEY)

    def query(self, messages, debug=False):
        if self.use_chat_gpt:
            return self.query_openai(messages=messages)

        # render the chat messages into the model's prompt format
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        if debug:
            print("-------- PROMPT ------- \n" + prompt + "\n ------------------")

        assert self.llm is not None
        llm_chain = LLMChain(
            prompt=PromptTemplate.from_template(prompt),
            llm=self.llm,
            verbose=False
        )
        return llm_chain.run({}).strip()

    def query_openai(self, messages):
        client = OpenAI(api_key=self.api_key)
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=self.temperature
        )
        return response.choices[0].message.content
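For orientation, a minimal usage sketch of the new Model wrapper (the API key string is a placeholder; the local branch assumes util.load_model and util.load_tokenizer behave as referenced above, mirroring the calls added in parse_request below):

import util
from model import Model

# Local GGUF model: messages are rendered with the HF chat template and run through LLMChain.
local_llm = Model(util.MISTRAL, util.MISTRAL_PATH, temperature=0, ctx_size=4096)

# OpenAI-backed model: messages are forwarded to the Chat Completions API instead.
openai_llm = Model(util.LLAMA, util.LLAMA_PATH, temperature=0.75, api_key="sk-...", use_chat_gpt=True)

messages = [{"role": "user", "content": "Summarize today's plan in one sentence."}]
print(local_llm.query(messages, debug=True))
print(openai_llm.query(messages))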
@@ -15,7 +15,7 @@ def replace_all(dt, statements):
print(statement)
parsed_nlp = cal.nlp(statement, dt)
parsed_nlp_now = cal.nlp(statement, dt_now)
if parsed_nlp is None:
if parsed_nlp is None or len(parsed_nlp) == 0:
replaced.append(statement)
continue
parsed_nlp = parsed_nlp[0]
......
@@ -3,57 +3,44 @@ import math
import re
import dateutil.parser
import spacy
import yake
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from rake_nltk import Rake
from sentence_transformers import SentenceTransformer
import memory_util
import util
import embeddings
import parse_dt
import openai_api
import model
from datetime import datetime, timedelta
# parses a query request
# format:
# query_type~system_prompt
def parse_request(request):
request = json.loads(request)
requestType = request["type"]
if requestType == "chat":
return query_chat(request, util.load_model(util.LLAMA_PATH, 4096, 0.75),
util.load_tokenizer(util.LLAMA_TOK))
return query_chat(request, model.Model(util.LLAMA, util.LLAMA_PATH, 0.75, 4096,
"sk-proj-aUDdsiCXHDwoHewZFL9AT3BlbkFJIkKZEYaMi5AGEBDbW2zv", use_chat_gpt=True))
elif requestType == "chat_summary":
return query_chat_summary(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
return query_chat_summary(request, model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096))
elif requestType == "chat_extract_plan":
return query_chat_extract_plan(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
return query_chat_extract_plan(request, model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096))
elif requestType == "reflection":
return query_reflection(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
return query_reflection(request, model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096))
elif requestType == "poignancy":
return query_poignancy(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
return query_poignancy(request, model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096))
elif requestType == "context":
return generate_context(request)
elif requestType == "knowledge":
return query_knowledge(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
return generate_context(request, model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096))
elif requestType == "plan_day":
return query_plan_day(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
return query_plan_day(request, model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096))
return "ERROR"
def query_chat(request, llm, tokenizer):
def query_chat(request, llm: model.Model):
parameters = request["data"]
chat = parameters["chat"].split("~")
@@ -72,12 +59,7 @@ def query_chat(request, llm, tokenizer):
for i in range(len(chat)):
messages.append({"role": roles[i % 2], "content": chat[i]})
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print()
print(prompt)
print()
res = run_query(prompt, llm)
res = llm.query(messages)
memories_accessed = [str(mem["NodeId"]) for mem in memories[:5]]
json_res = json.dumps({"response": res,
"data": {"memories_accessed": str.join(",", memories_accessed)}})
@@ -85,7 +67,7 @@ def query_chat(request, llm, tokenizer):
return json_res
def query_reflection(request, llm, tokenizer):
def query_reflection(request, llm: model.Model):
parameters = request["data"]
memories = request["memories"]
parameters["memories"] = memory_util.memories_to_string(memories, include_nodeId=True)
@@ -95,40 +77,35 @@ def query_reflection(request, llm, tokenizer):
"content": util.load_template("reflection_a").format(**parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
res = run_query(prompt, llm)
res = llm.query(messages)
insights = [s.replace("- ", "").strip() for s in res.split("\n") if "- " in s]
json_res = json.dumps({"memories": insights})
return json_res
def query_poignancy(request, llm, tokenizer):
def query_poignancy(request, llm: model.Model):
parameters = request["data"]
messages = [
{"role": "user",
"content": util.load_template("poignancy").format(**parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
res = run_query(prompt, llm)
res = llm.query(messages)
json_res = json.dumps({"response": res})
return json_res
# deduplicates a list of statements using an LLM
def query_deduplicate(statements, llm, tokenizer):
def query_deduplicate(statements, llm: model.Model):
messages = [
{"role": "user",
"content": util.load_template("dedup_statements").format(mems="\n".join(statements))},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print("DEDUPLICATED")
dedup = run_query(prompt, llm).strip()
dedup = llm.query(messages)
dedup = [s.replace("- ", "").strip() for s in dedup.split("\n") if "- " in s]
return dedup
@@ -138,7 +115,7 @@ def query_deduplicate(statements, llm, tokenizer):
# deduplication takes length of statements into account and will keep the
# longer one (potentially containing more information, but not guaranteed)
def deduplicate_sim(statements):
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
dedup = []
while len(statements) > 0:
@@ -146,7 +123,7 @@ def deduplicate_sim(statements):
s1 = statements[0]
for i in range(1, len(statements)):
s2 = statements[i]
sim = embeddings.cos_sim(embeddings.get_embeddings(s1, model), embeddings.get_embeddings(s2, model))
sim = embeddings.cos_sim(embeddings.get_embeddings(s1, st_model), embeddings.get_embeddings(s2, st_model))
print(sim, s1, s2)
if sim > 0.9:
if len(s2) > len(s1):
@@ -164,7 +141,7 @@ def deduplicate_sim(statements):
return dedup
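Since the hunk above only shows fragments of deduplicate_sim, here is a self-contained sketch of the logic it appears to implement; the 0.9 threshold and keep-the-longer-statement rule come from the visible lines, while the function name and surrounding structure are illustrative:

from sentence_transformers import SentenceTransformer, util as st_util

def dedup_by_similarity(statements, threshold=0.9):
    st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    kept = []
    remaining = list(statements)
    while remaining:
        best = remaining.pop(0)
        survivors = []
        for other in remaining:
            # cosine similarity between the sentence embeddings of the two statements
            sim = st_util.cos_sim(st_model.encode(best), st_model.encode(other)).item()
            if sim > threshold:
                # near-duplicate: keep whichever statement is longer
                if len(other) > len(best):
                    best = other
            else:
                survivors.append(other)
        kept.append(best)
        remaining = survivors
    return kept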
def query_chat_summary_extract(template, request, llm, tokenizer, ctx_window=5):
def query_chat_summary_extract(template, request, llm: model.Model, ctx_window=5):
parameters = request["data"]
chat = parameters["conversation"]
lines = chat.split("\n")
@@ -182,10 +159,7 @@ def query_chat_summary_extract(template, request, llm, tokenizer, ctx_window=5):
"content": util.load_template(template).format(**parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)
res = run_query(prompt, llm)
res = llm.query(messages)
res = [s.strip() for s in res.split("\n") if "- " in s]
# validate here? maybe works better for longer conversations?
@@ -194,34 +168,34 @@ def query_chat_summary_extract(template, request, llm, tokenizer, ctx_window=5):
statements.extend(res)
# use a deduplication prompt/sim to remove duplicate statements from collected list
dedup = query_deduplicate(statements, llm, tokenizer)
dedup = deduplicate_sim(dedup)
if ctx_window < len(lines): # only need to deduplicate when ctx window is smaller than convo length
statements = query_deduplicate(statements, llm)
statements = deduplicate_sim(statements)
# validate extracted statements against original conversation
parameters["conversation"] = chat
extracted = conversation_validate_statements(parameters, dedup, llm, tokenizer)
extracted = conversation_validate_statements(parameters, statements, llm)
return extracted
def query_chat_summary(request, llm, tokenizer):
def query_chat_summary(request, llm: model.Model):
parameters = request["data"]
messages_summary = [
{"role": "user",
"content": util.load_template("chat_summary_single").format(**parameters)},
]
prompt_summary = tokenizer.apply_chat_template(messages_summary, tokenize=False)
summary = run_query(prompt_summary, llm)
summary = llm.query(messages_summary)
summary = "Conversation: " + parameters["agent"] + " and " + parameters["user"] + " talked about " + summary
new_memories = [summary]
user_info = query_chat_summary_extract("chat_summary_user_info", request, llm, tokenizer)
user_info = query_chat_summary_extract("chat_summary_user_info", request, llm)
new_memories.extend(user_info)
unrelated_info = query_chat_summary_extract("chat_summary_unrelated", request, llm, tokenizer)
unrelated_info = query_chat_summary_extract("chat_summary_unrelated", request, llm)
new_memories.extend(unrelated_info)
print("SUMMARY")
@@ -233,23 +207,12 @@ def query_chat_summary(request, llm, tokenizer):
return json_res
def query_chat_extract_plan(request, llm, tokenizer):
def query_chat_extract_plan(request, llm: model.Model):
parameters = request["data"]
# alternative 1
plans = query_chat_summary_extract("chat_extract_plan", request, llm, tokenizer, 10)
# alternative 2
"""
messages = [
{"role": "user",
"content": util.load_template("chat_extract_plan").format(**parameters)},
]
plans = query_chat_summary_extract("chat_extract_plan", request, llm, 10)
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
plans = run_query(prompt, llm)
plans = [plan.strip() for plan in plans.split("\n") if "- " in plan]
"""
if len(plans) == 0:
return json.dumps({"memories": [], "data": {"replan_day": "0"}})
@@ -261,9 +224,6 @@ def query_chat_extract_plan(request, llm, tokenizer):
plans_validated, any_plans_for_today = plans_validate(plans, parameters["datetime"])
# only needed in alternative 2, as validation is already done for summary extraction
# plans_validated = conversation_validate_statements(parameters, plans_validated, llm, tokenizer)
print("VALIDATE PLANS")
for p in plans_validated:
print(p)
@@ -318,7 +278,7 @@ def get_calendar():
return calendar
def generate_context(request):
def generate_context(request, llm: model.Model):
parameters = request["data"]
memories = request["memories"]
memories.sort(key=lambda x: x["HrsSinceCreation"], reverse=True)
@@ -329,8 +289,7 @@ def generate_context(request):
"never met before."})
# agent's current action based on their schedule
action = query_agent_action(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
action = query_agent_action(request, llm)
# when did the agent last talk to user?
lastChatHrs = int(math.ceil(memories[-1]["HrsSinceCreation"]))
@@ -339,13 +298,12 @@ def generate_context(request):
+ ("hour" if lastChatHrs == 1 else "hours") + " ago.")
# what is the relationship between agent and user?
relationship = query_relationship(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
util.load_tokenizer(util.MISTRAL_TOK))
relationship = query_relationship(request, llm)
return json.dumps({"response": last_chat + " " + relationship})
def query_relationship(request, llm, tokenizer):
def query_relationship(request, llm: model.Model):
parameters = request["data"]
memories = request["memories"]
@@ -356,40 +314,24 @@ def query_relationship(request, llm, tokenizer):
"content": util.load_template("relationship").format(memories=memories_str, **parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
relationship = run_query(prompt, llm)
relationship = llm.query(messages)
return relationship
def query_agent_action(request, llm, tokenizer):
def query_agent_action(request, llm: model.Model):
parameters = request["data"]
messages = [
{"role": "user",
"content": util.load_template("agent_action").format(**parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
action = run_query(prompt, llm)
action = llm.query(messages)
return action
def query_knowledge(request, llm, tokenizer):
parameters = request["data"]
messages = [
{"role": "user",
"content": util.load_template("knowledge_summary").format(**parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
res = run_query(prompt, llm)
json_res = json.dumps({"response": res})
return json_res
def query_plan_day(request, llm, tokenizer):
def query_plan_day(request, llm: model.Model):
parameters = request["data"]
memories = request["memories"]
@@ -400,9 +342,7 @@ def query_plan_day(request, llm, tokenizer):
"content": util.load_template("plan_day").format(memories=memories_str, **parameters)},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
day_plan = run_query(prompt, llm)
day_plan = llm.query(messages)
print("Finished day query")
@@ -434,7 +374,7 @@ def query_plan_day(request, llm, tokenizer):
# returns a list of validated statements
def validate_statements(template, statements, parameters, llm, tokenizer):
def validate_statements(template, statements, parameters, llm: model.Model):
valid = []
for statement in statements:
message_validate = [
@@ -442,12 +382,8 @@ def validate_statements(template, statements, parameters, llm, tokenizer):
"content": util.load_template(template).format(statement=statement, **parameters)},
]
prompt_validate = tokenizer.apply_chat_template(message_validate, tokenize=False)
print(prompt_validate)
print("DOES THE STATEMENT: '" + statement + "' FIT THE PARAMETERS?: ")
val = run_query(prompt_validate, llm)
val = llm.query(message_validate)
if "yes" in val.lower():
valid.append(statement)
print()
@@ -456,7 +392,7 @@ def validate_statements(template, statements, parameters, llm, tokenizer):
# returns a list of validated statements
def conversation_validate_statements(parameters, statements, llm, tokenizer):
def conversation_validate_statements(parameters, statements, llm: model.Model):
valid = []
for statement in statements:
message_validate = [
@@ -465,10 +401,8 @@ def conversation_validate_statements(parameters, statements, llm, tokenizer):
**parameters)},
]
prompt_validate = tokenizer.apply_chat_template(message_validate, tokenize=False)
print("IS THE STATEMENT: '" + statement + "' CORRECT?: ")
val = run_query(prompt_validate, llm)
val = llm.query(message_validate)
if "yes" in val.lower():
valid.append(statement)
print()
@@ -476,7 +410,16 @@ def conversation_validate_statements(parameters, statements, llm, tokenizer):
return valid
def run_query(prompt, llm):
'''
def run_query(messages, llm, tokenizer, use_chat_gpt=False, debug=False):
if use_chat_gpt:
return openai_api.query_chatgpt(messages, llm.temperature)
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
if debug:
print("-------- PROMPT ------- \n" + prompt + "\n ------------------")
llm_chain = LLMChain(
prompt=PromptTemplate.from_template(prompt),
llm=llm,
@@ -484,6 +427,7 @@ def run_query(prompt, llm):
)
return llm_chain.run({}).strip()
'''
t1 = """
@@ -517,56 +461,12 @@ Aaron: No problem at all. Take care, Lisa. See you around!
Lisa: You too, Aaron. Bye for now!
"""
t2 = """
Aaron: Hey Lisa, how's your week going?
Lisa: Hi Aaron, it's been pretty good, thanks for asking. How about you?
Aaron: Not bad at all. Say, I was thinking, would you like to grab a coffee together next Sunday?
Lisa: Oh, that sounds wonderful! I'd love to. Where do you have in mind?
Aaron: There's this cozy café downtown that I've been wanting to try out. How does that sound?
Lisa: Perfect! Count me in. What time were you thinking?
Aaron: How about around 10 in the morning? Does that work for you?
Lisa: Absolutely, that works great for me.
Aaron: Awesome! Looking forward to it. Oh, and speaking of plans, I was also thinking about going hiking on March 25th. Would you be interested in joining me?
Lisa: Hiking sounds like a fantastic idea! March 25th works for me too. Where were you thinking of going?
Aaron: I was considering hiking up at Pine Ridge Trail. It's got some stunning views along the way.
Lisa: That sounds amazing! I've heard great things about Pine Ridge Trail. Count me in for that too.
Aaron: Fantastic! It'll be great to have some company. How about we start around 9 am?
Lisa: Sounds good to me, 9 am it is. I'll make sure to pack some snacks for the hike.
Aaron: Great! See you for coffee next Sunday at 10 am and then for the hike on March 25th at 9 am!
Lisa: Absolutely! Take care until then, Aaron.
Aaron: You too, Lisa. Bye for now!
"""
t3 = """
Aaron: Hey Lisa, how's it going?
Lisa: Hi Aaron, I'm doing well, thanks! How about you?
Aaron: Can't complain. Say, I remember you mentioning your friend John last time we chatted. How's he doing?
Lisa: Oh, John! Yeah, he's been keeping busy. He just got a promotion at work, so he's been really excited about that.
Aaron: That's fantastic news! What's his new role?
Lisa: He's now the marketing manager for his company. It's a big step up from his previous position, and he's thrilled about the opportunity.
Aaron: Sounds like he's been working hard. It's always inspiring to see friends succeed.
Lisa: Absolutely! John's dedication and perseverance have definitely paid off.
Aaron: Do you two hang out often?
Lisa: Not as much as we used to, unfortunately. With his new job and my busy schedule, it's been tough to find time to catch up.
Aaron: I totally get that. It can be challenging to balance work and personal life sometimes.
Lisa: Definitely. But we try to make time for each other whenever we can. Friendship is important, you know?
Aaron: Absolutely, couldn't agree more. Well, I'm glad to hear John's doing well. Hopefully, you both can catch up soon.
Lisa: Thanks, Aaron. I hope so too. So, what have you been up to lately?
Aaron: Oh, you know, the usual. Working, exploring the city, trying out new hobbies here and there.
Lisa: Anything exciting on the horizon?
Aaron: Not particularly, but I'm always on the lookout for new adventures. Maybe we can all hang out sometime, including John.
Lisa: That sounds like a plan! I'm sure John would appreciate catching up with you too.
Aaron: Great! Let's make it happen then. It's been great chatting with you, Lisa.
Lisa: Likewise, Aaron. Thanks for asking about John. Take care!
Aaron: You too, Lisa. Talk to you soon!
"""
d = {"type": "chat_summary",
"data": {
"agent": "Aaron",
"datetime": "Sunday, April 7, 2024, 3:31pm",
"user": "Lisa",
"conversation": t2
"conversation": t1
},
"memories": []}
......
{agent} had a conversation with {user}.
From the conversation, list 0-2 things we learn about people other than {agent} or {user}
Each item must with "-".
Each item must with "-".
Each item must with "-".
Each item must with "-".
From the conversation, list interesting things we learn that are unrelated to {agent} and {user}.
Each item must start with "-".
Each item must start with "-".
Each item must start with "-".
Each item must start with "-".
Example:
- Lukas' mom had a heart attack
- James' mom had a heart attack
- The chiefs won the superbowl
- The Graz 99ers played well yesterday
......
{agent} had a conversation with {user}.
List the 2 most important facts about {user}.
Each item must with "- {user}".
Each item must with "- {user}".
Each item must with "- {user}".
Each item must with "- {user}".
Each item must start with "- {user}".
Each item must start with "- {user}".
Each item must start with "- {user}".
Each item must start with "- {user}".
Example:
- {user} has a sister called Susan.
......
Your task is to summarize the most important pieces of information {agent} has about {user} given a chronological list of memories and insights belonging to {agent}.
Each memory item includes a time and date of when the memory was acquired.
Return the summary from a third-person perspective of {agent}'s point of view.
If the list is empty or the memories do not contain any relevant information about {user}, return 'No information'.
# Memories:
{memories}
# Information Summary:
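The template files above are plain str.format templates; a hedged sketch of how they are filled and sent through the new Model wrapper (the template name and all values here are illustrative, not taken from the repository):

import util
import model

parameters = {"agent": "Aaron", "user": "Lisa"}
memories_str = "- [Sunday, April 7, 2024] Lisa told Aaron about John's promotion."

# load_template is used this way throughout the query functions above
content = util.load_template("knowledge_summary").format(memories=memories_str, **parameters)
messages = [{"role": "user", "content": content}]

llm = model.Model(util.MISTRAL, util.MISTRAL_PATH, 0, 4096)
print(llm.query(messages))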
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import util
from huggingface_hub import hf_hub_download
model_name1 = "TheBloke/Mistral-7B-v0.1-GGUF"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model_file = "mistral-7b-v0.1.Q4_K_M.gguf"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name, token=util.LLAMA_API_KEY)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(util.MISTRAL_TOK)
messages = [
{
"role": "user",
"content": "You are a friendly chatbot who always responds in the style of a pirate",
},
{"role": "assistant", "content": "Understood."},
{"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(device)
tokenized_chat.to(device)
generated_ids = model.generate(tokenized_chat, max_new_tokens=1000, do_sample=True)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])
\ No newline at end of file
@@ -15,16 +15,16 @@ LLAMA_PATH_13b = "X:/LLM Models/llama-2-13b-chat.Q4_K_M.gguf"
MISTRAL_LARGE_PATH = "X:/LLM Models/mistral-7b-instruct-v0.1.Q5_K_M.gguf"
LUNA_UNC_PATH = "X:/LLM Models/luna-ai-llama2-uncensored.Q4_K_M.gguf"
MISTRAL_TOK = "mistralai/Mistral-7B-Instruct-v0.1"
LLAMA_TOK = "meta-llama/Llama-2-7b-chat-hf"
MISTRAL = "mistralai/Mistral-7B-Instruct-v0.1"
LLAMA = "meta-llama/Llama-2-7b-chat-hf"
LLAMA_API_KEY = "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"
KW_EN_CORE = "X:/LLM Models/nlm/en_core_web_sm-3.7.1-py3-none-any.whl"
def load_tokenizer(model):
def load_tokenizer(model, token=""):
if "llama" in model:
return AutoTokenizer.from_pretrained(model, token=LLAMA_API_KEY)
return AutoTokenizer.from_pretrained(model, token=token)
else:
return AutoTokenizer.from_pretrained(model)
......