sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """Kind of a `Scope`, relative to its parent scope (the ROOT scope has no parent)."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTES
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible as ordinary sources of this scope
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Drop every lazily computed property so it is rebuilt from `self.expression` on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """
        Branch from the current scope to a new, inner scope.

        The child inherits this scope's CTE sources merged with `cte_sources`. `sources` and
        `lateral_sources` are shallow-copied so the child cannot mutate the caller's dicts.
        """
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (not its child scopes) and bucket the nodes the lazy properties expose."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without entering child scopes.

        Args:
            bfs (bool): breadth-first when True, depth-first otherwise.
            prune ((node, parent, arg_key) -> bool): optional callable; returning True stops
                traversal below that node.

        Yields:
            tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
        """
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.Union]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # An unqualified column is dropped when its nearest such ancestor is QUALIFY,
                # HAVING, a hint, a star, a function-backed table, or a non-window ORDER BY
                # whose name matches one of this scope's projections.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same alias is referenced twice in this scope.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Derived tables and UDTFs are reported unnested unless they carry pivots.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.Union):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like the other node lists, rather than reporting no hints just
        # because nothing has triggered a collection yet.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """Pivot nodes attached to any of this scope's referenced sources."""
        # `is None` so an empty result is cached instead of being recomputed on every access
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # We ignore the DDL expression and build a scope for its query instead
        ddl_with = expression.args.get("with")
        expression = expression.expression

        # If the DDL has CTEs attached, we need to add them to the query, or
        # prepend them if the query itself already has CTEs attached to it
        if ddl_with:
            ddl_with.pop()
            query_ctes = expression.ctes
            if not query_ctes:
                expression.set("with", ddl_with)
            else:
                query_with = expression.args["with"]
                # RECURSIVE applies to the whole merged WITH list, so keep it if either
                # the DDL's or the query's WITH clause had it.
                query_with.set("recursive", ddl_with.recursive or query_with.recursive)
                query_with.set("expressions", [*ddl_with.expressions, *query_ctes])

    if isinstance(expression, exp.Query):
        return list(_traverse_scope(Scope(expression)))

    return []


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope
    """
    return seq_get(traverse_scope(expression), -1)


def _traverse_scope(scope):
    """Yield the child scopes of `scope` and then `scope` itself (unions yield their own root)."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return
    elif isinstance(scope.expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Iteratively build the left/right `union_scopes` of a (possibly nested) UNION tree."""
    prev_scope = None
    union_scope_stack = [scope]
    expression_stack = [scope.expression.right, scope.expression.left]

    # NOTE(review): pairing children via `prev_scope` appears to assume nested unions are
    # left-deep (as produced for `a UNION b UNION c`); parenthesized operands arrive as
    # Subquery nodes and are traversed recursively instead.
    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.Union):
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope


def _traverse_ctes(scope):
    """Yield scopes for each CTE of `scope`, registering every CTE as a source of `scope`."""
    sources = {}

    # The WITH clause (and whether it is RECURSIVE) is the same for every CTE in this scope
    with_ = scope.expression.args.get("with")
    is_recursive = bool(with_ and with_.recursive)

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        if is_recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            # Every scope produced for a recursive CTE may refer back to the CTE by name
            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)
                child_scope.cte_sources[alias] = recursive_scope

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return isinstance(expression, exp.Subquery) and bool(
        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    )


def _traverse_tables(scope):
    """Yield scopes for derived tables and UDTFs of `scope`, registering all table sources."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` grows while it is iterated, so nested joins are visited too
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield scopes for each subquery of `scope`, recording the top one per subquery."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """Yield scopes for derived tables nested in an UNNEST or LATERAL UDTF scope."""
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.DERIVED_TABLE,
                    outer_columns=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            scope.derived_table_scopes.append(top)
            scope.table_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (_is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # (honouring the caller's prune callable there as well)
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Build the isinstance() type tuple once instead of once per visited node
    types = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """
    Kind of a scope, relative to its parent scope.

    ROOT is the outermost scope of a traversal; every other member names the
    construct that opened a nested scope.
    """

    # Outermost scope (has no parent)
    ROOT = auto()
    # Query nested inside an expression, e.g. `WHERE a IN (SELECT ...)`
    SUBQUERY = auto()
    # Query used as a table in FROM/JOIN, e.g. `FROM (SELECT ...) AS y`
    DERIVED_TABLE = auto()
    # Body of a WITH-clause entry
    CTE = auto()
    # One operand of a UNION
    UNION = auto()
    # Table-valued function such as UNNEST or LATERAL
    UDTF = auto()
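An enumeration of scope kinds: the root scope, or a subquery, derived table, CTE, union branch, or UDTF scope nested under a parent scope.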
Inherited Members
- enum.Enum
- name
- value
26class Scope: 27 """ 28 Selection scope. 29 30 Attributes: 31 expression (exp.Select|exp.Union): Root expression of this scope 32 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 33 a Table expression or another Scope instance. For example: 34 SELECT * FROM x {"x": Table(this="x")} 35 SELECT * FROM x AS y {"y": Table(this="x")} 36 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 37 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 38 For example: 39 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 40 The LATERAL VIEW EXPLODE gets x as a source. 41 cte_sources (dict[str, Scope]): Sources from CTES 42 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 43 defines a column list for the alias of this scope, this is that list of columns. 44 For example: 45 SELECT * FROM (SELECT ...) AS y(col1, col2) 46 The inner query would have `["col1", "col2"]` for its `outer_columns` 47 parent (Scope): Parent scope 48 scope_type (ScopeType): Type of this scope, relative to it's parent 49 subquery_scopes (list[Scope]): List of all child scopes for subqueries 50 cte_scopes (list[Scope]): List of all child scopes for CTEs 51 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 52 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 53 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 54 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 55 a list of the left and right child scopes. 
56 """ 57 58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_columns=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_columns = outer_columns or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache() 84 85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None 99 100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 ) 113 114 def _collect(self): 115 self._tables = [] 116 self._ctes = [] 117 self._subqueries = [] 118 self._derived_tables = [] 119 self._udtfs = [] 120 self._raw_columns = [] 121 self._join_hints = [] 122 123 for node, parent, _ in self.walk(bfs=False): 124 if node is self.expression: 125 continue 126 127 if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 128 self._raw_columns.append(node) 129 elif 
isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 130 self._tables.append(node) 131 elif isinstance(node, exp.JoinHint): 132 self._join_hints.append(node) 133 elif isinstance(node, exp.UDTF): 134 self._udtfs.append(node) 135 elif isinstance(node, exp.CTE): 136 self._ctes.append(node) 137 elif _is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)): 138 self._derived_tables.append(node) 139 elif isinstance(node, exp.UNWRAPPED_QUERIES): 140 self._subqueries.append(node) 141 142 self._collected = True 143 144 def _ensure_collected(self): 145 if not self._collected: 146 self._collect() 147 148 def walk(self, bfs=True, prune=None): 149 return walk_in_scope(self.expression, bfs=bfs, prune=None) 150 151 def find(self, *expression_types, bfs=True): 152 return find_in_scope(self.expression, expression_types, bfs=bfs) 153 154 def find_all(self, *expression_types, bfs=True): 155 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 156 157 def replace(self, old, new): 158 """ 159 Replace `old` with `new`. 160 161 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 162 163 Args: 164 old (exp.Expression): old node 165 new (exp.Expression): new node 166 """ 167 old.replace(new) 168 self.clear_cache() 169 170 @property 171 def tables(self): 172 """ 173 List of tables in this scope. 174 175 Returns: 176 list[exp.Table]: tables 177 """ 178 self._ensure_collected() 179 return self._tables 180 181 @property 182 def ctes(self): 183 """ 184 List of CTEs in this scope. 185 186 Returns: 187 list[exp.CTE]: ctes 188 """ 189 self._ensure_collected() 190 return self._ctes 191 192 @property 193 def derived_tables(self): 194 """ 195 List of derived tables in this scope. 196 197 For example: 198 SELECT * FROM (SELECT ...) 
<- that's a derived table 199 200 Returns: 201 list[exp.Subquery]: derived tables 202 """ 203 self._ensure_collected() 204 return self._derived_tables 205 206 @property 207 def udtfs(self): 208 """ 209 List of "User Defined Tabular Functions" in this scope. 210 211 Returns: 212 list[exp.UDTF]: UDTFs 213 """ 214 self._ensure_collected() 215 return self._udtfs 216 217 @property 218 def subqueries(self): 219 """ 220 List of subqueries in this scope. 221 222 For example: 223 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 224 225 Returns: 226 list[exp.Select | exp.Union]: subqueries 227 """ 228 self._ensure_collected() 229 return self._subqueries 230 231 @property 232 def columns(self): 233 """ 234 List of columns in this scope. 235 236 Returns: 237 list[exp.Column]: Column instances in this scope, plus any 238 Columns that reference this scope from correlated subqueries. 239 """ 240 if self._columns is None: 241 self._ensure_collected() 242 columns = self._raw_columns 243 244 external_columns = [ 245 column 246 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 247 for column in scope.external_columns 248 ] 249 250 named_selects = set(self.expression.named_selects) 251 252 self._columns = [] 253 for column in columns + external_columns: 254 ancestor = column.find_ancestor( 255 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 256 ) 257 if ( 258 not ancestor 259 or column.table 260 or isinstance(ancestor, exp.Select) 261 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 262 or ( 263 isinstance(ancestor, exp.Order) 264 and ( 265 isinstance(ancestor.parent, exp.Window) 266 or column.name not in named_selects 267 ) 268 ) 269 ): 270 self._columns.append(column) 271 272 return self._columns 273 274 @property 275 def selected_sources(self): 276 """ 277 Mapping of nodes and sources that are actually selected from in this scope. 
278 279 That is, all tables in a schema are selectable at any point. But a 280 table only becomes a selected source if it's included in a FROM or JOIN clause. 281 282 Returns: 283 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 284 """ 285 if self._selected_sources is None: 286 result = {} 287 288 for name, node in self.references: 289 if name in result: 290 raise OptimizeError(f"Alias already used: {name}") 291 if name in self.sources: 292 result[name] = (node, self.sources[name]) 293 294 self._selected_sources = result 295 return self._selected_sources 296 297 @property 298 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 299 if self._references is None: 300 self._references = [] 301 302 for table in self.tables: 303 self._references.append((table.alias_or_name, table)) 304 for expression in itertools.chain(self.derived_tables, self.udtfs): 305 self._references.append( 306 ( 307 expression.alias, 308 expression if expression.args.get("pivots") else expression.unnest(), 309 ) 310 ) 311 312 return self._references 313 314 @property 315 def external_columns(self): 316 """ 317 Columns that appear to reference sources in outer scopes. 318 319 Returns: 320 list[exp.Column]: Column instances that don't reference 321 sources in the current scope. 322 """ 323 if self._external_columns is None: 324 if isinstance(self.expression, exp.Union): 325 left, right = self.union_scopes 326 self._external_columns = left.external_columns + right.external_columns 327 else: 328 self._external_columns = [ 329 c for c in self.columns if c.table not in self.selected_sources 330 ] 331 332 return self._external_columns 333 334 @property 335 def unqualified_columns(self): 336 """ 337 Unqualified columns in the current scope. 
338 339 Returns: 340 list[exp.Column]: Unqualified columns 341 """ 342 return [c for c in self.columns if not c.table] 343 344 @property 345 def join_hints(self): 346 """ 347 Hints that exist in the scope that reference tables 348 349 Returns: 350 list[exp.JoinHint]: Join hints that are referenced within the scope 351 """ 352 if self._join_hints is None: 353 return [] 354 return self._join_hints 355 356 @property 357 def pivots(self): 358 if not self._pivots: 359 self._pivots = [ 360 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 361 ] 362 363 return self._pivots 364 365 def source_columns(self, source_name): 366 """ 367 Get all columns in the current scope for a particular source. 368 369 Args: 370 source_name (str): Name of the source 371 Returns: 372 list[exp.Column]: Column instances that reference `source_name` 373 """ 374 return [column for column in self.columns if column.table == source_name] 375 376 @property 377 def is_subquery(self): 378 """Determine if this scope is a subquery""" 379 return self.scope_type == ScopeType.SUBQUERY 380 381 @property 382 def is_derived_table(self): 383 """Determine if this scope is a derived table""" 384 return self.scope_type == ScopeType.DERIVED_TABLE 385 386 @property 387 def is_union(self): 388 """Determine if this scope is a union""" 389 return self.scope_type == ScopeType.UNION 390 391 @property 392 def is_cte(self): 393 """Determine if this scope is a common table expression""" 394 return self.scope_type == ScopeType.CTE 395 396 @property 397 def is_root(self): 398 """Determine if this is the root scope""" 399 return self.scope_type == ScopeType.ROOT 400 401 @property 402 def is_udtf(self): 403 """Determine if this scope is a UDTF (User Defined Table Function)""" 404 return self.scope_type == ScopeType.UDTF 405 406 @property 407 def is_correlated_subquery(self): 408 """Determine if this scope is a correlated subquery""" 409 return bool( 410 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 411 and self.external_columns 412 ) 413 414 def rename_source(self, old_name, new_name): 415 """Rename a source in this scope""" 416 columns = self.sources.pop(old_name or "", []) 417 self.sources[new_name] = columns 418 419 def add_source(self, name, source): 420 """Add a source to this scope""" 421 self.sources[name] = source 422 self.clear_cache() 423 424 def remove_source(self, name): 425 """Remove a source from this scope""" 426 self.sources.pop(name, None) 427 self.clear_cache() 428 429 def __repr__(self): 430 return f"Scope<{self.expression.sql()}>" 431 432 def traverse(self): 433 """ 434 Traverse the scope tree from this node. 435 436 Yields: 437 Scope: scope instances in depth-first-search post-order 438 """ 439 stack = [self] 440 result = [] 441 while stack: 442 scope = stack.pop() 443 result.append(scope) 444 stack.extend( 445 itertools.chain( 446 scope.cte_scopes, 447 scope.union_scopes, 448 scope.table_scopes, 449 scope.subquery_scopes, 450 ) 451 ) 452 453 yield from reversed(result) 454 455 def ref_count(self): 456 """ 457 Count the number of times each scope in this tree is referenced. 458 459 Returns: 460 dict[int, int]: Mapping of Scope instance ID to reference count 461 """ 462 scope_ref_count = defaultdict(lambda: 0) 463 464 for scope in self.traverse(): 465 for _, source in scope.selected_sources.values(): 466 scope_ref_count[id(source)] += 1 467 468 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example:
  SELECT * FROM x {"x": Table(this="x")}
  SELECT * FROM x AS y {"y": Table(this="x")}
  SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example:
  SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
  The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression: root expression of this scope.
        sources (dict | None): mapping of source name to source. Lateral and CTE
            sources are merged into it.
        outer_columns (list[str] | None): column list declared by the outer alias.
        parent (Scope | None): enclosing scope.
        scope_type (ScopeType): kind of scope relative to its parent.
        lateral_sources (dict | None): sources coming from laterals.
        cte_sources (dict | None): sources coming from CTEs.
    """
    self.expression = expression
    # NOTE: a non-empty `sources` dict is reused as-is (not copied), so the merge
    # below also updates the caller's dict.
    self.sources = sources or {}
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    # CTE sources are merged last, so a CTE name wins over a lateral name on collision.
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)
    self.outer_columns = outer_columns or []
    self.parent = parent
    self.scope_type = scope_type
    # Each child-scope bucket gets its own fresh list.
    for bucket in (
        "subquery_scopes",
        "derived_table_scopes",
        "table_scopes",
        "cte_scopes",
        "union_scopes",
        "udtf_scopes",
    ):
        setattr(self, bucket, [])
    self.clear_cache()
def clear_cache(self):
    """
    Reset every lazily computed attribute so it is rebuilt on next access.
    """
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    Args:
        expression: expression for the child scope; it is unnested first.
        scope_type (ScopeType): kind of the child scope.
        sources (dict | None): child sources; copied so the caller's dict is untouched.
        cte_sources (dict | None): extra CTE sources, layered over this scope's CTEs.
        lateral_sources (dict | None): lateral sources; copied when given.
        **kwargs: forwarded to the `Scope` constructor.

    Returns:
        Scope: the new child scope, whose parent is this scope.
    """
    inner_expression = expression.unnest()
    own_sources = sources.copy() if sources else None
    # CTEs visible in this scope stay visible in the child; the child's own CTEs win.
    visible_ctes = {**self.cte_sources, **(cte_sources or {})}
    own_laterals = lateral_sources.copy() if lateral_sources else None
    return Scope(
        expression=inner_expression,
        sources=own_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=visible_ctes,
        lateral_sources=own_laterals,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so that the
    `Scope` stays in sync with the expression it describes.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # The tree changed, so anything collected from it is now stale.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Table expressions found directly in this scope.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    collected_tables = self._tables
    return collected_tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    Common table expressions defined in this scope.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    collected_ctes = self._ctes
    return collected_ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables (subqueries used as table sources) in this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    collected_derived = self._derived_tables
    return collected_derived
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" appearing in this scope.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    collected_udtfs = self._udtfs
    return collected_udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries nested in this scope (outside FROM/JOIN).

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.Union]: subqueries
    """
    self._ensure_collected()
    collected_subqueries = self._subqueries
    return collected_subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
@property
def columns(self):
    """
    List of columns in this scope.

    The result is computed once and cached in `self._columns` until
    `clear_cache()` is called.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        columns = self._raw_columns

        # Columns that child subquery / UDTF scopes report as external to them,
        # i.e. candidates for resolving against this scope.
        external_columns = [
            column
            for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
            for column in scope.external_columns
        ]

        named_selects = set(self.expression.named_selects)

        self._columns = []
        for column in columns + external_columns:
            # Nearest enclosing node of any of these types decides whether we keep it.
            ancestor = column.find_ancestor(
                exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
            )
            # Unqualified columns whose nearest such ancestor is QUALIFY, HAVING, a hint,
            # a star, a Table wrapping a function, or a non-window ORDER BY that names a
            # select alias are dropped. NOTE(review): presumably because those names may
            # refer to projection aliases rather than source columns; confirm.
            if (
                not ancestor
                or column.table
                or isinstance(ancestor, exp.Select)
                or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                or (
                    isinstance(ancestor, exp.Order)
                    and (
                        isinstance(ancestor.parent, exp.Window)
                        or column.name not in named_selects
                    )
                )
            ):
                self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Sources that are actually selected from in this scope, with their nodes.

    Every table in a schema is selectable at any point, but a table only becomes
    a selected source once it appears in a FROM or JOIN clause.

    Raises:
        OptimizeError: if a selected alias is referenced more than once.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
    """
    if self._selected_sources is not None:
        return self._selected_sources

    chosen = {}
    for name, node in self.references:
        if name in chosen:
            raise OptimizeError(f"Alias already used: {name}")
        if name not in self.sources:
            # Referenced but not a known source; skip it.
            continue
        chosen[name] = (node, self.sources[name])

    self._selected_sources = chosen
    return chosen
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for every table, derived table and UDTF in this scope.

    Tables are keyed by `alias_or_name`. Derived tables and UDTFs are keyed by
    their alias and contribute their unnested expression, unless they carry
    pivots, in which case the wrapper node itself is kept.

    Returns:
        list[tuple[str, exp.Expression]]: references, tables first
    """
    if self._references is None:
        self._references = refs = []
        refs.extend((table.alias_or_name, table) for table in self.tables)
        for wrapped in itertools.chain(self.derived_tables, self.udtfs):
            refs.append(
                (
                    wrapped.alias,
                    wrapped if wrapped.args.get("pivots") else wrapped.unnest(),
                )
            )

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a union scope, these are the external columns of its two branch scopes
    combined; otherwise, the columns whose table is not a selected source here.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.Union):
        left, right = self.union_scopes
        found = left.external_columns + right.external_columns
    else:
        # `selected_sources` is read per column, so it is only computed when
        # there is at least one column to check.
        found = [col for col in self.columns if col.table not in self.selected_sources]

    self._external_columns = found
    return found
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns in the current scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return list(filter(lambda column: not column.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    The scope is collected first, so the result does not depend on whether some
    other property (such as `tables`) happened to trigger collection earlier.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    # `_join_hints` is reset to None by clear_cache and filled in by _collect.
    # Without collecting here, reading join_hints first would always yield [].
    self._ensure_collected()
    if self._join_hints is None:
        return []
    return self._join_hints
Join hints in this scope that reference tables.
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns in the current scope whose table qualifier equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope's type is ScopeType.SUBQUERY."""
    expected = ScopeType.SUBQUERY
    return self.scope_type is expected
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is ScopeType.DERIVED_TABLE."""
    expected = ScopeType.DERIVED_TABLE
    return self.scope_type is expected
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is ScopeType.UNION."""
    expected = ScopeType.UNION
    return self.scope_type is expected
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is ScopeType.CTE (a common table expression)."""
    expected = ScopeType.CTE
    return self.scope_type is expected
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is ScopeType.ROOT."""
    expected = ScopeType.ROOT
    return self.scope_type is expected
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is ScopeType.UDTF (User Defined Table Function)."""
    expected = ScopeType.UDTF
    return self.scope_type is expected
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name (str | None): current name of the source; None is looked up as
            the empty name "".
        new_name (str): name to register the source under.

    NOTE(review): if `old_name` is not present, `new_name` is mapped to an empty
    list. This fallback is kept for backward compatibility.
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Cached lookups such as selected_sources are keyed by source name, so they must
    # be invalidated, just as add_source / remove_source do.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` in this scope and invalidate cached lookups."""
    registry = self.sources
    registry[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source called `name` (if present) and invalidate cached lookups."""
    registry = self.sources
    if name in registry:
        del registry[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Traverse the scope tree from this node.

    The whole tree is gathered with an explicit stack (CTE, union, table and
    subquery children pushed in that order), then emitted in reverse, so
    children come out before their parents.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    visited = []
    pending = [self]
    while pending:
        current = pending.pop()
        visited.append(current)
        for children in (
            current.cte_scopes,
            current.union_scopes,
            current.table_scopes,
            current.subquery_scopes,
        ):
            pending.extend(children)

    yield from visited[::-1]
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how many times each source in this scope tree is selected from.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
            (a defaultdict, so unseen IDs read as 0)
    """
    counts = defaultdict(int)
    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances (empty if `expression` is not a query
        and not a DDL wrapping a query)
    """
    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # Scopes are built for the DDL's query, not for the DDL itself.
        ddl_with = expression.args.get("with")
        query = expression.expression

        if ddl_with:
            # Move the DDL-level CTEs onto the query, in front of any CTEs it already has.
            ddl_with.pop()
            existing_ctes = query.ctes
            if existing_ctes:
                query_with = query.args["with"]
                query_with.set("recursive", ddl_with.recursive)
                query_with.set("expressions", [*ddl_with.expressions, *existing_ctes])
            else:
                query.set("with", ddl_with)

        expression = query

    if not isinstance(expression, exp.Query):
        return []
    return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope (the last scope produced by `traverse_scope`), or None
        when no scope could be built.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the root expression to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree. It is
            also applied while visiting the joins, laterals and pivots attached
            to derived tables and UDTFs.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator: the prune
    # lambda reads it lazily, so setting it to True right after yielding a node
    # excludes that node's subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (_is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should
                # visit them. Forward `prune` so the caller's pruning applies here too.
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root expression to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize once instead of rebuilding the type tuple for every visited node.
    wanted_types = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root expression to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root expression to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if nothing matches.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    for match in matches:
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root expression to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.