sqlglot.lineage
1from __future__ import annotations 2 3import json 4import logging 5import typing as t 6from dataclasses import dataclass, field 7 8from sqlglot import Schema, exp, maybe_parse 9from sqlglot.errors import SqlglotError 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify 11 12if t.TYPE_CHECKING: 13 from sqlglot.dialects.dialect import DialectType 14 15logger = logging.getLogger("sqlglot") 16 17 18@dataclass(frozen=True) 19class Node: 20 name: str 21 expression: exp.Expression 22 source: exp.Expression 23 downstream: t.List[Node] = field(default_factory=list) 24 source_name: str = "" 25 reference_node_name: str = "" 26 27 def walk(self) -> t.Iterator[Node]: 28 yield self 29 30 for d in self.downstream: 31 if isinstance(d, Node): 32 yield from d.walk() 33 else: 34 yield d 35 36 def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: 37 nodes = {} 38 edges = [] 39 40 for node in self.walk(): 41 if isinstance(node.expression, exp.Table): 42 label = f"FROM {node.expression.this}" 43 title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" 44 group = 1 45 else: 46 label = node.expression.sql(pretty=True, dialect=dialect) 47 source = node.source.transform( 48 lambda n: ( 49 exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n 50 ), 51 copy=False, 52 ).sql(pretty=True, dialect=dialect) 53 title = f"<pre>{source}</pre>" 54 group = 0 55 56 node_id = id(node) 57 58 nodes[node_id] = { 59 "id": node_id, 60 "label": label, 61 "title": title, 62 "group": group, 63 } 64 65 for d in node.downstream: 66 edges.append({"from": node_id, "to": id(d)}) 67 return GraphHTML(nodes, edges, **opts) 68 69 70def lineage( 71 column: str | exp.Column, 72 sql: str | exp.Expression, 73 schema: t.Optional[t.Dict | Schema] = None, 74 sources: t.Optional[t.Dict[str, str | exp.Query]] = None, 75 dialect: DialectType = None, 76 **kwargs, 77) -> Node: 78 """Build the lineage graph for a column of a SQL query. 
79 80 Args: 81 column: The column to build the lineage for. 82 sql: The SQL string or expression. 83 schema: The schema of tables. 84 sources: A mapping of queries which will be used to continue building lineage. 85 dialect: The dialect of input SQL. 86 **kwargs: Qualification optimizer kwargs. 87 88 Returns: 89 A lineage node. 90 """ 91 92 expression = maybe_parse(sql, dialect=dialect) 93 94 if sources: 95 expression = exp.expand( 96 expression, 97 {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()}, 98 dialect=dialect, 99 ) 100 101 qualified = qualify.qualify( 102 expression, 103 dialect=dialect, 104 schema=schema, 105 **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore 106 ) 107 108 scope = build_scope(qualified) 109 110 if not scope: 111 raise SqlglotError("Cannot build lineage, sql must be SELECT") 112 113 def to_node( 114 column: str | int, 115 scope: Scope, 116 scope_name: t.Optional[str] = None, 117 upstream: t.Optional[Node] = None, 118 source_name: t.Optional[str] = None, 119 reference_node_name: t.Optional[str] = None, 120 ) -> Node: 121 source_names = { 122 dt.alias: dt.comments[0].split()[1] 123 for dt in scope.derived_tables 124 if dt.comments and dt.comments[0].startswith("source: ") 125 } 126 127 # Find the specific select clause that is the source of the column we want. 128 # This can either be a specific, named select or a generic `*` clause. 
129 select = ( 130 scope.expression.selects[column] 131 if isinstance(column, int) 132 else next( 133 (select for select in scope.expression.selects if select.alias_or_name == column), 134 exp.Star() if scope.expression.is_star else scope.expression, 135 ) 136 ) 137 138 if isinstance(scope.expression, exp.Union): 139 upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) 140 141 index = ( 142 column 143 if isinstance(column, int) 144 else next( 145 ( 146 i 147 for i, select in enumerate(scope.expression.selects) 148 if select.alias_or_name == column or select.is_star 149 ), 150 -1, # mypy will not allow a None here, but a negative index should never be returned 151 ) 152 ) 153 154 if index == -1: 155 raise ValueError(f"Could not find {column} in {scope.expression}") 156 157 for s in scope.union_scopes: 158 to_node( 159 index, 160 scope=s, 161 upstream=upstream, 162 source_name=source_name, 163 reference_node_name=reference_node_name, 164 ) 165 166 return upstream 167 168 if isinstance(scope.expression, exp.Select): 169 # For better ergonomics in our node labels, replace the full select with 170 # a version that has only the column we care about. 171 # "x", SELECT x, y FROM foo 172 # => "x", SELECT x FROM foo 173 source = t.cast(exp.Expression, scope.expression.select(select, append=False)) 174 else: 175 source = scope.expression 176 177 # Create the node for this step in the lineage chain, and attach it to the previous one. 
178 node = Node( 179 name=f"{scope_name}.{column}" if scope_name else str(column), 180 source=source, 181 expression=select, 182 source_name=source_name or "", 183 reference_node_name=reference_node_name or "", 184 ) 185 186 if upstream: 187 upstream.downstream.append(node) 188 189 subquery_scopes = { 190 id(subquery_scope.expression): subquery_scope 191 for subquery_scope in scope.subquery_scopes 192 } 193 194 for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): 195 subquery_scope = subquery_scopes.get(id(subquery)) 196 if not subquery_scope: 197 logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") 198 continue 199 200 for name in subquery.named_selects: 201 to_node(name, scope=subquery_scope, upstream=node) 202 203 # if the select is a star add all scope sources as downstreams 204 if select.is_star: 205 for source in scope.sources.values(): 206 if isinstance(source, Scope): 207 source = source.expression 208 node.downstream.append(Node(name=select.sql(), source=source, expression=source)) 209 210 # Find all columns that went into creating this one to list their lineage nodes. 211 source_columns = set(find_all_in_scope(select, exp.Column)) 212 213 # If the source is a UDTF find columns used in the UTDF to generate the table 214 if isinstance(source, exp.UDTF): 215 source_columns |= set(source.find_all(exp.Column)) 216 217 for c in source_columns: 218 table = c.table 219 source = scope.sources.get(table) 220 221 if isinstance(source, Scope): 222 selected_node, _ = scope.selected_sources.get(table, (None, None)) 223 # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. 224 to_node( 225 c.name, 226 scope=source, 227 scope_name=table, 228 upstream=node, 229 source_name=source_names.get(table) or source_name, 230 reference_node_name=selected_node.name if selected_node else None, 231 ) 232 else: 233 # The source is not a scope - we've reached the end of the line. 
At this point, if a source is not found 234 # it means this column's lineage is unknown. This can happen if the definition of a source used in a query 235 # is not passed into the `sources` map. 236 source = source or exp.Placeholder() 237 node.downstream.append(Node(name=c.sql(), source=source, expression=source)) 238 239 return node 240 241 return to_node(column if isinstance(column, str) else column.name, scope) 242 243 244class GraphHTML: 245 """Node to HTML generator using vis.js. 246 247 https://visjs.github.io/vis-network/docs/network/ 248 """ 249 250 def __init__( 251 self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None 252 ): 253 self.imports = imports 254 255 self.options = { 256 "height": "500px", 257 "width": "100%", 258 "layout": { 259 "hierarchical": { 260 "enabled": True, 261 "nodeSpacing": 200, 262 "sortMethod": "directed", 263 }, 264 }, 265 "interaction": { 266 "dragNodes": False, 267 "selectable": False, 268 }, 269 "physics": { 270 "enabled": False, 271 }, 272 "edges": { 273 "arrows": "to", 274 }, 275 "nodes": { 276 "font": "20px monaco", 277 "shape": "box", 278 "widthConstraint": { 279 "maximum": 300, 280 }, 281 }, 282 **(options or {}), 283 } 284 285 self.nodes = nodes 286 self.edges = edges 287 288 def __str__(self): 289 nodes = json.dumps(list(self.nodes.values())) 290 edges = json.dumps(self.edges) 291 options = json.dumps(self.options) 292 imports = ( 293 """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script> 294 <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script> 295 <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />""" 296 if self.imports 297 else "" 298 ) 299 300 return f"""<div> 301 <div id="sqlglot-lineage"></div> 302 {imports} 303 <script type="text/javascript"> 304 var nodes = new vis.DataSet({nodes}) 305 nodes.forEach(row => 
row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0]) 306 307 new vis.Network( 308 document.getElementById("sqlglot-lineage"), 309 {{ 310 nodes: nodes, 311 edges: new vis.DataSet({edges}) 312 }}, 313 {options}, 314 ) 315 </script> 316</div>""" 317 318 def _repr_html_(self) -> str: 319 return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph.

    A node pairs an expression with the source it appears in. It links, through
    `downstream`, to the nodes its value was derived from.
    """

    # Display name: the column name for the root, "<scope>.<column>" for nested steps,
    # "UNION" for union fan-out nodes, or the column SQL for leaf nodes.
    name: str
    # The projection this step refers to (or, for leaf nodes, the table/placeholder).
    expression: exp.Expression
    # The expression `expression` appears in; for SELECTs it is narrowed to that projection.
    source: exp.Expression
    # Nodes this one was derived from (one step further toward the base tables).
    downstream: t.List[Node] = field(default_factory=list)
    # Presumably the `sources` entry this step was expanded from ("" if none) - verify.
    source_name: str = ""
    # Name of the FROM-clause source through which the parent scope reached this one, or "".
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node reachable through `downstream`, depth-first."""
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                yield d

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the lineage graph rooted at this node as a vis.js HTML widget.

        The graph itself is not modified.

        Args:
            dialect: The dialect used to generate the SQL shown in labels and tooltips.
            **opts: Extra keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

        Returns:
            A `GraphHTML` object; its string form is an embeddable HTML snippet.
        """
        import html

        # Temporary markers placed around the highlighted expression. Control characters
        # pass through `html.escape` unchanged, so they can be swapped for real <b> tags
        # after the SQL text has been escaped.
        highlight_start, highlight_end = "\x02", "\x03"

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                text = f"SELECT {node.name} FROM {node.expression.this}"
                title = f"<pre>{html.escape(text)}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)

                # Highlight node.expression within node.source on a private copy.
                # Transforming node.source in place would leave Tag nodes in the lineage
                # graph and re-wrap them on every call. Copying breaks object identity,
                # so the copy's counterpart is found by traversing both trees in lockstep.
                highlighted = node.source.copy()
                target = next(
                    (
                        copied
                        for original, copied in zip(
                            node.source.find_all(exp.Expression),
                            highlighted.find_all(exp.Expression),
                        )
                        if original is node.expression
                    ),
                    None,
                )
                highlighted = highlighted.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix=highlight_start, postfix=highlight_end)
                        if n is target
                        else n
                    ),
                    copy=False,
                )
                source = (
                    html.escape(highlighted.sql(pretty=True, dialect=dialect))
                    .replace(highlight_start, "<b>")
                    .replace(highlight_end, "</b>")
                )
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in node.downstream:
                edges.append({"from": node_id, "to": id(d)})
        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
