
sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28UNESCAPED_SEQUENCES = {
  29    "\\a": "\a",
  30    "\\b": "\b",
  31    "\\f": "\f",
  32    "\\n": "\n",
  33    "\\r": "\r",
  34    "\\t": "\t",
  35    "\\v": "\v",
  36    "\\\\": "\\",
  37}
  38
  39
  40class Dialects(str, Enum):
  41    """Dialects supported by SQLGLot."""
  42
  43    DIALECT = ""
  44
  45    ATHENA = "athena"
  46    BIGQUERY = "bigquery"
  47    CLICKHOUSE = "clickhouse"
  48    DATABRICKS = "databricks"
  49    DORIS = "doris"
  50    DRILL = "drill"
  51    DUCKDB = "duckdb"
  52    HIVE = "hive"
  53    MATERIALIZE = "materialize"
  54    MYSQL = "mysql"
  55    ORACLE = "oracle"
  56    POSTGRES = "postgres"
  57    PRESTO = "presto"
  58    PRQL = "prql"
  59    REDSHIFT = "redshift"
  60    RISINGWAVE = "risingwave"
  61    SNOWFLAKE = "snowflake"
  62    SPARK = "spark"
  63    SPARK2 = "spark2"
  64    SQLITE = "sqlite"
  65    STARROCKS = "starrocks"
  66    TABLEAU = "tableau"
  67    TERADATA = "teradata"
  68    TRINO = "trino"
  69    TSQL = "tsql"
  70
  71
  72class NormalizationStrategy(str, AutoName):
  73    """Specifies the strategy according to which identifiers should be normalized."""
  74
  75    LOWERCASE = auto()
  76    """Unquoted identifiers are lowercased."""
  77
  78    UPPERCASE = auto()
  79    """Unquoted identifiers are uppercased."""
  80
  81    CASE_SENSITIVE = auto()
  82    """Always case-sensitive, regardless of quotes."""
  83
  84    CASE_INSENSITIVE = auto()
  85    """Always case-insensitive, regardless of quotes."""
  86
  87
  88class _Dialect(type):
  89    classes: t.Dict[str, t.Type[Dialect]] = {}
  90
  91    def __eq__(cls, other: t.Any) -> bool:
  92        if cls is other:
  93            return True
  94        if isinstance(other, str):
  95            return cls is cls.get(other)
  96        if isinstance(other, Dialect):
  97            return cls is type(other)
  98
  99        return False
 100
 101    def __hash__(cls) -> int:
 102        return hash(cls.__name__.lower())
 103
 104    @classmethod
 105    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 106        return cls.classes[key]
 107
 108    @classmethod
 109    def get(
 110        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 111    ) -> t.Optional[t.Type[Dialect]]:
 112        return cls.classes.get(key, default)
 113
 114    def __new__(cls, clsname, bases, attrs):
 115        klass = super().__new__(cls, clsname, bases, attrs)
 116        enum = Dialects.__members__.get(clsname.upper())
 117        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 118
 119        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 120        klass.FORMAT_TRIE = (
 121            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 122        )
 123        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 124        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 125        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
 126        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
 127
 128        base = seq_get(bases, 0)
 129        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 130        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
 131        base_parser = (getattr(base, "parser_class", Parser),)
 132        base_generator = (getattr(base, "generator_class", Generator),)
 133
 134        klass.tokenizer_class = klass.__dict__.get(
 135            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 136        )
 137        klass.jsonpath_tokenizer_class = klass.__dict__.get(
 138            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
 139        )
 140        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 141        klass.generator_class = klass.__dict__.get(
 142            "Generator", type("Generator", base_generator, {})
 143        )
 144
 145        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 146        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 147            klass.tokenizer_class._IDENTIFIERS.items()
 148        )[0]
 149
 150        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 151            return next(
 152                (
 153                    (s, e)
 154                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 155                    if t == token_type
 156                ),
 157                (None, None),
 158            )
 159
 160        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 161        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 162        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 163        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 164
 165        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 166            klass.UNESCAPED_SEQUENCES = {
 167                **UNESCAPED_SEQUENCES,
 168                **klass.UNESCAPED_SEQUENCES,
 169            }
 170
 171        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 172
 173        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS
 174
 175        if enum not in ("", "bigquery"):
 176            klass.generator_class.SELECT_KINDS = ()
 177
 178        if enum not in ("", "athena", "presto", "trino"):
 179            klass.generator_class.TRY_SUPPORTED = False
 180            klass.generator_class.SUPPORTS_UESCAPE = False
 181
 182        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 183            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 184            for modifier in ("cluster", "distribute", "sort"):
 185                modifier_transforms.pop(modifier, None)
 186
 187            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 188
 189        if enum not in ("", "doris", "mysql"):
 190            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
 191                TokenType.STRAIGHT_JOIN,
 192            }
 193            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 194                TokenType.STRAIGHT_JOIN,
 195            }
 196
 197        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 198            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 199                TokenType.ANTI,
 200                TokenType.SEMI,
 201            }
 202
 203        return klass
 204
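A short illustrative sketch (not part of the module) of what this metaclass provides: every `Dialect` subclass is registered under its `Dialects` enum value (or its lowercased class name) and receives auto-built `Tokenizer`/`Parser`/`Generator` subclasses.

    import sqlglot  # importing the package registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect["duckdb"]             # registry lookup via _Dialect.__getitem__
    assert Dialect.get("not-a-dialect") is None
    assert duckdb == "duckdb"              # _Dialect.__eq__ also accepts strings
    assert issubclass(duckdb.parser_class, Dialect.parser_class)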
 205
 206class Dialect(metaclass=_Dialect):
 207    INDEX_OFFSET = 0
 208    """The base index offset for arrays."""
 209
 210    WEEK_OFFSET = 0
 211    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 212
 213    UNNEST_COLUMN_ONLY = False
 214    """Whether `UNNEST` table aliases are treated as column aliases."""
 215
 216    ALIAS_POST_TABLESAMPLE = False
 217    """Whether the table alias comes after tablesample."""
 218
 219    TABLESAMPLE_SIZE_IS_PERCENT = False
 220    """Whether a size in the table sample clause represents percentage."""
 221
 222    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 223    """Specifies the strategy according to which identifiers should be normalized."""
 224
 225    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 226    """Whether an unquoted identifier can start with a digit."""
 227
 228    DPIPE_IS_STRING_CONCAT = True
 229    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 230
 231    STRICT_STRING_CONCAT = False
 232    """Whether `CONCAT`'s arguments must be strings."""
 233
 234    SUPPORTS_USER_DEFINED_TYPES = True
 235    """Whether user-defined data types are supported."""
 236
 237    SUPPORTS_SEMI_ANTI_JOIN = True
 238    """Whether `SEMI` or `ANTI` joins are supported."""
 239
 240    SUPPORTS_COLUMN_JOIN_MARKS = False
 241    """Whether the old-style outer join (+) syntax is supported."""
 242
 243    COPY_PARAMS_ARE_CSV = True
 244    """Separator of COPY statement parameters."""
 245
 246    NORMALIZE_FUNCTIONS: bool | str = "upper"
 247    """
 248    Determines how function names are going to be normalized.
 249    Possible values:
 250        "upper" or True: Convert names to uppercase.
 251        "lower": Convert names to lowercase.
 252        False: Disables function name normalization.
 253    """
 254
 255    LOG_BASE_FIRST: t.Optional[bool] = True
 256    """
 257    Whether the base comes first in the `LOG` function.
 258    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 259    """
 260
 261    NULL_ORDERING = "nulls_are_small"
 262    """
 263    Default `NULL` ordering method to use if not explicitly set.
 264    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 265    """
 266
 267    TYPED_DIVISION = False
 268    """
 269    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 270    False means `a / b` is always float division.
 271    True means `a / b` is integer division if both `a` and `b` are integers.
 272    """
 273
 274    SAFE_DIVISION = False
 275    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 276
 277    CONCAT_COALESCE = False
 278    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 279
 280    HEX_LOWERCASE = False
 281    """Whether the `HEX` function returns a lowercase hexadecimal string."""
 282
 283    DATE_FORMAT = "'%Y-%m-%d'"
 284    DATEINT_FORMAT = "'%Y%m%d'"
 285    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 286
 287    TIME_MAPPING: t.Dict[str, str] = {}
 288    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 289
 290    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 291    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 292    FORMAT_MAPPING: t.Dict[str, str] = {}
 293    """
 294    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 295    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 296    """
 297
 298    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 299    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 300
 301    PSEUDOCOLUMNS: t.Set[str] = set()
 302    """
 303    Columns that are auto-generated by the engine corresponding to this dialect.
 304    For example, such columns may be excluded from `SELECT *` queries.
 305    """
 306
 307    PREFER_CTE_ALIAS_COLUMN = False
 308    """
 309    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 310    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 311    any projection aliases in the subquery.
 312
 313    For example,
 314        WITH y(c) AS (
 315            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 316        ) SELECT c FROM y;
 317
 318        will be rewritten as
 319
 320        WITH y(c) AS (
 321            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 322        ) SELECT c FROM y;
 323    """
 324
 325    COPY_PARAMS_ARE_CSV = True
 326    """
 327    Whether COPY statement parameters are separated by comma or whitespace
 328    """
 329
 330    FORCE_EARLY_ALIAS_REF_EXPANSION = False
 331    """
 332    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
 333
 334    For example:
 335        WITH data AS (
 336        SELECT
 337            1 AS id,
 338            2 AS my_id
 339        )
 340        SELECT
 341            id AS my_id
 342        FROM
 343            data
 344        WHERE
 345            my_id = 1
 346        GROUP BY
  347            my_id
 348        HAVING
 349            my_id = 1
 350
 351    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
  352        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
  353        - ClickHouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
 354    """
 355
 356    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
 357    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""
 358
 359    # --- Autofilled ---
 360
 361    tokenizer_class = Tokenizer
 362    jsonpath_tokenizer_class = JSONPathTokenizer
 363    parser_class = Parser
 364    generator_class = Generator
 365
 366    # A trie of the time_mapping keys
 367    TIME_TRIE: t.Dict = {}
 368    FORMAT_TRIE: t.Dict = {}
 369
 370    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 371    INVERSE_TIME_TRIE: t.Dict = {}
 372    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
 373    INVERSE_FORMAT_TRIE: t.Dict = {}
 374
 375    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 376
 377    # Delimiters for string literals and identifiers
 378    QUOTE_START = "'"
 379    QUOTE_END = "'"
 380    IDENTIFIER_START = '"'
 381    IDENTIFIER_END = '"'
 382
 383    # Delimiters for bit, hex, byte and unicode literals
 384    BIT_START: t.Optional[str] = None
 385    BIT_END: t.Optional[str] = None
 386    HEX_START: t.Optional[str] = None
 387    HEX_END: t.Optional[str] = None
 388    BYTE_START: t.Optional[str] = None
 389    BYTE_END: t.Optional[str] = None
 390    UNICODE_START: t.Optional[str] = None
 391    UNICODE_END: t.Optional[str] = None
 392
 393    DATE_PART_MAPPING = {
 394        "Y": "YEAR",
 395        "YY": "YEAR",
 396        "YYY": "YEAR",
 397        "YYYY": "YEAR",
 398        "YR": "YEAR",
 399        "YEARS": "YEAR",
 400        "YRS": "YEAR",
 401        "MM": "MONTH",
 402        "MON": "MONTH",
 403        "MONS": "MONTH",
 404        "MONTHS": "MONTH",
 405        "D": "DAY",
 406        "DD": "DAY",
 407        "DAYS": "DAY",
 408        "DAYOFMONTH": "DAY",
 409        "DAY OF WEEK": "DAYOFWEEK",
 410        "WEEKDAY": "DAYOFWEEK",
 411        "DOW": "DAYOFWEEK",
 412        "DW": "DAYOFWEEK",
 413        "WEEKDAY_ISO": "DAYOFWEEKISO",
 414        "DOW_ISO": "DAYOFWEEKISO",
 415        "DW_ISO": "DAYOFWEEKISO",
 416        "DAY OF YEAR": "DAYOFYEAR",
 417        "DOY": "DAYOFYEAR",
 418        "DY": "DAYOFYEAR",
 419        "W": "WEEK",
 420        "WK": "WEEK",
 421        "WEEKOFYEAR": "WEEK",
 422        "WOY": "WEEK",
 423        "WY": "WEEK",
 424        "WEEK_ISO": "WEEKISO",
 425        "WEEKOFYEARISO": "WEEKISO",
 426        "WEEKOFYEAR_ISO": "WEEKISO",
 427        "Q": "QUARTER",
 428        "QTR": "QUARTER",
 429        "QTRS": "QUARTER",
 430        "QUARTERS": "QUARTER",
 431        "H": "HOUR",
 432        "HH": "HOUR",
 433        "HR": "HOUR",
 434        "HOURS": "HOUR",
 435        "HRS": "HOUR",
 436        "M": "MINUTE",
 437        "MI": "MINUTE",
 438        "MIN": "MINUTE",
 439        "MINUTES": "MINUTE",
 440        "MINS": "MINUTE",
 441        "S": "SECOND",
 442        "SEC": "SECOND",
 443        "SECONDS": "SECOND",
 444        "SECS": "SECOND",
 445        "MS": "MILLISECOND",
 446        "MSEC": "MILLISECOND",
 447        "MSECS": "MILLISECOND",
 448        "MSECOND": "MILLISECOND",
 449        "MSECONDS": "MILLISECOND",
 450        "MILLISEC": "MILLISECOND",
 451        "MILLISECS": "MILLISECOND",
 452        "MILLISECON": "MILLISECOND",
 453        "MILLISECONDS": "MILLISECOND",
 454        "US": "MICROSECOND",
 455        "USEC": "MICROSECOND",
 456        "USECS": "MICROSECOND",
 457        "MICROSEC": "MICROSECOND",
 458        "MICROSECS": "MICROSECOND",
 459        "USECOND": "MICROSECOND",
 460        "USECONDS": "MICROSECOND",
 461        "MICROSECONDS": "MICROSECOND",
 462        "NS": "NANOSECOND",
 463        "NSEC": "NANOSECOND",
 464        "NANOSEC": "NANOSECOND",
 465        "NSECOND": "NANOSECOND",
 466        "NSECONDS": "NANOSECOND",
 467        "NANOSECS": "NANOSECOND",
 468        "EPOCH_SECOND": "EPOCH",
 469        "EPOCH_SECONDS": "EPOCH",
 470        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
 471        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
 472        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
 473        "TZH": "TIMEZONE_HOUR",
 474        "TZM": "TIMEZONE_MINUTE",
 475        "DEC": "DECADE",
 476        "DECS": "DECADE",
 477        "DECADES": "DECADE",
 478        "MIL": "MILLENIUM",
 479        "MILS": "MILLENIUM",
 480        "MILLENIA": "MILLENIUM",
 481        "C": "CENTURY",
 482        "CENT": "CENTURY",
 483        "CENTS": "CENTURY",
 484        "CENTURIES": "CENTURY",
 485    }
 486
 487    @classmethod
 488    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 489        """
 490        Look up a dialect in the global dialect registry and return it if it exists.
 491
 492        Args:
 493            dialect: The target dialect. If this is a string, it can be optionally followed by
 494                additional key-value pairs that are separated by commas and are used to specify
 495                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 496
 497        Example:
  498            >>> dialect = Dialect.get_or_raise("duckdb")
  499            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 500
 501        Returns:
 502            The corresponding Dialect instance.
 503        """
 504
 505        if not dialect:
 506            return cls()
 507        if isinstance(dialect, _Dialect):
 508            return dialect()
 509        if isinstance(dialect, Dialect):
 510            return dialect
 511        if isinstance(dialect, str):
 512            try:
 513                dialect_name, *kv_pairs = dialect.split(",")
 514                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 515            except ValueError:
 516                raise ValueError(
 517                    f"Invalid dialect format: '{dialect}'. "
 518                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 519                )
 520
 521            result = cls.get(dialect_name.strip())
 522            if not result:
 523                from difflib import get_close_matches
 524
 525                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 526                if similar:
 527                    similar = f" Did you mean {similar}?"
 528
 529                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 530
 531            return result(**kwargs)
 532
 533        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 534
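A quick sketch of the settings syntax accepted above: `normalization_strategy` is consumed by `__init__`, and any remaining key-value pairs land in `self.settings`.

    import sqlglot  # registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
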
 535    @classmethod
 536    def format_time(
 537        cls, expression: t.Optional[str | exp.Expression]
 538    ) -> t.Optional[exp.Expression]:
 539        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 540        if isinstance(expression, str):
 541            return exp.Literal.string(
 542                # the time formats are quoted
 543                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 544            )
 545
 546        if expression and expression.is_string:
 547            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 548
 549        return expression
 550
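For example, a sketch assuming Hive's `TIME_MAPPING` translates `yyyy-MM-dd` into the Python strftime format `%Y-%m-%d`:

    import sqlglot  # registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect

    fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")  # string inputs are quoted
    print(fmt.sql())  # '%Y-%m-%d'
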
 551    def __init__(self, **kwargs) -> None:
 552        normalization_strategy = kwargs.pop("normalization_strategy", None)
 553
 554        if normalization_strategy is None:
 555            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 556        else:
 557            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 558
 559        self.settings = kwargs
 560
 561    def __eq__(self, other: t.Any) -> bool:
 562        # Does not currently take dialect state into account
 563        return type(self) == other
 564
 565    def __hash__(self) -> int:
 566        # Does not currently take dialect state into account
 567        return hash(type(self))
 568
 569    def normalize_identifier(self, expression: E) -> E:
 570        """
 571        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 572
 573        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 574        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 575        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 576        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 577
 578        There are also dialects like Spark, which are case-insensitive even when quotes are
 579        present, and dialects like MySQL, whose resolution rules match those employed by the
 580        underlying operating system, for example they may always be case-sensitive in Linux.
 581
 582        Finally, the normalization behavior of some engines can even be controlled through flags,
 583        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 584
 585        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 586        that it can analyze queries in the optimizer and successfully capture their semantics.
 587        """
 588        if (
 589            isinstance(expression, exp.Identifier)
 590            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 591            and (
 592                not expression.quoted
 593                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 594            )
 595        ):
 596            expression.set(
 597                "this",
 598                (
 599                    expression.this.upper()
 600                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 601                    else expression.this.lower()
 602                ),
 603            )
 604
 605        return expression
 606
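A minimal sketch of the behaviors described above, using the built-in Postgres (lowercase) and Snowflake (uppercase) strategies:

    import sqlglot  # registers the built-in dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    unquoted = exp.to_identifier("FoO")
    quoted = exp.to_identifier("FoO", quoted=True)

    assert Dialect.get_or_raise("postgres").normalize_identifier(unquoted.copy()).this == "foo"
    assert Dialect.get_or_raise("snowflake").normalize_identifier(unquoted.copy()).this == "FOO"

    # Quoted identifiers are preserved under case-sensitive resolution rules
    assert Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).this == "FoO"
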
 607    def case_sensitive(self, text: str) -> bool:
 608        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 609        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 610            return False
 611
 612        unsafe = (
 613            str.islower
 614            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 615            else str.isupper
 616        )
 617        return any(unsafe(char) for char in text)
 618
 619    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 620        """Checks if text can be identified given an identify option.
 621
 622        Args:
 623            text: The text to check.
 624            identify:
 625                `"always"` or `True`: Always returns `True`.
 626                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 627
 628        Returns:
 629            Whether the given text can be identified.
 630        """
 631        if identify is True or identify == "always":
 632            return True
 633
 634        if identify == "safe":
 635            return not self.case_sensitive(text)
 636
 637        return False
 638
 639    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 640        """
 641        Adds quotes to a given identifier.
 642
 643        Args:
 644            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 645            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 646                "unsafe", with respect to its characters and this dialect's normalization strategy.
 647        """
 648        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 649            name = expression.this
 650            expression.set(
 651                "quoted",
 652                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 653            )
 654
 655        return expression
 656
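A sketch tying the three helpers above together:

    import sqlglot  # registers the built-in dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    snowflake = Dialect.get_or_raise("snowflake")  # UPPERCASE strategy
    assert snowflake.case_sensitive("foo")         # lowercase chars are "unsafe" here
    assert not snowflake.can_identify("foo", "safe")
    assert snowflake.can_identify("foo", "always")

    postgres = Dialect.get_or_raise("postgres")    # LOWERCASE strategy
    safe = postgres.quote_identifier(exp.to_identifier("foo"), identify=False)
    unsafe = postgres.quote_identifier(exp.to_identifier("my col"), identify=False)
    assert not safe.quoted and unsafe.quoted
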
 657    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 658        if isinstance(path, exp.Literal):
 659            path_text = path.name
 660            if path.is_number:
 661                path_text = f"[{path_text}]"
 662            try:
 663                return parse_json_path(path_text, self)
 664            except ParseError as e:
 665                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 666
 667        return path
 668
 669    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 670        return self.parser(**opts).parse(self.tokenize(sql), sql)
 671
 672    def parse_into(
 673        self, expression_type: exp.IntoType, sql: str, **opts
 674    ) -> t.List[t.Optional[exp.Expression]]:
 675        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 676
 677    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 678        return self.generator(**opts).generate(expression, copy=copy)
 679
 680    def transpile(self, sql: str, **opts) -> t.List[str]:
 681        return [
 682            self.generate(expression, copy=False, **opts) if expression else ""
 683            for expression in self.parse(sql)
 684        ]
 685
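A round-trip sketch using these instance-level entry points:

    import sqlglot  # registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")
    tree = duckdb.parse("SELECT a FROM t")[0]  # tokenize + parse
    assert duckdb.generate(tree) == "SELECT a FROM t"
    assert duckdb.transpile("SELECT a FROM t") == ["SELECT a FROM t"]
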
 686    def tokenize(self, sql: str) -> t.List[Token]:
 687        return self.tokenizer.tokenize(sql)
 688
 689    @property
 690    def tokenizer(self) -> Tokenizer:
 691        return self.tokenizer_class(dialect=self)
 692
 693    @property
 694    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
 695        return self.jsonpath_tokenizer_class(dialect=self)
 696
 697    def parser(self, **opts) -> Parser:
 698        return self.parser_class(dialect=self, **opts)
 699
 700    def generator(self, **opts) -> Generator:
 701        return self.generator_class(dialect=self, **opts)
 702
 703
 704DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 705
 706
 707def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 708    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 709
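`rename_func` is typically installed in a generator's `TRANSFORMS` mapping to emit a different function name for an expression type. The toy dialect below is purely illustrative, not part of the library:

    from sqlglot import exp, transpile
    from sqlglot.dialects.dialect import Dialect, rename_func
    from sqlglot.generator import Generator as BaseGenerator

    class MyDialect(Dialect):
        class Generator(BaseGenerator):
            TRANSFORMS = {
                **BaseGenerator.TRANSFORMS,
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }

    sql = transpile("SELECT APPROX_DISTINCT(x) FROM t", write=MyDialect)[0]
    assert sql == "SELECT APPROX_COUNT_DISTINCT(x) FROM t"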
 710
 711def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 712    if expression.args.get("accuracy"):
 713        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 714    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 715
 716
 717def if_sql(
 718    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 719) -> t.Callable[[Generator, exp.If], str]:
 720    def _if_sql(self: Generator, expression: exp.If) -> str:
 721        return self.func(
 722            name,
 723            expression.this,
 724            expression.args.get("true"),
 725            expression.args.get("false") or false_value,
 726        )
 727
 728    return _if_sql
 729
 730
 731def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 732    this = expression.this
 733    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 734        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 735
 736    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 737
 738
 739def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 740    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 741
 742
 743def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 744    elem = seq_get(expression.expressions, 0)
 745    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 746        return self.func("ARRAY", elem)
 747    return inline_array_sql(self, expression)
 748
 749
 750def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 751    return self.like_sql(
 752        exp.Like(
 753            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
 754        )
 755    )
 756
 757
 758def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 759    zone = self.sql(expression, "this")
 760    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 761
 762
 763def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 764    if expression.args.get("recursive"):
 765        self.unsupported("Recursive CTEs are unsupported")
 766        expression.args["recursive"] = False
 767    return self.with_sql(expression)
 768
 769
 770def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 771    n = self.sql(expression, "this")
 772    d = self.sql(expression, "expression")
 773    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 774
 775
 776def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 777    self.unsupported("TABLESAMPLE unsupported")
 778    return self.sql(expression.this)
 779
 780
 781def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 782    self.unsupported("PIVOT unsupported")
 783    return ""
 784
 785
 786def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 787    return self.cast_sql(expression)
 788
 789
 790def no_comment_column_constraint_sql(
 791    self: Generator, expression: exp.CommentColumnConstraint
 792) -> str:
 793    self.unsupported("CommentColumnConstraint unsupported")
 794    return ""
 795
 796
 797def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 798    self.unsupported("MAP_FROM_ENTRIES unsupported")
 799    return ""
 800
 801
 802def str_position_sql(
 803    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 804) -> str:
 805    this = self.sql(expression, "this")
 806    substr = self.sql(expression, "substr")
 807    position = self.sql(expression, "position")
 808    instance = expression.args.get("instance") if generate_instance else None
 809    position_offset = ""
 810
 811    if position:
 812        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 813        this = self.func("SUBSTR", this, position)
 814        position_offset = f" + {position} - 1"
 815
 816    return self.func("STRPOS", this, substr, instance) + position_offset
 817
 818
 819def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 820    return (
 821        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 822    )
 823
 824
 825def var_map_sql(
 826    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 827) -> str:
 828    keys = expression.args["keys"]
 829    values = expression.args["values"]
 830
 831    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 832        self.unsupported("Cannot convert array columns into map.")
 833        return self.func(map_func_name, keys, values)
 834
 835    args = []
 836    for key, value in zip(keys.expressions, values.expressions):
 837        args.append(self.sql(key))
 838        args.append(self.sql(value))
 839
 840    return self.func(map_func_name, *args)
 841
 842
 843def build_formatted_time(
 844    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 845) -> t.Callable[[t.List], E]:
 846    """Helper used for time expressions.
 847
 848    Args:
 849        exp_class: the expression class to instantiate.
 850        dialect: target sql dialect.
  851        default: the default format; `True` means the dialect's `TIME_FORMAT`.
 852
 853    Returns:
 854        A callable that can be used to return the appropriately formatted time expression.
 855    """
 856
 857    def _builder(args: t.List):
 858        return exp_class(
 859            this=seq_get(args, 0),
 860            format=Dialect[dialect].format_time(
 861                seq_get(args, 1)
 862                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 863            ),
 864        )
 865
 866    return _builder
 867
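A sketch of the builder this factory returns, again assuming Hive maps `yyyy-MM-dd` to `%Y-%m-%d`:

    import sqlglot  # registers the built-in dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    builder = build_formatted_time(exp.StrToTime, "hive")
    node = builder([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])
    assert node.args["format"].sql() == "'%Y-%m-%d'"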
 868
 869def time_format(
 870    dialect: DialectType = None,
 871) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 872    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 873        """
 874        Returns the time format for a given expression, unless it's equivalent
 875        to the default time format of the dialect of interest.
 876        """
 877        time_format = self.format_time(expression)
 878        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 879
 880    return _time_format
 881
 882
 883def build_date_delta(
 884    exp_class: t.Type[E],
 885    unit_mapping: t.Optional[t.Dict[str, str]] = None,
 886    default_unit: t.Optional[str] = "DAY",
 887) -> t.Callable[[t.List], E]:
 888    def _builder(args: t.List) -> E:
 889        unit_based = len(args) == 3
 890        this = args[2] if unit_based else seq_get(args, 0)
 891        unit = None
 892        if unit_based or default_unit:
 893            unit = args[0] if unit_based else exp.Literal.string(default_unit)
 894            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 895        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 896
 897    return _builder
 898
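A sketch of the returned builder: with three arguments the order is (unit, value, this), and `unit_mapping` keys are matched against the lowercased unit name:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_date_delta

    builder = build_date_delta(exp.DateAdd, unit_mapping={"d": "DAY"})
    node = builder([exp.var("D"), exp.Literal.number(1), exp.column("dt")])
    assert isinstance(node, exp.DateAdd)
    assert node.args["unit"].name == "DAY"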
 899
 900def build_date_delta_with_interval(
 901    expression_class: t.Type[E],
 902) -> t.Callable[[t.List], t.Optional[E]]:
 903    def _builder(args: t.List) -> t.Optional[E]:
 904        if len(args) < 2:
 905            return None
 906
 907        interval = args[1]
 908
 909        if not isinstance(interval, exp.Interval):
 910            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 911
 912        expression = interval.this
 913        if expression and expression.is_string:
 914            expression = exp.Literal.number(expression.this)
 915
 916        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 917
 918    return _builder
 919
 920
 921def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 922    unit = seq_get(args, 0)
 923    this = seq_get(args, 1)
 924
 925    if isinstance(this, exp.Cast) and this.is_type("date"):
 926        return exp.DateTrunc(unit=unit, this=this)
 927    return exp.TimestampTrunc(this=this, unit=unit)
 928
 929
 930def date_add_interval_sql(
 931    data_type: str, kind: str
 932) -> t.Callable[[Generator, exp.Expression], str]:
 933    def func(self: Generator, expression: exp.Expression) -> str:
 934        this = self.sql(expression, "this")
 935        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 936        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 937
 938    return func
 939
 940
 941def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
 942    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 943        args = [unit_to_str(expression), expression.this]
 944        if zone:
 945            args.append(expression.args.get("zone"))
 946        return self.func("DATE_TRUNC", *args)
 947
 948    return _timestamptrunc_sql
 949
 950
 951def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 952    zone = expression.args.get("zone")
 953    if not zone:
 954        from sqlglot.optimizer.annotate_types import annotate_types
 955
 956        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 957        return self.sql(exp.cast(expression.this, target_type))
 958    if zone.name.lower() in TIMEZONES:
 959        return self.sql(
 960            exp.AtTimeZone(
 961                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 962                zone=zone,
 963            )
 964        )
 965    return self.func("TIMESTAMP", expression.this, zone)
 966
 967
 968def no_time_sql(self: Generator, expression: exp.Time) -> str:
 969    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
 970    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
 971    expr = exp.cast(
 972        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
 973    )
 974    return self.sql(expr)
 975
 976
 977def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
 978    this = expression.this
 979    expr = expression.expression
 980
 981    if expr.name.lower() in TIMEZONES:
 982        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
 983        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
 984        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
 985        return self.sql(this)
 986
 987    this = exp.cast(this, exp.DataType.Type.DATE)
 988    expr = exp.cast(expr, exp.DataType.Type.TIME)
 989
 990    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
 991
 992
 993def locate_to_strposition(args: t.List) -> exp.Expression:
 994    return exp.StrPosition(
 995        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 996    )
 997
 998
 999def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
1000    return self.func(
1001        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
1002    )
1003
1004
1005def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
1006    return self.sql(
1007        exp.Substring(
1008            this=expression.this, start=exp.Literal.number(1), length=expression.expression
1009        )
1010    )
1011
1012
 1013def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
1014    return self.sql(
1015        exp.Substring(
1016            this=expression.this,
1017            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
1018        )
1019    )
1020
1021
1022def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
1023    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
1024
1025
1026def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
1027    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
1028
1029
 1030# Used for Presto and DuckDB, whose functions don't support a charset and assume utf-8
1031def encode_decode_sql(
1032    self: Generator, expression: exp.Expression, name: str, replace: bool = True
1033) -> str:
1034    charset = expression.args.get("charset")
1035    if charset and charset.name.lower() != "utf-8":
1036        self.unsupported(f"Expected utf-8 character set, got {charset}.")
1037
1038    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1039
1040
1041def min_or_least(self: Generator, expression: exp.Min) -> str:
1042    name = "LEAST" if expression.expressions else "MIN"
1043    return rename_func(name)(self, expression)
1044
1045
1046def max_or_greatest(self: Generator, expression: exp.Max) -> str:
1047    name = "GREATEST" if expression.expressions else "MAX"
1048    return rename_func(name)(self, expression)
1049
1050
1051def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
1052    cond = expression.this
1053
1054    if isinstance(expression.this, exp.Distinct):
1055        cond = expression.this.expressions[0]
1056        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
1057
1058    return self.func("sum", exp.func("if", cond, 1, 0))
1059
1060
1061def trim_sql(self: Generator, expression: exp.Trim) -> str:
1062    target = self.sql(expression, "this")
1063    trim_type = self.sql(expression, "position")
1064    remove_chars = self.sql(expression, "expression")
1065    collation = self.sql(expression, "collation")
1066
1067    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
1068    if not remove_chars and not collation:
1069        return self.trim_sql(expression)
1070
1071    trim_type = f"{trim_type} " if trim_type else ""
1072    remove_chars = f"{remove_chars} " if remove_chars else ""
1073    from_part = "FROM " if trim_type or remove_chars else ""
1074    collation = f" COLLATE {collation}" if collation else ""
1075    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1076
1077
1078def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
1079    return self.func("STRPTIME", expression.this, self.format_time(expression))
1080
1081
1082def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
1083    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
1084
1085
1086def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
1087    delim, *rest_args = expression.expressions
1088    return self.sql(
1089        reduce(
1090            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
1091            rest_args,
1092        )
1093    )
1094
1095
1096def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
1097    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
1098    if bad_args:
1099        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
1100
1101    return self.func(
1102        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
1103    )
1104
1105
1106def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
1107    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
1108    if bad_args:
1109        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
1110
1111    return self.func(
1112        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
1113    )
1114
1115
1116def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
1117    names = []
1118    for agg in aggregations:
1119        if isinstance(agg, exp.Alias):
1120            names.append(agg.alias)
1121        else:
1122            """
1123            This case corresponds to aggregations without aliases being used as suffixes
1124            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
1125            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 1126            Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested backquotes).
1127            """
1128            agg_all_unquoted = agg.transform(
1129                lambda node: (
1130                    exp.Identifier(this=node.name, quoted=False)
1131                    if isinstance(node, exp.Identifier)
1132                    else node
1133                )
1134            )
1135            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
1136
1137    return names
1138
1139
1140def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
1141    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
1142
1143
1144# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
1145def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
1146    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
1147
1148
1149def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
1150    return self.func("MAX", expression.this)
1151
1152
1153def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
1154    a = self.sql(expression.left)
1155    b = self.sql(expression.right)
1156    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
1157
1158
1159def is_parse_json(expression: exp.Expression) -> bool:
1160    return isinstance(expression, exp.ParseJSON) or (
1161        isinstance(expression, exp.Cast) and expression.is_type("json")
1162    )
1163
1164
1165def isnull_to_is_null(args: t.List) -> exp.Expression:
1166    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
1167
1168
1169def generatedasidentitycolumnconstraint_sql(
1170    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
1171) -> str:
1172    start = self.sql(expression, "start") or "1"
1173    increment = self.sql(expression, "increment") or "1"
1174    return f"IDENTITY({start}, {increment})"
1175
1176
1177def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
1178    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
1179        if expression.args.get("count"):
1180            self.unsupported(f"Only two arguments are supported in function {name}.")
1181
1182        return self.func(name, expression.this, expression.expression)
1183
1184    return _arg_max_or_min_sql
1185
1186
1187def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1188    this = expression.this.copy()
1189
1190    return_type = expression.return_type
1191    if return_type.is_type(exp.DataType.Type.DATE):
1192        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1193        # can truncate timestamp strings, because some dialects can't cast them to DATE
1194        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1195
1196    expression.this.replace(exp.cast(this, return_type))
1197    return expression
1198
1199
1200def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1201    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1202        if cast and isinstance(expression, exp.TsOrDsAdd):
1203            expression = ts_or_ds_add_cast(expression)
1204
1205        return self.func(
1206            name,
1207            unit_to_var(expression),
1208            expression.expression,
1209            expression.this,
1210        )
1211
1212    return _delta_sql
1213
1214
1215def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1216    unit = expression.args.get("unit")
1217
1218    if isinstance(unit, exp.Placeholder):
1219        return unit
1220    if unit:
1221        return exp.Literal.string(unit.name)
1222    return exp.Literal.string(default) if default else None
1223
1224
1225def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1226    unit = expression.args.get("unit")
1227
1228    if isinstance(unit, (exp.Var, exp.Placeholder)):
1229        return unit
1230    return exp.Var(this=default) if default else None
1231
1232
1233@t.overload
1234def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
1235    pass
1236
1237
1238@t.overload
1239def map_date_part(
1240    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
1241) -> t.Optional[exp.Expression]:
1242    pass
1243
1244
1245def map_date_part(part, dialect: DialectType = Dialect):
1246    mapped = (
1247        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
1248    )
1249    return exp.var(mapped) if mapped else part
1250
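For example (a sketch; keys are matched via `part.name.upper()`):

    from sqlglot import exp
    from sqlglot.dialects.dialect import map_date_part

    assert map_date_part(exp.var("mons")).name == "MONTH"
    assert map_date_part(exp.var("dow")).name == "DAYOFWEEK"
    assert map_date_part(exp.var("month")).name == "month"  # no alias, unchanged
    assert map_date_part(None) is None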
1251
1252def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1253    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1254    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1255    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1256
1257    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1258
1259
1260def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1261    """Remove table refs from columns in when statements."""
1262    alias = expression.this.args.get("alias")
1263
1264    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1265        return self.dialect.normalize_identifier(identifier).name if identifier else None
1266
1267    targets = {normalize(expression.this.this)}
1268
1269    if alias:
1270        targets.add(normalize(alias.this))
1271
1272    for when in expression.expressions:
1273        when.transform(
1274            lambda node: (
1275                exp.column(node.this)
1276                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1277                else node
1278            ),
1279            copy=False,
1280        )
1281
1282    return self.merge_sql(expression)
1283
1284
1285def build_json_extract_path(
1286    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1287) -> t.Callable[[t.List], F]:
1288    def _builder(args: t.List) -> F:
1289        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1290        for arg in args[1:]:
1291            if not isinstance(arg, exp.Literal):
1292                # We use the fallback parser because we can't really transpile non-literals safely
1293                return expr_type.from_arg_list(args)
1294
1295            text = arg.name
1296            if is_int(text):
1297                index = int(text)
1298                segments.append(
1299                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1300                )
1301            else:
1302                segments.append(exp.JSONPathKey(this=text))
1303
1304        # This is done to avoid failing in the expression validator due to the arg count
1305        del args[2:]
1306        return expr_type(
1307            this=seq_get(args, 0),
1308            expression=exp.JSONPath(expressions=segments),
1309            only_json_types=arrow_req_json_type,
1310        )
1311
1312    return _builder
1313
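A sketch of the builder: literal arguments become structured `JSONPath` segments, while any non-literal argument falls back to a plain function node:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])

    path = node.expression
    assert isinstance(path, exp.JSONPath)
    assert isinstance(path.expressions[1], exp.JSONPathKey)        # .a
    assert isinstance(path.expressions[2], exp.JSONPathSubscript)  # [0]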
1314
1315def json_extract_segments(
1316    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1317) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1318    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1319        path = expression.expression
1320        if not isinstance(path, exp.JSONPath):
1321            return rename_func(name)(self, expression)
1322
1323        segments = []
1324        for segment in path.expressions:
1325            path = self.sql(segment)
1326            if path:
1327                if isinstance(segment, exp.JSONPathPart) and (
1328                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1329                ):
1330                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1331
1332                segments.append(path)
1333
1334        if op:
1335            return f" {op} ".join([self.sql(expression.this), *segments])
1336        return self.func(name, expression.this, *segments)
1337
1338    return _json_extract_segments
1339
1340
1341def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1342    if isinstance(expression.this, exp.JSONPathWildcard):
1343        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1344
1345    return expression.name
1346
1347
1348def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1349    cond = expression.expression
1350    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1351        alias = cond.expressions[0]
1352        cond = cond.this
1353    elif isinstance(cond, exp.Predicate):
1354        alias = "_u"
1355    else:
1356        self.unsupported("Unsupported filter condition")
1357        return ""
1358
1359    unnest = exp.Unnest(expressions=[expression.this])
1360    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1361    return self.sql(exp.Array(expressions=[filtered]))
1362
1363
1364def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1365    return self.func(
1366        "TO_NUMBER",
1367        expression.this,
1368        expression.args.get("format"),
1369        expression.args.get("nlsparam"),
1370    )
1371
1372
1373def build_default_decimal_type(
1374    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1375) -> t.Callable[[exp.DataType], exp.DataType]:
1376    def _builder(dtype: exp.DataType) -> exp.DataType:
1377        if dtype.expressions or precision is None:
1378            return dtype
1379
1380        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1381        return exp.DataType.build(f"DECIMAL({params})")
1382
1383    return _builder
1384
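For example, a sketch that defaults bare `DECIMAL` to `DECIMAL(38, 9)` while leaving explicitly parameterized types untouched:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_default_decimal_type

    builder = build_default_decimal_type(precision=38, scale=9)
    assert builder(exp.DataType.build("DECIMAL")).sql() == "DECIMAL(38, 9)"
    assert builder(exp.DataType.build("DECIMAL(10, 2)")).sql() == "DECIMAL(10, 2)"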
1385
1386def build_timestamp_from_parts(args: t.List) -> exp.Func:
1387    if len(args) == 2:
1388        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
1389        # so we parse this into Anonymous for now instead of introducing complexity
1390        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
1391
1392    return exp.TimestampFromParts.from_arg_list(args)
1393
1394
1395def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
1396    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
logger = <Logger sqlglot (WARNING)>
UNESCAPED_SEQUENCES = {'\\a': '\x07', '\\b': '\x08', '\\f': '\x0c', '\\n': '\n', '\\r': '\r', '\\t': '\t', '\\v': '\x0b', '\\\\': '\\'}
class Dialects(builtins.str, enum.Enum):
41class Dialects(str, Enum):
42    """Dialects supported by SQLGLot."""
43
44    DIALECT = ""
45
46    ATHENA = "athena"
47    BIGQUERY = "bigquery"
48    CLICKHOUSE = "clickhouse"
49    DATABRICKS = "databricks"
50    DORIS = "doris"
51    DRILL = "drill"
52    DUCKDB = "duckdb"
53    HIVE = "hive"
54    MATERIALIZE = "materialize"
55    MYSQL = "mysql"
56    ORACLE = "oracle"
57    POSTGRES = "postgres"
58    PRESTO = "presto"
59    PRQL = "prql"
60    REDSHIFT = "redshift"
61    RISINGWAVE = "risingwave"
62    SNOWFLAKE = "snowflake"
63    SPARK = "spark"
64    SPARK2 = "spark2"
65    SQLITE = "sqlite"
66    STARROCKS = "starrocks"
67    TABLEAU = "tableau"
68    TERADATA = "teradata"
69    TRINO = "trino"
70    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MATERIALIZE = <Dialects.MATERIALIZE: 'materialize'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
RISINGWAVE = <Dialects.RISINGWAVE: 'risingwave'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
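
Since Dialects subclasses str, each member compares equal to its plain string value, so either form can be used as a dialect name:

>>> Dialects.DUCKDB == "duckdb"
True
>>> Dialects.DUCKDB.value
'duckdb'
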
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
73class NormalizationStrategy(str, AutoName):
74    """Specifies the strategy according to which identifiers should be normalized."""
75
76    LOWERCASE = auto()
77    """Unquoted identifiers are lowercased."""
78
79    UPPERCASE = auto()
80    """Unquoted identifiers are uppercased."""
81
82    CASE_SENSITIVE = auto()
83    """Always case-sensitive, regardless of quotes."""
84
85    CASE_INSENSITIVE = auto()
86    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.
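
The strategy can also be chosen per Dialect instance through the normalization_strategy setting (case-insensitive), per the constructor shown further below:

>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> Dialect(normalization_strategy="case_sensitive").normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True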

class Dialect:
207class Dialect(metaclass=_Dialect):
208    INDEX_OFFSET = 0
209    """The base index offset for arrays."""
210
211    WEEK_OFFSET = 0
212    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
213
214    UNNEST_COLUMN_ONLY = False
215    """Whether `UNNEST` table aliases are treated as column aliases."""
216
217    ALIAS_POST_TABLESAMPLE = False
218    """Whether the table alias comes after tablesample."""
219
220    TABLESAMPLE_SIZE_IS_PERCENT = False
221    """Whether a size in the table sample clause represents percentage."""
222
223    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
224    """Specifies the strategy according to which identifiers should be normalized."""
225
226    IDENTIFIERS_CAN_START_WITH_DIGIT = False
227    """Whether an unquoted identifier can start with a digit."""
228
229    DPIPE_IS_STRING_CONCAT = True
230    """Whether the DPIPE token (`||`) is a string concatenation operator."""
231
232    STRICT_STRING_CONCAT = False
233    """Whether `CONCAT`'s arguments must be strings."""
234
235    SUPPORTS_USER_DEFINED_TYPES = True
236    """Whether user-defined data types are supported."""
237
238    SUPPORTS_SEMI_ANTI_JOIN = True
239    """Whether `SEMI` or `ANTI` joins are supported."""
240
241    SUPPORTS_COLUMN_JOIN_MARKS = False
242    """Whether the old-style outer join (+) syntax is supported."""
243
244    COPY_PARAMS_ARE_CSV = True
245    """Separator of COPY statement parameters."""
246
247    NORMALIZE_FUNCTIONS: bool | str = "upper"
248    """
249    Determines how function names are going to be normalized.
250    Possible values:
251        "upper" or True: Convert names to uppercase.
252        "lower": Convert names to lowercase.
253        False: Disables function name normalization.
254    """
255
256    LOG_BASE_FIRST: t.Optional[bool] = True
257    """
258    Whether the base comes first in the `LOG` function.
259    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
260    """
261
262    NULL_ORDERING = "nulls_are_small"
263    """
264    Default `NULL` ordering method to use if not explicitly set.
265    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
266    """
267
268    TYPED_DIVISION = False
269    """
270    Whether the behavior of `a / b` depends on the types of `a` and `b`.
271    False means `a / b` is always float division.
272    True means `a / b` is integer division if both `a` and `b` are integers.
273    """
274
275    SAFE_DIVISION = False
276    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
277
278    CONCAT_COALESCE = False
279    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
280
281    HEX_LOWERCASE = False
282    """Whether the `HEX` function returns a lowercase hexadecimal string."""
283
284    DATE_FORMAT = "'%Y-%m-%d'"
285    DATEINT_FORMAT = "'%Y%m%d'"
286    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
287
288    TIME_MAPPING: t.Dict[str, str] = {}
289    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
290
291    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
292    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
293    FORMAT_MAPPING: t.Dict[str, str] = {}
294    """
295    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
296    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
297    """
298
299    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
300    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
301
302    PSEUDOCOLUMNS: t.Set[str] = set()
303    """
304    Columns that are auto-generated by the engine corresponding to this dialect.
305    For example, such columns may be excluded from `SELECT *` queries.
306    """
307
308    PREFER_CTE_ALIAS_COLUMN = False
309    """
310    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
311    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
312    any projection aliases in the subquery.
313
314    For example,
315        WITH y(c) AS (
316            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
317        ) SELECT c FROM y;
318
319        will be rewritten as
320
321        WITH y(c) AS (
322            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
323        ) SELECT c FROM y;
324    """
325
326    COPY_PARAMS_ARE_CSV = True
327    """
328    Whether COPY statement parameters are separated by comma or whitespace
329    """
330
331    FORCE_EARLY_ALIAS_REF_EXPANSION = False
332    """
333    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
334
335    For example:
336        WITH data AS (
337        SELECT
338            1 AS id,
339            2 AS my_id
340        )
341        SELECT
342            id AS my_id
343        FROM
344            data
345        WHERE
346            my_id = 1
347        GROUP BY
348            my_id
349        HAVING
350            my_id = 1
351
352    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
353        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
354        - Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
355    """
356
357    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
358    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""
359
360    # --- Autofilled ---
361
362    tokenizer_class = Tokenizer
363    jsonpath_tokenizer_class = JSONPathTokenizer
364    parser_class = Parser
365    generator_class = Generator
366
367    # A trie of the time_mapping keys
368    TIME_TRIE: t.Dict = {}
369    FORMAT_TRIE: t.Dict = {}
370
371    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
372    INVERSE_TIME_TRIE: t.Dict = {}
373    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
374    INVERSE_FORMAT_TRIE: t.Dict = {}
375
376    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
377
378    # Delimiters for string literals and identifiers
379    QUOTE_START = "'"
380    QUOTE_END = "'"
381    IDENTIFIER_START = '"'
382    IDENTIFIER_END = '"'
383
384    # Delimiters for bit, hex, byte and unicode literals
385    BIT_START: t.Optional[str] = None
386    BIT_END: t.Optional[str] = None
387    HEX_START: t.Optional[str] = None
388    HEX_END: t.Optional[str] = None
389    BYTE_START: t.Optional[str] = None
390    BYTE_END: t.Optional[str] = None
391    UNICODE_START: t.Optional[str] = None
392    UNICODE_END: t.Optional[str] = None
393
394    DATE_PART_MAPPING = {
395        "Y": "YEAR",
396        "YY": "YEAR",
397        "YYY": "YEAR",
398        "YYYY": "YEAR",
399        "YR": "YEAR",
400        "YEARS": "YEAR",
401        "YRS": "YEAR",
402        "MM": "MONTH",
403        "MON": "MONTH",
404        "MONS": "MONTH",
405        "MONTHS": "MONTH",
406        "D": "DAY",
407        "DD": "DAY",
408        "DAYS": "DAY",
409        "DAYOFMONTH": "DAY",
410        "DAY OF WEEK": "DAYOFWEEK",
411        "WEEKDAY": "DAYOFWEEK",
412        "DOW": "DAYOFWEEK",
413        "DW": "DAYOFWEEK",
414        "WEEKDAY_ISO": "DAYOFWEEKISO",
415        "DOW_ISO": "DAYOFWEEKISO",
416        "DW_ISO": "DAYOFWEEKISO",
417        "DAY OF YEAR": "DAYOFYEAR",
418        "DOY": "DAYOFYEAR",
419        "DY": "DAYOFYEAR",
420        "W": "WEEK",
421        "WK": "WEEK",
422        "WEEKOFYEAR": "WEEK",
423        "WOY": "WEEK",
424        "WY": "WEEK",
425        "WEEK_ISO": "WEEKISO",
426        "WEEKOFYEARISO": "WEEKISO",
427        "WEEKOFYEAR_ISO": "WEEKISO",
428        "Q": "QUARTER",
429        "QTR": "QUARTER",
430        "QTRS": "QUARTER",
431        "QUARTERS": "QUARTER",
432        "H": "HOUR",
433        "HH": "HOUR",
434        "HR": "HOUR",
435        "HOURS": "HOUR",
436        "HRS": "HOUR",
437        "M": "MINUTE",
438        "MI": "MINUTE",
439        "MIN": "MINUTE",
440        "MINUTES": "MINUTE",
441        "MINS": "MINUTE",
442        "S": "SECOND",
443        "SEC": "SECOND",
444        "SECONDS": "SECOND",
445        "SECS": "SECOND",
446        "MS": "MILLISECOND",
447        "MSEC": "MILLISECOND",
448        "MSECS": "MILLISECOND",
449        "MSECOND": "MILLISECOND",
450        "MSECONDS": "MILLISECOND",
451        "MILLISEC": "MILLISECOND",
452        "MILLISECS": "MILLISECOND",
453        "MILLISECON": "MILLISECOND",
454        "MILLISECONDS": "MILLISECOND",
455        "US": "MICROSECOND",
456        "USEC": "MICROSECOND",
457        "USECS": "MICROSECOND",
458        "MICROSEC": "MICROSECOND",
459        "MICROSECS": "MICROSECOND",
460        "USECOND": "MICROSECOND",
461        "USECONDS": "MICROSECOND",
462        "MICROSECONDS": "MICROSECOND",
463        "NS": "NANOSECOND",
464        "NSEC": "NANOSECOND",
465        "NANOSEC": "NANOSECOND",
466        "NSECOND": "NANOSECOND",
467        "NSECONDS": "NANOSECOND",
468        "NANOSECS": "NANOSECOND",
469        "EPOCH_SECOND": "EPOCH",
470        "EPOCH_SECONDS": "EPOCH",
471        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
472        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
473        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
474        "TZH": "TIMEZONE_HOUR",
475        "TZM": "TIMEZONE_MINUTE",
476        "DEC": "DECADE",
477        "DECS": "DECADE",
478        "DECADES": "DECADE",
479        "MIL": "MILLENIUM",
480        "MILS": "MILLENIUM",
481        "MILLENIA": "MILLENIUM",
482        "C": "CENTURY",
483        "CENT": "CENTURY",
484        "CENTS": "CENTURY",
485        "CENTURIES": "CENTURY",
486    }
487
488    @classmethod
489    def get_or_raise(cls, dialect: DialectType) -> Dialect:
490        """
491        Look up a dialect in the global dialect registry and return it if it exists.
492
493        Args:
494            dialect: The target dialect. If this is a string, it can be optionally followed by
495                additional key-value pairs that are separated by commas and are used to specify
496                dialect settings, such as whether the dialect's identifiers are case-sensitive.
497
498        Example:
499            >>> dialect = dialect_class = get_or_raise("duckdb")
500            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
501
502        Returns:
503            The corresponding Dialect instance.
504        """
505
506        if not dialect:
507            return cls()
508        if isinstance(dialect, _Dialect):
509            return dialect()
510        if isinstance(dialect, Dialect):
511            return dialect
512        if isinstance(dialect, str):
513            try:
514                dialect_name, *kv_pairs = dialect.split(",")
515                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
516            except ValueError:
517                raise ValueError(
518                    f"Invalid dialect format: '{dialect}'. "
519                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
520                )
521
522            result = cls.get(dialect_name.strip())
523            if not result:
524                from difflib import get_close_matches
525
526                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
527                if similar:
528                    similar = f" Did you mean {similar}?"
529
530                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
531
532            return result(**kwargs)
533
534        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
535
536    @classmethod
537    def format_time(
538        cls, expression: t.Optional[str | exp.Expression]
539    ) -> t.Optional[exp.Expression]:
540        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
541        if isinstance(expression, str):
542            return exp.Literal.string(
543                # the time formats are quoted
544                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
545            )
546
547        if expression and expression.is_string:
548            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
549
550        return expression
551
552    def __init__(self, **kwargs) -> None:
553        normalization_strategy = kwargs.pop("normalization_strategy", None)
554
555        if normalization_strategy is None:
556            self.normalization_strategy = self.NORMALIZATION_STRATEGY
557        else:
558            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
559
560        self.settings = kwargs
561
562    def __eq__(self, other: t.Any) -> bool:
563        # Does not currently take dialect state into account
564        return type(self) == other
565
566    def __hash__(self) -> int:
567        # Does not currently take dialect state into account
568        return hash(type(self))
569
570    def normalize_identifier(self, expression: E) -> E:
571        """
572        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
573
574        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
575        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
576        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
577        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
578
579        There are also dialects like Spark, which are case-insensitive even when quotes are
580        present, and dialects like MySQL, whose resolution rules match those employed by the
581        underlying operating system, for example they may always be case-sensitive in Linux.
582
583        Finally, the normalization behavior of some engines can even be controlled through flags,
584        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
585
586        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
587        that it can analyze queries in the optimizer and successfully capture their semantics.
588        """
589        if (
590            isinstance(expression, exp.Identifier)
591            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
592            and (
593                not expression.quoted
594                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
595            )
596        ):
597            expression.set(
598                "this",
599                (
600                    expression.this.upper()
601                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
602                    else expression.this.lower()
603                ),
604            )
605
606        return expression
607
608    def case_sensitive(self, text: str) -> bool:
609        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
610        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
611            return False
612
613        unsafe = (
614            str.islower
615            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
616            else str.isupper
617        )
618        return any(unsafe(char) for char in text)
619
620    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
621        """Checks if text can be identified given an identify option.
622
623        Args:
624            text: The text to check.
625            identify:
626                `"always"` or `True`: Always returns `True`.
627                `"safe"`: Only returns `True` if the identifier is case-insensitive.
628
629        Returns:
630            Whether the given text can be identified.
631        """
632        if identify is True or identify == "always":
633            return True
634
635        if identify == "safe":
636            return not self.case_sensitive(text)
637
638        return False
639
640    def quote_identifier(self, expression: E, identify: bool = True) -> E:
641        """
642        Adds quotes to a given identifier.
643
644        Args:
645            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
646            identify: If set to `False`, the quotes will only be added if the identifier is deemed
647                "unsafe", with respect to its characters and this dialect's normalization strategy.
648        """
649        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
650            name = expression.this
651            expression.set(
652                "quoted",
653                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
654            )
655
656        return expression
657
658    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
659        if isinstance(path, exp.Literal):
660            path_text = path.name
661            if path.is_number:
662                path_text = f"[{path_text}]"
663            try:
664                return parse_json_path(path_text, self)
665            except ParseError as e:
666                logger.warning(f"Invalid JSON path syntax. {str(e)}")
667
668        return path
669
670    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
671        return self.parser(**opts).parse(self.tokenize(sql), sql)
672
673    def parse_into(
674        self, expression_type: exp.IntoType, sql: str, **opts
675    ) -> t.List[t.Optional[exp.Expression]]:
676        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
677
678    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
679        return self.generator(**opts).generate(expression, copy=copy)
680
681    def transpile(self, sql: str, **opts) -> t.List[str]:
682        return [
683            self.generate(expression, copy=False, **opts) if expression else ""
684            for expression in self.parse(sql)
685        ]
686
687    def tokenize(self, sql: str) -> t.List[Token]:
688        return self.tokenizer.tokenize(sql)
689
690    @property
691    def tokenizer(self) -> Tokenizer:
692        return self.tokenizer_class(dialect=self)
693
694    @property
695    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
696        return self.jsonpath_tokenizer_class(dialect=self)
697
698    def parser(self, **opts) -> Parser:
699        return self.parser_class(dialect=self, **opts)
700
701    def generator(self, **opts) -> Generator:
702        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
552    def __init__(self, **kwargs) -> None:
553        normalization_strategy = kwargs.pop("normalization_strategy", None)
554
555        if normalization_strategy is None:
556            self.normalization_strategy = self.NORMALIZATION_STRATEGY
557        else:
558            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
559
560        self.settings = kwargs
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

SUPPORTS_COLUMN_JOIN_MARKS = False

Whether the old-style outer join (+) syntax is supported.

COPY_PARAMS_ARE_CSV = True

Whether COPY statement parameters are separated by comma or whitespace.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (\\n) to its unescaped version (\n).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;

FORCE_EARLY_ALIAS_REF_EXPANSION = False

Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

For example:

WITH data AS (
    SELECT
        1 AS id,
        2 AS my_id
)
SELECT
    id AS my_id
FROM data
WHERE my_id = 1
GROUP BY my_id
HAVING my_id = 1

In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
  • BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
  • Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"

EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False

Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
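
Dialect implementations customize this behavior by overriding the class-level flags above; a minimal sketch with a hypothetical dialect (the name and settings are illustrative, not a bundled sqlglot dialect):

class CustomDialect(Dialect):
    # Hypothetical settings: fold unquoted identifiers to uppercase and
    # treat NULLs as the largest values when ordering.
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
    NULL_ORDERING = "nulls_are_large"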

tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
jsonpath_tokenizer_class = <class 'sqlglot.tokens.JSONPathTokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_FORMAT_MAPPING: Dict[str, str] = {}
INVERSE_FORMAT_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
DATE_PART_MAPPING = {'Y': 'YEAR', 'YY': 'YEAR', 'YYY': 'YEAR', 'YYYY': 'YEAR', 'YR': 'YEAR', 'YEARS': 'YEAR', 'YRS': 'YEAR', 'MM': 'MONTH', 'MON': 'MONTH', 'MONS': 'MONTH', 'MONTHS': 'MONTH', 'D': 'DAY', 'DD': 'DAY', 'DAYS': 'DAY', 'DAYOFMONTH': 'DAY', 'DAY OF WEEK': 'DAYOFWEEK', 'WEEKDAY': 'DAYOFWEEK', 'DOW': 'DAYOFWEEK', 'DW': 'DAYOFWEEK', 'WEEKDAY_ISO': 'DAYOFWEEKISO', 'DOW_ISO': 'DAYOFWEEKISO', 'DW_ISO': 'DAYOFWEEKISO', 'DAY OF YEAR': 'DAYOFYEAR', 'DOY': 'DAYOFYEAR', 'DY': 'DAYOFYEAR', 'W': 'WEEK', 'WK': 'WEEK', 'WEEKOFYEAR': 'WEEK', 'WOY': 'WEEK', 'WY': 'WEEK', 'WEEK_ISO': 'WEEKISO', 'WEEKOFYEARISO': 'WEEKISO', 'WEEKOFYEAR_ISO': 'WEEKISO', 'Q': 'QUARTER', 'QTR': 'QUARTER', 'QTRS': 'QUARTER', 'QUARTERS': 'QUARTER', 'H': 'HOUR', 'HH': 'HOUR', 'HR': 'HOUR', 'HOURS': 'HOUR', 'HRS': 'HOUR', 'M': 'MINUTE', 'MI': 'MINUTE', 'MIN': 'MINUTE', 'MINUTES': 'MINUTE', 'MINS': 'MINUTE', 'S': 'SECOND', 'SEC': 'SECOND', 'SECONDS': 'SECOND', 'SECS': 'SECOND', 'MS': 'MILLISECOND', 'MSEC': 'MILLISECOND', 'MSECS': 'MILLISECOND', 'MSECOND': 'MILLISECOND', 'MSECONDS': 'MILLISECOND', 'MILLISEC': 'MILLISECOND', 'MILLISECS': 'MILLISECOND', 'MILLISECON': 'MILLISECOND', 'MILLISECONDS': 'MILLISECOND', 'US': 'MICROSECOND', 'USEC': 'MICROSECOND', 'USECS': 'MICROSECOND', 'MICROSEC': 'MICROSECOND', 'MICROSECS': 'MICROSECOND', 'USECOND': 'MICROSECOND', 'USECONDS': 'MICROSECOND', 'MICROSECONDS': 'MICROSECOND', 'NS': 'NANOSECOND', 'NSEC': 'NANOSECOND', 'NANOSEC': 'NANOSECOND', 'NSECOND': 'NANOSECOND', 'NSECONDS': 'NANOSECOND', 'NANOSECS': 'NANOSECOND', 'EPOCH_SECOND': 'EPOCH', 'EPOCH_SECONDS': 'EPOCH', 'EPOCH_MILLISECONDS': 'EPOCH_MILLISECOND', 'EPOCH_MICROSECONDS': 'EPOCH_MICROSECOND', 'EPOCH_NANOSECONDS': 'EPOCH_NANOSECOND', 'TZH': 'TIMEZONE_HOUR', 'TZM': 'TIMEZONE_MINUTE', 'DEC': 'DECADE', 'DECS': 'DECADE', 'DECADES': 'DECADE', 'MIL': 'MILLENIUM', 'MILS': 'MILLENIUM', 'MILLENIA': 'MILLENIUM', 'C': 'CENTURY', 'CENT': 'CENTURY', 'CENTS': 'CENTURY', 'CENTURIES': 'CENTURY'}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
488    @classmethod
489    def get_or_raise(cls, dialect: DialectType) -> Dialect:
490        """
491        Look up a dialect in the global dialect registry and return it if it exists.
492
493        Args:
494            dialect: The target dialect. If this is a string, it can be optionally followed by
495                additional key-value pairs that are separated by commas and are used to specify
496                dialect settings, such as whether the dialect's identifiers are case-sensitive.
497
498        Example:
499            >>> dialect = dialect_class = get_or_raise("duckdb")
500            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
501
502        Returns:
503            The corresponding Dialect instance.
504        """
505
506        if not dialect:
507            return cls()
508        if isinstance(dialect, _Dialect):
509            return dialect()
510        if isinstance(dialect, Dialect):
511            return dialect
512        if isinstance(dialect, str):
513            try:
514                dialect_name, *kv_pairs = dialect.split(",")
515                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
516            except ValueError:
517                raise ValueError(
518                    f"Invalid dialect format: '{dialect}'. "
519                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
520                )
521
522            result = cls.get(dialect_name.strip())
523            if not result:
524                from difflib import get_close_matches
525
526                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
527                if similar:
528                    similar = f" Did you mean {similar}?"
529
530                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
531
532            return result(**kwargs)
533
534        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
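
Settings passed in the string form are forwarded to the returned instance's constructor, so, building on the __init__ shown below, one would expect:

>>> import sqlglot.dialects  # importing the package should register the bundled dialects
>>> Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive").normalization_strategy
<NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>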

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
536    @classmethod
537    def format_time(
538        cls, expression: t.Optional[str | exp.Expression]
539    ) -> t.Optional[exp.Expression]:
540        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
541        if isinstance(expression, str):
542            return exp.Literal.string(
543                # the time formats are quoted
544                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
545            )
546
547        if expression and expression.is_string:
548            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
549
550        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
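
A hedged illustration, assuming Hive's TIME_MAPPING translates yyyy/MM/dd in the usual way:

>>> import sqlglot.dialects  # registers the bundled dialects
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
"'%Y-%m-%d'"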

settings
def normalize_identifier(self, expression: ~E) -> ~E:
570    def normalize_identifier(self, expression: E) -> E:
571        """
572        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
573
574        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
575        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
576        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
577        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
578
579        There are also dialects like Spark, which are case-insensitive even when quotes are
580        present, and dialects like MySQL, whose resolution rules match those employed by the
581        underlying operating system, for example they may always be case-sensitive in Linux.
582
583        Finally, the normalization behavior of some engines can even be controlled through flags,
584        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
585
586        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
587        that it can analyze queries in the optimizer and successfully capture their semantics.
588        """
589        if (
590            isinstance(expression, exp.Identifier)
591            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
592            and (
593                not expression.quoted
594                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
595            )
596        ):
597            expression.set(
598                "this",
599                (
600                    expression.this.upper()
601                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
602                    else expression.this.lower()
603                ),
604            )
605
606        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
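
A quick illustration with an unquoted identifier:

>>> import sqlglot.dialects
>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).this
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).this
'FOO'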

def case_sensitive(self, text: str) -> bool:
608    def case_sensitive(self, text: str) -> bool:
609        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
610        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
611            return False
612
613        unsafe = (
614            str.islower
615            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
616            else str.isupper
617        )
618        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
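
Under the default LOWERCASE strategy, any uppercase character makes the text case-sensitive:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().case_sensitive("foo")
False
>>> Dialect().case_sensitive("Foo")
True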

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
620    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
621        """Checks if text can be identified given an identify option.
622
623        Args:
624            text: The text to check.
625            identify:
626                `"always"` or `True`: Always returns `True`.
627                `"safe"`: Only returns `True` if the identifier is case-insensitive.
628
629        Returns:
630            Whether the given text can be identified.
631        """
632        if identify is True or identify == "always":
633            return True
634
635        if identify == "safe":
636            return not self.case_sensitive(text)
637
638        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
      "always" or True: Always returns True.
      "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
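
For example, with the default normalization strategy:

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.can_identify("foo")
True
>>> d.can_identify("Foo")
False
>>> d.can_identify("Foo", identify="always")
True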

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
640    def quote_identifier(self, expression: E, identify: bool = True) -> E:
641        """
642        Adds quotes to a given identifier.
643
644        Args:
645            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
646            identify: If set to `False`, the quotes will only be added if the identifier is deemed
647                "unsafe", with respect to its characters and this dialect's normalization strategy.
648        """
649        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
650            name = expression.this
651            expression.set(
652                "quoted",
653                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
654            )
655
656        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
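
For example:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().quote_identifier(exp.to_identifier("foo")).sql()
'"foo"'
>>> Dialect().quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
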
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
658    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
659        if isinstance(path, exp.Literal):
660            path_text = path.name
661            if path.is_number:
662                path_text = f"[{path_text}]"
663            try:
664                return parse_json_path(path_text, self)
665            except ParseError as e:
666                logger.warning(f"Invalid JSON path syntax. {str(e)}")
667
668        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
670    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
671        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
673    def parse_into(
674        self, expression_type: exp.IntoType, sql: str, **opts
675    ) -> t.List[t.Optional[exp.Expression]]:
676        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
678    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
679        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
681    def transpile(self, sql: str, **opts) -> t.List[str]:
682        return [
683            self.generate(expression, copy=False, **opts) if expression else ""
684            for expression in self.parse(sql)
685        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
687    def tokenize(self, sql: str) -> t.List[Token]:
688        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
690    @property
691    def tokenizer(self) -> Tokenizer:
692        return self.tokenizer_class(dialect=self)
jsonpath_tokenizer: sqlglot.jsonpath.JSONPathTokenizer
694    @property
695    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
696        return self.jsonpath_tokenizer_class(dialect=self)
def parser(self, **opts) -> sqlglot.parser.Parser:
698    def parser(self, **opts) -> Parser:
699        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
701    def generator(self, **opts) -> Generator:
702        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
708def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
709    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
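
rename_func is typically dropped into a Generator's TRANSFORMS table to emit an expression under a different function name; a minimal sketch (the HLL target name is hypothetical):

from sqlglot import exp
from sqlglot.generator import Generator

class CustomGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Render exp.ApproxDistinct as HLL(<args>)
        exp.ApproxDistinct: rename_func("HLL"),
    }
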
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
712def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
713    if expression.args.get("accuracy"):
714        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
715    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
718def if_sql(
719    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
720) -> t.Callable[[Generator, exp.If], str]:
721    def _if_sql(self: Generator, expression: exp.If) -> str:
722        return self.func(
723            name,
724            expression.this,
725            expression.args.get("true"),
726            expression.args.get("false") or false_value,
727        )
728
729    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
732def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
733    this = expression.this
734    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
735        this.replace(exp.cast(this, exp.DataType.Type.JSON))
736
737    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
740def inline_array_sql(self: Generator, expression: exp.Array) -> str:
741    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
744def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
745    elem = seq_get(expression.expressions, 0)
746    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
747        return self.func("ARRAY", elem)
748    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
751def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
752    return self.like_sql(
753        exp.Like(
754            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
755        )
756    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
759def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
760    zone = self.sql(expression, "this")
761    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
764def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
765    if expression.args.get("recursive"):
766        self.unsupported("Recursive CTEs are unsupported")
767        expression.args["recursive"] = False
768    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
771def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
772    n = self.sql(expression, "this")
773    d = self.sql(expression, "expression")
774    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
777def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
778    self.unsupported("TABLESAMPLE unsupported")
779    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
782def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
783    self.unsupported("PIVOT unsupported")
784    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
787def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
788    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
791def no_comment_column_constraint_sql(
792    self: Generator, expression: exp.CommentColumnConstraint
793) -> str:
794    self.unsupported("CommentColumnConstraint unsupported")
795    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
798def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
799    self.unsupported("MAP_FROM_ENTRIES unsupported")
800    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
803def str_position_sql(
804    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
805) -> str:
806    this = self.sql(expression, "this")
807    substr = self.sql(expression, "substr")
808    position = self.sql(expression, "position")
809    instance = expression.args.get("instance") if generate_instance else None
810    position_offset = ""
811
812    if position:
813        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
814        this = self.func("SUBSTR", this, position)
815        position_offset = f" + {position} - 1"
816
817    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
820def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
821    return (
822        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
823    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
826def var_map_sql(
827    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
828) -> str:
829    keys = expression.args["keys"]
830    values = expression.args["values"]
831
832    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
833        self.unsupported("Cannot convert array columns into map.")
834        return self.func(map_func_name, keys, values)
835
836    args = []
837    for key, value in zip(keys.expressions, values.expressions):
838        args.append(self.sql(key))
839        args.append(self.sql(value))
840
841    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
844def build_formatted_time(
845    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
846) -> t.Callable[[t.List], E]:
847    """Helper used for time expressions.
848
849    Args:
850        exp_class: the expression class to instantiate.
851        dialect: target sql dialect.
852        default: the default format, True being time.
853
854    Returns:
855        A callable that can be used to return the appropriately formatted time expression.
856    """
857
858    def _builder(args: t.List):
859        return exp_class(
860            this=seq_get(args, 0),
861            format=Dialect[dialect].format_time(
862                seq_get(args, 1)
863                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
864            ),
865        )
866
867    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
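
A hedged sketch of how the returned builder is typically registered in a dialect parser's FUNCTIONS table, so that the format argument is translated through the dialect's TIME_MAPPING (the dialect, function name, and mapping below are hypothetical):

from sqlglot import exp, parser

class MyDialect(Dialect):
    TIME_MAPPING = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}

    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            # Parse TO_TIMESTAMP(<value>, <format>) into exp.StrToTime,
            # converting <format> through MyDialect's TIME_MAPPING
            "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "mydialect"),
        }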

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
870def time_format(
871    dialect: DialectType = None,
872) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
873    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
874        """
875        Returns the time format for a given expression, unless it's equivalent
876        to the default time format of the dialect of interest.
877        """
878        time_format = self.format_time(expression)
879        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
880
881    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None, default_unit: Optional[str] = 'DAY') -> Callable[[List], ~E]:
884def build_date_delta(
885    exp_class: t.Type[E],
886    unit_mapping: t.Optional[t.Dict[str, str]] = None,
887    default_unit: t.Optional[str] = "DAY",
888) -> t.Callable[[t.List], E]:
889    def _builder(args: t.List) -> E:
890        unit_based = len(args) == 3
891        this = args[2] if unit_based else seq_get(args, 0)
892        unit = None
893        if unit_based or default_unit:
894            unit = args[0] if unit_based else exp.Literal.string(default_unit)
895            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
896        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
897
898    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
901def build_date_delta_with_interval(
902    expression_class: t.Type[E],
903) -> t.Callable[[t.List], t.Optional[E]]:
904    def _builder(args: t.List) -> t.Optional[E]:
905        if len(args) < 2:
906            return None
907
908        interval = args[1]
909
910        if not isinstance(interval, exp.Interval):
911            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
912
913        expression = interval.this
914        if expression and expression.is_string:
915            expression = exp.Literal.number(expression.this)
916
917        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
918
919    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
922def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
923    unit = seq_get(args, 0)
924    this = seq_get(args, 1)
925
926    if isinstance(this, exp.Cast) and this.is_type("date"):
927        return exp.DateTrunc(unit=unit, this=this)
928    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
931def date_add_interval_sql(
932    data_type: str, kind: str
933) -> t.Callable[[Generator, exp.Expression], str]:
934    def func(self: Generator, expression: exp.Expression) -> str:
935        this = self.sql(expression, "this")
936        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
937        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
938
939    return func
def timestamptrunc_sql( zone: bool = False) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.TimestampTrunc], str]:
942def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
943    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
944        args = [unit_to_str(expression), expression.this]
945        if zone:
946            args.append(expression.args.get("zone"))
947        return self.func("DATE_TRUNC", *args)
948
949    return _timestamptrunc_sql
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
952def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
953    zone = expression.args.get("zone")
954    if not zone:
955        from sqlglot.optimizer.annotate_types import annotate_types
956
957        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
958        return self.sql(exp.cast(expression.this, target_type))
959    if zone.name.lower() in TIMEZONES:
960        return self.sql(
961            exp.AtTimeZone(
962                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
963                zone=zone,
964            )
965        )
966    return self.func("TIMESTAMP", expression.this, zone)
def no_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Time) -> str:
969def no_time_sql(self: Generator, expression: exp.Time) -> str:
970    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
971    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
972    expr = exp.cast(
973        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
974    )
975    return self.sql(expr)
def no_datetime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Datetime) -> str:
978def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
979    this = expression.this
980    expr = expression.expression
981
982    if expr.name.lower() in TIMEZONES:
983        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
984        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
985        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
986        return self.sql(this)
987
988    this = exp.cast(this, exp.DataType.Type.DATE)
989    expr = exp.cast(expr, exp.DataType.Type.TIME)
990
991    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
994def locate_to_strposition(args: t.List) -> exp.Expression:
995    return exp.StrPosition(
996        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
997    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
1000def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
1001    return self.func(
1002        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
1003    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
1006def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
1007    return self.sql(
1008        exp.Substring(
1009            this=expression.this, start=exp.Literal.number(1), length=expression.expression
1010        )
1011    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
1014def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
1015    return self.sql(
1016        exp.Substring(
1017            this=expression.this,
1018            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
1019        )
1020    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
1023def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
1024    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
1027def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
1028    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
1032def encode_decode_sql(
1033    self: Generator, expression: exp.Expression, name: str, replace: bool = True
1034) -> str:
1035    charset = expression.args.get("charset")
1036    if charset and charset.name.lower() != "utf-8":
1037        self.unsupported(f"Expected utf-8 character set, got {charset}.")
1038
1039    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
1042def min_or_least(self: Generator, expression: exp.Min) -> str:
1043    name = "LEAST" if expression.expressions else "MIN"
1044    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
1047def max_or_greatest(self: Generator, expression: exp.Max) -> str:
1048    name = "GREATEST" if expression.expressions else "MAX"
1049    return rename_func(name)(self, expression)
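
Both helpers dispatch on arity: a variadic MIN/MAX becomes LEAST/GREATEST, while the single-argument aggregate is left alone. A sketch (generator choice and columns are arbitrary):

    from sqlglot import exp
    from sqlglot.dialects.dialect import min_or_least
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    variadic = exp.Min(this=exp.column("a"), expressions=[exp.column("b")])
    print(min_or_least(gen, variadic))                      # LEAST(a, b)
    print(min_or_least(gen, exp.Min(this=exp.column("a"))))  # MIN(a)
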
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
1052def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
1053    cond = expression.this
1054
1055    if isinstance(expression.this, exp.Distinct):
1056        cond = expression.this.expressions[0]
1057        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
1058
1059    return self.func("sum", exp.func("if", cond, 1, 0))
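
A sketch (DuckDB generator and the predicate are illustrative assumptions):

    from sqlglot import exp
    from sqlglot.dialects.dialect import count_if_to_sum
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    node = exp.CountIf(this=exp.condition("x > 5"))
    # Renders as SUM over a conditional; dialects without IF() expand it further,
    # e.g. SUM(CASE WHEN x > 5 THEN 1 ELSE 0 END)
    print(count_if_to_sum(gen, node))
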
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
1062def trim_sql(self: Generator, expression: exp.Trim) -> str:
1063    target = self.sql(expression, "this")
1064    trim_type = self.sql(expression, "position")
1065    remove_chars = self.sql(expression, "expression")
1066    collation = self.sql(expression, "collation")
1067
1068    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
1069    if not remove_chars and not collation:
1070        return self.trim_sql(expression)
1071
1072    trim_type = f"{trim_type} " if trim_type else ""
1073    remove_chars = f"{remove_chars} " if remove_chars else ""
1074    from_part = "FROM " if trim_type or remove_chars else ""
1075    collation = f" COLLATE {collation}" if collation else ""
1076    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
1079def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
1080    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
1083def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
1084    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
1087def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
1088    delim, *rest_args = expression.expressions
1089    return self.sql(
1090        reduce(
1091            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
1092            rest_args,
1093        )
1094    )
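
The reduce interleaves the delimiter between consecutive arguments. A sketch (generator and operands are illustrative):

    from sqlglot import exp
    from sqlglot.dialects.dialect import concat_ws_to_dpipe_sql
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    node = exp.ConcatWs(
        expressions=[exp.Literal.string("-"), exp.column("a"), exp.column("b")]
    )
    print(concat_ws_to_dpipe_sql(gen, node))  # a || '-' || b
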
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
1097def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
1098    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
1099    if bad_args:
1100        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
1101
1102    return self.func(
1103        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
1104    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
1107def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
1108    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
1109    if bad_args:
1110        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
1111
1112    return self.func(
1113        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
1114    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
1117def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
1118    names = []
1119    for agg in aggregations:
1120        if isinstance(agg, exp.Alias):
1121            names.append(agg.alias)
1122        else:
1123            """
1124            This case corresponds to aggregations without aliases being used as suffixes
1125            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
1126            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
1127            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
1128            """
1129            agg_all_unquoted = agg.transform(
1130                lambda node: (
1131                    exp.Identifier(this=node.name, quoted=False)
1132                    if isinstance(node, exp.Identifier)
1133                    else node
1134                )
1135            )
1136            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
1137
1138    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
1141def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
1142    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
1146def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
1147    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
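
Both factories produce parser callbacks that take a raw argument list. A sketch (the chosen expression types and the rendered output are illustrative; default-dialect rendering may differ):

    from sqlglot import exp
    from sqlglot.dialects.dialect import binary_from_function, build_timestamp_trunc

    build_and = binary_from_function(exp.BitwiseAnd)
    print(build_and([exp.column("a"), exp.column("b")]).sql())  # a & b

    # TIMESTAMP_TRUNC-style calls put the unit first; the builder swaps the args
    node = build_timestamp_trunc([exp.var("DAY"), exp.column("ts")])
    print(node.sql())  # e.g. TIMESTAMP_TRUNC(ts, DAY)
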
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
1150def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
1151    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
1154def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
1155    a = self.sql(expression.left)
1156    b = self.sql(expression.right)
1157    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
1160def is_parse_json(expression: exp.Expression) -> bool:
1161    return isinstance(expression, exp.ParseJSON) or (
1162        isinstance(expression, exp.Cast) and expression.is_type("json")
1163    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
1166def isnull_to_is_null(args: t.List) -> exp.Expression:
1167    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
1170def generatedasidentitycolumnconstraint_sql(
1171    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
1172) -> str:
1173    start = self.sql(expression, "start") or "1"
1174    increment = self.sql(expression, "increment") or "1"
1175    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
1178def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
1179    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
1180        if expression.args.get("count"):
1181            self.unsupported(f"Only two arguments are supported in function {name}.")
1182
1183        return self.func(name, expression.this, expression.expression)
1184
1185    return _arg_max_or_min_sql
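
A sketch of the factory in use (the MAX_BY name and DuckDB generator are illustrative assumptions):

    from sqlglot import exp
    from sqlglot.dialects.dialect import arg_max_or_min_no_count
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    max_by_sql = arg_max_or_min_no_count("MAX_BY")
    node = exp.ArgMax(this=exp.column("name"), expression=exp.column("score"))
    print(max_by_sql(gen, node))  # MAX_BY(name, score)
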
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
1188def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1189    this = expression.this.copy()
1190
1191    return_type = expression.return_type
1192    if return_type.is_type(exp.DataType.Type.DATE):
1193        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1194        # can truncate timestamp strings, because some dialects can't cast them to DATE
1195        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1196
1197    expression.this.replace(exp.cast(this, return_type))
1198    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1201def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1202    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1203        if cast and isinstance(expression, exp.TsOrDsAdd):
1204            expression = ts_or_ds_add_cast(expression)
1205
1206        return self.func(
1207            name,
1208            unit_to_var(expression),
1209            expression.expression,
1210            expression.this,
1211        )
1212
1213    return _delta_sql
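
The factory fixes the target function name and emits arguments in unit-first order, using unit_to_var (below) for the unit. A sketch (the DATEADD name and operands are illustrative):

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_delta_sql
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    dateadd = date_delta_sql("DATEADD")
    node = exp.DateAdd(
        this=exp.column("d"), expression=exp.Literal.number(7), unit=exp.var("DAY")
    )
    print(dateadd(gen, node))  # DATEADD(DAY, 7, d)
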
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1216def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1217    unit = expression.args.get("unit")
1218
1219    if isinstance(unit, exp.Placeholder):
1220        return unit
1221    if unit:
1222        return exp.Literal.string(unit.name)
1223    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1226def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1227    unit = expression.args.get("unit")
1228
1229    if isinstance(unit, (exp.Var, exp.Placeholder)):
1230        return unit
1231    return exp.Var(this=default) if default else None
def map_date_part( part, dialect: Union[str, Dialect, Type[Dialect], NoneType] = <class 'Dialect'>):
1246def map_date_part(part, dialect: DialectType = Dialect):
1247    mapped = (
1248        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
1249    )
1250    return exp.var(mapped) if mapped else part
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1253def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1254    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1255    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1256    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1257
1258    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
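
A sketch (generator choice is arbitrary; each step of the pipeline is rendered per the target dialect's own date functions):

    from sqlglot import exp
    from sqlglot.dialects.dialect import no_last_day_sql
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    # Expands LAST_DAY(d) to a date_trunc/date_add/date_sub pipeline cast back to DATE
    print(no_last_day_sql(gen, exp.LastDay(this=exp.column("d"))))
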
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1261def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1262    """Remove table refs from columns in when statements."""
1263    alias = expression.this.args.get("alias")
1264
1265    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1266        return self.dialect.normalize_identifier(identifier).name if identifier else None
1267
1268    targets = {normalize(expression.this.this)}
1269
1270    if alias:
1271        targets.add(normalize(alias.this))
1272
1273    for when in expression.expressions:
1274        when.transform(
1275            lambda node: (
1276                exp.column(node.this)
1277                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1278                else node
1279            ),
1280            copy=False,
1281        )
1282
1283    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1286def build_json_extract_path(
1287    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1288) -> t.Callable[[t.List], F]:
1289    def _builder(args: t.List) -> F:
1290        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1291        for arg in args[1:]:
1292            if not isinstance(arg, exp.Literal):
1293                # We use the fallback parser because we can't really transpile non-literals safely
1294                return expr_type.from_arg_list(args)
1295
1296            text = arg.name
1297            if is_int(text):
1298                index = int(text)
1299                segments.append(
1300                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1301                )
1302            else:
1303                segments.append(exp.JSONPathKey(this=text))
1304
1305        # This is done to avoid failing in the expression validator due to the arg count
1306        del args[2:]
1307        return expr_type(
1308            this=seq_get(args, 0),
1309            expression=exp.JSONPath(expressions=segments),
1310            only_json_types=arrow_req_json_type,
1311        )
1312
1313    return _builder
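
A sketch of the builder on literal path segments (the column name and the DuckDB arrow rendering shown in the comment are illustrative assumptions):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("j"), exp.Literal.string("a"), exp.Literal.string("0")])
    # node.expression is an exp.JSONPath equivalent to $.a[0]
    print(node.sql("duckdb"))  # e.g. j -> '$.a[0]'
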
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1316def json_extract_segments(
1317    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1318) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1319    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1320        path = expression.expression
1321        if not isinstance(path, exp.JSONPath):
1322            return rename_func(name)(self, expression)
1323
1324        segments = []
1325        for segment in path.expressions:
1326            path = self.sql(segment)
1327            if path:
1328                if isinstance(segment, exp.JSONPathPart) and (
1329                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1330                ):
1331                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1332
1333                segments.append(path)
1334
1335        if op:
1336            return f" {op} ".join([self.sql(expression.this), *segments])
1337        return self.func(name, expression.this, *segments)
1338
1339    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1342def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1343    if isinstance(expression.this, exp.JSONPathWildcard):
1344        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1345
1346    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1349def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1350    cond = expression.expression
1351    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1352        alias = cond.expressions[0]
1353        cond = cond.this
1354    elif isinstance(cond, exp.Predicate):
1355        alias = "_u"
1356    else:
1357        self.unsupported("Unsupported filter condition")
1358        return ""
1359
1360    unnest = exp.Unnest(expressions=[expression.this])
1361    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1362    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ToNumber) -> str:
1365def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1366    return self.func(
1367        "TO_NUMBER",
1368        expression.this,
1369        expression.args.get("format"),
1370        expression.args.get("nlsparam"),
1371    )
def build_default_decimal_type( precision: Optional[int] = None, scale: Optional[int] = None) -> Callable[[sqlglot.expressions.DataType], sqlglot.expressions.DataType]:
1374def build_default_decimal_type(
1375    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1376) -> t.Callable[[exp.DataType], exp.DataType]:
1377    def _builder(dtype: exp.DataType) -> exp.DataType:
1378        if dtype.expressions or precision is None:
1379            return dtype
1380
1381        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1382        return exp.DataType.build(f"DECIMAL({params})")
1383
1384    return _builder
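
A sketch showing that the default only applies to a bare DECIMAL (precision/scale values are illustrative):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_default_decimal_type

    builder = build_default_decimal_type(precision=38, scale=9)
    print(builder(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(38, 9)
    print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql())  # unchanged: DECIMAL(10, 2)
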
def build_timestamp_from_parts(args: List) -> sqlglot.expressions.Func:
1387def build_timestamp_from_parts(args: t.List) -> exp.Func:
1388    if len(args) == 2:
1389        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
1390        # so we parse this into Anonymous for now instead of introducing complexity
1391        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
1392
1393    return exp.TimestampFromParts.from_arg_list(args)
def sha256_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SHA2) -> str:
1396def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
1397    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
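
A sketch (generator and column are illustrative): the SHA2 length argument selects the function name, defaulting to SHA256 when absent.

    from sqlglot import exp
    from sqlglot.dialects.dialect import sha256_sql
    from sqlglot.dialects.duckdb import DuckDB

    gen = DuckDB().generator()
    node = exp.SHA2(this=exp.column("x"), length=exp.Literal.number(512))
    print(sha256_sql(gen, node))                            # SHA512(x)
    print(sha256_sql(gen, exp.SHA2(this=exp.column("x"))))  # SHA256(x)
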