Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
# Uses the package-wide "sqlglot" logger name (not __name__); `lineage` warns through it
# when a subquery's scope cannot be resolved.
  15logger = logging.getLogger("sqlglot")
 16
 17
 18@dataclass(frozen=True)
 19class Node:
 20    name: str
 21    expression: exp.Expression
 22    source: exp.Expression
 23    downstream: t.List[Node] = field(default_factory=list)
 24    source_name: str = ""
 25    reference_node_name: str = ""
 26
 27    def walk(self) -> t.Iterator[Node]:
 28        yield self
 29
 30        for d in self.downstream:
 31            if isinstance(d, Node):
 32                yield from d.walk()
 33            else:
 34                yield d
 35
 36    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 37        nodes = {}
 38        edges = []
 39
 40        for node in self.walk():
 41            if isinstance(node.expression, exp.Table):
 42                label = f"FROM {node.expression.this}"
 43                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 44                group = 1
 45            else:
 46                label = node.expression.sql(pretty=True, dialect=dialect)
 47                source = node.source.transform(
 48                    lambda n: (
 49                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 50                    ),
 51                    copy=False,
 52                ).sql(pretty=True, dialect=dialect)
 53                title = f"<pre>{source}</pre>"
 54                group = 0
 55
 56            node_id = id(node)
 57
 58            nodes[node_id] = {
 59                "id": node_id,
 60                "label": label,
 61                "title": title,
 62                "group": group,
 63            }
 64
 65            for d in node.downstream:
 66                edges.append({"from": node_id, "to": id(d)})
 67        return GraphHTML(nodes, edges, **opts)
 68
 69
 70def lineage(
 71    column: str | exp.Column,
 72    sql: str | exp.Expression,
 73    schema: t.Optional[t.Dict | Schema] = None,
 74    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 75    dialect: DialectType = None,
 76    **kwargs,
 77) -> Node:
 78    """Build the lineage graph for a column of a SQL query.
 79
 80    Args:
 81        column: The column to build the lineage for, given as a name or an `exp.Column`
               (only the column's name is used).
 82        sql: The SQL string or expression.
 83        schema: The schema of tables, passed through to the `qualify` optimizer rule.
 84        sources: A mapping of source names to queries (SQL strings or expressions) that
               `sql` is expanded with, so that lineage can continue through them.
 85        dialect: The dialect of input SQL.
 86        **kwargs: Qualification optimizer kwargs, forwarded to `qualify`; they override
               the defaults `validate_qualify_columns=False` and `identify=False`.
 87
 88    Returns:
 89        A lineage node: the root of the lineage graph (a "UNION" node when the query
               is a union).

       Raises:
           SqlglotError: If no scope can be built from `sql` (e.g. it is not a SELECT).
           ValueError: If `column` cannot be found in the selects of a UNION.
 90    """
 91
       # Parse the input and, when `sources` is given, expand it with those source queries.
 92    expression = maybe_parse(sql, dialect=dialect)
 93
 94    if sources:
 95        expression = exp.expand(
 96            expression,
 97            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
 98            dialect=dialect,
 99        )
100
       # Run the `qualify` rule; caller kwargs win over these defaults because they are
       # unpacked last.
101    qualified = qualify.qualify(
102        expression,
103        dialect=dialect,
104        schema=schema,
105        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
106    )
107
108    scope = build_scope(qualified)
109
110    if not scope:
111        raise SqlglotError("Cannot build lineage, sql must be SELECT")
112
       # Recursive helper: builds the Node for `column` within `scope`, appends it to
       # `upstream.downstream` when an upstream is given, and recurses into the scopes the
       # column depends on. `column` is a select name, or a positional index when
       # descending into the branches of a UNION.
113    def to_node(
114        column: str | int,
115        scope: Scope,
116        scope_name: t.Optional[str] = None,
117        upstream: t.Optional[Node] = None,
118        source_name: t.Optional[str] = None,
119        reference_node_name: t.Optional[str] = None,
120    ) -> Node:
           # Map derived-table aliases to the name in a leading "source: <name>" comment.
           # NOTE(review): a comment of exactly "source: " (no name) would raise IndexError
           # on `split()[1]`.
121        source_names = {
122            dt.alias: dt.comments[0].split()[1]
123            for dt in scope.derived_tables
124            if dt.comments and dt.comments[0].startswith("source: ")
125        }
126
127        # Find the specific select clause that is the source of the column we want.
128        # This can either be a specific, named select or a generic `*` clause.
129        select = (
130            scope.expression.selects[column]
131            if isinstance(column, int)
132            else next(
133                (select for select in scope.expression.selects if select.alias_or_name == column),
134                exp.Star() if scope.expression.is_star else scope.expression,
135            )
136        )
137
138        if isinstance(scope.expression, exp.Union):
               # A single "UNION" node (or the given upstream, if any) collects the matching
               # column from every branch; branches are matched by position.
139            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
140
               # Position of `column` in the union's select list; a star select also matches.
141            index = (
142                column
143                if isinstance(column, int)
144                else next(
145                    (
146                        i
147                        for i, select in enumerate(scope.expression.selects)
148                        if select.alias_or_name == column or select.is_star
149                    ),
150                    -1,  # mypy will not allow a None here, but a negative index should never be returned
151                )
152            )
153
154            if index == -1:
155                raise ValueError(f"Could not find {column} in {scope.expression}")
156
157            for s in scope.union_scopes:
158                to_node(
159                    index,
160                    scope=s,
161                    upstream=upstream,
162                    source_name=source_name,
163                    reference_node_name=reference_node_name,
164                )
165
166            return upstream
167
168        if isinstance(scope.expression, exp.Select):
169            # For better ergonomics in our node labels, replace the full select with
170            # a version that has only the column we care about.
171            #   "x", SELECT x, y FROM foo
172            #     => "x", SELECT x FROM foo
173            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
174        else:
175            source = scope.expression
176
177        # Create the node for this step in the lineage chain, and attach it to the previous one.
178        node = Node(
179            name=f"{scope_name}.{column}" if scope_name else str(column),
180            source=source,
181            expression=select,
182            source_name=source_name or "",
183            reference_node_name=reference_node_name or "",
184        )
185
186        if upstream:
187            upstream.downstream.append(node)
188
           # Recurse into subqueries found inside the selected expression, looking up their
           # scopes by object identity; subqueries without a known scope are logged and skipped.
189        subquery_scopes = {
190            id(subquery_scope.expression): subquery_scope
191            for subquery_scope in scope.subquery_scopes
192        }
193
194        for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
195            subquery_scope = subquery_scopes.get(id(subquery))
196            if not subquery_scope:
197                logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
198                continue
199
200            for name in subquery.named_selects:
201                to_node(name, scope=subquery_scope, upstream=node)
202
203        # if the select is a star add all scope sources as downstreams
           # NOTE(review): this loop rebinds `source`, so the UDTF check further down sees the
           # last star source rather than this node's own source - confirm this is intended.
204        if select.is_star:
205            for source in scope.sources.values():
206                if isinstance(source, Scope):
207                    source = source.expression
208                node.downstream.append(Node(name=select.sql(), source=source, expression=source))
209
210        # Find all columns that went into creating this one to list their lineage nodes.
211        source_columns = set(find_all_in_scope(select, exp.Column))
212
213        # If the source is a UDTF find columns used in the UDTF to generate the table
214        if isinstance(source, exp.UDTF):
215            source_columns |= set(source.find_all(exp.Column))
216
217        for c in source_columns:
218            table = c.table
219            source = scope.sources.get(table)
220
221            if isinstance(source, Scope):
222                selected_node, _ = scope.selected_sources.get(table, (None, None))
223                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
224                to_node(
225                    c.name,
226                    scope=source,
227                    scope_name=table,
228                    upstream=node,
229                    source_name=source_names.get(table) or source_name,
230                    reference_node_name=selected_node.name if selected_node else None,
231                )
232            else:
233                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
234                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
235                # is not passed into the `sources` map.
236                source = source or exp.Placeholder()
237                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
238
239        return node
240
       # `column` may be a string or an `exp.Column`; only its name is used.
241    return to_node(column if isinstance(column, str) else column.name, scope)
242
243
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and build the vis.js network options.

        Args:
            nodes: Mapping of node id to a vis.js node dict; only its values are rendered.
            edges: List of vis.js edge dicts, each with ``from`` and ``to`` node ids.
            imports: Whether to include the vis.js ``<script>``/``<link>`` tags in the output.
            options: vis.js network options. This is a shallow merge: each top-level key
                replaces the corresponding default entry entirely.
        """
        self.imports = imports

        default_options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
        }
        self.options = {**default_options, **(options or {})}

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to embed inside an HTML ``<script>`` element.

        ``<``, ``>`` and ``&`` only ever occur inside JSON strings, where ``\\u003c`` etc.
        decode to the same characters in JavaScript. Escaping them stops text such as
        ``</script>`` (which can come from SQL string literals) from ending the script block.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same HTML as ``str(self)``."""
        return str(self)
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph.

    Nodes are built by `lineage`. Each node records the expression it describes, the
    query (or leaf source) that expression lives in, and the nodes it was derived from.
    """

    # Display name: "column" or "scope_name.column"; star and leaf nodes use the SQL of
    # the star / column they stand for.
    name: str
    # The expression this node describes: a projection, a star, or a leaf source.
    expression: exp.Expression
    # The query (narrowed to the selected column) or leaf source containing `expression`.
    source: exp.Expression
    # Nodes this one was derived from. The dataclass is frozen, but this list is still
    # appended to while the graph is being built.
    downstream: t.List[Node] = field(default_factory=list)
    # Name taken from a "source: <name>" comment on a derived table, if any.
    source_name: str = ""
    # Name of the selected source node through which this column was reached, if any.
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node reachable through `downstream`, depth-first (pre-order)."""
        yield self
        for child in self.downstream:
            # Non-Node entries are yielded as-is instead of being recursed into.
            yield from child.walk() if isinstance(child, Node) else (child,)

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the lineage graph rooted at this node as an interactive vis.js graph.

        Args:
            dialect: The dialect used to generate SQL for node labels and tooltips.
            **opts: Extra keyword arguments forwarded to `GraphHTML` (``imports``, ``options``).

        Returns:
            A `GraphHTML` object; ``str()`` of it is an embeddable HTML snippet.
        """
        import html  # stdlib; imported locally so the module's imports stay unchanged

        # Private-use code points never appear in generated SQL, so they can mark where the
        # highlighted expression starts/ends and pass through `html.escape` untouched.
        bold_open, bold_close = "\ue000", "\ue001"

        def highlighted_source(node: Node) -> str:
            """Return ``node.source`` as SQL with ``node.expression`` wrapped in bold markers.

            A tagged *copy* of the expression is swapped into the tree only while the SQL is
            generated and is swapped back afterwards, so the lineage expressions (which are
            shared with the rest of the graph) are left unmodified and repeated calls don't
            accumulate tags.
            """
            target, source = node.expression, node.source

            if target is source:
                return f"{bold_open}{source.sql(pretty=True, dialect=dialect)}{bold_close}"

            # Only highlight `target` if it is actually part of `source`'s tree.
            ancestor = target.parent
            while ancestor is not None and ancestor is not source:
                ancestor = ancestor.parent
            if ancestor is None:
                return source.sql(pretty=True, dialect=dialect)

            tag = exp.Tag(this=target.copy(), prefix=bold_open, postfix=bold_close)
            target.replace(tag)
            try:
                return source.sql(pretty=True, dialect=dialect)
            finally:
                tag.replace(target)

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = html.escape(f"SELECT {node.name} FROM {node.expression.this}")
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                title = (
                    html.escape(highlighted_source(node))
                    .replace(bold_open, "<b>")
                    .replace(bold_close, "</b>")
                )
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                # The page parses the title as HTML, so the SQL inside it must be escaped.
                "title": f"<pre>{title}</pre>",
                "group": group,
            }

            for d in node.downstream:
                edges.append({"from": node_id, "to": id(d)})

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
name: str
downstream: List[Node]
source_name: str = ''
reference_node_name: str = ''
def walk(self) -> Iterator[Node]:
28    def walk(self) -> t.Iterator[Node]:
29        yield self
30
31        for d in self.downstream:
32            if isinstance(d, Node):
33                yield from d.walk()
34            else:
35                yield d
def to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
37    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
38        nodes = {}
39        edges = []
40
41        for node in self.walk():
42            if isinstance(node.expression, exp.Table):
43                label = f"FROM {node.expression.this}"
44                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
45                group = 1
46            else:
47                label = node.expression.sql(pretty=True, dialect=dialect)
48                source = node.source.transform(
49                    lambda n: (
50                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
51                    ),
52                    copy=False,
53                ).sql(pretty=True, dialect=dialect)
54                title = f"<pre>{source}</pre>"
55                group = 0
56
57            node_id = id(node)
58
59            nodes[node_id] = {
60                "id": node_id,
61                "label": label,
62                "title": title,
63                "group": group,
64            }
65
66            for d in node.downstream:
67                edges.append({"from": node_id, "to": id(d)})
68        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 71def lineage(
 72    column: str | exp.Column,
 73    sql: str | exp.Expression,
 74    schema: t.Optional[t.Dict | Schema] = None,
 75    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 76    dialect: DialectType = None,
 77    **kwargs,
 78) -> Node:
 79    """Build the lineage graph for a column of a SQL query.
 80
 81    Args:
 82        column: The column to build the lineage for, given as a name or an `exp.Column`
               (only the column's name is used).
 83        sql: The SQL string or expression.
 84        schema: The schema of tables, passed through to the `qualify` optimizer rule.
 85        sources: A mapping of source names to queries (SQL strings or expressions) that
               `sql` is expanded with, so that lineage can continue through them.
 86        dialect: The dialect of input SQL.
 87        **kwargs: Qualification optimizer kwargs, forwarded to `qualify`; they override
               the defaults `validate_qualify_columns=False` and `identify=False`.
 88
 89    Returns:
 90        A lineage node: the root of the lineage graph (a "UNION" node when the query
               is a union).

       Raises:
           SqlglotError: If no scope can be built from `sql` (e.g. it is not a SELECT).
           ValueError: If `column` cannot be found in the selects of a UNION.
 91    """
 92
       # Parse the input and, when `sources` is given, expand it with those source queries.
 93    expression = maybe_parse(sql, dialect=dialect)
 94
 95    if sources:
 96        expression = exp.expand(
 97            expression,
 98            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
 99            dialect=dialect,
100        )
101
       # Run the `qualify` rule; caller kwargs win over these defaults because they are
       # unpacked last.
102    qualified = qualify.qualify(
103        expression,
104        dialect=dialect,
105        schema=schema,
106        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
107    )
108
109    scope = build_scope(qualified)
110
111    if not scope:
112        raise SqlglotError("Cannot build lineage, sql must be SELECT")
113
       # Recursive helper: builds the Node for `column` within `scope`, appends it to
       # `upstream.downstream` when an upstream is given, and recurses into the scopes the
       # column depends on. `column` is a select name, or a positional index when
       # descending into the branches of a UNION.
114    def to_node(
115        column: str | int,
116        scope: Scope,
117        scope_name: t.Optional[str] = None,
118        upstream: t.Optional[Node] = None,
119        source_name: t.Optional[str] = None,
120        reference_node_name: t.Optional[str] = None,
121    ) -> Node:
           # Map derived-table aliases to the name in a leading "source: <name>" comment.
           # NOTE(review): a comment of exactly "source: " (no name) would raise IndexError
           # on `split()[1]`.
122        source_names = {
123            dt.alias: dt.comments[0].split()[1]
124            for dt in scope.derived_tables
125            if dt.comments and dt.comments[0].startswith("source: ")
126        }
127
128        # Find the specific select clause that is the source of the column we want.
129        # This can either be a specific, named select or a generic `*` clause.
130        select = (
131            scope.expression.selects[column]
132            if isinstance(column, int)
133            else next(
134                (select for select in scope.expression.selects if select.alias_or_name == column),
135                exp.Star() if scope.expression.is_star else scope.expression,
136            )
137        )
138
139        if isinstance(scope.expression, exp.Union):
               # A single "UNION" node (or the given upstream, if any) collects the matching
               # column from every branch; branches are matched by position.
140            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
141
               # Position of `column` in the union's select list; a star select also matches.
142            index = (
143                column
144                if isinstance(column, int)
145                else next(
146                    (
147                        i
148                        for i, select in enumerate(scope.expression.selects)
149                        if select.alias_or_name == column or select.is_star
150                    ),
151                    -1,  # mypy will not allow a None here, but a negative index should never be returned
152                )
153            )
154
155            if index == -1:
156                raise ValueError(f"Could not find {column} in {scope.expression}")
157
158            for s in scope.union_scopes:
159                to_node(
160                    index,
161                    scope=s,
162                    upstream=upstream,
163                    source_name=source_name,
164                    reference_node_name=reference_node_name,
165                )
166
167            return upstream
168
169        if isinstance(scope.expression, exp.Select):
170            # For better ergonomics in our node labels, replace the full select with
171            # a version that has only the column we care about.
172            #   "x", SELECT x, y FROM foo
173            #     => "x", SELECT x FROM foo
174            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
175        else:
176            source = scope.expression
177
178        # Create the node for this step in the lineage chain, and attach it to the previous one.
179        node = Node(
180            name=f"{scope_name}.{column}" if scope_name else str(column),
181            source=source,
182            expression=select,
183            source_name=source_name or "",
184            reference_node_name=reference_node_name or "",
185        )
186
187        if upstream:
188            upstream.downstream.append(node)
189
           # Recurse into subqueries found inside the selected expression, looking up their
           # scopes by object identity; subqueries without a known scope are logged and skipped.
190        subquery_scopes = {
191            id(subquery_scope.expression): subquery_scope
192            for subquery_scope in scope.subquery_scopes
193        }
194
195        for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
196            subquery_scope = subquery_scopes.get(id(subquery))
197            if not subquery_scope:
198                logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
199                continue
200
201            for name in subquery.named_selects:
202                to_node(name, scope=subquery_scope, upstream=node)
203
204        # if the select is a star add all scope sources as downstreams
           # NOTE(review): this loop rebinds `source`, so the UDTF check further down sees the
           # last star source rather than this node's own source - confirm this is intended.
205        if select.is_star:
206            for source in scope.sources.values():
207                if isinstance(source, Scope):
208                    source = source.expression
209                node.downstream.append(Node(name=select.sql(), source=source, expression=source))
210
211        # Find all columns that went into creating this one to list their lineage nodes.
212        source_columns = set(find_all_in_scope(select, exp.Column))
213
214        # If the source is a UDTF find columns used in the UDTF to generate the table
215        if isinstance(source, exp.UDTF):
216            source_columns |= set(source.find_all(exp.Column))
217
218        for c in source_columns:
219            table = c.table
220            source = scope.sources.get(table)
221
222            if isinstance(source, Scope):
223                selected_node, _ = scope.selected_sources.get(table, (None, None))
224                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
225                to_node(
226                    c.name,
227                    scope=source,
228                    scope_name=table,
229                    upstream=node,
230                    source_name=source_names.get(table) or source_name,
231                    reference_node_name=selected_node.name if selected_node else None,
232                )
233            else:
234                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
235                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
236                # is not passed into the `sources` map.
237                source = source or exp.Placeholder()
238                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
239
240        return node
241
       # `column` may be a string or an `exp.Column`; only its name is used.
242    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
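  • schema: The schema of the tables referenced by the query, used when qualifying it.
  • sources: A mapping of source names to queries; the SQL is expanded with these queries so that lineage can continue through them.
  • dialect: The dialect of the input SQL.
  • **kwargs: Additional keyword arguments forwarded to the qualify optimizer rule.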
Returns:

A lineage node.

class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and build the vis.js network options.

        Args:
            nodes: Mapping of node id to a vis.js node dict; only its values are rendered.
            edges: List of vis.js edge dicts, each with ``from`` and ``to`` node ids.
            imports: Whether to include the vis.js ``<script>``/``<link>`` tags in the output.
            options: vis.js network options. This is a shallow merge: each top-level key
                replaces the corresponding default entry entirely.
        """
        self.imports = imports

        default_options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
        }
        self.options = {**default_options, **(options or {})}

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to embed inside an HTML ``<script>`` element.

        ``<``, ``>`` and ``&`` only ever occur inside JSON strings, where ``\\u003c`` etc.
        decode to the same characters in JavaScript. Escaping them stops text such as
        ``</script>`` (which can come from SQL string literals) from ending the script block.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same HTML as ``str(self)``."""
        return str(self)

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
251    def __init__(
252        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
253    ):
254        self.imports = imports
255
256        self.options = {
257            "height": "500px",
258            "width": "100%",
259            "layout": {
260                "hierarchical": {
261                    "enabled": True,
262                    "nodeSpacing": 200,
263                    "sortMethod": "directed",
264                },
265            },
266            "interaction": {
267                "dragNodes": False,
268                "selectable": False,
269            },
270            "physics": {
271                "enabled": False,
272            },
273            "edges": {
274                "arrows": "to",
275            },
276            "nodes": {
277                "font": "20px monaco",
278                "shape": "box",
279                "widthConstraint": {
280                    "maximum": 300,
281                },
282            },
283            **(options or {}),
284        }
285
286        self.nodes = nodes
287        self.edges = edges
imports
options
nodes
edges