expression: sqlglot.expressions.Expression
source: sqlglot.expressions.Expression
downstream: List[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
    """Render the lineage graph rooted at this node as a vis.js HTML widget.

    The graph itself is not modified.

    Args:
        dialect: The dialect used to generate the SQL shown in labels and tooltips.
        **opts: Extra keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

    Returns:
        A `GraphHTML` object; its string form is an embeddable HTML snippet.
    """
    import html

    # Temporary markers placed around the highlighted expression. Control characters
    # pass through `html.escape` unchanged, so they can be swapped for real <b> tags
    # after the SQL text has been escaped.
    highlight_start, highlight_end = "\x02", "\x03"

    nodes = {}
    edges = []

    for node in self.walk():
        if isinstance(node.expression, exp.Table):
            label = f"FROM {node.expression.this}"
            text = f"SELECT {node.name} FROM {node.expression.this}"
            title = f"<pre>{html.escape(text)}</pre>"
            group = 1
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)

            # Highlight node.expression within node.source on a private copy.
            # Transforming node.source in place would leave Tag nodes in the lineage graph
            # and re-wrap them on every call. Copying breaks object identity, so the
            # copy's counterpart is found by traversing both trees in lockstep.
            highlighted = node.source.copy()
            target = next(
                (
                    copied
                    for original, copied in zip(
                        node.source.find_all(exp.Expression),
                        highlighted.find_all(exp.Expression),
                    )
                    if original is node.expression
                ),
                None,
            )
            highlighted = highlighted.transform(
                lambda n: (
                    exp.Tag(this=n, prefix=highlight_start, postfix=highlight_end)
                    if n is target
                    else n
                ),
                copy=False,
            )
            source = (
                html.escape(highlighted.sql(pretty=True, dialect=dialect))
                .replace(highlight_start, "<b>")
                .replace(highlight_end, "</b>")
            )
            title = f"<pre>{source}</pre>"
            group = 0

        node_id = id(node)

        nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        for d in node.downstream:
            edges.append({"from": node_id, "to": id(d)})
    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    *,
    copy: bool = True,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for.
        sql: The SQL string or expression.
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        copy: Whether to copy `sql` when it is an Expression. Qualification rewrites the
            tree in place, so passing False lets the caller's expression be modified.
        **kwargs: Qualification optimizer kwargs, forwarded to `qualify.qualify`. They
            override the defaults `validate_qualify_columns=False` and `identify=False`.

    Returns:
        A lineage node.

    Raises:
        SqlglotError: If no scope can be built (the SQL is not a SELECT).
        ValueError: If the column cannot be found in a UNION's projections.
    """

    expression = maybe_parse(sql, dialect=dialect, copy=copy)

    if sources:
        expression = exp.expand(
            expression,
            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
            dialect=dialect,
        )

    qualified = qualify.qualify(
        expression,
        dialect=dialect,
        schema=schema,
        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
    )

    scope = build_scope(qualified)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    def to_node(
        column: str | int,
        scope: Scope,
        scope_name: t.Optional[str] = None,
        upstream: t.Optional[Node] = None,
        source_name: t.Optional[str] = None,
        reference_node_name: t.Optional[str] = None,
    ) -> Node:
        """Create the node for `column` in `scope` and recurse into the columns it uses.

        `column` is a projection name, or an index when fanning out over UNION branches.
        """
        # Derived tables whose first comment reads "source: <name>" map alias -> <name>.
        source_names = {
            dt.alias: dt.comments[0].split()[1]
            for dt in scope.derived_tables
            if dt.comments and dt.comments[0].startswith("source: ")
        }

        # Find the specific select clause that is the source of the column we want.
        # This can either be a specific, named select or a generic `*` clause.
        select = (
            scope.expression.selects[column]
            if isinstance(column, int)
            else next(
                (select for select in scope.expression.selects if select.alias_or_name == column),
                exp.Star() if scope.expression.is_star else scope.expression,
            )
        )

        if isinstance(scope.expression, exp.Union):
            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

            # UNION branches are matched by position, not by name.
            index = (
                column
                if isinstance(column, int)
                else next(
                    (
                        i
                        for i, projection in enumerate(scope.expression.selects)
                        if projection.alias_or_name == column or projection.is_star
                    ),
                    -1,  # "not found" sentinel (mypy rejects None here); handled below
                )
            )

            if index == -1:
                raise ValueError(f"Could not find {column} in {scope.expression}")

            for s in scope.union_scopes:
                to_node(
                    index,
                    scope=s,
                    upstream=upstream,
                    source_name=source_name,
                    reference_node_name=reference_node_name,
                )

            return upstream

        if isinstance(scope.expression, exp.Select):
            # For better ergonomics in our node labels, replace the full select with
            # a version that has only the column we care about.
            #   "x", SELECT x, y FROM foo
            #   => "x", SELECT x FROM foo
            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
        else:
            source = scope.expression

        # Create the node for this step in the lineage chain, and attach it to the previous one.
        node = Node(
            name=f"{scope_name}.{column}" if scope_name else str(column),
            source=source,
            expression=select,
            source_name=source_name or "",
            reference_node_name=reference_node_name or "",
        )

        if upstream:
            upstream.downstream.append(node)

        subquery_scopes = {
            id(subquery_scope.expression): subquery_scope
            for subquery_scope in scope.subquery_scopes
        }

        for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
            subquery_scope = subquery_scopes.get(id(subquery))
            if not subquery_scope:
                logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
                continue

            for name in subquery.named_selects:
                to_node(name, scope=subquery_scope, upstream=node)

        # If the select is a star, add all scope sources as downstreams. A distinct loop
        # variable is used so this node's `source` (checked for UDTFs below) is not clobbered.
        if select.is_star:
            for star_source in scope.sources.values():
                if isinstance(star_source, Scope):
                    star_source = star_source.expression
                node.downstream.append(
                    Node(name=select.sql(), source=star_source, expression=star_source)
                )

        # Find all columns that went into creating this one to list their lineage nodes.
        # dict.fromkeys de-duplicates like a set but keeps traversal order, so downstream
        # nodes come out in a stable order instead of one depending on string hashing.
        source_columns = dict.fromkeys(find_all_in_scope(select, exp.Column))

        # If the source is a UDTF, find the columns used in the UDTF to generate the table.
        if isinstance(source, exp.UDTF):
            source_columns.update(dict.fromkeys(source.find_all(exp.Column)))

        for c in source_columns:
            table = c.table
            column_source = scope.sources.get(table)

            if isinstance(column_source, Scope):
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                # The table itself came from a more specific scope. Recurse into that one
                # using the unaliased column name.
                to_node(
                    c.name,
                    scope=column_source,
                    scope_name=table,
                    upstream=node,
                    source_name=source_names.get(table) or source_name,
                    reference_node_name=selected_node.name if selected_node else None,
                )
            else:
                # The source is not a scope - we've reached the end of the line. At this
                # point, if a source is not found it means this column's lineage is
                # unknown. This can happen if the definition of a source used in a query
                # is not passed into the `sources` map.
                column_source = column_source or exp.Placeholder()
                node.downstream.append(
                    Node(name=c.sql(), source=column_source, expression=column_source)
                )

        return node

    return to_node(column if isinstance(column, str) else column.name, scope)
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of tables.
- sources: A mapping of queries which will be used to continue building lineage.
- dialect: The dialect of input SQL.
- **kwargs: Qualification optimizer kwargs.
Returns:
A lineage node.
class
GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store graph data and vis.js options.

        Args:
            nodes: Mapping of node id to vis.js node dict; the values are rendered.
            edges: List of vis.js edge dicts.
            imports: Whether to emit the vis.js <script>/<link> tags.
            options: vis.js options shallow-merged over the defaults below.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    def __str__(self):
        """Return a self-contained HTML snippet that renders the graph with vis.js."""
        import uuid

        # "<" is written as the JSON escape "\u003c". This keeps SQL text inside node titles
        # (e.g. a literal containing "</script>") from closing the <script> element early.
        # JavaScript decodes the escape back to "<".
        nodes = json.dumps(list(self.nodes.values())).replace("<", "\\u003c")
        edges = json.dumps(self.edges).replace("<", "\\u003c")
        options = json.dumps(self.options).replace("<", "\\u003c")
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Use a fresh element id per rendering. With a fixed id, every graph shown in the
        # same page (e.g. a notebook) would be drawn into the first matching element.
        element_id = f"sqlglot-lineage-{uuid.uuid4().hex}"

        return f"""<div>
  <div id="{element_id}" class="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("{element_id}"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; same output as `str(self)`."""
        return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
def __init__(
    self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
):
    """Keep the graph data and build the vis.js network options.

    Args:
        nodes: Mapping of node id to vis.js node dict.
        edges: List of vis.js edge dicts.
        imports: Whether the rendered HTML should include the vis.js <script>/<link> tags.
        options: vis.js options laid over the defaults. The merge is shallow: a top-level
            key given here replaces the whole default entry for that key.
    """
    node_defaults = {
        "font": "20px monaco",
        "shape": "box",
        "widthConstraint": {"maximum": 300},
    }
    default_options = {
        "height": "500px",
        "width": "100%",
        "layout": {
            "hierarchical": {"enabled": True, "nodeSpacing": 200, "sortMethod": "directed"},
        },
        "interaction": {"dragNodes": False, "selectable": False},
        "physics": {"enabled": False},
        "edges": {"arrows": "to"},
        "nodes": node_defaults,
    }
    overrides = options if options else {}

    self.imports = imports
    self.options = {**default_options, **overrides}
    self.nodes = nodes
    self.edges = edges