今天我们接着看fast-graphrag的源码,首先看 fast_graphrag.GraphRAG 类,他是入口,继承自fast_graphrag._graphrag.BaseGraphRAG 类
@dataclass
class BaseGraphRAG(Generic[GTEmbedding, GTHash, GTChunk, GTNode, GTEdge, GTId]):
"""A class representing a Graph-based Retrieval-Augmented Generation system."""
working_dir: str = field()
domain: str = field()
example_queries: str = field()
entity_types: List[str] = field()
n_checkpoints: int = field(default=0)
llm_service: BaseLLMService = field(init=False, default_factory=lambda: BaseLLMService())
chunking_service: BaseChunkingService[GTChunk] = field(init=False, default_factory=lambda: BaseChunkingService())
information_extraction_service: BaseInformationExtractionService[GTChunk, GTNode, GTEdge, GTId] = field(
init=False,
default_factory=lambda: BaseInformationExtractionService(
graph_upsert=BaseGraphUpsertPolicy(
config=None,
nodes_upsert_cls=BaseNodeUpsertPolicy,
edges_upsert_cls=BaseEdgeUpsertPolicy,
)
),
)
state_manager: BaseStateManagerService[GTNode, GTEdge, GTHash, GTChunk, GTId, GTEmbedding] = field(
init=False,
default_factory=lambda: BaseStateManagerService(
workspace=None,
graph_storage=BaseGraphStorage[GTNode, GTEdge, GTId](config=None),
entity_storage=BaseVectorStorage[GTId, GTEmbedding](config=None),
chunk_storage=BaseIndexedKeyValueStorage[GTHash, GTChunk](config=None),
embedding_service=BaseEmbeddingService(),
node_upsert_policy=BaseNodeUpsertPolicy(config=None),
edge_upsert_policy=BaseEdgeUpsertPolicy(config=None),
),
)
要看懂这段代码,首先要了解几个概念:
数据类(Data Classes):
@dataclass 是 Python 的一个装饰器,用于简化类的定义,自动生成初始化方法、表示方法等。
泛型(Generics):
Generic 是 Python 的类型提示功能的一部分,允许我们定义可以接受不同类型的类。BaseGraphRAG 类使用了多个泛型参数(如 GTEmbedding, GTHash 等),了解泛型的概念将帮助我们理解这个类的灵活性。
类型提示(Type Hinting):
代码中使用了许多类型提示(如 List[str], Optional[Dict[str, Any]]),这有助于提高代码的可读性和可维护性。了解 Python 的类型提示将有助于我们理解参数和返回值的预期类型。
依赖注入(Dependency Injection):
代码中使用了 default_factory 来初始化一些服务(如 llm_service, chunking_service 等),这是一种依赖注入的方式。了解依赖注入的概念将帮助我们理解如何管理和使用这些服务。
下面以一个简单的图类来对比说明一下使用泛型和不使用泛型的区别:
简单实现一个图类,可以添加节点添加边。
可以看到使用左边使用泛型的版本,更加的灵活,初始化的时候可以通过泛型来传入节点的类型
graph = SimpleGraph[int, tuple[int, int]]() # 节点为整数,边为整数元组
graph = SimpleGraph[str, tuple[str, str]]() # 节点为字符串,边为字符串元组
而右边节点和边的类型被固定为 str 和 Tuple[str, str]。这意味着你只能使用字符串作为节点,且边只能是连接两个字符串的元组。传入不符合的类型,编辑器会提示错误。(虽然直接运行也不会出错就是了~)
简单的说:
泛型参数:是占位符,允许在使用时指定具体类型。
泛型类:是使用泛型参数定义的类,能够处理多种类型的数据,提高代码的灵活性和可重用性。
llm_service: BaseLLMService = field(init=False, default_factory=lambda: BaseLLMService())
llm_service 是一个属性,类型为 BaseLLMService。init=False 表示在初始化时不需要传入这个参数。default_factory 用于指定一个工厂函数,在创建实例时调用这个函数来生成默认值。其他几个参数类似,表示初始化不需要传入,有默认值。
再回到fast_graphrag.GraphRAG类。
GraphRAG 继承自 BaseGraphRAG
内部配置类:Config 是一个内部数据类,负责配置 GraphRAG 的各种服务和策略。它包含多个字段,每个字段都有默认值或工厂函数来初始化。
后初始化方法:__post_init__ 方法在 Config 和 GraphRAG 类实例化后被调用,用于初始化一些属性,例如设置嵌入维度。
服务实例化:在 GraphRAG 的 __post_init__ 方法中,实例化了各种服务和策略,确保它们在类的使用过程中可用。
工作空间管理:state_manager 的初始化涉及到工作空间的创建,使用了 Workspace.new 方法来管理工作目录和检查点。
此时再看初始化代码就很明了了。
grag = GraphRAG(
working_dir="./book_example",
domain=DOMAIN,
example_queries="",
entity_types=ENTITY_TYPES,
n_checkpoints=2,
config=GraphRAG.Config(
llm_service=OpenAILLMService(
model="gpt-4o-mini", base_url=os.getenv("OPENAI_API_BASE"), api_key=os.getenv("OPENAI_API_KEY"), mode=instructor.Mode.JSON
),
embedding_service=OpenAIEmbeddingService(
model="text-embedding-ada-002",
base_url=os.getenv("OPENAI_API_BASE"),
api_key=os.getenv("OPENAI_API_KEY"),
embedding_dim=1536, # the output embedding dim of the chosen model
),
),
)
到这,本篇已经完结,下次咱们研究这个fast graphrag的insert策略,也欢迎大家在评论区留言,分享你对模型应用的一些经验和看法~