Hi! Thanks for the response. Unfortunately, that's not quite the problem that's occurring here. I initially thought it might be related to LangGraph state or a hidden prompt issue, but it's not. I've tried replaying the messages manually with different system prompts, but I get the same result.
I've finally been able to distill the issue further: the system isn't using the full context of the message array. I've created an annotated script from the Jupyter notebook I've been experimenting with.
As you can see, the tool calls behave just fine. The problem occurs when the human responds to the last AIMessage with something that isn't a command but rather a reply to the agent's suggestion.
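To make the failure concrete, here's the shape of the history at the point where things go wrong (types only; paraphrased from the transcript embedded in the script below):

SystemMessage(...)
HumanMessage("Can you play Taylor Swift's most popular song?")
AIMessage(tool_calls=[play_song_on_apple])
ToolMessage("Sorry, you don't have access to Shake It Off on Apple Music")
AIMessage("... Would you like to try playing it on Spotify instead?")
HumanMessage("Yes, please!")

Everything up to the final HumanMessage is handled correctly; it's the reply to the agent's suggestion that produces the empty response.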
import getpass
import os
import sys
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain.globals import set_verbose
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph import MessagesState, START
from langgraph.prebuilt import ToolNode
set_verbose(True)
# Get absolute paths for ai-starter-kit. Your location may vary
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
ai_kit_dir = os.path.abspath(os.path.join(parent_dir, 'ai-starter-kit'))
env_dir = parent_dir
sys.path.append(ai_kit_dir)
from utils.model_wrappers.langchain_chat_models import ChatSambaNovaCloud
os.environ["LANGCHAIN_TRACING_V2"] = "false"
# load env variables from a .env file into Python environment
if load_dotenv(os.path.join(env_dir, '.env')):
    api_key = os.getenv('SAMBANOVA_API_KEY')
else:
    os.environ["SAMBANOVA_API_KEY"] = getpass.getpass()
    api_key = os.environ.get("SAMBANOVA_API_KEY")
# Get the SambaNova chat client.
# Here, I've chosen to use the 70B version to have more context
model = ChatSambaNovaCloud(
    base_url="https://api.sambanova.ai/v1/",
    api_key=api_key,
    streaming=False,
    temperature=0.01,
    model="Meta-Llama-3.1-70B-Instruct",
)
# Argument schemas
# NOTE: These aren't technically necessary, but most of the SambaNova examples
# use them, whereas the newer LangGraph examples skip them (see the
# commented-out variant after the class definitions below).
class GetSongSchema(BaseModel):
    """Play a song"""
    song: str = Field(description='song name to play')

class GetSpotifySongSchema(GetSongSchema):
    """Play a song on Spotify"""
    pass

class GetAppleMusicSongSchema(GetSongSchema):
    """Play a song on Apple Music"""
    pass
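# For comparison, the schema-free style from the newer LangGraph examples
# would presumably look like this (not what I ran here; the argument schema
# is inferred from the signature and docstring):
#
# @tool
# def play_song_on_spotify(song: str) -> str:
#     """Play a song on Spotify"""
#     return f"Successfully played {song} on Spotify!"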
# Define the available tools. These are just mocks for a real tool call.
# We want the system to fail with a call to Apple Music, attempt to recover by
# suggesting Spotify, and then respond to the user's follow-up message.
@tool(args_schema=GetSpotifySongSchema)
def play_song_on_spotify(song: str):
    """Play a song on Spotify"""
    # Call the Spotify API ...
    return f"Successfully played {song} on Spotify!"

@tool(args_schema=GetAppleMusicSongSchema)
def play_song_on_apple(song: str):
    """Play a song on Apple Music"""
    # Call the Apple Music API ...
    # We always want this to fail so that the agent has to respond with
    # context. The actual song choice is irrelevant.
    if False:
        return f"Successfully played {song} on Apple Music!"
    else:
        return f"Sorry, you don't have access to {song} on Apple Music"
# By placing Apple first, the agent will try the failing method before trying
# Spotify, which always succeeds.
tools = [play_song_on_apple, play_song_on_spotify]
tool_node = ToolNode(tools)
# Bind the tools. Do not enable parallel calls because we want the model to
# behave as if it is calling a real service.
model = model.bind_tools(tools, parallel_tool_calls=False)
# Now we need to define the LangGraph components. This is taken from the
# LangGraph example and modified slightly to provide better instrumentation
# and to wrap the tool call with an AIMessage, which is important for
# SambaNova's client.
# Define the function that determines whether to continue or not
def should_continue(state) -> str:
    messages = state["messages"]
    last_message = messages[-1]
    # If there is no tool call, then we finish
    if not last_message.tool_calls:
        return "end"
    # Otherwise, if there is, we continue
    else:
        return "continue"
# The values "end" and "continue" are condition names in the graph
# Define the function that calls the model
def call_model(state):
    messages = state["messages"]
    if isinstance(messages[-1], ToolMessage):
        # We have to add an AI prompt here for SambaNova. Otherwise, it won't
        # know what to do. This should only be appended after a ToolMessage
        # because otherwise the LLM gets very chatty about irrelevant topics.
        response_message = AIMessage(
            content="Please respond conversationally to the user"
        )
        messages.append(response_message)
    response = call_llm(model, messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
# I'm putting this here to print the type of messages that we're invoking.
# This is the only place in the workflow where the LLM is queried.
def call_llm(llm, messages):
    # Uncomment these to show all of the message types being passed to the LLM
    # message_types = [str(type(message)) for message in messages]
    # print(f"[messages] [{', '.join(message_types)}]")
    return llm.invoke(messages)  # Send back the response as usual
def create_workflow():
    """Get a new LangGraph workflow"""
    # Define a new graph
    workflow = StateGraph(MessagesState)
    # Define the two nodes we will cycle between
    workflow.add_node("agent", call_model)
    workflow.add_node("action", tool_node)
    # Set the entrypoint as `agent`.
    # This means that this node is the first one called.
    workflow.add_edge(START, "agent")
    # We now add a conditional edge
    workflow.add_conditional_edges(
        # First, we define the start node. We use `agent`.
        # This means these are the edges taken after the `agent` node is called.
        "agent",
        # Next, we pass in the function that will determine which node is called next.
        should_continue,
        # Finally, we pass in a mapping.
        # The keys are strings, and the values are other nodes.
        # END is a special node marking that the graph should finish.
        # What will happen is we will call `should_continue`, and then the output
        # of that will be matched against the keys in this mapping.
        # Based on which one it matches, that node will then be called.
        {
            # If `continue`, then we call the tool node.
            "continue": "action",
            # Otherwise we finish.
            "end": END,
        },
    )
    # We now add a normal edge from `action` to `agent`.
    # This means that after `action` is called, the `agent` node is called next.
    workflow.add_edge("action", "agent")
    # Set up memory
    memory = MemorySaver()
    # Finally, we compile it!
    # This compiles it into a LangChain Runnable,
    # meaning you can use it as you would any other runnable.
    # Note that, unlike the human-in-the-loop example, we do not pass
    # `interrupt_before=["action"]`, because we want the tool call to run
    # without a breakpoint.
    app = workflow.compile(checkpointer=memory)
    return app
# This prompt is good enough. You can certainly use the more elaborate
# prompt from the function-calling examples.
simple_assistant_prompt = """
Your answer should be in the same language as the initial query.
Either call a tool or respond to the user.
You are a helpful assistant.
"""
if __name__ == '__main__':
    app = create_workflow()
    # For the configuration, any thread id will work. Just invent something.
    config = {"configurable": {"thread_id": "13542"}}
    # Instantiate the template
    chat_template = ChatPromptTemplate.from_messages([('system', simple_assistant_prompt)])
    history = chat_template.format_prompt().to_messages()
    initial_human_message = HumanMessage(content="Can you play Taylor Swift's most popular song?")
    history.append(initial_human_message)
    # Call the system
    for event in app.stream({"messages": history}, config, stream_mode="values"):
        event["messages"][-1].pretty_print()
    # The result at this point should look something like this:
    #
    # ================================ Human Message =================================
    #
    # Can you play Taylor Swift's most popular song?
    #
    # ================================== Ai Message ==================================
    # Tool Calls:
    #   play_song_on_apple (call_ad7294371a424cfcb3)
    #  Call ID: call_ad7294371a424cfcb3
    #   Args:
    #     song: Shake It Off
    # ================================= Tool Message =================================
    # Name: play_song_on_apple
    #
    # Sorry, you don't have access to Shake It Off on Apple Music
    #
    # ================================== Ai Message ==================================
    #
    # I apologize, but it seems that you don't have access to "Shake It Off" on Apple Music.
    # Would you like to try playing it on Spotify instead?
    #
    # If you look at the state and the messages by calling
    # `app.get_state(config).values['messages']`, you should find the messages
    # [SystemMessage, HumanMessage, AIMessage, ToolMessage, AIMessage] as shown
    # above.
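    # For example, a quick (hypothetical) way to spot-check what's stored:
    #
    # for m in app.get_state(config).values["messages"]:
    #     print(type(m).__name__, repr(getattr(m, "content", ""))[:60])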
    acknowledge_message = HumanMessage(content="Yes, please!")
    # Note here that we don't append the human message directly to the history,
    # because the history is already stored in the checkpointer under the
    # thread id in `config`. You can verify that the message is correctly
    # appended by setting a breakpoint or by printing the message array in
    # call_llm(...)
    try:
        for event in app.stream({"messages": [acknowledge_message]}, config, stream_mode="values"):
            event["messages"][-1].pretty_print()
    except Exception:
        # This will be triggered because the response from the LLM is empty.
        # Note, however, that the LLM has all of the information it needs.
        print("An error occurred trying to recover using Spotify")