Updating the LLM eval agent to work with Ollama

This commit is contained in:
Cole Medin
2024-11-13 09:08:59 -06:00
parent db25e16fde
commit 3900e0adeb
3 changed files with 27 additions and 7 deletions

View File

@@ -28,7 +28,12 @@ HUGGINGFACEHUB_API_TOKEN=
# And all Anthropic models you can use here -
# https://docs.anthropic.com/en/docs/about-claude/models
# Good defaults to start with: gpt-4o-mini, claude-3-5-sonnet-20240620, or llama3-groq-70b-8192-tool-use-preview
LLM_MODEL=gpt-4o-mini
LLM_MODEL=
# The provider you want to use for your LLMs
# If you don't specify this, the script will try to determine your provider from your model, but this isn't guaranteed to work!
# Possible providers: openai, anthropic, groq, ollama
LLM_PROVIDER=
# Get your personal Asana access token through the developer console in Asana.
# Feel free to follow these instructions -

View File

@@ -10,6 +10,7 @@ import json
import os
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import ToolMessage, AIMessage
@@ -21,6 +22,14 @@ from tools.vector_db_tools import available_vector_db_functions
# Load environment variables from the local .env file into os.environ.
load_dotenv()
# LLM model identifier; falls back to 'gpt-4o' when LLM_MODEL is unset.
# NOTE(review): the .env.example suggests gpt-4o-mini as the default — confirm the intended fallback.
model = os.getenv('LLM_MODEL', 'gpt-4o')
# Explicit provider selection ('auto' means: infer the provider from the model name).
provider = os.getenv('LLM_PROVIDER', 'auto')
# Maps an LLM_PROVIDER value (substring-matched, lowercased) to the LangChain
# chat-model class to instantiate. Keys must match the provider names the
# .env documents: openai, anthropic, groq, ollama.
provider_mapping = {
    "openai": ChatOpenAI,
    "anthropic": ChatAnthropic,
    "ollama": ChatOllama,
    # Was "llama", which never matched LLM_PROVIDER=groq (and could
    # shadow-match "ollama"); the key must be the provider name itself.
    "groq": ChatGroq
}
model_mapping = {
"gpt": ChatOpenAI,
@@ -53,10 +62,16 @@ def get_local_model():
available_functions = available_asana_functions | available_drive_functions | available_vector_db_functions
tools = [tool for _, tool in available_functions.items()]
for key, chatbot_class in model_mapping.items():
if key in model.lower():
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
break
if provider == "auto":
for key, chatbot_class in model_mapping.items():
if key in model.lower():
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
break
else:
for key, chatbot_class in provider_mapping.items():
if key in provider.lower():
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
break
chatbot_with_tools = chatbot.bind_tools(tools)
@@ -70,7 +85,7 @@ class GraphState(TypedDict):
"""
messages: Annotated[list[AnyMessage], add_messages]
async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]:
def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]:
"""
Function that calls the model to generate a response.
@@ -88,7 +103,7 @@ async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, Any
))
# Invoke the chatbot with the binded tools
response = await chatbot_with_tools.ainvoke(messages, config)
response = chatbot_with_tools.invoke(messages, config)
# print("Response from model:", response)
# We return an object because this will get added to the existing list