mirror of
https://github.com/coleam00/ai-agents-masterclass.git
synced 2025-11-29 00:23:14 +00:00
Updating the LLM eval agent to work with Ollama
This commit is contained in:
@@ -28,7 +28,12 @@ HUGGINGFACEHUB_API_TOKEN=
|
||||
# And all Anthropic models you can use here -
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
# A good default to go with here is gpt-4o-mini, claude-3-5-sonnet-20240620, or llama3-groq-70b-8192-tool-use-preview
|
||||
LLM_MODEL=gpt-4o-mini
|
||||
LLM_MODEL=
|
||||
|
||||
# The provider you want to use for your LLMs
|
||||
# If you don't specify this, the script will try to determine your provider from your model, but this isn't guaranteed to work!
|
||||
# Possible providers: openai, anthropic, groq, ollama
|
||||
LLM_PROVIDER=
|
||||
|
||||
# Get your personal Asana access token through the developer console in Asana.
|
||||
# Feel free to follow these instructions -
|
||||
|
||||
Binary file not shown.
@@ -10,6 +10,7 @@ import json
|
||||
import os
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import ToolMessage, AIMessage
|
||||
@@ -21,6 +22,14 @@ from tools.vector_db_tools import available_vector_db_functions
|
||||
|
||||
load_dotenv()
|
||||
model = os.getenv('LLM_MODEL', 'gpt-4o')
|
||||
provider = os.getenv('LLM_PROVIDER', 'auto')
|
||||
|
||||
provider_mapping = {
|
||||
"openai": ChatOpenAI,
|
||||
"anthropic": ChatAnthropic,
|
||||
"ollama": ChatOllama,
|
||||
"llama": ChatGroq
|
||||
}
|
||||
|
||||
model_mapping = {
|
||||
"gpt": ChatOpenAI,
|
||||
@@ -53,10 +62,16 @@ def get_local_model():
|
||||
available_functions = available_asana_functions | available_drive_functions | available_vector_db_functions
|
||||
tools = [tool for _, tool in available_functions.items()]
|
||||
|
||||
for key, chatbot_class in model_mapping.items():
|
||||
if provider == "auto":
|
||||
for key, chatbot_class in model_mapping.items():
|
||||
if key in model.lower():
|
||||
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
|
||||
break
|
||||
else:
|
||||
for key, chatbot_class in provider_mapping.items():
|
||||
if key in provider.lower():
|
||||
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
|
||||
break
|
||||
|
||||
chatbot_with_tools = chatbot.bind_tools(tools)
|
||||
|
||||
@@ -70,7 +85,7 @@ class GraphState(TypedDict):
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
|
||||
async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]:
|
||||
def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]:
|
||||
"""
|
||||
Function that calls the model to generate a response.
|
||||
|
||||
@@ -88,7 +103,7 @@ async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, Any
|
||||
))
|
||||
|
||||
# Invoke the chatbot with the binded tools
|
||||
response = await chatbot_with_tools.ainvoke(messages, config)
|
||||
response = chatbot_with_tools.invoke(messages, config)
|
||||
# print("Response from model:", response)
|
||||
|
||||
# We return an object because this will get added to the existing list
|
||||
|
||||
Reference in New Issue
Block a user