mirror of
https://github.com/coleam00/ai-agents-masterclass.git
synced 2025-11-29 00:23:14 +00:00
Updating the LLM eval agent to work with Ollama
This commit is contained in:
@@ -28,7 +28,12 @@ HUGGINGFACEHUB_API_TOKEN=
|
|||||||
# And all Anthropic models you can use here -
|
# And all Anthropic models you can use here -
|
||||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||||
# A good default to go with here is gpt-4o-mini, claude-3-5-sonnet-20240620, or llama3-groq-70b-8192-tool-use-preview
|
# A good default to go with here is gpt-4o-mini, claude-3-5-sonnet-20240620, or llama3-groq-70b-8192-tool-use-preview
|
||||||
LLM_MODEL=gpt-4o-mini
|
LLM_MODEL=
|
||||||
|
|
||||||
|
# The provider you want to use for your LLMs
|
||||||
|
# If you don't specify this, the script will try to determine your provider from your model, but this isn't guaranteed to work!
|
||||||
|
# Possible providers: openai, anthropic, groq, ollama
|
||||||
|
LLM_PROVIDER=
|
||||||
|
|
||||||
# Get your personal Asana access token through the developer console in Asana.
|
# Get your personal Asana access token through the developer console in Asana.
|
||||||
# Feel free to follow these instructions -
|
# Feel free to follow these instructions -
|
||||||
|
|||||||
Binary file not shown.
@@ -10,6 +10,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
|
from langchain_ollama import ChatOllama
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_core.messages import ToolMessage, AIMessage
|
from langchain_core.messages import ToolMessage, AIMessage
|
||||||
@@ -21,6 +22,14 @@ from tools.vector_db_tools import available_vector_db_functions
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
model = os.getenv('LLM_MODEL', 'gpt-4o')
|
model = os.getenv('LLM_MODEL', 'gpt-4o')
|
||||||
|
provider = os.getenv('LLM_PROVIDER', 'auto')
|
||||||
|
|
||||||
|
provider_mapping = {
|
||||||
|
"openai": ChatOpenAI,
|
||||||
|
"anthropic": ChatAnthropic,
|
||||||
|
"ollama": ChatOllama,
|
||||||
|
"llama": ChatGroq
|
||||||
|
}
|
||||||
|
|
||||||
model_mapping = {
|
model_mapping = {
|
||||||
"gpt": ChatOpenAI,
|
"gpt": ChatOpenAI,
|
||||||
@@ -53,10 +62,16 @@ def get_local_model():
|
|||||||
available_functions = available_asana_functions | available_drive_functions | available_vector_db_functions
|
available_functions = available_asana_functions | available_drive_functions | available_vector_db_functions
|
||||||
tools = [tool for _, tool in available_functions.items()]
|
tools = [tool for _, tool in available_functions.items()]
|
||||||
|
|
||||||
for key, chatbot_class in model_mapping.items():
|
if provider == "auto":
|
||||||
|
for key, chatbot_class in model_mapping.items():
|
||||||
if key in model.lower():
|
if key in model.lower():
|
||||||
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
|
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
for key, chatbot_class in provider_mapping.items():
|
||||||
|
if key in provider.lower():
|
||||||
|
chatbot = chatbot_class(model=model) if key != "huggingface" else chatbot_class(llm=get_local_model())
|
||||||
|
break
|
||||||
|
|
||||||
chatbot_with_tools = chatbot.bind_tools(tools)
|
chatbot_with_tools = chatbot.bind_tools(tools)
|
||||||
|
|
||||||
@@ -70,7 +85,7 @@ class GraphState(TypedDict):
|
|||||||
"""
|
"""
|
||||||
messages: Annotated[list[AnyMessage], add_messages]
|
messages: Annotated[list[AnyMessage], add_messages]
|
||||||
|
|
||||||
async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]:
|
def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]:
|
||||||
"""
|
"""
|
||||||
Function that calls the model to generate a response.
|
Function that calls the model to generate a response.
|
||||||
|
|
||||||
@@ -88,7 +103,7 @@ async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, Any
|
|||||||
))
|
))
|
||||||
|
|
||||||
# Invoke the chatbot with the binded tools
|
# Invoke the chatbot with the binded tools
|
||||||
response = await chatbot_with_tools.ainvoke(messages, config)
|
response = chatbot_with_tools.invoke(messages, config)
|
||||||
# print("Response from model:", response)
|
# print("Response from model:", response)
|
||||||
|
|
||||||
# We return an object because this will get added to the existing list
|
# We return an object because this will get added to the existing list
|
||||||
|
|||||||
Reference in New Issue
Block a user