LlamaIndex is a Python library for LLM applications. It provides several abstractions and utilities that make building RAG applications easier. This is a deep dive into the stages of LlamaIndex and its source code.
Loading Link to heading
Starting with LlamaIndex readers, I will dig deeper into SimpleDirectoryReader which, as the name suggests, is a simple reader but powerful enough to handle several file types.
Document Link to heading
The docs have a simple example that loads one document using SimpleDirectoryReader. Note that the reader returns a Document object (or a list of them).
from llama_index.core import SimpleDirectoryReader
reader = SimpleDirectoryReader(
input_files=["./data/paul_graham/paul_graham_essay1.txt"]
)
documents = reader.load_data()
documents[0]
This prints the Document, which has metadata about the file as well as its content.
Document(id_='8766bb96-9125-44c7-b388-761180153ed5', embedding=None, metadata={'file_path': 'data/paul_graham/paul_graham_essay1.txt', 'file_name': 'paul_graham_essay1.txt', 'file_type': 'text/plain', 'file_size': 75042, 'creation_date': '2024-03-16', 'last_modified_date': '2024-03-16'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=[
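SimpleDirectoryReader can also point at a whole directory instead of explicit files. A minimal sketch (the directory path and filters here are just examples, not from the docs):
from llama_index.core import SimpleDirectoryReader
# hypothetical directory; required_exts and recursive control which paths _add_files keeps
reader = SimpleDirectoryReader(
    input_dir="./data/paul_graham",
    required_exts=[".txt"],
    recursive=True,
)
documents = reader.load_data()
print(len(documents), documents[0].metadata["file_name"])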
So, how does SimpleDirectoryReader work?
It starts by validating the paths in __init__ and pushing them into input_files. Then load_data calls load_file for each file, which eventually reads the file and creates a Document object per file.
class SimpleDirectoryReader(BaseReader):
"""Simple directory reader.
Load files from file directory.
Automatically select the best file reader given file extensions.
"""
def __init__(
self,
input_dir: Optional[str] = None,
input_files: Optional[List] = None,
...
) -> None:
...
...
if input_files:
self.input_files = []
for path in input_files:
if not self.fs.isfile(path):
raise ValueError(f"File {path} does not exist.")
input_file = Path(path)
self.input_files.append(input_file)
elif input_dir:
if not self.fs.isdir(input_dir):
raise ValueError(f"Directory {input_dir} does not exist.")
self.input_dir = Path(input_dir)
self.exclude = exclude
self.input_files = self._add_files(self.input_dir)
...
...
def load_data(
self,
show_progress: bool = False,
num_workers: Optional[int] = None,
fs: Optional[fsspec.AbstractFileSystem] = None,
) -> List[Document]:
"""Load data from the input directory.
Args:
show_progress (bool): Whether to show tqdm progress bars. Defaults to False.
num_workers (Optional[int]): Number of workers to parallelize data-loading over.
fs (Optional[fsspec.AbstractFileSystem]): File system to use. If fs was specified
in the constructor, it will override the fs parameter here.
Returns:
List[Document]: A list of documents.
"""
for input_file in files_to_process:
documents.extend(
SimpleDirectoryReader.load_file(
input_file=input_file,
file_metadata=self.file_metadata,
file_extractor=self.file_extractor,
filename_as_id=self.filename_as_id,
encoding=self.encoding,
errors=self.errors,
fs=fs,
)
)
And finally, load_file as promised.
def load_file(
input_file: Path,
file_metadata: Callable[[str], Dict],
file_extractor: Dict[str, BaseReader],
filename_as_id: bool = False,
encoding: str = "utf-8",
errors: str = "ignore",
fs: Optional[fsspec.AbstractFileSystem] = None,
) -> List[Document]:
...
...
if file_suffix not in file_extractor:
# instantiate file reader if not already
reader_cls = default_file_reader_cls[file_suffix]
file_extractor[file_suffix] = reader_cls()
reader = file_extractor[file_suffix]
...
...
else:
# do standard read
fs = fs or get_default_fs()
with fs.open(input_file, errors=errors, encoding=encoding) as f:
data = f.read().decode(encoding, errors=errors)
doc = Document(text=data, metadata=metadata or {})
if filename_as_id:
doc.id_ = str(input_file)
documents.append(doc)
The Document class extends TextNode and adds the id_ field. So, most of the fields are defined in TextNode.
class Document(TextNode):
"""Generic interface for a data document.
This document connects to data sources.
"""
# TODO: A lot of backwards compatibility logic here, clean up
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the node.",
alias="doc_id",
)
TextNode Link to heading
The next step is creating TextNode objects from the Document. This snippet shows how to split a Document into TextNodes using SentenceSplitter.
from llama_index.core.node_parser import SentenceSplitter
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(documents)
A node holds a chunk of the file, with relationships between nodes (like a linked list). Nothing fancy there.
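To see the linked-list structure, here is a quick sketch using the nodes from the snippet above (NodeRelationship lives in llama_index.core.schema):
from llama_index.core.schema import NodeRelationship
# the splitter links consecutive chunks and points each back to its source Document
first, second = nodes[0], nodes[1]
print(first.relationships[NodeRelationship.SOURCE].node_id)     # the originating Document
print(first.relationships[NodeRelationship.NEXT].node_id)       # the next chunk
print(second.relationships[NodeRelationship.PREVIOUS].node_id)  # the previous chunk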
class TextNode(BaseNode):
text: str = Field(default="", description="Text content of the node.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node."
)
text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL,
description=(
"Template for how text is formatted, with {content} and "
"{metadata_str} placeholders."
),
)
Indexing Link to heading
Now that we have nodes, we can run an embedding model to create vectors. There are other types of indexes, but VectorStoreIndex is the most popular one. __init__ in VectorStoreIndex takes nodes and an embed_model.
from llama_index.core import VectorStoreIndex
index = VectorStoreIndex(nodes=nodes, embed_model=embed_model, show_progress=True)
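The embed_model in this snippet is assumed to exist already. One way to get one is resolve_embed_model; note the "local:" prefix needs the llama-index-embeddings-huggingface package installed, so treat this as a sketch:
from llama_index.core.embeddings import resolve_embed_model
# "local:<model_name>" runs a HuggingFace model locally instead of calling an API
embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")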
But VectorStoreIndex delegates most of the setup to BaseIndex.
class VectorStoreIndex(BaseIndex[IndexDict]):
...
...
def __init__(
self,
nodes: Optional[Sequence[BaseNode]] = None,
# vector store index params
use_async: bool = False,
store_nodes_override: bool = False,
embed_model: Optional[EmbedType] = None,
insert_batch_size: int = 2048,
# parent class params
objects: Optional[Sequence[IndexNode]] = None,
index_struct: Optional[IndexDict] = None,
storage_context: Optional[StorageContext] = None,
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
show_progress: bool = False,
# deprecated
service_context: Optional[ServiceContext] = None,
**kwargs: Any,
) -> None:
...
...
super().__init__(
nodes=nodes,
index_struct=index_struct,
service_context=service_context,
storage_context=storage_context,
show_progress=show_progress,
objects=objects,
callback_manager=callback_manager,
transformations=transformations,
**kwargs,
)
In BaseIndex, embed_nodes is called with the nodes and the embedding model.
def _get_node_with_embedding(
self,
nodes: Sequence[BaseNode],
show_progress: bool = False,
) -> List[BaseNode]:
"""Get tuples of id, node, and embedding.
Allows us to store these nodes in a vector store.
Embeddings are called in batches.
"""embed_nodes
id_to_embed_map = embed_nodes(
nodes, self._embed_model, show_progress=show_progress
)
This is the sequence of calls between __init__ and the eventual call to embed_nodes.
class BaseIndex(Generic[IS], ABC):
"""Base LlamaIndex.
Args:
nodes (List[Node]): List of nodes to index
show_progress (bool): Whether to show tqdm progress bars. Defaults to False.
service_context (ServiceContext): Service context container (contains
components like LLM, Embeddings, etc.).
"""
index_struct_cls: Type[IS]
def __init__(
self,
nodes: Optional[Sequence[BaseNode]] = None,
objects: Optional[Sequence[IndexNode]] = None,
index_struct: Optional[IS] = None,
storage_context: Optional[StorageContext] = None,
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
show_progress: bool = False,
# deprecated
service_context: Optional[ServiceContext] = None,
**kwargs: Any,
) -> None:
...
...
index_struct = self.build_index_from_nodes(
nodes + objects # type: ignore
)
def build_index_from_nodes(
self,
nodes: Sequence[BaseNode],
**insert_kwargs: Any,
) -> IndexDict:
return self._build_index_from_nodes(nodes, **insert_kwargs)
def _add_nodes_to_index(
self,
index_struct: IndexDict,
nodes: Sequence[BaseNode],
show_progress: bool = False,
**insert_kwargs: Any,
) -> None:
"""Add document to index."""
if not nodes:
return
for nodes_batch in iter_batch(nodes, self._insert_batch_size):
nodes_batch = self._get_node_with_embedding(nodes_batch, show_progress)
def _get_node_with_embedding(
self,
nodes: Sequence[BaseNode],
show_progress: bool = False,
) -> List[BaseNode]:
"""Get tuples of id, node, and embedding.
Allows us to store these nodes in a vector store.
Embeddings are called in batches.
"""
id_to_embed_map = embed_nodes(
nodes, self._embed_model, show_progress=show_progress
)
def embed_nodes(
nodes: Sequence[BaseNode], embed_model: BaseEmbedding, show_progress: bool = False
) -> Dict[str, List[float]]:
"""Get embeddings of the given nodes, run embedding model if necessary.
Args:
nodes (Sequence[BaseNode]): The nodes to embed.
embed_model (BaseEmbedding): The embedding model to use.
show_progress (bool): Whether to show progress bar.
Returns:
Dict[str, List[float]]: A map from node id to embedding.
"""
id_to_embed_map: Dict[str, List[float]] = {}
texts_to_embed = []
ids_to_embed = []
for node in nodes:
if node.embedding is None:
ids_to_embed.append(node.node_id)
texts_to_embed.append(node.get_content(metadata_mode=MetadataMode.EMBED))
else:
id_to_embed_map[node.node_id] = node.embedding
new_embeddings = embed_model.get_text_embedding_batch(
texts_to_embed, show_progress=show_progress
)
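At the bottom of this chain, the embedding model is simply asked for vectors in batches. A small sketch of that final call, reusing the embed_model from earlier:
# what embed_nodes ultimately calls on the embedding model
new_embeddings = embed_model.get_text_embedding_batch(
    ["first chunk of text", "second chunk of text"],
    show_progress=True,
)
print(len(new_embeddings), len(new_embeddings[0]))  # 2 vectors of the model's dimension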
Storage Link to heading
This is optional, but the next step is creating storage. The easiest way is to call persist on the storage_context.
index.storage_context.persist(persist_dir="./storage")
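The counterpart for loading it back later is load_index_from_storage with a StorageContext pointed at the same directory. A sketch:
from llama_index.core import StorageContext, load_index_from_storage
storage_context = StorageContext.from_defaults(persist_dir="./storage")
index = load_index_from_storage(storage_context)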
@dataclass
class StorageContext:
"""Storage context.
The storage context container is a utility container for storing nodes,
indices, and vectors. It contains the following:
- docstore: BaseDocumentStore
- index_store: BaseIndexStore
- vector_store: VectorStore
- graph_store: GraphStore
"""
docstore: BaseDocumentStore
index_store: BaseIndexStore
vector_stores: Dict[str, VectorStore]
graph_store: GraphStore
Query Link to heading
There are two main types of engines provided by LlamaIndex:
- Chat engine
- Query engine
Query engine Link to heading
The query engine is an abstraction that maps to the OpenAI completions API. It has 3 stages:
- retriever
- post-processor
- response synthesizer
The query engine can be built on top of an index by calling .as_query_engine()
query_engine = index.as_query_engine()
response = query_engine.query("sup?")
We will look at RetrieverQueryEngine, since as_query_engine on BaseIndex returns a RetrieverQueryEngine with whatever retriever the index defines (and the LLM).
def as_query_engine(
self, llm: Optional[LLMType] = None, **kwargs: Any
) -> BaseQueryEngine:
# NOTE: lazy import
from llama_index.core.query_engine.retriever_query_engine import (
RetrieverQueryEngine,
)
retriever = self.as_retriever(**kwargs)
llm = (
resolve_llm(llm, callback_manager=self._callback_manager)
if llm
else llm_from_settings_or_context(Settings, self.service_context)
)
return RetrieverQueryEngine.from_args(
retriever,
llm=llm,
**kwargs,
)
RetrieverQueryEngine Link to heading
Looking at RetrieverQueryEngine, we see the 3 stages of the query engine: retriever, post-processor, and synthesizer. Starting with __init__.
class RetrieverQueryEngine(BaseQueryEngine):
"""Retriever query engine.
Args:
retriever (BaseRetriever): A retriever object.
response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
object.
callback_manager (Optional[CallbackManager]): A callback manager.
"""
def __init__(
self,
retriever: BaseRetriever,
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
) -> None:
self._retriever = retriever
self._response_synthesizer = response_synthesizer or get_response_synthesizer(
llm=llm_from_settings_or_context(Settings, retriever.get_service_context()),
callback_manager=callback_manager
or callback_manager_from_settings_or_context(
Settings, retriever.get_service_context()
),
)
self._node_postprocessors = node_postprocessors or []
callback_manager = (
callback_manager or self._response_synthesizer.callback_manager
)
for node_postprocessor in self._node_postprocessors:
node_postprocessor.callback_manager = callback_manager
super().__init__(callback_manager=callback_manager)
RetrieverQueryEngine also has _query, which calls retrieve and synthesize.
@dispatcher.span
def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
"""Answer a query."""
with self.callback_manager.event(
CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
) as query_event:
nodes = self.retrieve(query_bundle)
response = self._response_synthesizer.synthesize(
query=query_bundle,
nodes=nodes,
)
query_event.on_end(payload={EventPayload.RESPONSE: response})
return response
retrieve is defined in RetrieverQueryEngine; it looks up nodes and calls the postprocessors (if any are defined).
def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
nodes = self._retriever.retrieve(query_bundle)
return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)
There is also synthesize, which calls synthesize on _response_synthesizer.
def synthesize(
self,
query_bundle: QueryBundle,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
) -> RESPONSE_TYPE:
return self._response_synthesizer.synthesize(
query=query_bundle,
nodes=nodes,
additional_source_nodes=additional_source_nodes,
)
Retriever Link to heading
So, what does the retriever do? The retriever searches the index using similarity search (similarity_top_k defines how many results to return) and returns the matching node(s).
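In practice, the retriever comes from the index via as_retriever. A sketch (the query string is just an example):
retriever = index.as_retriever(similarity_top_k=3)
nodes_with_scores = retriever.retrieve("What did the author do growing up?")
for n in nodes_with_scores:
    print(n.score, n.node.node_id)  # similarity score and chunk id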
Let’s look at VectorIndexRetriever, which is the retriever for VectorStoreIndex.
class VectorIndexRetriever(BaseRetriever):
"""Vector index retriever.
Args:
index (VectorStoreIndex): vector store index.
similarity_top_k (int): number of top k results to return.
vector_store_query_mode (str): vector store query mode
See reference for VectorStoreQueryMode for full list of supported modes.
filters (Optional[MetadataFilters]): metadata filters, defaults to None
alpha (float): weight for sparse/dense retrieval, only used for
hybrid query mode.
doc_ids (Optional[List[str]]): list of documents to constrain search.
vector_store_kwargs (dict): Additional vector store specific kwargs to pass
through to the vector store at query time.
"""
def __init__(
self,
index: VectorStoreIndex,
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT,
filters: Optional[MetadataFilters] = None,
alpha: Optional[float] = None,
node_ids: Optional[List[str]] = None,
doc_ids: Optional[List[str]] = None,
sparse_top_k: Optional[int] = None,
callback_manager: Optional[CallbackManager] = None,
object_map: Optional[dict] = None,
embed_model: Optional[BaseEmbedding] = None,
verbose: bool = False,
**kwargs: Any,
) -> None:
In _retrieve
, the embedding of the query is first calculated with get_agg_embedding_from_queries
.
@dispatcher.span
def _retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._vector_store.is_embedding_query:
if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0:
query_bundle.embedding = (
self._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)
)
return self._get_nodes_with_embeddings(query_bundle)
Then _get_nodes_with_embeddings is called with the query embedding to look up the nodes; the embedding travels in the query_bundle_with_embeddings argument.
def _get_nodes_with_embeddings(
self, query_bundle_with_embeddings: QueryBundle
) -> List[NodeWithScore]:
query = self._build_vector_store_query(query_bundle_with_embeddings)
query_result = self._vector_store.query(query, **self._kwargs)
return self._build_node_list_from_query_result(query_result)
Postprocessing Link to heading
This is optional, so I will mostly skip it.
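For completeness, here is a hedged sketch of plugging one in anyway: node_postprocessors passed to as_query_engine end up on the RetrieverQueryEngine shown above.
from llama_index.core.postprocessor import SimilarityPostprocessor
# drop retrieved nodes whose similarity score is below 0.7
query_engine = index.as_query_engine(
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)]
)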
Response Synthesizer Link to heading
The response synthesizer is built in RetrieverQueryEngine
by calling get_response_synthesizer
and giving to some parameters including response type(default is COMPACT).
Here is the list of synthesizer modes.
from enum import Enum
class ResponseMode(str, Enum):
"""Response modes of the response builder (and synthesizer)."""
REFINE = "refine"
"""
Refine is an iterative way of generating a response.
We first use the context in the first node, along with the query, to generate an \
initial answer.
We then pass this answer, the query, and the context of the second node as input \
into a “refine prompt” to generate a refined answer. We refine through N-1 nodes, \
where N is the total number of nodes.
"""
COMPACT = "compact"
"""
Compact and refine mode first combine text chunks into larger consolidated chunks \
that more fully utilize the available context window, then refine answers \
across them.
This mode is faster than refine since we make fewer calls to the LLM.
"""
SIMPLE_SUMMARIZE = "simple_summarize"
"""
Merge all text chunks into one, and make a LLM call.
This will fail if the merged text chunk exceeds the context window size.
"""
TREE_SUMMARIZE = "tree_summarize"
"""
Build a tree index over the set of candidate nodes, with a summary prompt seeded \
with the query.
The tree is built in a bottoms-up fashion, and in the end the root node is \
returned as the response
"""
GENERATION = "generation"
"""Ignore context, just use LLM to generate a response."""
NO_TEXT = "no_text"
"""Return the retrieved context nodes, without synthesizing a final response."""
ACCUMULATE = "accumulate"
"""Synthesize a response for each text chunk, and then return the concatenation."""
COMPACT_ACCUMULATE = "compact_accumulate"
"""
Compact and accumulate mode first combine text chunks into larger consolidated \
chunks that more fully utilize the available context window, then accumulate \
answers for each of them and finally return the concatenation.
This mode is faster than accumulate since we make fewer calls to the LLM.
"""
In BaseSynthesizer, there is a call to get_response, which is implemented by the different synthesizer classes.
@dispatcher.span
def synthesize(
self,
query: QueryTextType,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
**response_kwargs: Any,
) -> RESPONSE_TYPE:
dispatcher.event(SynthesizeStartEvent(query=query))
...
...
response_str = self.get_response(
query_str=query.query_str,
text_chunks=[
n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
],
**response_kwargs,
)
additional_source_nodes = additional_source_nodes or []
source_nodes = list(nodes) + list(additional_source_nodes)
response = self._prepare_response_output(response_str, source_nodes)
...
...
return response
Let’s look at CompactAndRefine
as it’s the default.
class CompactAndRefine(Refine):
"""Refine responses across compact text chunks."""
get_response calls _make_compact_text_chunks to combine the prompt and text chunks before calling super().get_response from Refine.
@dispatcher.span
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
"""Get compact response."""
# use prompt helper to fix compact text_chunks under the prompt limitation
# TODO: This is a temporary fix - reason it's temporary is that
# the refine template does not account for size of previous answer.
new_texts = self._make_compact_text_chunks(query_str, text_chunks)
return super().get_response(
query_str=query_str,
text_chunks=new_texts,
prev_response=prev_response,
**response_kwargs,
)
And in Refine we also have get_response, which just loops over the chunks and passes each chunk, the query, and the response from the last LLM call to the LLM.
class Refine(BaseSynthesizer):
"""Refine a response to a query across text chunks."""
...
...
@dispatcher.span
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
"""Give response over chunks."""
dispatcher.event(GetResponseStartEvent())
response: Optional[RESPONSE_TEXT_TYPE] = None
for text_chunk in text_chunks:
if prev_response is None:
# if this is the first chunk, and text chunk already
# is an answer, then return it
response = self._give_response_single(
query_str, text_chunk, **response_kwargs
)
else:
# refine response if possible
response = self._refine_response_single(
prev_response, query_str, text_chunk, **response_kwargs
)
prev_response = response
Chat engine Link to heading
The chat engine maps to the OpenAI chat API. LlamaIndex provides an abstraction for chatbots that uses the conversation history (aka memory) to keep the LLM aware of previous prompts and responses.
There are several modes for the chat engine.
class ChatMode(str, Enum):
"""Chat Engine Modes."""
SIMPLE = "simple"
"""Corresponds to `SimpleChatEngine`.
Chat with LLM, without making use of a knowledge base.
"""
CONDENSE_QUESTION = "condense_question"
"""Corresponds to `CondenseQuestionChatEngine`.
First generate a standalone question from conversation context and last message,
then query the query engine for a response.
"""
CONTEXT = "context"
"""Corresponds to `ContextChatEngine`.
First retrieve text from the index using the user's message, then use the context
in the system prompt to generate a response.
"""
CONDENSE_PLUS_CONTEXT = "condense_plus_context"
"""Corresponds to `CondensePlusContextChatEngine`.
First condense a conversation and latest user message to a standalone question.
Then build a context for the standalone question from a retriever,
Then pass the context along with prompt and user message to LLM to generate a response.
"""
REACT = "react"
"""Corresponds to `ReActAgent`.
Use a ReAct agent loop with query engine tools.
"""
OPENAI = "openai"
"""Corresponds to `OpenAIAgent`.
Use an OpenAI function calling agent loop.
NOTE: only works with OpenAI models that support function calling API.
"""
BEST = "best"
"""Select the best chat engine based on the current LLM.
Corresponds to `OpenAIAgent` if using an OpenAI model that supports
function calling API, otherwise, corresponds to `ReActAgent`.
"""
Each of the chat engines defines sync and async (for asyncio apps) versions of the chat method. There are also streaming versions.
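A chat engine is usually obtained from an index with as_chat_engine, where chat_mode selects one of the modes above. A sketch:
chat_engine = index.as_chat_engine(chat_mode="condense_question", verbose=True)
response = chat_engine.chat("What did the author work on?")
print(response)
# streaming variant
streaming_response = chat_engine.stream_chat("And after that?")
for token in streaming_response.response_gen:
    print(token, end="")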
For SimpleChatEngine, chat adds the current message to the history and sends the whole thing to the LLM.
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
)
)
all_messages = self._prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = self._llm.chat(all_messages)
ai_message = chat_response.message
self._memory.put(ai_message)
return AgentChatResponse(response=str(chat_response.message.content))
For CondenseQuestionChatEngine, it first calls the LLM with the history and the current question to create a standalone question that can be sent to the query engine. Cool!
DEFAULT_TEMPLATE = """\
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
"""
DEFAULT_PROMPT = PromptTemplate(DEFAULT_TEMPLATE)
class CondenseQuestionChatEngine(BaseChatEngine):
"""Condense Question Chat Engine.
First generate a standalone question from conversation context and last message,
then query the query engine for a response.
"""
Finally, the chat method, which just creates the condensed query and sends it to the query engine.
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
chat_history = chat_history or self._memory.get()
# Generate standalone question from conversation context and last message
condensed_question = self._condense_question(chat_history, message)
log_str = f"Querying with: {condensed_question}"
logger.info(log_str)
if self._verbose:
print(log_str)
# TODO: right now, query engine uses class attribute to configure streaming,
# we are moving towards separate streaming and non-streaming methods.
# In the meanwhile, use this hack to toggle streaming.
from llama_index.core.query_engine.retriever_query_engine import (
RetrieverQueryEngine,
)
if isinstance(self._query_engine, RetrieverQueryEngine):
is_streaming = self._query_engine._response_synthesizer._streaming
self._query_engine._response_synthesizer._streaming = False
# Query with standalone question
query_response = self._query_engine.query(condensed_question)
# NOTE: reset streaming flag
if isinstance(self._query_engine, RetrieverQueryEngine):
self._query_engine._response_synthesizer._streaming = is_streaming
tool_output = self._get_tool_output_from_response(
condensed_question, query_response
)
# Record response
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
self._memory.put(
ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response))
)
return AgentChatResponse(response=str(query_response), sources=[tool_output])
For ContextChatEngine, it just adds the retriever output to the prompt. Nothing fancy.
DEFAULT_CONTEXT_TEMPLATE = (
"Context information is below."
"\n--------------------\n"
"{context_str}"
"\n--------------------\n"
)
class ContextChatEngine(BaseChatEngine):
"""Context Chat Engine.
Uses a retriever to retrieve a context, set the context in the system prompt,
and then uses an LLM to generate a response, for a fluid chat experience.
"""
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
context_str_template, nodes = self._generate_context(message)
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
prefix_messages_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in prefix_messages])
)
)
all_messages = prefix_messages + self._memory.get(
initial_token_count=prefix_messages_token_count
)
chat_response = self._llm.chat(all_messages)
ai_message = chat_response.message
self._memory.put(ai_message)
return AgentChatResponse(
response=str(chat_response.message.content),
sources=[
ToolOutput(
tool_name="retriever",
content=str(prefix_messages[0]),
raw_input={"message": message},
raw_output=prefix_messages[0],
)
],
source_nodes=nodes,
)
Agents Link to heading
TLDR: The agent creates a prompt that tells the LLM about the tools (that the user provided) and asks the LLM to choose the right tool with the right input. The agent calls the tool, gets the output, and asks whether the LLM is ready to answer or not. If not, it keeps calling the tools until it's done.
While query and chat engines are useful for queries and chatbots over indexes, RAG agents provide a way to call user-defined tools and create thoughts to determine the next step.
I will start with ReActAgent because it is the most important agent for RAG. From the docs:
An “agent” is an automated reasoning and decision engine. It takes in a user input/query and can make internal decisions for executing that query in order to return the correct result. The key agent components can include, but are not limited to:
Breaking down a complex question into smaller ones Choosing an external Tool to use + coming up with parameters for calling the Tool Planning out a set of tasks Storing previously completed tasks in a memory module
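Before digging into the source, this is roughly what using a ReAct agent looks like. A sketch: llm stands for an already-configured LLM, and the multiply tool is just an illustration.
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b
multiply_tool = FunctionTool.from_defaults(fn=multiply)
# llm is assumed to be configured elsewhere (OpenAI, a local model, etc.)
agent = ReActAgent.from_tools([multiply_tool], llm=llm, verbose=True)
print(agent.chat("What is 7 multiplied by 6? Use a tool."))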
ReActAgent Link to heading
In llama-index-core/llama_index/core/agent/react/base.py, ReActAgent extends AgentRunner and creates its own instance of ReActAgentWorker. So, most of the logic happens in AgentRunner and ReActAgentWorker. As its docstring says:
Simple wrapper around AgentRunner + ReActAgentWorker.
class ReActAgent(AgentRunner):
def __init__(
) -> None:
step_engine = ReActAgentWorker.from_tools(
tools=tools,
tool_retriever=tool_retriever,
llm=llm,
max_iterations=max_iterations,
react_chat_formatter=react_chat_formatter,
output_parser=output_parser,
callback_manager=callback_manager,
verbose=verbose,
)
super().__init__(
step_engine,
memory=memory,
llm=llm,
callback_manager=callback_manager,
)
AgentRunner Link to heading
In llama-index-core/llama_index/core/agent/runner/base.py, the AgentRunner docstring describes what it does.
class AgentRunner(BaseAgentRunner):
"""Agent runner.
Top-level agent orchestrator that can create tasks, run each step in a task,
or run a task e2e. Stores state and keeps track of tasks.
Args:
agent_worker (BaseAgentWorker): step executor
chat_history (Optional[List[ChatMessage]], optional): chat history. Defaults to None.
state (Optional[AgentState], optional): agent state. Defaults to None.
memory (Optional[BaseMemory], optional): memory. Defaults to None.
llm (Optional[LLM], optional): LLM. Defaults to None.
callback_manager (Optional[CallbackManager], optional): callback manager. Defaults to None.
init_task_state_kwargs (Optional[dict], optional): init task state kwargs. Defaults to None.
"""
In chat(), self._chat() is called with the message, the tool choice, and the history.
@trace_method("chat")
def chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> AgentChatResponse:
# override tool choice is provided as input.
if tool_choice is None:
tool_choice = self.default_tool_choice
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message=message,
chat_history=chat_history,
tool_choice=tool_choice,
mode=ChatResponseMode.WAIT,
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
In _chat(), _run_step is called in a loop after creating a Task for the message.
@dispatcher.span
def _chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)
result_output = None
dispatcher.event(AgentChatWithStepStartEvent())
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = self._run_step(
task.task_id, mode=mode, tool_choice=tool_choice
)
if cur_step_output.is_last:
result_output = cur_step_output
break
# ensure tool_choice does not cause endless loops
tool_choice = "auto"
result = self.finalize_response(
task.task_id,
result_output,
)
dispatcher.event(AgentChatWithStepEndEvent())
return result
Still in runner/base.py
def _run_step(
self,
task_id: str,
step: Optional[TaskStep] = None,
input: Optional[str] = None,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
dispatcher.event(AgentRunStepStartEvent())
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
if input is not None:
step.input = input
if self.verbose:
print(f"> Running step {step.step_id}. Step input: {step.input}")
# TODO: figure out if you can dynamically swap in different step executors
# not clear when you would do that by theoretically possible
if mode == ChatResponseMode.WAIT:
cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
elif mode == ChatResponseMode.STREAM:
cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
else:
raise ValueError(f"Invalid mode: {mode}")
# append cur_step_output next steps to queue
next_steps = cur_step_output.next_steps
step_queue.extend(next_steps)
# add cur_step_output to completed steps
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)
dispatcher.event(AgentRunStepEndEvent())
return cur_step_output
Jumping to llama-index-core/llama_index/core/agent/react/step.py, where self.agent_worker.run_step is defined in ReActAgentWorker.
def _run_step(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
# send prompt
chat_response = self._llm.chat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = self._process_actions(
task, tools, output=chat_response
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
agent_response = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
if is_done:
task.extra_state["new_memory"].put(
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
)
return self._get_task_step_response(agent_response, step, is_done)
Note that the default formatter is ReActChatFormatter, which uses the following REACT_CHAT_SYSTEM_HEADER. Basically, LlamaIndex tells the LLM about the tools and asks it to call them in JSON format. That is happening in the self._react_chat_formatter.format line above.
"""Default prompt for ReAct agent."""
# ReAct chat prompt
# TODO: have formatting instructions be a part of react output parser
REACT_CHAT_SYSTEM_HEADER = """\
You are designed to help with a variety of tasks, from answering questions \
to providing summaries to other types of analyses.
## Tools
You have access to a wide variety of tools. You are responsible for using
the tools in any sequence you deem appropriate to complete the task at hand.
This may require breaking the task into subtasks and using different tools
to complete each subtask.
You have access to the following tools:
{tool_desc}
## Output Format
Please answer in the same language as the question and use the following format:
Thought: The current language of the user is: (user's language). I need to use a tool to help me answer the question.
Action: tool name (one of {tool_names}) if using a tool.
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
Please ALWAYS start with a Thought.
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
If this format is used, the user will respond in the following format:
Observation: tool response
You should keep repeating the above format until you have enough information
to answer the question without using any more tools. At that point, you MUST respond
in the one of the following two formats:
Thought: I can answer without using any more tools. I'll use the user's language to answer
Answer: [your answer here (In the same language as the user's question)]
Thought: I cannot answer the question with the provided tools.
Answer: [your answer here (In the same language as the user's question)]
## Current Conversation
Below is the current conversation consisting of interleaving human and assistant messages.
"""
So, after the first LLM call with the tools and the question, LlamaIndex parses the output to figure out which tool it needs to call. That's happening in _process_actions: it extracts the tool name the LLM wanted, along with the input, and calls the tool. Cool!
def _process_actions(
self,
task: Task,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict: Dict[str, AsyncBaseTool] = {
tool.metadata.get_name(): tool for tool in tools
}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = tool.call(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
task.extra_state["sources"].append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
The actual parsing is happening in _extract_reasoning_step
.
def _extract_reasoning_step(
self, output: ChatResponse, is_streaming: bool = False
) -> Tuple[str, List[BaseReasoningStep], bool]:
"""
Extracts the reasoning step from the given output.
This method parses the message content from the output,
extracts the reasoning step, and determines whether the processing is
complete. It also performs validation checks on the output and
handles possible errors.
"""
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content
current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
except BaseException as exc:
raise ValueError(f"Could not parse output: {message_content}") from exc
if self._verbose:
print_text(f"{reasoning_step.get_content()}\n", color="pink")
current_reasoning.append(reasoning_step)
if reasoning_step.is_done:
return message_content, current_reasoning, True
reasoning_step = cast(ActionReasoningStep, reasoning_step)
if not isinstance(reasoning_step, ActionReasoningStep):
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
return message_content, current_reasoning, False
It calls the logic in ReActOutputParser
class ReActOutputParser(BaseOutputParser):
"""ReAct Output parser."""
def parse(self, output: str, is_streaming: bool = False) -> BaseReasoningStep:
"""Parse output from ReAct agent.
We expect the output to be in one of the following formats:
1. If the agent need to use a tool to answer the question:
Thought: <thought>
Action: <action>
Action Input: <action_input>
2. If the agent can answer the question without any tools:
Thought: <thought>
Answer: <answer>
"""
if "Thought:" not in output:
# NOTE: handle the case where the agent directly outputs the answer
# instead of following the thought-answer format
return ResponseReasoningStep(
thought="(Implicit) I can answer without any more tools!",
response=output,
is_streaming=is_streaming,
)
if "Answer:" in output:
thought, answer = extract_final_response(output)
return ResponseReasoningStep(
thought=thought, response=answer, is_streaming=is_streaming
)
if "Action:" in output:
return parse_action_reasoning_step(output)
raise ValueError(f"Could not parse output: {output}")
Note that the parser returns different types of steps depending on the thought from the LLM. The types are:
- ActionReasoningStep
- ObservationReasoningStep
- ResponseReasoningStep
class ActionReasoningStep(BaseReasoningStep):
"""Action Reasoning step."""
thought: str
action: str
action_input: Dict
def get_content(self) -> str:
"""Get content."""
return (
f"Thought: {self.thought}\nAction: {self.action}\n"
f"Action Input: {self.action_input}"
)
@property
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
return False
class ObservationReasoningStep(BaseReasoningStep):
"""Observation reasoning step."""
observation: str
def get_content(self) -> str:
"""Get content."""
return f"Observation: {self.observation}"
@property
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
return False
class ResponseReasoningStep(BaseReasoningStep):
"""Response reasoning step."""
thought: str
response: str
is_streaming: bool = False
def get_content(self) -> str:
"""Get content."""
if self.is_streaming:
return (
f"Thought: {self.thought}\n"
f"Answer (Starts With): {self.response} ..."
)
else:
return f"Thought: {self.thought}\n" f"Answer: {self.response}"
@property
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
return True