class Neo4jConfig(BaseModel):
    """Configuration for the Neo4j database connection used by the exporter.

    Attributes are plain values so that ``model_dump()`` yields a dictionary
    that can be passed straight through the pipeline configuration.
    """

    # Connection URI understood by the Neo4j driver.
    uri: str = Field(default="bolt://localhost:7687", description="Neo4j URI")
    username: str = Field(default="neo4j", description="Database username")
    # repr=False keeps the credential out of repr()/log output; the value is
    # still a plain ``str`` and is still included in ``model_dump()``.
    password: str = Field(
        default="password",
        repr=False,
        description="Database password",
    )
    database: str = Field(default="neo4j", description="Target database name")
    # ge=1: a batch size of 0 would make the exporter's range() step invalid,
    # and a negative one would silently export nothing.
    batch_size: int = Field(
        default=1000,
        ge=1,
        description="Batch size for ingestion",
    )
    write_mode: Literal["merge", "create"] = Field(
        default="merge",
        description="Write strategy",
    )
class Neo4jExporter:
    """Exporter for populating a live Neo4j database from a NetworkX graph.

    Nodes are grouped by their ``label`` attribute and written in batches with
    ``UNWIND``; edges are grouped by (source label, relationship type, target
    label) so that endpoint lookups are label-qualified. Labels and
    relationship types cannot be passed as Cypher parameters, so they are
    sanitised to alphanumerics/underscores and backtick-quoted before being
    interpolated into the query text.
    """

    _VALID_WRITE_MODES = ("merge", "create")
    _DEFAULT_NODE_LABEL = "Entity"
    _DEFAULT_REL_TYPE = "RELATED_TO"

    def __init__(
        self,
        uri: str,
        auth: Optional[tuple[str, str]] = None,
        database: str = "neo4j",
        batch_size: int = 1000,
        write_mode: str = "merge",  # "merge" or "create"
    ):
        """
        Initialize the Neo4j exporter.

        Args:
            uri: Neo4j database URI (e.g., 'bolt://localhost:7687')
            auth: Tuple of (username, password)
            database: Database name to use
            batch_size: Number of records sent per query (must be >= 1)
            write_mode: Strategy for writing ('merge' updates existing,
                'create' always adds new); case-insensitive

        Raises:
            ValueError: If ``write_mode`` is not 'merge'/'create' or
                ``batch_size`` is smaller than 1.
        """
        mode = write_mode.lower()
        if mode not in self._VALID_WRITE_MODES:
            # Previously any unknown value silently fell through to CREATE.
            raise ValueError(
                f"write_mode must be one of {self._VALID_WRITE_MODES}, got {write_mode!r}"
            )
        if batch_size < 1:
            # A zero step would crash range(); a negative one exports nothing.
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        self.uri = uri
        self.auth = auth
        self.database = database
        self.batch_size = batch_size
        self.write_mode = mode
        self._driver: "Optional[Driver]" = None

    def _get_driver(self) -> "Driver":
        """Return the cached driver, creating it on first use."""
        if self._driver is None:
            self._driver = GraphDatabase.driver(self.uri, auth=self.auth)
        return self._driver

    def close(self) -> None:
        """Close the driver if one is open; safe to call repeatedly."""
        if self._driver:
            self._driver.close()
            self._driver = None

    def export(self, graph: "nx.DiGraph") -> None:
        """
        Export the NetworkX graph to Neo4j.

        Nodes are written first so that the edge queries can match them. The
        driver is always closed afterwards, whether or not the export succeeds.

        Args:
            graph: The NetworkX directed graph to export

        Raises:
            Exception: Any error raised during the export is reported and
                re-raised unchanged.
        """
        if graph.number_of_nodes() == 0:
            rich_print("[yellow]Graph is empty. Skipping Neo4j export.[/yellow]")
            return

        driver = self._get_driver()

        try:
            with driver.session(database=self.database) as session:
                self._export_nodes(session, graph)
                self._export_edges(session, graph)

            rich_print(
                f"[green]Successfully exported graph to Neo4j database '{self.database}'[/green]"
            )
        except Exception as e:
            # Top-level boundary: report, then let the caller decide.
            rich_print(f"[red]Failed to export to Neo4j:[/red] {e}")
            raise
        finally:
            self.close()

    @staticmethod
    def _sanitize_identifier(raw: Any, fallback: str) -> str:
        """Reduce ``raw`` to alphanumerics/underscores, or return ``fallback``.

        ``None`` and values that sanitise to an empty string yield ``fallback``;
        non-string values are converted with ``str()`` first.
        """
        text = "" if raw is None else str(raw)
        cleaned = "".join(ch for ch in text if ch.isalnum() or ch == "_")
        return cleaned or fallback

    def _write_clause(self) -> str:
        """Return the Cypher write keyword for the configured write mode."""
        return "MERGE" if self.write_mode == "merge" else "CREATE"

    def _batches(self, rows: List[Any]) -> List[List[Any]]:
        """Split ``rows`` into consecutive slices of at most ``batch_size``."""
        size = self.batch_size
        return [rows[start : start + size] for start in range(0, len(rows), size)]

    def _group_nodes_by_label(self, graph: "nx.DiGraph") -> Dict[str, List[Dict[str, Any]]]:
        """Build ``{label: [row, ...]}`` where each row is ``{"id", "properties"}``.

        ``properties`` holds every node attribute except ``label``, plus ``id``
        (the node identifier always wins over an ``id`` attribute).
        """
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for node_id, data in graph.nodes(data=True):
            label = self._sanitize_identifier(data.get("label"), self._DEFAULT_NODE_LABEL)
            properties = {key: value for key, value in data.items() if key != "label"}
            properties["id"] = node_id  # Ensure ID is a property
            grouped.setdefault(label, []).append({"id": node_id, "properties": properties})
        return grouped

    def _group_edges(
        self, graph: "nx.DiGraph"
    ) -> Dict[tuple[str, str, str], List[Dict[str, Any]]]:
        """Build ``{(source_label, rel_type, target_label): [row, ...]}``.

        Each row is ``{"source_id", "target_id", "properties"}``; ``properties``
        holds every edge attribute except ``label``. Endpoint labels are
        derived with the same sanitiser used when the nodes are written.
        """
        node_labels = {
            node_id: self._sanitize_identifier(data.get("label"), self._DEFAULT_NODE_LABEL)
            for node_id, data in graph.nodes(data=True)
        }

        grouped: Dict[tuple[str, str, str], List[Dict[str, Any]]] = {}
        for source_id, target_id, data in graph.edges(data=True):
            raw_type = data.get("label")
            rel_type = self._sanitize_identifier(
                None if raw_type is None else str(raw_type).upper(),
                self._DEFAULT_REL_TYPE,
            )
            key = (
                node_labels.get(source_id, self._DEFAULT_NODE_LABEL),
                rel_type,
                node_labels.get(target_id, self._DEFAULT_NODE_LABEL),
            )
            grouped.setdefault(key, []).append(
                {
                    "source_id": source_id,
                    "target_id": target_id,
                    "properties": {k: v for k, v in data.items() if k != "label"},
                }
            )
        return grouped

    def _export_nodes(self, session: Any, graph: "nx.DiGraph") -> None:
        """Batch write nodes to Neo4j, one query shape per node label."""
        write_clause = self._write_clause()
        total_nodes = 0

        for label, rows in self._group_nodes_by_label(graph).items():
            # Backticks keep labels that start with a digit valid; the label
            # only contains alphanumerics/underscores after sanitising.
            # NOTE(review): property values are passed through as-is; values
            # Neo4j cannot store (e.g. nested dicts) would be rejected by the
            # server - confirm upstream attributes are primitives/lists.
            cypher = (
                "UNWIND $batch AS row "
                f"{write_clause} (n:`{label}` {{id: row.id}}) "
                "SET n += row.properties"
            )
            for batch in self._batches(rows):
                # consume() forces execution so server errors surface here.
                session.run(cypher, batch=batch).consume()
                total_nodes += len(batch)

        rich_print(f" - Exported {total_nodes} nodes")

    def _export_edges(self, session: Any, graph: "nx.DiGraph") -> None:
        """Batch write edges to Neo4j, matching endpoints by label and id."""
        write_clause = self._write_clause()
        total_edges = 0

        for (source_label, rel_type, target_label), rows in self._group_edges(graph).items():
            # Label-qualified MATCH avoids scanning every node per edge and
            # avoids linking to unrelated nodes that share the same id.
            cypher = (
                "UNWIND $batch AS row "
                f"MATCH (source:`{source_label}` {{id: row.source_id}}) "
                f"MATCH (target:`{target_label}` {{id: row.target_id}}) "
                f"{write_clause} (source)-[r:`{rel_type}`]->(target) "
                "SET r += row.properties"
            )
            for batch in self._batches(rows):
                session.run(cypher, batch=batch).consume()
                total_edges += len(batch)

        rich_print(f" - Exported {total_edges} edges")
