import json
import math
from datetime import datetime

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

import memory_util
import util


# parses a query request and dispatches it to the matching handler
# requests are JSON objects of the form:
#   {"type": <query_type>, "data": {<template parameters>}, "memories": [<memory nodes>]}
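# e.g. (illustrative; mirrors the sample request `d` at the bottom of this file):
#   parse_request(json.dumps({"type": "chat_summary",
#                             "data": {"agent": "...", "user": "...", "conversation": "..."},
#                             "memories": []}))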

def parse_request(request):
    request = json.loads(request)
    request_type = request["type"]

    if request_type == "chat":
        return query_chat(request, util.load_model(util.LLAMA_PATH, 4096, 0.75),
                          util.load_tokenizer(util.LLAMA_TOK))
    elif request_type == "chat_summary":
        return query_chat_summary(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                                  util.load_tokenizer(util.MISTRAL_TOK))
    elif request_type == "chat_extract_plan":
        return query_chat_extract_plan(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                                       util.load_tokenizer(util.MISTRAL_TOK))
    elif request_type == "reflection":
        return query_reflection(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                                util.load_tokenizer(util.MISTRAL_TOK))
    elif request_type == "poignancy":
        return query_poignancy(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                               util.load_tokenizer(util.MISTRAL_TOK))
    elif request_type == "context":
        return generate_context(request)
    elif request_type == "knowledge":
        return query_knowledge(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                               util.load_tokenizer(util.MISTRAL_TOK))
    elif request_type == "plan_day":
        return query_plan_day(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                              util.load_tokenizer(util.MISTRAL_TOK))

    return "ERROR"


def query_chat(request, llm, tokenizer):
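    """Answers the latest turn of a user/agent chat.

    Filters the request's memories by relevance to the most recent message,
    injects the top five into the "chat_system" template, replays the chat as
    alternating user/assistant turns, and returns JSON with the reply plus the
    NodeIds of the memories that were used.
    """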
    parameters = request["data"]
    chat = parameters["chat"].split("~")

    memories = request["memories"]
    memory_util.filter_memories(memories, parameters["user"] + ";" + chat[-1])

    parameters["memories"] = memory_util.memories_to_string(memories[:5], True)  # TODO

    print("\n" + memory_util.memories_to_string(memories[:5], True) + "\n")

    messages = [
        {"role": "system",
         "content": util.load_template("chat_system").format(**parameters)},
    ]

    roles = ["user", "assistant"]

    for i in range(len(chat)):
        messages.append({"role": roles[i % 2], "content": chat[i]})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    print()
    print(prompt)
    print()

    res = run_query(prompt, llm)
    memories_accessed = [str(mem["NodeId"]) for mem in memories[:5]]
    json_res = json.dumps({"response": res,
                           "data": {"memories_accessed": str.join(",", memories_accessed)}})

    return json_res


def query_reflection(request, llm, tokenizer):
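    """Asks the model to derive higher-level insights from the given memories
    (via the "reflection_a" template) and returns the bullet-point insights as
    new memories."""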
    parameters = request["data"]
    memories = request["memories"]
    parameters["memories"] = memory_util.memories_to_string(memories, include_nodeId=True)

    messages = [
        {"role": "user",
         "content": util.load_template("reflection_a").format(**parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    res = run_query(prompt, llm)
    insights = [s.strip().replace("- ", "", 1) for s in res.split("\n") if s.strip().startswith("- ")]

    json_res = json.dumps({"memories": insights})
    return json_res


def query_poignancy(request, llm, tokenizer):
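    """Runs the "poignancy" template over the request data and returns the raw
    model output (presumably a poignancy rating) as JSON."""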
    parameters = request["data"]
    messages = [
        {"role": "user",
         "content": util.load_template("poignancy").format(**parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    res = run_query(prompt, llm)

    json_res = json.dumps({"response": res})
    return json_res


def query_chat_summary(request, llm, tokenizer):
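    """Summarizes a finished conversation into new memories.

    Produces a one-line summary plus bullet lists of facts learned about the
    user and about the agent; each fact list is filtered through
    conversation_validate_statements unless the model reports that there is
    "no valuable information".
    """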
    parameters = request["data"]
    messages_summary = [
        {"role": "user",
         "content": util.load_template("chat_summary_single").format(**parameters)},
    ]

    messages_user = [
        {"role": "user",
         "content": util.load_template("chat_summary_user").format(**parameters)},
    ]

    messages_agent = [
        {"role": "user",
         "content": util.load_template("chat_summary_agent").format(**parameters)},
    ]

    prompt_summary = tokenizer.apply_chat_template(messages_summary, tokenize=False)
    prompt_user = tokenizer.apply_chat_template(messages_user, tokenize=False)
    prompt_agent = tokenizer.apply_chat_template(messages_agent, tokenize=False)

    print("------- SUMMARY: ")
    summary = run_query(prompt_summary, llm)
    summary = parameters["agent"] + " had a conversation with " + parameters["user"] + " - they talked about " + summary
    print()

    print("------- USER: ")
    summary_user = run_query(prompt_user, llm)
    summary_user_list = [s.strip().replace("- ", "", 1) for s in summary_user.split("\n") if s.strip().startswith("- ")]

    if "no valuable information" not in summary_user.lower():
        summary_user_list = conversation_validate_statements(parameters, summary_user_list, llm, tokenizer)

    print("------- AGENT: ")
    summary_agent = run_query(prompt_agent, llm)
    summary_agent_list = [s.strip().replace("- ", "", 1) for s in summary_agent.split("\n") if s.strip().startswith("- ")]
    if "no valuable information" not in summary_agent.lower():
        summary_agent_list = conversation_validate_statements(parameters, summary_agent_list, llm, tokenizer)

    new_memories = [summary]
    new_memories.extend(summary_user_list)
    new_memories.extend(summary_agent_list)

    json_res = json.dumps({"memories": new_memories})
    return json_res


def query_chat_extract_plan(request, llm, tokenizer):
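    """Extracts any plan mentioned in the conversation (via the
    "chat_extract_plan" template); the response is empty when the model
    reports "no plan"."""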
    parameters = request["data"]

    messages = [
        {"role": "user",
         "content": util.load_template("chat_extract_plan").format(**parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    res = run_query(prompt, llm)
    if "no plan" in res.lower():
        res = ""

    json_res = json.dumps({"response": res})
    return json_res


def generate_context(request):
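    """Builds the context line that opens a conversation between agent and user.

    With no shared memories the two have never met; otherwise the context
    reports when the agent last talked to the user and what their
    relationship is.
    """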
    parameters = request["data"]
    memories = request["memories"]
    memories.sort(key=lambda x: x["HrsSinceCreation"], reverse=True)

    # if the agent has no memory associated with the user, the two have never had a conversation
    if len(memories) == 0:
        return json.dumps({"response": parameters["agent"] + " is having a conversation with someone they have "
                                                             "never met before."})

    # agent's current action based on their schedule (currently not included in the returned context)
    action = query_agent_action(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                                util.load_tokenizer(util.MISTRAL_TOK))

    # when did the agent last talk to the user? (memories are sorted oldest-first, so the last is the most recent)
    last_chat_hrs = int(math.ceil(memories[-1]["HrsSinceCreation"]))
    last_chat = (parameters["agent"] + " last talked to " + parameters["user"] + " on " + memories[-1]["Created"]
                 + " - " + str(last_chat_hrs) + " " + ("hour" if last_chat_hrs == 1 else "hours") + " ago.")

    # what is the relationship between agent and user?
    relationship = query_relationship(request, util.load_model(util.MISTRAL_PATH, 4096, 0),
                                      util.load_tokenizer(util.MISTRAL_TOK))

    return json.dumps({"response": last_chat + " " + relationship})


def query_relationship(request, llm, tokenizer):
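    """Summarizes the agent's relationship with the user from their shared
    memories (via the "relationship" template)."""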
    parameters = request["data"]
    memories = request["memories"]

    memories_str = memory_util.memories_to_string(memories, include_date_created=True)

    messages = [
        {"role": "user",
         "content": util.load_template("relationship").format(memories=memories_str, **parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    relationship = parameters["agent"] + "'s relationship with " + parameters["user"] + " is " + run_query(prompt, llm)

    return relationship


def query_agent_action(request, llm, tokenizer):
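    """Asks the model what the agent is currently doing, based on the agent's
    schedule (via the "agent_action" template)."""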
    parameters = request["data"]
    messages = [
        {"role": "user",
         "content": util.load_template("agent_action").format(**parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    action = run_query(prompt, llm)

    return action


def query_knowledge(request, llm, tokenizer):
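    """Runs the "knowledge_summary" template over the request data and returns
    the model output as JSON."""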
    parameters = request["data"]
    messages = [
        {"role": "user",
         "content": util.load_template("knowledge_summary").format(**parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    res = run_query(prompt, llm)
    json_res = json.dumps({"response": res})
    return json_res


def query_plan_day(request, llm, tokenizer):
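    """Plans the agent's day: generates a rough plan with the "plan_day"
    template, then decomposes it into one activity per hour with
    "plan_day_decomp", and returns the hourly schedule as JSON."""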
    parameters = request["data"]
    memories = request["memories"]

    memories_str = memory_util.memories_to_string(memories)

    messages = [
        {"role": "user",
         "content": util.load_template("plan_day").format(**parameters)},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    print(prompt)

    rough_plan = run_query(prompt, llm)
    # plans = [s.strip().replace("- ", "", 1) for s in rough_plan.split("\n") if s.strip().startswith("- ")]

    # decompose the rough plan into a concrete activity for each hour of the day
    schedule = []
    for h in range(24):
        time_start = datetime.strptime(str(h) + ":00", "%H:%M")
        time_start = time_start.strftime("%I:%M%p")
        time_end = datetime.strptime(str((h + 1) % 24) + ":00", "%H:%M")
        time_end = time_end.strftime("%I:%M%p")

        messages = [
            {"role": "user",
             "content": util.load_template("plan_day_decomp").format(time_start=time_start, time_end=time_end,
                                                                     memories=memories_str, plan=rough_plan,
                                                                     **parameters)},
        ]

        prompt = tokenizer.apply_chat_template(messages, tokenize=False)

        res = time_start + "-" + time_end + ":" + run_query(prompt, llm)
        schedule.append(res)

        print()
        print(res)
        print()

    json_res = json.dumps({"response": "\n".join(schedule)})
    return json_res


# returns the subset of statements that the model judged to be correct
def conversation_validate_statements(parameters, statements, llm, tokenizer):
    valid = []
    for statement in statements:
        message_validate = [
            {"role": "user",
             "content": util.load_template("chat_validate_statement").format(statement=statement,
                                                                             **parameters)},
        ]

        prompt_validate = tokenizer.apply_chat_template(message_validate, tokenize=False)

        print("IS THE STATEMENT: '" + statement + "' CORRECT?: ")
        val = run_query(prompt_validate, llm)
        if "yes" in val.lower():
            valid.append(statement)
        print()

    return valid


def run_query(prompt, llm):
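    """Runs a fully rendered prompt through the model and returns the stripped output.

    Note: PromptTemplate.from_template parses the prompt as an f-string-style
    template, so any literal curly braces left in the rendered prompt could be
    misread as template variables.
    """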
    llm_chain = LLMChain(
        prompt=PromptTemplate.from_template(prompt),
        llm=llm,
        verbose=False
    )

    return llm_chain.run({}).strip()


d = {"type": "chat_summary",
     "data": {
         "agent": "John Lin",
         "user": "James",
         "conversation": '''
James: Hi, I'm James. Is the pharmacy still open?
John Lin: Ah, good evening! *adjusts glasses* Oh, hello there! Yes, the pharmacy is still open, my friend! *smiles widely* It's Thursday, April 4th, 2024, and it's currently 12:54 AM. How may I assist you? *leaning forward, eager to help*
James: I need some Aspirin for my sister. She is having a bad headache. Can you help me there?
John Lin: Of course, I'd be happy to help! *reaches for Aspirin box* Ah, yes, Aspirin... *takes out a few tablets* Let me just check if we have any... *checks inventory* Yes, we have plenty! *hands you a couple of tablets* Here you go, my friend! *smiles* Take these to your sister and give her one when she needs it. *winks* Hope she feels better soon!
James: Thank you so much. How much is it?
John Lin: Of course, of course! *glances at price tag* The Aspirin costs $10.99 for a small bottle of 20 tablets. *smiles* It's a bit on the pricey side, I know, but it's high-quality stuff! *winks* Would you like me to pour you a small bottle? *offers a bag*
James: I only have cash, so here you go. Have a nice day
John Lin: Of course, no problem at all! *smiles* I hope your sister feels better soon. *pauses* It's always good to have some Aspirin on hand, you know. *winks* It's a good idea to keep some medication in the house, just in case. *nods* Do you need any more advice or recommendations?
        '''
     },
     "memories": []}

# query_chat_summary(d, util.load_model(util.MISTRAL_PATH, 4096, 0), util.load_tokenizer(util.MISTRAL_TOK))