
sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    annotator = TypeAnnotator(schema)
 53    infer_schema = schema.empty if infer_schema is None else infer_schema
 54    dialect = Dialect.get_or_raise(schema.dialect)
 55    pseudocolumns = dialect.PSEUDOCOLUMNS
 56
 57    for scope in traverse_scope(expression):
 58        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 59        _pop_table_column_aliases(scope.ctes)
 60        _pop_table_column_aliases(scope.derived_tables)
 61        using_column_tables = _expand_using(scope, resolver)
 62
 63        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 64            _expand_alias_refs(
 65                scope,
 66                resolver,
 67                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
 68            )
 69
 70        _convert_columns_to_dots(scope, resolver)
 71        _qualify_columns(scope, resolver)
 72
 73        if not schema.empty and expand_alias_refs:
 74            _expand_alias_refs(scope, resolver)
 75
 76        if not isinstance(scope.expression, exp.UDTF):
 77            if expand_stars:
 78                _expand_stars(
 79                    scope,
 80                    resolver,
 81                    using_column_tables,
 82                    pseudocolumns,
 83                    annotator,
 84                )
 85            qualify_outputs(scope)
 86
 87        _expand_group_by(scope, dialect)
 88        _expand_order_by(scope, resolver)
 89
 90        if dialect == "bigquery":
 91            annotator.annotate_scope(scope)
 92
 93    return expression
 94
 95
 96def validate_qualify_columns(expression: E) -> E:
 97    """Raise an `OptimizeError` if any columns aren't qualified"""
 98    all_unqualified_columns = []
 99    for scope in traverse_scope(expression):
100        if isinstance(scope.expression, exp.Select):
101            unqualified_columns = scope.unqualified_columns
102
103            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
104                column = scope.external_columns[0]
105                for_table = f" for table: '{column.table}'" if column.table else ""
106                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
107
108            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
109                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
110                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
111                # this list here to ensure those in the former category will be excluded.
112                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
113                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
114
115            all_unqualified_columns.extend(unqualified_columns)
116
117    if all_unqualified_columns:
118        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
119
120    return expression
121
122
123def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
124    name_column = []
125    field = unpivot.args.get("field")
126    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
127        name_column.append(field.this)
128
129    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
130    return itertools.chain(name_column, value_columns)
131
132
133def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
134    """
135    Remove table column aliases.
136
137    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
138    """
139    for derived_table in derived_tables:
140        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
141            continue
142        table_alias = derived_table.args.get("alias")
143        if table_alias:
144            table_alias.args.pop("columns", None)
145
146
147def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
148    joins = list(scope.find_all(exp.Join))
149    names = {join.alias_or_name for join in joins}
150    ordered = [key for key in scope.selected_sources if key not in names]
151
152    # Mapping of automatically joined column names to an ordered set of source names (dict).
153    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
154
155    for i, join in enumerate(joins):
156        using = join.args.get("using")
157
158        if not using:
159            continue
160
161        join_table = join.alias_or_name
162
163        columns = {}
164
165        for source_name in scope.selected_sources:
166            if source_name in ordered:
167                for column_name in resolver.get_source_columns(source_name):
168                    if column_name not in columns:
169                        columns[column_name] = source_name
170
171        source_table = ordered[-1]
172        ordered.append(join_table)
173        join_columns = resolver.get_source_columns(join_table)
174        conditions = []
175        using_identifier_count = len(using)
176
177        for identifier in using:
178            identifier = identifier.name
179            table = columns.get(identifier)
180
181            if not table or identifier not in join_columns:
182                if (columns and "*" not in columns) and join_columns:
183                    raise OptimizeError(f"Cannot automatically join: {identifier}")
184
185            table = table or source_table
186
187            if i == 0 or using_identifier_count == 1:
188                lhs: exp.Expression = exp.column(identifier, table=table)
189            else:
190                lhs = exp.func("coalesce", *[exp.column(identifier, table=t) for t in ordered[:-1]])
191
192            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
193
194            # Set all values in the dict to None, because we only care about the key ordering
195            tables = column_tables.setdefault(identifier, {})
196            if table not in tables:
197                tables[table] = None
198            if join_table not in tables:
199                tables[join_table] = None
200
201        join.args.pop("using")
202        join.set("on", exp.and_(*conditions, copy=False))
203
204    if column_tables:
205        for column in scope.columns:
206            if not column.table and column.name in column_tables:
207                tables = column_tables[column.name]
208                coalesce_args = [exp.column(column.name, table=table) for table in tables]
209                replacement = exp.func("coalesce", *coalesce_args)
210
211                # Ensure selects keep their output name
212                if isinstance(column.parent, exp.Select):
213                    replacement = alias(replacement, alias=column.name, copy=False)
214
215                scope.replace(column, replacement)
216
217    return column_tables
218
219
220def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
221    expression = scope.expression
222
223    if not isinstance(expression, exp.Select):
224        return
225
226    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
227
228    def replace_columns(
229        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
230    ) -> None:
231        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
232            return
233
234        for column in walk_in_scope(node, prune=lambda node: node.is_star):
235            if not isinstance(column, exp.Column):
236                continue
237
238            table = resolver.get_table(column.name) if resolve_table and not column.table else None
239            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
240            double_agg = (
241                (
242                    alias_expr.find(exp.AggFunc)
243                    and (
244                        column.find_ancestor(exp.AggFunc)
245                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
246                    )
247                )
248                if alias_expr
249                else False
250            )
251
252            if table and (not alias_expr or double_agg):
253                column.set("table", table)
254            elif not column.table and alias_expr and not double_agg:
255                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
256                    if literal_index:
257                        column.replace(exp.Literal.number(i))
258                else:
259                    column = column.replace(exp.paren(alias_expr))
260                    simplified = simplify_parens(column)
261                    if simplified is not column:
262                        column.replace(simplified)
263
264    for i, projection in enumerate(scope.expression.selects):
265        replace_columns(projection)
266
267        if isinstance(projection, exp.Alias):
268            alias_to_expression[projection.alias] = (projection.this, i + 1)
269
270    replace_columns(expression.args.get("where"))
271    replace_columns(expression.args.get("group"), literal_index=True)
272    replace_columns(expression.args.get("having"), resolve_table=True)
273    replace_columns(expression.args.get("qualify"), resolve_table=True)
274
275    scope.clear_cache()
276
277
278def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
279    expression = scope.expression
280    group = expression.args.get("group")
281    if not group:
282        return
283
284    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
285    expression.set("group", group)
286
287
288def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
289    order = scope.expression.args.get("order")
290    if not order:
291        return
292
293    ordereds = order.expressions
294    for ordered, new_expression in zip(
295        ordereds,
296        _expand_positional_references(
297            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
298        ),
299    ):
300        for agg in ordered.find_all(exp.AggFunc):
301            for col in agg.find_all(exp.Column):
302                if not col.table:
303                    col.set("table", resolver.get_table(col.name))
304
305        ordered.set("this", new_expression)
306
307    if scope.expression.args.get("group"):
308        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
309
310        for ordered in ordereds:
311            ordered = ordered.this
312
313            ordered.replace(
314                exp.to_identifier(_select_by_pos(scope, ordered).alias)
315                if ordered.is_int
316                else selects.get(ordered, ordered)
317            )
318
319
320def _expand_positional_references(
321    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
322) -> t.List[exp.Expression]:
323    new_nodes: t.List[exp.Expression] = []
324    ambiguous_projections = None
325
326    for node in expressions:
327        if node.is_int:
328            select = _select_by_pos(scope, t.cast(exp.Literal, node))
329
330            if alias:
331                new_nodes.append(exp.column(select.args["alias"].copy()))
332            else:
333                select = select.this
334
335                if dialect == "bigquery":
336                    if ambiguous_projections is None:
337                        # When a projection name is also a source name and it is referenced in the
338                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
339                        ambiguous_projections = {
340                            s.alias_or_name
341                            for s in scope.expression.selects
342                            if s.alias_or_name in scope.selected_sources
343                        }
344
345                    ambiguous = any(
346                        column.parts[0].name in ambiguous_projections
347                        for column in select.find_all(exp.Column)
348                    )
349                else:
350                    ambiguous = False
351
352                if (
353                    isinstance(select, exp.CONSTANTS)
354                    or select.find(exp.Explode, exp.Unnest)
355                    or ambiguous
356                ):
357                    new_nodes.append(node)
358                else:
359                    new_nodes.append(select.copy())
360        else:
361            new_nodes.append(node)
362
363    return new_nodes
364
365
366def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
367    try:
368        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
369    except IndexError:
370        raise OptimizeError(f"Unknown output column: {node.name}")
371
372
373def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
374    """
375    Converts `Column` instances that represent struct field lookup into chained `Dots`.
376
377    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
378    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
379    """
380    converted = False
381    for column in itertools.chain(scope.columns, scope.stars):
382        if isinstance(column, exp.Dot):
383            continue
384
385        column_table: t.Optional[str | exp.Identifier] = column.table
386        if (
387            column_table
388            and column_table not in scope.sources
389            and (
390                not scope.parent
391                or column_table not in scope.parent.sources
392                or not scope.is_correlated_subquery
393            )
394        ):
395            root, *parts = column.parts
396
397            if root.name in scope.sources:
398                # The struct is already qualified, but we still need to change the AST
399                column_table = root
400                root, *parts = parts
401            else:
402                column_table = resolver.get_table(root.name)
403
404            if column_table:
405                converted = True
406                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
407
408    if converted:
409        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
410        # a `for column in scope.columns` iteration, even though they shouldn't be
411        scope.clear_cache()
412
413
414def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
415    """Disambiguate columns, ensuring each column specifies a source"""
416    for column in scope.columns:
417        column_table = column.table
418        column_name = column.name
419
420        if column_table and column_table in scope.sources:
421            source_columns = resolver.get_source_columns(column_table)
422            if source_columns and column_name not in source_columns and "*" not in source_columns:
423                raise OptimizeError(f"Unknown column: {column_name}")
424
425        if not column_table:
426            if scope.pivots and not column.find_ancestor(exp.Pivot):
427                # If the column is under the Pivot expression, we need to qualify it
428                # using the name of the pivoted source instead of the pivot's alias
429                column.set("table", exp.to_identifier(scope.pivots[0].alias))
430                continue
431
432            # column_table can be '' because a BigQuery UNNEST has no table alias
433            column_table = resolver.get_table(column_name)
434            if column_table:
435                column.set("table", column_table)
436
437    for pivot in scope.pivots:
438        for column in pivot.find_all(exp.Column):
439            if not column.table and column.name in resolver.all_columns:
440                column_table = resolver.get_table(column.name)
441                if column_table:
442                    column.set("table", column_table)
443
444
445def _expand_struct_stars(
446    expression: exp.Dot,
447) -> t.List[exp.Alias]:
448    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
449
450    dot_column = t.cast(exp.Column, expression.find(exp.Column))
451    if not dot_column.is_type(exp.DataType.Type.STRUCT):
452        return []
453
454    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
455    dot_column = dot_column.copy()
456    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
457
458    # First part is the table name and last part is the star so they can be dropped
459    dot_parts = expression.parts[1:-1]
460
461    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
462    for part in dot_parts[1:]:
463        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
464            # Unable to expand star unless all fields are named
465            if not isinstance(field.this, exp.Identifier):
466                return []
467
468            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
469                starting_struct = field
470                break
471        else:
472            # There is no matching field in the struct
473            return []
474
475    taken_names = set()
476    new_selections = []
477
478    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
479        name = field.name
480
481        # Ambiguous or anonymous fields can't be expanded
482        if name in taken_names or not isinstance(field.this, exp.Identifier):
483            return []
484
485        taken_names.add(name)
486
487        this = field.this.copy()
488        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
489        new_column = exp.column(
490            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
491        )
492        new_selections.append(alias(new_column, this, copy=False))
493
494    return new_selections
495
496
497def _expand_stars(
498    scope: Scope,
499    resolver: Resolver,
500    using_column_tables: t.Dict[str, t.Any],
501    pseudocolumns: t.Set[str],
502    annotator: TypeAnnotator,
503) -> None:
504    """Expand stars to lists of column selections"""
505
506    new_selections = []
507    except_columns: t.Dict[int, t.Set[str]] = {}
508    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
509    coalesced_columns = set()
510    dialect = resolver.schema.dialect
511
512    pivot_output_columns = None
513    pivot_exclude_columns = None
514
515    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
516    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
517        if pivot.unpivot:
518            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
519
520            field = pivot.args.get("field")
521            if isinstance(field, exp.In):
522                pivot_exclude_columns = {
523                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
524                }
525        else:
526            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
527
528            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
529            if not pivot_output_columns:
530                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
531
532    is_bigquery = dialect == "bigquery"
533    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
534        # Found struct expansion, annotate scope ahead of time
535        annotator.annotate_scope(scope)
536
537    for expression in scope.expression.selects:
538        tables = []
539        if isinstance(expression, exp.Star):
540            tables.extend(scope.selected_sources)
541            _add_except_columns(expression, tables, except_columns)
542            _add_replace_columns(expression, tables, replace_columns)
543        elif expression.is_star:
544            if not isinstance(expression, exp.Dot):
545                tables.append(expression.table)
546                _add_except_columns(expression.this, tables, except_columns)
547                _add_replace_columns(expression.this, tables, replace_columns)
548            elif is_bigquery:
549                struct_fields = _expand_struct_stars(expression)
550                if struct_fields:
551                    new_selections.extend(struct_fields)
552                    continue
553
554        if not tables:
555            new_selections.append(expression)
556            continue
557
558        for table in tables:
559            if table not in scope.sources:
560                raise OptimizeError(f"Unknown table: {table}")
561
562            columns = resolver.get_source_columns(table, only_visible=True)
563            columns = columns or scope.outer_columns
564
565            if pseudocolumns:
566                columns = [name for name in columns if name.upper() not in pseudocolumns]
567
568            if not columns or "*" in columns:
569                return
570
571            table_id = id(table)
572            columns_to_exclude = except_columns.get(table_id) or set()
573
574            if pivot:
575                if pivot_output_columns and pivot_exclude_columns:
576                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
577                    pivot_columns.extend(pivot_output_columns)
578                else:
579                    pivot_columns = pivot.alias_column_names
580
581                if pivot_columns:
582                    new_selections.extend(
583                        alias(exp.column(name, table=pivot.alias), name, copy=False)
584                        for name in pivot_columns
585                        if name not in columns_to_exclude
586                    )
587                    continue
588
589            for name in columns:
590                if name in columns_to_exclude or name in coalesced_columns:
591                    continue
592                if name in using_column_tables and table in using_column_tables[name]:
593                    coalesced_columns.add(name)
594                    tables = using_column_tables[name]
595                    coalesce_args = [exp.column(name, table=table) for table in tables]
596
597                    new_selections.append(
598                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
599                    )
600                else:
601                    alias_ = replace_columns.get(table_id, {}).get(name, name)
602                    column = exp.column(name, table=table)
603                    new_selections.append(
604                        alias(column, alias_, copy=False) if alias_ != name else column
605                    )
606
607    # Ensures we don't overwrite the initial selections with an empty list
608    if new_selections and isinstance(scope.expression, exp.Select):
609        scope.expression.set("expressions", new_selections)
610
611
612def _add_except_columns(
613    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
614) -> None:
615    except_ = expression.args.get("except")
616
617    if not except_:
618        return
619
620    columns = {e.name for e in except_}
621
622    for table in tables:
623        except_columns[id(table)] = columns
624
625
626def _add_replace_columns(
627    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
628) -> None:
629    replace = expression.args.get("replace")
630
631    if not replace:
632        return
633
634    columns = {e.this.name: e.alias for e in replace}
635
636    for table in tables:
637        replace_columns[id(table)] = columns
638
639
640def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
641    """Ensure all output columns are aliased"""
642    if isinstance(scope_or_expression, exp.Expression):
643        scope = build_scope(scope_or_expression)
644        if not isinstance(scope, Scope):
645            return
646    else:
647        scope = scope_or_expression
648
649    new_selections = []
650    for i, (selection, aliased_column) in enumerate(
651        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
652    ):
653        if selection is None:
654            break
655
656        if isinstance(selection, exp.Subquery):
657            if not selection.output_name:
658                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
659        elif not isinstance(selection, exp.Alias) and not selection.is_star:
660            selection = alias(
661                selection,
662                alias=selection.output_name or f"_col_{i}",
663                copy=False,
664            )
665        if aliased_column:
666            selection.set("alias", exp.to_identifier(aliased_column))
667
668        new_selections.append(selection)
669
670    if isinstance(scope.expression, exp.Select):
671        scope.expression.set("expressions", new_selections)
672
673
674def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
675    """Makes sure all identifiers that need to be quoted are quoted."""
676    return expression.transform(
677        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
678    )  # type: ignore
679
680
681def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
682    """
683    Pushes down the CTE alias columns into the projection.
684
685    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
686
687    Example:
688        >>> import sqlglot
689        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
690        >>> pushdown_cte_alias_columns(expression).sql()
691        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
692
693    Args:
694        expression: Expression to pushdown.
695
696    Returns:
697        The expression with the CTE aliases pushed down into the projection.
698    """
699    for cte in expression.find_all(exp.CTE):
700        if cte.alias_column_names:
701            new_expressions = []
702            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
703                if isinstance(projection, exp.Alias):
704                    projection.set("alias", _alias)
705                else:
706                    projection = alias(projection, alias=_alias)
707                new_expressions.append(projection)
708            cte.this.set("expressions", new_expressions)
709
710    return expression
711
712
713class Resolver:
714    """
715    Helper for resolving columns.
716
717    This is a class so we can lazily load some things and easily share them across functions.
718    """
719
720    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
721        self.scope = scope
722        self.schema = schema
723        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
724        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
725        self._all_columns: t.Optional[t.Set[str]] = None
726        self._infer_schema = infer_schema
727
728    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
729        """
730        Get the table for a column name.
731
732        Args:
733            column_name: The column name to find the table for.
734        Returns:
735            The table name if it can be found/inferred.
736        """
737        if self._unambiguous_columns is None:
738            self._unambiguous_columns = self._get_unambiguous_columns(
739                self._get_all_source_columns()
740            )
741
742        table_name = self._unambiguous_columns.get(column_name)
743
744        if not table_name and self._infer_schema:
745            sources_without_schema = tuple(
746                source
747                for source, columns in self._get_all_source_columns().items()
748                if not columns or "*" in columns
749            )
750            if len(sources_without_schema) == 1:
751                table_name = sources_without_schema[0]
752
753        if table_name not in self.scope.selected_sources:
754            return exp.to_identifier(table_name)
755
756        node, _ = self.scope.selected_sources.get(table_name)
757
758        if isinstance(node, exp.Query):
759            while node and node.alias != table_name:
760                node = node.parent
761
762        node_alias = node.args.get("alias")
763        if node_alias:
764            return exp.to_identifier(node_alias.this)
765
766        return exp.to_identifier(table_name)
767
768    @property
769    def all_columns(self) -> t.Set[str]:
770        """All available columns of all sources in this scope"""
771        if self._all_columns is None:
772            self._all_columns = {
773                column for columns in self._get_all_source_columns().values() for column in columns
774            }
775        return self._all_columns
776
777    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
778        """Resolve the source columns for a given source `name`."""
779        if name not in self.scope.sources:
780            raise OptimizeError(f"Unknown table: {name}")
781
782        source = self.scope.sources[name]
783
784        if isinstance(source, exp.Table):
785            columns = self.schema.column_names(source, only_visible)
786        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
787            columns = source.expression.named_selects
788
789            # in bigquery, unnest structs are automatically scoped as tables, so you can
790            # directly select a struct field in a query.
791            # this handles the case where the unnest is statically defined.
792            if self.schema.dialect == "bigquery":
793                if source.expression.is_type(exp.DataType.Type.STRUCT):
794                    for k in source.expression.type.expressions:  # type: ignore
795                        columns.append(k.name)
796        else:
797            columns = source.expression.named_selects
798
799        node, _ = self.scope.selected_sources.get(name) or (None, None)
800        if isinstance(node, Scope):
801            column_aliases = node.expression.alias_column_names
802        elif isinstance(node, exp.Expression):
803            column_aliases = node.alias_column_names
804        else:
805            column_aliases = []
806
807        if column_aliases:
808            # If the source's columns are aliased, their aliases shadow the corresponding column names.
809            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
810            return [
811                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
812            ]
813        return columns
814
815    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
816        if self._source_columns is None:
817            self._source_columns = {
818                source_name: self.get_source_columns(source_name)
819                for source_name, source in itertools.chain(
820                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
821                )
822            }
823        return self._source_columns
824
825    def _get_unambiguous_columns(
826        self, source_columns: t.Dict[str, t.Sequence[str]]
827    ) -> t.Mapping[str, str]:
828        """
829        Find all the unambiguous columns in sources.
830
831        Args:
832            source_columns: Mapping of names to source columns.
833
834        Returns:
835            Mapping of column name to source name.
836        """
837        if not source_columns:
838            return {}
839
840        source_columns_pairs = list(source_columns.items())
841
842        first_table, first_columns = source_columns_pairs[0]
843
844        if len(source_columns_pairs) == 1:
845            # Performance optimization - avoid copying first_columns if there is only one table.
846            return SingleValuedMapping(first_columns, first_table)
847
848        unambiguous_columns = {col: first_table for col in first_columns}
849        all_columns = set(unambiguous_columns)
850
851        for table, columns in source_columns_pairs[1:]:
852            unique = set(columns)
853            ambiguous = all_columns.intersection(unique)
854            all_columns.update(columns)
855
856            for column in ambiguous:
857                unambiguous_columns.pop(column, None)
858            for column in unique.difference(ambiguous):
859                unambiguous_columns[column] = table
860
861        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
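
A hedged usage sketch (the orders/users schema and column names below are hypothetical): with a schema covering more than one table, unqualified columns are attributed to their source and every projection receives an explicit output alias.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical two-table schema; `name` only exists in `users`, so it is unambiguous.
schema = {"orders": {"id": "INT", "user_id": "INT"}, "users": {"id": "INT", "name": "TEXT"}}
sql = "SELECT name, orders.id FROM orders JOIN users ON orders.user_id = users.id"
print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
# Expected along the lines of:
# SELECT users.name AS name, orders.id AS id FROM orders JOIN users ON orders.user_id = users.id
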
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
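
A hedged sketch of the intended usage after qualification; the tables x and y are hypothetical, and with an empty schema the column b cannot be attributed to either source, so validation fails:

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

expression = qualify_columns(sqlglot.parse_one("SELECT b FROM x JOIN y ON x.a = y.a"), schema={})
try:
    validate_qualify_columns(expression)
except OptimizeError as error:
    print(error)  # the error names the column(s) that could not be qualified, here `b`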

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
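
A hedged sketch (the table x is hypothetical): named projections are aliased to their output name, while anonymous expressions fall back to positional aliases such as _col_1.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x.a, 1 + 1 FROM x")
qualify_outputs(expression)  # mutates the expression in place and returns None
print(expression.sql())
# Expected along the lines of: SELECT x.a AS a, 1 + 1 AS _col_1 FROM x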

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
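
A hedged sketch: with identify=True (the default), identifiers are quoted using the target dialect's quoting rules; the table x below is hypothetical.

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT x.a FROM x")
print(quote_identifiers(expression, dialect="snowflake").sql(dialect="snowflake"))
# Expected along the lines of: SELECT "x"."a" FROM "x"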

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
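
A hedged sketch of how the qualification pass wires a Resolver to a Scope; the tbl/col names are hypothetical.

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))
print(resolver.all_columns)                # e.g. {'col'}
print(resolver.get_source_columns("tbl"))  # e.g. ['col']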

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.
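
A hedged sketch (t1/t2 are hypothetical sources): a column is resolved to a table only when exactly one source provides it.

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT a, b FROM t1 CROSS JOIN t2"))
resolver = Resolver(scope, schema)
print(resolver.get_table("a"))  # an exp.Identifier pointing at t1
print(resolver.get_table("b"))  # None: `b` exists in both t1 and t2, so it is ambiguous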

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source name.
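
A hedged sketch: for a table source the columns come from the schema, while for a derived table they come from its projections; the alias t and its projections below are hypothetical.

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

scope = build_scope(sqlglot.parse_one("SELECT * FROM (SELECT 1 AS a, 2 AS b) AS t"))
resolver = Resolver(scope, ensure_schema({}))
print(resolver.get_source_columns("t"))  # e.g. ['a', 'b']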