rich_print(f"[green]→[/green] Saved CSV files to [green]{output_dir}[/green]") + rich_print(f"[green]✔[/green] Saved CSV files to [green]{output_dir}[/green]") elif export_format == "cypher": cypher_path = output_dir / f"{base_name}_graph.cypher" CypherExporter().export(knowledge_graph, cypher_path) - rich_print(f"[green]→[/green] Saved Cypher script to [green]{cypher_path}[/green]") + rich_print(f"[green]✔[/green] Saved Cypher script to [green]{cypher_path}[/green]") + elif export_format == "neo4j": + # Extract Neo4j config + neo_conf = conf.get("neo4j", {}) + exporter = Neo4jExporter( + uri=neo_conf.get("uri", "bolt://localhost:7687"), + auth=(neo_conf.get("username", "neo4j"), neo_conf.get("password", "password")), + database=neo_conf.get("database", "neo4j"), + batch_size=neo_conf.get("batch_size", 1000), + write_mode=neo_conf.get("write_mode", "merge") + ) + exporter.export(knowledge_graph) else: raise ValueError(f"Unknown export format: {export_format}") diff --git a/pyproject.toml b/pyproject.toml index 74cadfae..102f53ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "rich>=13,<14", "typer[all]>=0.12,<1.0.0", "python-dotenv>=1.0,<2.0", + "neo4j>=5.0.0", ] [project.urls]