sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, is_int, seq_get 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26logger = logging.getLogger("sqlglot") 27 28UNESCAPED_SEQUENCES = { 29 "\\a": "\a", 30 "\\b": "\b", 31 "\\f": "\f", 32 "\\n": "\n", 33 "\\r": "\r", 34 "\\t": "\t", 35 "\\v": "\v", 36 "\\\\": "\\", 37} 38 39 40class Dialects(str, Enum): 41 """Dialects supported by SQLGLot.""" 42 43 DIALECT = "" 44 45 ATHENA = "athena" 46 BIGQUERY = "bigquery" 47 CLICKHOUSE = "clickhouse" 48 DATABRICKS = "databricks" 49 DORIS = "doris" 50 DRILL = "drill" 51 DUCKDB = "duckdb" 52 HIVE = "hive" 53 MATERIALIZE = "materialize" 54 MYSQL = "mysql" 55 ORACLE = "oracle" 56 POSTGRES = "postgres" 57 PRESTO = "presto" 58 PRQL = "prql" 59 REDSHIFT = "redshift" 60 RISINGWAVE = "risingwave" 61 SNOWFLAKE = "snowflake" 62 SPARK = "spark" 63 SPARK2 = "spark2" 64 SQLITE = "sqlite" 65 STARROCKS = "starrocks" 66 TABLEAU = "tableau" 67 TERADATA = "teradata" 68 TRINO = "trino" 69 TSQL = "tsql" 70 71 72class NormalizationStrategy(str, AutoName): 73 """Specifies the strategy according to which identifiers should be normalized.""" 74 75 LOWERCASE = auto() 76 """Unquoted identifiers are lowercased.""" 77 78 UPPERCASE = auto() 79 """Unquoted identifiers are uppercased.""" 80 81 CASE_SENSITIVE = auto() 82 """Always case-sensitive, regardless of quotes.""" 83 84 CASE_INSENSITIVE = auto() 85 """Always case-insensitive, regardless of quotes.""" 86 87 88class _Dialect(type): 89 classes: t.Dict[str, t.Type[Dialect]] = {} 90 91 def __eq__(cls, other: t.Any) -> bool: 92 if cls is other: 93 return True 94 if isinstance(other, str): 95 return cls is cls.get(other) 96 if isinstance(other, Dialect): 97 return cls is type(other) 98 99 return False 100 101 def __hash__(cls) -> int: 102 return hash(cls.__name__.lower()) 103 104 @classmethod 105 def __getitem__(cls, key: str) -> t.Type[Dialect]: 106 return cls.classes[key] 107 108 @classmethod 109 def get( 110 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 111 ) -> t.Optional[t.Type[Dialect]]: 112 return cls.classes.get(key, default) 113 114 def __new__(cls, clsname, bases, attrs): 115 klass = super().__new__(cls, clsname, bases, attrs) 116 enum = Dialects.__members__.get(clsname.upper()) 117 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 118 119 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 120 klass.FORMAT_TRIE = ( 121 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 122 ) 123 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 124 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 125 klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} 126 klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) 127 128 base = seq_get(bases, 0) 129 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 130 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 131 base_parser = (getattr(base, "parser_class", Parser),) 132 base_generator = (getattr(base, "generator_class", Generator),) 133 134 klass.tokenizer_class = klass.__dict__.get( 135 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 136 ) 137 klass.jsonpath_tokenizer_class = klass.__dict__.get( 138 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 139 ) 140 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 141 klass.generator_class = klass.__dict__.get( 142 "Generator", type("Generator", base_generator, {}) 143 ) 144 145 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 146 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 147 klass.tokenizer_class._IDENTIFIERS.items() 148 )[0] 149 150 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 151 return next( 152 ( 153 (s, e) 154 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 155 if t == token_type 156 ), 157 (None, None), 158 ) 159 160 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 161 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 162 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 163 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 164 165 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 166 klass.UNESCAPED_SEQUENCES = { 167 **UNESCAPED_SEQUENCES, 168 **klass.UNESCAPED_SEQUENCES, 169 } 170 171 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 172 173 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 174 175 if enum not in ("", "bigquery"): 176 klass.generator_class.SELECT_KINDS = () 177 178 if enum not in ("", "athena", "presto", "trino"): 179 klass.generator_class.TRY_SUPPORTED = False 180 klass.generator_class.SUPPORTS_UESCAPE = False 181 182 if enum not in ("", "databricks", "hive", "spark", "spark2"): 183 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 184 for modifier in ("cluster", "distribute", "sort"): 185 modifier_transforms.pop(modifier, None) 186 187 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 188 189 if enum not in ("", "doris", "mysql"): 190 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 191 TokenType.STRAIGHT_JOIN, 192 } 193 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 194 TokenType.STRAIGHT_JOIN, 195 } 196 197 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 198 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 199 TokenType.ANTI, 200 TokenType.SEMI, 201 } 202 203 return klass 204 205 206class Dialect(metaclass=_Dialect): 207 INDEX_OFFSET = 0 208 """The base index offset for arrays.""" 209 210 WEEK_OFFSET = 0 211 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 212 213 UNNEST_COLUMN_ONLY = False 214 """Whether `UNNEST` table aliases are treated as column aliases.""" 215 216 ALIAS_POST_TABLESAMPLE = False 217 """Whether the table alias comes after tablesample.""" 218 219 TABLESAMPLE_SIZE_IS_PERCENT = False 220 """Whether a size in the table sample clause represents percentage.""" 221 222 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 223 """Specifies the strategy according to which identifiers should be normalized.""" 224 225 IDENTIFIERS_CAN_START_WITH_DIGIT = False 226 """Whether an unquoted identifier can start with a digit.""" 227 228 DPIPE_IS_STRING_CONCAT = True 229 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 230 231 STRICT_STRING_CONCAT = False 232 """Whether `CONCAT`'s arguments must be strings.""" 233 234 SUPPORTS_USER_DEFINED_TYPES = True 235 """Whether user-defined data types are supported.""" 236 237 SUPPORTS_SEMI_ANTI_JOIN = True 238 """Whether `SEMI` or `ANTI` joins are supported.""" 239 240 SUPPORTS_COLUMN_JOIN_MARKS = False 241 """Whether the old-style outer join (+) syntax is supported.""" 242 243 COPY_PARAMS_ARE_CSV = True 244 """Separator of COPY statement parameters.""" 245 246 NORMALIZE_FUNCTIONS: bool | str = "upper" 247 """ 248 Determines how function names are going to be normalized. 249 Possible values: 250 "upper" or True: Convert names to uppercase. 251 "lower": Convert names to lowercase. 252 False: Disables function name normalization. 253 """ 254 255 LOG_BASE_FIRST: t.Optional[bool] = True 256 """ 257 Whether the base comes first in the `LOG` function. 258 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 259 """ 260 261 NULL_ORDERING = "nulls_are_small" 262 """ 263 Default `NULL` ordering method to use if not explicitly set. 264 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 265 """ 266 267 TYPED_DIVISION = False 268 """ 269 Whether the behavior of `a / b` depends on the types of `a` and `b`. 270 False means `a / b` is always float division. 271 True means `a / b` is integer division if both `a` and `b` are integers. 272 """ 273 274 SAFE_DIVISION = False 275 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 276 277 CONCAT_COALESCE = False 278 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 279 280 HEX_LOWERCASE = False 281 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 282 283 DATE_FORMAT = "'%Y-%m-%d'" 284 DATEINT_FORMAT = "'%Y%m%d'" 285 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 286 287 TIME_MAPPING: t.Dict[str, str] = {} 288 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 289 290 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 291 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 292 FORMAT_MAPPING: t.Dict[str, str] = {} 293 """ 294 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 295 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 296 """ 297 298 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 299 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 300 301 PSEUDOCOLUMNS: t.Set[str] = set() 302 """ 303 Columns that are auto-generated by the engine corresponding to this dialect. 304 For example, such columns may be excluded from `SELECT *` queries. 305 """ 306 307 PREFER_CTE_ALIAS_COLUMN = False 308 """ 309 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 310 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 311 any projection aliases in the subquery. 312 313 For example, 314 WITH y(c) AS ( 315 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 316 ) SELECT c FROM y; 317 318 will be rewritten as 319 320 WITH y(c) AS ( 321 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 322 ) SELECT c FROM y; 323 """ 324 325 COPY_PARAMS_ARE_CSV = True 326 """ 327 Whether COPY statement parameters are separated by comma or whitespace 328 """ 329 330 FORCE_EARLY_ALIAS_REF_EXPANSION = False 331 """ 332 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 333 334 For example: 335 WITH data AS ( 336 SELECT 337 1 AS id, 338 2 AS my_id 339 ) 340 SELECT 341 id AS my_id 342 FROM 343 data 344 WHERE 345 my_id = 1 346 GROUP BY 347 my_id, 348 HAVING 349 my_id = 1 350 351 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 352 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 353 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 354 """ 355 356 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 357 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 358 359 # --- Autofilled --- 360 361 tokenizer_class = Tokenizer 362 jsonpath_tokenizer_class = JSONPathTokenizer 363 parser_class = Parser 364 generator_class = Generator 365 366 # A trie of the time_mapping keys 367 TIME_TRIE: t.Dict = {} 368 FORMAT_TRIE: t.Dict = {} 369 370 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 371 INVERSE_TIME_TRIE: t.Dict = {} 372 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 373 INVERSE_FORMAT_TRIE: t.Dict = {} 374 375 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 376 377 # Delimiters for string literals and identifiers 378 QUOTE_START = "'" 379 QUOTE_END = "'" 380 IDENTIFIER_START = '"' 381 IDENTIFIER_END = '"' 382 383 # Delimiters for bit, hex, byte and unicode literals 384 BIT_START: t.Optional[str] = None 385 BIT_END: t.Optional[str] = None 386 HEX_START: t.Optional[str] = None 387 HEX_END: t.Optional[str] = None 388 BYTE_START: t.Optional[str] = None 389 BYTE_END: t.Optional[str] = None 390 UNICODE_START: t.Optional[str] = None 391 UNICODE_END: t.Optional[str] = None 392 393 DATE_PART_MAPPING = { 394 "Y": "YEAR", 395 "YY": "YEAR", 396 "YYY": "YEAR", 397 "YYYY": "YEAR", 398 "YR": "YEAR", 399 "YEARS": "YEAR", 400 "YRS": "YEAR", 401 "MM": "MONTH", 402 "MON": "MONTH", 403 "MONS": "MONTH", 404 "MONTHS": "MONTH", 405 "D": "DAY", 406 "DD": "DAY", 407 "DAYS": "DAY", 408 "DAYOFMONTH": "DAY", 409 "DAY OF WEEK": "DAYOFWEEK", 410 "WEEKDAY": "DAYOFWEEK", 411 "DOW": "DAYOFWEEK", 412 "DW": "DAYOFWEEK", 413 "WEEKDAY_ISO": "DAYOFWEEKISO", 414 "DOW_ISO": "DAYOFWEEKISO", 415 "DW_ISO": "DAYOFWEEKISO", 416 "DAY OF YEAR": "DAYOFYEAR", 417 "DOY": "DAYOFYEAR", 418 "DY": "DAYOFYEAR", 419 "W": "WEEK", 420 "WK": "WEEK", 421 "WEEKOFYEAR": "WEEK", 422 "WOY": "WEEK", 423 "WY": "WEEK", 424 "WEEK_ISO": "WEEKISO", 425 "WEEKOFYEARISO": "WEEKISO", 426 "WEEKOFYEAR_ISO": "WEEKISO", 427 "Q": "QUARTER", 428 "QTR": "QUARTER", 429 "QTRS": "QUARTER", 430 "QUARTERS": "QUARTER", 431 "H": "HOUR", 432 "HH": "HOUR", 433 "HR": "HOUR", 434 "HOURS": "HOUR", 435 "HRS": "HOUR", 436 "M": "MINUTE", 437 "MI": "MINUTE", 438 "MIN": "MINUTE", 439 "MINUTES": "MINUTE", 440 "MINS": "MINUTE", 441 "S": "SECOND", 442 "SEC": "SECOND", 443 "SECONDS": "SECOND", 444 "SECS": "SECOND", 445 "MS": "MILLISECOND", 446 "MSEC": "MILLISECOND", 447 "MSECS": "MILLISECOND", 448 "MSECOND": "MILLISECOND", 449 "MSECONDS": "MILLISECOND", 450 "MILLISEC": "MILLISECOND", 451 "MILLISECS": "MILLISECOND", 452 "MILLISECON": "MILLISECOND", 453 "MILLISECONDS": "MILLISECOND", 454 "US": "MICROSECOND", 455 "USEC": "MICROSECOND", 456 "USECS": "MICROSECOND", 457 "MICROSEC": "MICROSECOND", 458 "MICROSECS": "MICROSECOND", 459 "USECOND": "MICROSECOND", 460 "USECONDS": "MICROSECOND", 461 "MICROSECONDS": "MICROSECOND", 462 "NS": "NANOSECOND", 463 "NSEC": "NANOSECOND", 464 "NANOSEC": "NANOSECOND", 465 "NSECOND": "NANOSECOND", 466 "NSECONDS": "NANOSECOND", 467 "NANOSECS": "NANOSECOND", 468 "EPOCH_SECOND": "EPOCH", 469 "EPOCH_SECONDS": "EPOCH", 470 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 471 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 472 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 473 "TZH": "TIMEZONE_HOUR", 474 "TZM": "TIMEZONE_MINUTE", 475 "DEC": "DECADE", 476 "DECS": "DECADE", 477 "DECADES": "DECADE", 478 "MIL": "MILLENIUM", 479 "MILS": "MILLENIUM", 480 "MILLENIA": "MILLENIUM", 481 "C": "CENTURY", 482 "CENT": "CENTURY", 483 "CENTS": "CENTURY", 484 "CENTURIES": "CENTURY", 485 } 486 487 @classmethod 488 def get_or_raise(cls, dialect: DialectType) -> Dialect: 489 """ 490 Look up a dialect in the global dialect registry and return it if it exists. 491 492 Args: 493 dialect: The target dialect. If this is a string, it can be optionally followed by 494 additional key-value pairs that are separated by commas and are used to specify 495 dialect settings, such as whether the dialect's identifiers are case-sensitive. 496 497 Example: 498 >>> dialect = dialect_class = get_or_raise("duckdb") 499 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 500 501 Returns: 502 The corresponding Dialect instance. 503 """ 504 505 if not dialect: 506 return cls() 507 if isinstance(dialect, _Dialect): 508 return dialect() 509 if isinstance(dialect, Dialect): 510 return dialect 511 if isinstance(dialect, str): 512 try: 513 dialect_name, *kv_pairs = dialect.split(",") 514 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 515 except ValueError: 516 raise ValueError( 517 f"Invalid dialect format: '{dialect}'. " 518 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 519 ) 520 521 result = cls.get(dialect_name.strip()) 522 if not result: 523 from difflib import get_close_matches 524 525 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 526 if similar: 527 similar = f" Did you mean {similar}?" 528 529 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 530 531 return result(**kwargs) 532 533 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 534 535 @classmethod 536 def format_time( 537 cls, expression: t.Optional[str | exp.Expression] 538 ) -> t.Optional[exp.Expression]: 539 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 540 if isinstance(expression, str): 541 return exp.Literal.string( 542 # the time formats are quoted 543 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 544 ) 545 546 if expression and expression.is_string: 547 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 548 549 return expression 550 551 def __init__(self, **kwargs) -> None: 552 normalization_strategy = kwargs.pop("normalization_strategy", None) 553 554 if normalization_strategy is None: 555 self.normalization_strategy = self.NORMALIZATION_STRATEGY 556 else: 557 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 558 559 self.settings = kwargs 560 561 def __eq__(self, other: t.Any) -> bool: 562 # Does not currently take dialect state into account 563 return type(self) == other 564 565 def __hash__(self) -> int: 566 # Does not currently take dialect state into account 567 return hash(type(self)) 568 569 def normalize_identifier(self, expression: E) -> E: 570 """ 571 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 572 573 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 574 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 575 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 576 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 577 578 There are also dialects like Spark, which are case-insensitive even when quotes are 579 present, and dialects like MySQL, whose resolution rules match those employed by the 580 underlying operating system, for example they may always be case-sensitive in Linux. 581 582 Finally, the normalization behavior of some engines can even be controlled through flags, 583 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 584 585 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 586 that it can analyze queries in the optimizer and successfully capture their semantics. 587 """ 588 if ( 589 isinstance(expression, exp.Identifier) 590 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 591 and ( 592 not expression.quoted 593 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 594 ) 595 ): 596 expression.set( 597 "this", 598 ( 599 expression.this.upper() 600 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 601 else expression.this.lower() 602 ), 603 ) 604 605 return expression 606 607 def case_sensitive(self, text: str) -> bool: 608 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 609 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 610 return False 611 612 unsafe = ( 613 str.islower 614 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 615 else str.isupper 616 ) 617 return any(unsafe(char) for char in text) 618 619 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 620 """Checks if text can be identified given an identify option. 621 622 Args: 623 text: The text to check. 624 identify: 625 `"always"` or `True`: Always returns `True`. 626 `"safe"`: Only returns `True` if the identifier is case-insensitive. 627 628 Returns: 629 Whether the given text can be identified. 630 """ 631 if identify is True or identify == "always": 632 return True 633 634 if identify == "safe": 635 return not self.case_sensitive(text) 636 637 return False 638 639 def quote_identifier(self, expression: E, identify: bool = True) -> E: 640 """ 641 Adds quotes to a given identifier. 642 643 Args: 644 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 645 identify: If set to `False`, the quotes will only be added if the identifier is deemed 646 "unsafe", with respect to its characters and this dialect's normalization strategy. 647 """ 648 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 649 name = expression.this 650 expression.set( 651 "quoted", 652 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 653 ) 654 655 return expression 656 657 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 658 if isinstance(path, exp.Literal): 659 path_text = path.name 660 if path.is_number: 661 path_text = f"[{path_text}]" 662 try: 663 return parse_json_path(path_text, self) 664 except ParseError as e: 665 logger.warning(f"Invalid JSON path syntax. {str(e)}") 666 667 return path 668 669 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 670 return self.parser(**opts).parse(self.tokenize(sql), sql) 671 672 def parse_into( 673 self, expression_type: exp.IntoType, sql: str, **opts 674 ) -> t.List[t.Optional[exp.Expression]]: 675 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 676 677 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 678 return self.generator(**opts).generate(expression, copy=copy) 679 680 def transpile(self, sql: str, **opts) -> t.List[str]: 681 return [ 682 self.generate(expression, copy=False, **opts) if expression else "" 683 for expression in self.parse(sql) 684 ] 685 686 def tokenize(self, sql: str) -> t.List[Token]: 687 return self.tokenizer.tokenize(sql) 688 689 @property 690 def tokenizer(self) -> Tokenizer: 691 return self.tokenizer_class(dialect=self) 692 693 @property 694 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 695 return self.jsonpath_tokenizer_class(dialect=self) 696 697 def parser(self, **opts) -> Parser: 698 return self.parser_class(dialect=self, **opts) 699 700 def generator(self, **opts) -> Generator: 701 return self.generator_class(dialect=self, **opts) 702 703 704DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 705 706 707def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 708 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 709 710 711def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 712 if expression.args.get("accuracy"): 713 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 714 return self.func("APPROX_COUNT_DISTINCT", expression.this) 715 716 717def if_sql( 718 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 719) -> t.Callable[[Generator, exp.If], str]: 720 def _if_sql(self: Generator, expression: exp.If) -> str: 721 return self.func( 722 name, 723 expression.this, 724 expression.args.get("true"), 725 expression.args.get("false") or false_value, 726 ) 727 728 return _if_sql 729 730 731def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 732 this = expression.this 733 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 734 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 735 736 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 737 738 739def inline_array_sql(self: Generator, expression: exp.Array) -> str: 740 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 741 742 743def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 744 elem = seq_get(expression.expressions, 0) 745 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 746 return self.func("ARRAY", elem) 747 return inline_array_sql(self, expression) 748 749 750def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 751 return self.like_sql( 752 exp.Like( 753 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 754 ) 755 ) 756 757 758def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 759 zone = self.sql(expression, "this") 760 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 761 762 763def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 764 if expression.args.get("recursive"): 765 self.unsupported("Recursive CTEs are unsupported") 766 expression.args["recursive"] = False 767 return self.with_sql(expression) 768 769 770def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 771 n = self.sql(expression, "this") 772 d = self.sql(expression, "expression") 773 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 774 775 776def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 777 self.unsupported("TABLESAMPLE unsupported") 778 return self.sql(expression.this) 779 780 781def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 782 self.unsupported("PIVOT unsupported") 783 return "" 784 785 786def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 787 return self.cast_sql(expression) 788 789 790def no_comment_column_constraint_sql( 791 self: Generator, expression: exp.CommentColumnConstraint 792) -> str: 793 self.unsupported("CommentColumnConstraint unsupported") 794 return "" 795 796 797def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 798 self.unsupported("MAP_FROM_ENTRIES unsupported") 799 return "" 800 801 802def str_position_sql( 803 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 804) -> str: 805 this = self.sql(expression, "this") 806 substr = self.sql(expression, "substr") 807 position = self.sql(expression, "position") 808 instance = expression.args.get("instance") if generate_instance else None 809 position_offset = "" 810 811 if position: 812 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 813 this = self.func("SUBSTR", this, position) 814 position_offset = f" + {position} - 1" 815 816 return self.func("STRPOS", this, substr, instance) + position_offset 817 818 819def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 820 return ( 821 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 822 ) 823 824 825def var_map_sql( 826 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 827) -> str: 828 keys = expression.args["keys"] 829 values = expression.args["values"] 830 831 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 832 self.unsupported("Cannot convert array columns into map.") 833 return self.func(map_func_name, keys, values) 834 835 args = [] 836 for key, value in zip(keys.expressions, values.expressions): 837 args.append(self.sql(key)) 838 args.append(self.sql(value)) 839 840 return self.func(map_func_name, *args) 841 842 843def build_formatted_time( 844 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 845) -> t.Callable[[t.List], E]: 846 """Helper used for time expressions. 847 848 Args: 849 exp_class: the expression class to instantiate. 850 dialect: target sql dialect. 851 default: the default format, True being time. 852 853 Returns: 854 A callable that can be used to return the appropriately formatted time expression. 855 """ 856 857 def _builder(args: t.List): 858 return exp_class( 859 this=seq_get(args, 0), 860 format=Dialect[dialect].format_time( 861 seq_get(args, 1) 862 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 863 ), 864 ) 865 866 return _builder 867 868 869def time_format( 870 dialect: DialectType = None, 871) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 872 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 873 """ 874 Returns the time format for a given expression, unless it's equivalent 875 to the default time format of the dialect of interest. 876 """ 877 time_format = self.format_time(expression) 878 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 879 880 return _time_format 881 882 883def build_date_delta( 884 exp_class: t.Type[E], 885 unit_mapping: t.Optional[t.Dict[str, str]] = None, 886 default_unit: t.Optional[str] = "DAY", 887) -> t.Callable[[t.List], E]: 888 def _builder(args: t.List) -> E: 889 unit_based = len(args) == 3 890 this = args[2] if unit_based else seq_get(args, 0) 891 unit = None 892 if unit_based or default_unit: 893 unit = args[0] if unit_based else exp.Literal.string(default_unit) 894 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 895 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 896 897 return _builder 898 899 900def build_date_delta_with_interval( 901 expression_class: t.Type[E], 902) -> t.Callable[[t.List], t.Optional[E]]: 903 def _builder(args: t.List) -> t.Optional[E]: 904 if len(args) < 2: 905 return None 906 907 interval = args[1] 908 909 if not isinstance(interval, exp.Interval): 910 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 911 912 expression = interval.this 913 if expression and expression.is_string: 914 expression = exp.Literal.number(expression.this) 915 916 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 917 918 return _builder 919 920 921def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 922 unit = seq_get(args, 0) 923 this = seq_get(args, 1) 924 925 if isinstance(this, exp.Cast) and this.is_type("date"): 926 return exp.DateTrunc(unit=unit, this=this) 927 return exp.TimestampTrunc(this=this, unit=unit) 928 929 930def date_add_interval_sql( 931 data_type: str, kind: str 932) -> t.Callable[[Generator, exp.Expression], str]: 933 def func(self: Generator, expression: exp.Expression) -> str: 934 this = self.sql(expression, "this") 935 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 936 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 937 938 return func 939 940 941def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 942 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 943 args = [unit_to_str(expression), expression.this] 944 if zone: 945 args.append(expression.args.get("zone")) 946 return self.func("DATE_TRUNC", *args) 947 948 return _timestamptrunc_sql 949 950 951def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 952 zone = expression.args.get("zone") 953 if not zone: 954 from sqlglot.optimizer.annotate_types import annotate_types 955 956 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 957 return self.sql(exp.cast(expression.this, target_type)) 958 if zone.name.lower() in TIMEZONES: 959 return self.sql( 960 exp.AtTimeZone( 961 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 962 zone=zone, 963 ) 964 ) 965 return self.func("TIMESTAMP", expression.this, zone) 966 967 968def no_time_sql(self: Generator, expression: exp.Time) -> str: 969 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 970 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 971 expr = exp.cast( 972 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 973 ) 974 return self.sql(expr) 975 976 977def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 978 this = expression.this 979 expr = expression.expression 980 981 if expr.name.lower() in TIMEZONES: 982 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 983 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 984 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 985 return self.sql(this) 986 987 this = exp.cast(this, exp.DataType.Type.DATE) 988 expr = exp.cast(expr, exp.DataType.Type.TIME) 989 990 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 991 992 993def locate_to_strposition(args: t.List) -> exp.Expression: 994 return exp.StrPosition( 995 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 996 ) 997 998 999def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1000 return self.func( 1001 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1002 ) 1003 1004 1005def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1006 return self.sql( 1007 exp.Substring( 1008 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1009 ) 1010 ) 1011 1012 1013def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1014 return self.sql( 1015 exp.Substring( 1016 this=expression.this, 1017 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1018 ) 1019 ) 1020 1021 1022def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 1023 return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) 1024 1025 1026def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1027 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1028 1029 1030# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1031def encode_decode_sql( 1032 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1033) -> str: 1034 charset = expression.args.get("charset") 1035 if charset and charset.name.lower() != "utf-8": 1036 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1037 1038 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1039 1040 1041def min_or_least(self: Generator, expression: exp.Min) -> str: 1042 name = "LEAST" if expression.expressions else "MIN" 1043 return rename_func(name)(self, expression) 1044 1045 1046def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1047 name = "GREATEST" if expression.expressions else "MAX" 1048 return rename_func(name)(self, expression) 1049 1050 1051def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1052 cond = expression.this 1053 1054 if isinstance(expression.this, exp.Distinct): 1055 cond = expression.this.expressions[0] 1056 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1057 1058 return self.func("sum", exp.func("if", cond, 1, 0)) 1059 1060 1061def trim_sql(self: Generator, expression: exp.Trim) -> str: 1062 target = self.sql(expression, "this") 1063 trim_type = self.sql(expression, "position") 1064 remove_chars = self.sql(expression, "expression") 1065 collation = self.sql(expression, "collation") 1066 1067 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1068 if not remove_chars and not collation: 1069 return self.trim_sql(expression) 1070 1071 trim_type = f"{trim_type} " if trim_type else "" 1072 remove_chars = f"{remove_chars} " if remove_chars else "" 1073 from_part = "FROM " if trim_type or remove_chars else "" 1074 collation = f" COLLATE {collation}" if collation else "" 1075 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1076 1077 1078def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1079 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1080 1081 1082def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1083 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1084 1085 1086def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1087 delim, *rest_args = expression.expressions 1088 return self.sql( 1089 reduce( 1090 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1091 rest_args, 1092 ) 1093 ) 1094 1095 1096def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1097 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1098 if bad_args: 1099 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1100 1101 return self.func( 1102 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1103 ) 1104 1105 1106def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1107 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1108 if bad_args: 1109 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1110 1111 return self.func( 1112 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1113 ) 1114 1115 1116def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1117 names = [] 1118 for agg in aggregations: 1119 if isinstance(agg, exp.Alias): 1120 names.append(agg.alias) 1121 else: 1122 """ 1123 This case corresponds to aggregations without aliases being used as suffixes 1124 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1125 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1126 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1127 """ 1128 agg_all_unquoted = agg.transform( 1129 lambda node: ( 1130 exp.Identifier(this=node.name, quoted=False) 1131 if isinstance(node, exp.Identifier) 1132 else node 1133 ) 1134 ) 1135 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1136 1137 return names 1138 1139 1140def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1141 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1142 1143 1144# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1145def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1146 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1147 1148 1149def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1150 return self.func("MAX", expression.this) 1151 1152 1153def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1154 a = self.sql(expression.left) 1155 b = self.sql(expression.right) 1156 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1157 1158 1159def is_parse_json(expression: exp.Expression) -> bool: 1160 return isinstance(expression, exp.ParseJSON) or ( 1161 isinstance(expression, exp.Cast) and expression.is_type("json") 1162 ) 1163 1164 1165def isnull_to_is_null(args: t.List) -> exp.Expression: 1166 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1167 1168 1169def generatedasidentitycolumnconstraint_sql( 1170 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1171) -> str: 1172 start = self.sql(expression, "start") or "1" 1173 increment = self.sql(expression, "increment") or "1" 1174 return f"IDENTITY({start}, {increment})" 1175 1176 1177def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1178 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1179 if expression.args.get("count"): 1180 self.unsupported(f"Only two arguments are supported in function {name}.") 1181 1182 return self.func(name, expression.this, expression.expression) 1183 1184 return _arg_max_or_min_sql 1185 1186 1187def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1188 this = expression.this.copy() 1189 1190 return_type = expression.return_type 1191 if return_type.is_type(exp.DataType.Type.DATE): 1192 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1193 # can truncate timestamp strings, because some dialects can't cast them to DATE 1194 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1195 1196 expression.this.replace(exp.cast(this, return_type)) 1197 return expression 1198 1199 1200def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1201 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1202 if cast and isinstance(expression, exp.TsOrDsAdd): 1203 expression = ts_or_ds_add_cast(expression) 1204 1205 return self.func( 1206 name, 1207 unit_to_var(expression), 1208 expression.expression, 1209 expression.this, 1210 ) 1211 1212 return _delta_sql 1213 1214 1215def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1216 unit = expression.args.get("unit") 1217 1218 if isinstance(unit, exp.Placeholder): 1219 return unit 1220 if unit: 1221 return exp.Literal.string(unit.name) 1222 return exp.Literal.string(default) if default else None 1223 1224 1225def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1226 unit = expression.args.get("unit") 1227 1228 if isinstance(unit, (exp.Var, exp.Placeholder)): 1229 return unit 1230 return exp.Var(this=default) if default else None 1231 1232 1233@t.overload 1234def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1235 pass 1236 1237 1238@t.overload 1239def map_date_part( 1240 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1241) -> t.Optional[exp.Expression]: 1242 pass 1243 1244 1245def map_date_part(part, dialect: DialectType = Dialect): 1246 mapped = ( 1247 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1248 ) 1249 return exp.var(mapped) if mapped else part 1250 1251 1252def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1253 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1254 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1255 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1256 1257 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1258 1259 1260def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1261 """Remove table refs from columns in when statements.""" 1262 alias = expression.this.args.get("alias") 1263 1264 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1265 return self.dialect.normalize_identifier(identifier).name if identifier else None 1266 1267 targets = {normalize(expression.this.this)} 1268 1269 if alias: 1270 targets.add(normalize(alias.this)) 1271 1272 for when in expression.expressions: 1273 when.transform( 1274 lambda node: ( 1275 exp.column(node.this) 1276 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1277 else node 1278 ), 1279 copy=False, 1280 ) 1281 1282 return self.merge_sql(expression) 1283 1284 1285def build_json_extract_path( 1286 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1287) -> t.Callable[[t.List], F]: 1288 def _builder(args: t.List) -> F: 1289 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1290 for arg in args[1:]: 1291 if not isinstance(arg, exp.Literal): 1292 # We use the fallback parser because we can't really transpile non-literals safely 1293 return expr_type.from_arg_list(args) 1294 1295 text = arg.name 1296 if is_int(text): 1297 index = int(text) 1298 segments.append( 1299 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1300 ) 1301 else: 1302 segments.append(exp.JSONPathKey(this=text)) 1303 1304 # This is done to avoid failing in the expression validator due to the arg count 1305 del args[2:] 1306 return expr_type( 1307 this=seq_get(args, 0), 1308 expression=exp.JSONPath(expressions=segments), 1309 only_json_types=arrow_req_json_type, 1310 ) 1311 1312 return _builder 1313 1314 1315def json_extract_segments( 1316 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1317) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1318 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1319 path = expression.expression 1320 if not isinstance(path, exp.JSONPath): 1321 return rename_func(name)(self, expression) 1322 1323 segments = [] 1324 for segment in path.expressions: 1325 path = self.sql(segment) 1326 if path: 1327 if isinstance(segment, exp.JSONPathPart) and ( 1328 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1329 ): 1330 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1331 1332 segments.append(path) 1333 1334 if op: 1335 return f" {op} ".join([self.sql(expression.this), *segments]) 1336 return self.func(name, expression.this, *segments) 1337 1338 return _json_extract_segments 1339 1340 1341def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1342 if isinstance(expression.this, exp.JSONPathWildcard): 1343 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1344 1345 return expression.name 1346 1347 1348def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1349 cond = expression.expression 1350 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1351 alias = cond.expressions[0] 1352 cond = cond.this 1353 elif isinstance(cond, exp.Predicate): 1354 alias = "_u" 1355 else: 1356 self.unsupported("Unsupported filter condition") 1357 return "" 1358 1359 unnest = exp.Unnest(expressions=[expression.this]) 1360 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1361 return self.sql(exp.Array(expressions=[filtered])) 1362 1363 1364def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1365 return self.func( 1366 "TO_NUMBER", 1367 expression.this, 1368 expression.args.get("format"), 1369 expression.args.get("nlsparam"), 1370 ) 1371 1372 1373def build_default_decimal_type( 1374 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1375) -> t.Callable[[exp.DataType], exp.DataType]: 1376 def _builder(dtype: exp.DataType) -> exp.DataType: 1377 if dtype.expressions or precision is None: 1378 return dtype 1379 1380 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1381 return exp.DataType.build(f"DECIMAL({params})") 1382 1383 return _builder 1384 1385 1386def build_timestamp_from_parts(args: t.List) -> exp.Func: 1387 if len(args) == 2: 1388 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1389 # so we parse this into Anonymous for now instead of introducing complexity 1390 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1391 1392 return exp.TimestampFromParts.from_arg_list(args) 1393 1394 1395def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1396 return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
41class Dialects(str, Enum): 42 """Dialects supported by SQLGLot.""" 43 44 DIALECT = "" 45 46 ATHENA = "athena" 47 BIGQUERY = "bigquery" 48 CLICKHOUSE = "clickhouse" 49 DATABRICKS = "databricks" 50 DORIS = "doris" 51 DRILL = "drill" 52 DUCKDB = "duckdb" 53 HIVE = "hive" 54 MATERIALIZE = "materialize" 55 MYSQL = "mysql" 56 ORACLE = "oracle" 57 POSTGRES = "postgres" 58 PRESTO = "presto" 59 PRQL = "prql" 60 REDSHIFT = "redshift" 61 RISINGWAVE = "risingwave" 62 SNOWFLAKE = "snowflake" 63 SPARK = "spark" 64 SPARK2 = "spark2" 65 SQLITE = "sqlite" 66 STARROCKS = "starrocks" 67 TABLEAU = "tableau" 68 TERADATA = "teradata" 69 TRINO = "trino" 70 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
73class NormalizationStrategy(str, AutoName): 74 """Specifies the strategy according to which identifiers should be normalized.""" 75 76 LOWERCASE = auto() 77 """Unquoted identifiers are lowercased.""" 78 79 UPPERCASE = auto() 80 """Unquoted identifiers are uppercased.""" 81 82 CASE_SENSITIVE = auto() 83 """Always case-sensitive, regardless of quotes.""" 84 85 CASE_INSENSITIVE = auto() 86 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
207class Dialect(metaclass=_Dialect): 208 INDEX_OFFSET = 0 209 """The base index offset for arrays.""" 210 211 WEEK_OFFSET = 0 212 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 213 214 UNNEST_COLUMN_ONLY = False 215 """Whether `UNNEST` table aliases are treated as column aliases.""" 216 217 ALIAS_POST_TABLESAMPLE = False 218 """Whether the table alias comes after tablesample.""" 219 220 TABLESAMPLE_SIZE_IS_PERCENT = False 221 """Whether a size in the table sample clause represents percentage.""" 222 223 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 224 """Specifies the strategy according to which identifiers should be normalized.""" 225 226 IDENTIFIERS_CAN_START_WITH_DIGIT = False 227 """Whether an unquoted identifier can start with a digit.""" 228 229 DPIPE_IS_STRING_CONCAT = True 230 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 231 232 STRICT_STRING_CONCAT = False 233 """Whether `CONCAT`'s arguments must be strings.""" 234 235 SUPPORTS_USER_DEFINED_TYPES = True 236 """Whether user-defined data types are supported.""" 237 238 SUPPORTS_SEMI_ANTI_JOIN = True 239 """Whether `SEMI` or `ANTI` joins are supported.""" 240 241 SUPPORTS_COLUMN_JOIN_MARKS = False 242 """Whether the old-style outer join (+) syntax is supported.""" 243 244 COPY_PARAMS_ARE_CSV = True 245 """Separator of COPY statement parameters.""" 246 247 NORMALIZE_FUNCTIONS: bool | str = "upper" 248 """ 249 Determines how function names are going to be normalized. 250 Possible values: 251 "upper" or True: Convert names to uppercase. 252 "lower": Convert names to lowercase. 253 False: Disables function name normalization. 254 """ 255 256 LOG_BASE_FIRST: t.Optional[bool] = True 257 """ 258 Whether the base comes first in the `LOG` function. 259 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 260 """ 261 262 NULL_ORDERING = "nulls_are_small" 263 """ 264 Default `NULL` ordering method to use if not explicitly set. 265 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 266 """ 267 268 TYPED_DIVISION = False 269 """ 270 Whether the behavior of `a / b` depends on the types of `a` and `b`. 271 False means `a / b` is always float division. 272 True means `a / b` is integer division if both `a` and `b` are integers. 273 """ 274 275 SAFE_DIVISION = False 276 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 277 278 CONCAT_COALESCE = False 279 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 280 281 HEX_LOWERCASE = False 282 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 283 284 DATE_FORMAT = "'%Y-%m-%d'" 285 DATEINT_FORMAT = "'%Y%m%d'" 286 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 287 288 TIME_MAPPING: t.Dict[str, str] = {} 289 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 290 291 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 292 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 293 FORMAT_MAPPING: t.Dict[str, str] = {} 294 """ 295 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 296 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 297 """ 298 299 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 300 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 301 302 PSEUDOCOLUMNS: t.Set[str] = set() 303 """ 304 Columns that are auto-generated by the engine corresponding to this dialect. 305 For example, such columns may be excluded from `SELECT *` queries. 306 """ 307 308 PREFER_CTE_ALIAS_COLUMN = False 309 """ 310 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 311 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 312 any projection aliases in the subquery. 313 314 For example, 315 WITH y(c) AS ( 316 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 317 ) SELECT c FROM y; 318 319 will be rewritten as 320 321 WITH y(c) AS ( 322 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 323 ) SELECT c FROM y; 324 """ 325 326 COPY_PARAMS_ARE_CSV = True 327 """ 328 Whether COPY statement parameters are separated by comma or whitespace 329 """ 330 331 FORCE_EARLY_ALIAS_REF_EXPANSION = False 332 """ 333 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 334 335 For example: 336 WITH data AS ( 337 SELECT 338 1 AS id, 339 2 AS my_id 340 ) 341 SELECT 342 id AS my_id 343 FROM 344 data 345 WHERE 346 my_id = 1 347 GROUP BY 348 my_id, 349 HAVING 350 my_id = 1 351 352 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 353 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 354 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 355 """ 356 357 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 358 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 359 360 # --- Autofilled --- 361 362 tokenizer_class = Tokenizer 363 jsonpath_tokenizer_class = JSONPathTokenizer 364 parser_class = Parser 365 generator_class = Generator 366 367 # A trie of the time_mapping keys 368 TIME_TRIE: t.Dict = {} 369 FORMAT_TRIE: t.Dict = {} 370 371 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 372 INVERSE_TIME_TRIE: t.Dict = {} 373 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 374 INVERSE_FORMAT_TRIE: t.Dict = {} 375 376 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 377 378 # Delimiters for string literals and identifiers 379 QUOTE_START = "'" 380 QUOTE_END = "'" 381 IDENTIFIER_START = '"' 382 IDENTIFIER_END = '"' 383 384 # Delimiters for bit, hex, byte and unicode literals 385 BIT_START: t.Optional[str] = None 386 BIT_END: t.Optional[str] = None 387 HEX_START: t.Optional[str] = None 388 HEX_END: t.Optional[str] = None 389 BYTE_START: t.Optional[str] = None 390 BYTE_END: t.Optional[str] = None 391 UNICODE_START: t.Optional[str] = None 392 UNICODE_END: t.Optional[str] = None 393 394 DATE_PART_MAPPING = { 395 "Y": "YEAR", 396 "YY": "YEAR", 397 "YYY": "YEAR", 398 "YYYY": "YEAR", 399 "YR": "YEAR", 400 "YEARS": "YEAR", 401 "YRS": "YEAR", 402 "MM": "MONTH", 403 "MON": "MONTH", 404 "MONS": "MONTH", 405 "MONTHS": "MONTH", 406 "D": "DAY", 407 "DD": "DAY", 408 "DAYS": "DAY", 409 "DAYOFMONTH": "DAY", 410 "DAY OF WEEK": "DAYOFWEEK", 411 "WEEKDAY": "DAYOFWEEK", 412 "DOW": "DAYOFWEEK", 413 "DW": "DAYOFWEEK", 414 "WEEKDAY_ISO": "DAYOFWEEKISO", 415 "DOW_ISO": "DAYOFWEEKISO", 416 "DW_ISO": "DAYOFWEEKISO", 417 "DAY OF YEAR": "DAYOFYEAR", 418 "DOY": "DAYOFYEAR", 419 "DY": "DAYOFYEAR", 420 "W": "WEEK", 421 "WK": "WEEK", 422 "WEEKOFYEAR": "WEEK", 423 "WOY": "WEEK", 424 "WY": "WEEK", 425 "WEEK_ISO": "WEEKISO", 426 "WEEKOFYEARISO": "WEEKISO", 427 "WEEKOFYEAR_ISO": "WEEKISO", 428 "Q": "QUARTER", 429 "QTR": "QUARTER", 430 "QTRS": "QUARTER", 431 "QUARTERS": "QUARTER", 432 "H": "HOUR", 433 "HH": "HOUR", 434 "HR": "HOUR", 435 "HOURS": "HOUR", 436 "HRS": "HOUR", 437 "M": "MINUTE", 438 "MI": "MINUTE", 439 "MIN": "MINUTE", 440 "MINUTES": "MINUTE", 441 "MINS": "MINUTE", 442 "S": "SECOND", 443 "SEC": "SECOND", 444 "SECONDS": "SECOND", 445 "SECS": "SECOND", 446 "MS": "MILLISECOND", 447 "MSEC": "MILLISECOND", 448 "MSECS": "MILLISECOND", 449 "MSECOND": "MILLISECOND", 450 "MSECONDS": "MILLISECOND", 451 "MILLISEC": "MILLISECOND", 452 "MILLISECS": "MILLISECOND", 453 "MILLISECON": "MILLISECOND", 454 "MILLISECONDS": "MILLISECOND", 455 "US": "MICROSECOND", 456 "USEC": "MICROSECOND", 457 "USECS": "MICROSECOND", 458 "MICROSEC": "MICROSECOND", 459 "MICROSECS": "MICROSECOND", 460 "USECOND": "MICROSECOND", 461 "USECONDS": "MICROSECOND", 462 "MICROSECONDS": "MICROSECOND", 463 "NS": "NANOSECOND", 464 "NSEC": "NANOSECOND", 465 "NANOSEC": "NANOSECOND", 466 "NSECOND": "NANOSECOND", 467 "NSECONDS": "NANOSECOND", 468 "NANOSECS": "NANOSECOND", 469 "EPOCH_SECOND": "EPOCH", 470 "EPOCH_SECONDS": "EPOCH", 471 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 472 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 473 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 474 "TZH": "TIMEZONE_HOUR", 475 "TZM": "TIMEZONE_MINUTE", 476 "DEC": "DECADE", 477 "DECS": "DECADE", 478 "DECADES": "DECADE", 479 "MIL": "MILLENIUM", 480 "MILS": "MILLENIUM", 481 "MILLENIA": "MILLENIUM", 482 "C": "CENTURY", 483 "CENT": "CENTURY", 484 "CENTS": "CENTURY", 485 "CENTURIES": "CENTURY", 486 } 487 488 @classmethod 489 def get_or_raise(cls, dialect: DialectType) -> Dialect: 490 """ 491 Look up a dialect in the global dialect registry and return it if it exists. 492 493 Args: 494 dialect: The target dialect. If this is a string, it can be optionally followed by 495 additional key-value pairs that are separated by commas and are used to specify 496 dialect settings, such as whether the dialect's identifiers are case-sensitive. 497 498 Example: 499 >>> dialect = dialect_class = get_or_raise("duckdb") 500 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 501 502 Returns: 503 The corresponding Dialect instance. 504 """ 505 506 if not dialect: 507 return cls() 508 if isinstance(dialect, _Dialect): 509 return dialect() 510 if isinstance(dialect, Dialect): 511 return dialect 512 if isinstance(dialect, str): 513 try: 514 dialect_name, *kv_pairs = dialect.split(",") 515 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 516 except ValueError: 517 raise ValueError( 518 f"Invalid dialect format: '{dialect}'. " 519 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 520 ) 521 522 result = cls.get(dialect_name.strip()) 523 if not result: 524 from difflib import get_close_matches 525 526 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 527 if similar: 528 similar = f" Did you mean {similar}?" 529 530 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 531 532 return result(**kwargs) 533 534 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 535 536 @classmethod 537 def format_time( 538 cls, expression: t.Optional[str | exp.Expression] 539 ) -> t.Optional[exp.Expression]: 540 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 541 if isinstance(expression, str): 542 return exp.Literal.string( 543 # the time formats are quoted 544 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 545 ) 546 547 if expression and expression.is_string: 548 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 549 550 return expression 551 552 def __init__(self, **kwargs) -> None: 553 normalization_strategy = kwargs.pop("normalization_strategy", None) 554 555 if normalization_strategy is None: 556 self.normalization_strategy = self.NORMALIZATION_STRATEGY 557 else: 558 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 559 560 self.settings = kwargs 561 562 def __eq__(self, other: t.Any) -> bool: 563 # Does not currently take dialect state into account 564 return type(self) == other 565 566 def __hash__(self) -> int: 567 # Does not currently take dialect state into account 568 return hash(type(self)) 569 570 def normalize_identifier(self, expression: E) -> E: 571 """ 572 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 573 574 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 575 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 576 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 577 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 578 579 There are also dialects like Spark, which are case-insensitive even when quotes are 580 present, and dialects like MySQL, whose resolution rules match those employed by the 581 underlying operating system, for example they may always be case-sensitive in Linux. 582 583 Finally, the normalization behavior of some engines can even be controlled through flags, 584 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 585 586 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 587 that it can analyze queries in the optimizer and successfully capture their semantics. 588 """ 589 if ( 590 isinstance(expression, exp.Identifier) 591 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 592 and ( 593 not expression.quoted 594 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 595 ) 596 ): 597 expression.set( 598 "this", 599 ( 600 expression.this.upper() 601 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 602 else expression.this.lower() 603 ), 604 ) 605 606 return expression 607 608 def case_sensitive(self, text: str) -> bool: 609 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 610 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 611 return False 612 613 unsafe = ( 614 str.islower 615 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 616 else str.isupper 617 ) 618 return any(unsafe(char) for char in text) 619 620 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 621 """Checks if text can be identified given an identify option. 622 623 Args: 624 text: The text to check. 625 identify: 626 `"always"` or `True`: Always returns `True`. 627 `"safe"`: Only returns `True` if the identifier is case-insensitive. 628 629 Returns: 630 Whether the given text can be identified. 631 """ 632 if identify is True or identify == "always": 633 return True 634 635 if identify == "safe": 636 return not self.case_sensitive(text) 637 638 return False 639 640 def quote_identifier(self, expression: E, identify: bool = True) -> E: 641 """ 642 Adds quotes to a given identifier. 643 644 Args: 645 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 646 identify: If set to `False`, the quotes will only be added if the identifier is deemed 647 "unsafe", with respect to its characters and this dialect's normalization strategy. 648 """ 649 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 650 name = expression.this 651 expression.set( 652 "quoted", 653 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 654 ) 655 656 return expression 657 658 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 659 if isinstance(path, exp.Literal): 660 path_text = path.name 661 if path.is_number: 662 path_text = f"[{path_text}]" 663 try: 664 return parse_json_path(path_text, self) 665 except ParseError as e: 666 logger.warning(f"Invalid JSON path syntax. {str(e)}") 667 668 return path 669 670 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 671 return self.parser(**opts).parse(self.tokenize(sql), sql) 672 673 def parse_into( 674 self, expression_type: exp.IntoType, sql: str, **opts 675 ) -> t.List[t.Optional[exp.Expression]]: 676 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 677 678 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 679 return self.generator(**opts).generate(expression, copy=copy) 680 681 def transpile(self, sql: str, **opts) -> t.List[str]: 682 return [ 683 self.generate(expression, copy=False, **opts) if expression else "" 684 for expression in self.parse(sql) 685 ] 686 687 def tokenize(self, sql: str) -> t.List[Token]: 688 return self.tokenizer.tokenize(sql) 689 690 @property 691 def tokenizer(self) -> Tokenizer: 692 return self.tokenizer_class(dialect=self) 693 694 @property 695 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 696 return self.jsonpath_tokenizer_class(dialect=self) 697 698 def parser(self, **opts) -> Parser: 699 return self.parser_class(dialect=self, **opts) 700 701 def generator(self, **opts) -> Generator: 702 return self.generator_class(dialect=self, **opts)
552 def __init__(self, **kwargs) -> None: 553 normalization_strategy = kwargs.pop("normalization_strategy", None) 554 555 if normalization_strategy is None: 556 self.normalization_strategy = self.NORMALIZATION_STRATEGY 557 else: 558 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 559 560 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id, HAVING my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
488 @classmethod 489 def get_or_raise(cls, dialect: DialectType) -> Dialect: 490 """ 491 Look up a dialect in the global dialect registry and return it if it exists. 492 493 Args: 494 dialect: The target dialect. If this is a string, it can be optionally followed by 495 additional key-value pairs that are separated by commas and are used to specify 496 dialect settings, such as whether the dialect's identifiers are case-sensitive. 497 498 Example: 499 >>> dialect = dialect_class = get_or_raise("duckdb") 500 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 501 502 Returns: 503 The corresponding Dialect instance. 504 """ 505 506 if not dialect: 507 return cls() 508 if isinstance(dialect, _Dialect): 509 return dialect() 510 if isinstance(dialect, Dialect): 511 return dialect 512 if isinstance(dialect, str): 513 try: 514 dialect_name, *kv_pairs = dialect.split(",") 515 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 516 except ValueError: 517 raise ValueError( 518 f"Invalid dialect format: '{dialect}'. " 519 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 520 ) 521 522 result = cls.get(dialect_name.strip()) 523 if not result: 524 from difflib import get_close_matches 525 526 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 527 if similar: 528 similar = f" Did you mean {similar}?" 529 530 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 531 532 return result(**kwargs) 533 534 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
536 @classmethod 537 def format_time( 538 cls, expression: t.Optional[str | exp.Expression] 539 ) -> t.Optional[exp.Expression]: 540 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 541 if isinstance(expression, str): 542 return exp.Literal.string( 543 # the time formats are quoted 544 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 545 ) 546 547 if expression and expression.is_string: 548 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 549 550 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
570 def normalize_identifier(self, expression: E) -> E: 571 """ 572 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 573 574 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 575 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 576 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 577 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 578 579 There are also dialects like Spark, which are case-insensitive even when quotes are 580 present, and dialects like MySQL, whose resolution rules match those employed by the 581 underlying operating system, for example they may always be case-sensitive in Linux. 582 583 Finally, the normalization behavior of some engines can even be controlled through flags, 584 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 585 586 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 587 that it can analyze queries in the optimizer and successfully capture their semantics. 588 """ 589 if ( 590 isinstance(expression, exp.Identifier) 591 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 592 and ( 593 not expression.quoted 594 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 595 ) 596 ): 597 expression.set( 598 "this", 599 ( 600 expression.this.upper() 601 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 602 else expression.this.lower() 603 ), 604 ) 605 606 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
608 def case_sensitive(self, text: str) -> bool: 609 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 610 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 611 return False 612 613 unsafe = ( 614 str.islower 615 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 616 else str.isupper 617 ) 618 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
620 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 621 """Checks if text can be identified given an identify option. 622 623 Args: 624 text: The text to check. 625 identify: 626 `"always"` or `True`: Always returns `True`. 627 `"safe"`: Only returns `True` if the identifier is case-insensitive. 628 629 Returns: 630 Whether the given text can be identified. 631 """ 632 if identify is True or identify == "always": 633 return True 634 635 if identify == "safe": 636 return not self.case_sensitive(text) 637 638 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
640 def quote_identifier(self, expression: E, identify: bool = True) -> E: 641 """ 642 Adds quotes to a given identifier. 643 644 Args: 645 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 646 identify: If set to `False`, the quotes will only be added if the identifier is deemed 647 "unsafe", with respect to its characters and this dialect's normalization strategy. 648 """ 649 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 650 name = expression.this 651 expression.set( 652 "quoted", 653 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 654 ) 655 656 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
658 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 659 if isinstance(path, exp.Literal): 660 path_text = path.name 661 if path.is_number: 662 path_text = f"[{path_text}]" 663 try: 664 return parse_json_path(path_text, self) 665 except ParseError as e: 666 logger.warning(f"Invalid JSON path syntax. {str(e)}") 667 668 return path
718def if_sql( 719 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 720) -> t.Callable[[Generator, exp.If], str]: 721 def _if_sql(self: Generator, expression: exp.If) -> str: 722 return self.func( 723 name, 724 expression.this, 725 expression.args.get("true"), 726 expression.args.get("false") or false_value, 727 ) 728 729 return _if_sql
732def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 733 this = expression.this 734 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 735 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 736 737 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
803def str_position_sql( 804 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 805) -> str: 806 this = self.sql(expression, "this") 807 substr = self.sql(expression, "substr") 808 position = self.sql(expression, "position") 809 instance = expression.args.get("instance") if generate_instance else None 810 position_offset = "" 811 812 if position: 813 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 814 this = self.func("SUBSTR", this, position) 815 position_offset = f" + {position} - 1" 816 817 return self.func("STRPOS", this, substr, instance) + position_offset
826def var_map_sql( 827 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 828) -> str: 829 keys = expression.args["keys"] 830 values = expression.args["values"] 831 832 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 833 self.unsupported("Cannot convert array columns into map.") 834 return self.func(map_func_name, keys, values) 835 836 args = [] 837 for key, value in zip(keys.expressions, values.expressions): 838 args.append(self.sql(key)) 839 args.append(self.sql(value)) 840 841 return self.func(map_func_name, *args)
844def build_formatted_time( 845 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 846) -> t.Callable[[t.List], E]: 847 """Helper used for time expressions. 848 849 Args: 850 exp_class: the expression class to instantiate. 851 dialect: target sql dialect. 852 default: the default format, True being time. 853 854 Returns: 855 A callable that can be used to return the appropriately formatted time expression. 856 """ 857 858 def _builder(args: t.List): 859 return exp_class( 860 this=seq_get(args, 0), 861 format=Dialect[dialect].format_time( 862 seq_get(args, 1) 863 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 864 ), 865 ) 866 867 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
870def time_format( 871 dialect: DialectType = None, 872) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 873 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 874 """ 875 Returns the time format for a given expression, unless it's equivalent 876 to the default time format of the dialect of interest. 877 """ 878 time_format = self.format_time(expression) 879 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 880 881 return _time_format
884def build_date_delta( 885 exp_class: t.Type[E], 886 unit_mapping: t.Optional[t.Dict[str, str]] = None, 887 default_unit: t.Optional[str] = "DAY", 888) -> t.Callable[[t.List], E]: 889 def _builder(args: t.List) -> E: 890 unit_based = len(args) == 3 891 this = args[2] if unit_based else seq_get(args, 0) 892 unit = None 893 if unit_based or default_unit: 894 unit = args[0] if unit_based else exp.Literal.string(default_unit) 895 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 896 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 897 898 return _builder
901def build_date_delta_with_interval( 902 expression_class: t.Type[E], 903) -> t.Callable[[t.List], t.Optional[E]]: 904 def _builder(args: t.List) -> t.Optional[E]: 905 if len(args) < 2: 906 return None 907 908 interval = args[1] 909 910 if not isinstance(interval, exp.Interval): 911 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 912 913 expression = interval.this 914 if expression and expression.is_string: 915 expression = exp.Literal.number(expression.this) 916 917 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 918 919 return _builder
931def date_add_interval_sql( 932 data_type: str, kind: str 933) -> t.Callable[[Generator, exp.Expression], str]: 934 def func(self: Generator, expression: exp.Expression) -> str: 935 this = self.sql(expression, "this") 936 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 937 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 938 939 return func
942def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 943 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 944 args = [unit_to_str(expression), expression.this] 945 if zone: 946 args.append(expression.args.get("zone")) 947 return self.func("DATE_TRUNC", *args) 948 949 return _timestamptrunc_sql
952def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 953 zone = expression.args.get("zone") 954 if not zone: 955 from sqlglot.optimizer.annotate_types import annotate_types 956 957 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 958 return self.sql(exp.cast(expression.this, target_type)) 959 if zone.name.lower() in TIMEZONES: 960 return self.sql( 961 exp.AtTimeZone( 962 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 963 zone=zone, 964 ) 965 ) 966 return self.func("TIMESTAMP", expression.this, zone)
969def no_time_sql(self: Generator, expression: exp.Time) -> str: 970 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 971 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 972 expr = exp.cast( 973 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 974 ) 975 return self.sql(expr)
978def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 979 this = expression.this 980 expr = expression.expression 981 982 if expr.name.lower() in TIMEZONES: 983 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 984 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 985 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 986 return self.sql(this) 987 988 this = exp.cast(this, exp.DataType.Type.DATE) 989 expr = exp.cast(expr, exp.DataType.Type.TIME) 990 991 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1032def encode_decode_sql( 1033 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1034) -> str: 1035 charset = expression.args.get("charset") 1036 if charset and charset.name.lower() != "utf-8": 1037 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1038 1039 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1052def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1053 cond = expression.this 1054 1055 if isinstance(expression.this, exp.Distinct): 1056 cond = expression.this.expressions[0] 1057 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1058 1059 return self.func("sum", exp.func("if", cond, 1, 0))
1062def trim_sql(self: Generator, expression: exp.Trim) -> str: 1063 target = self.sql(expression, "this") 1064 trim_type = self.sql(expression, "position") 1065 remove_chars = self.sql(expression, "expression") 1066 collation = self.sql(expression, "collation") 1067 1068 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1069 if not remove_chars and not collation: 1070 return self.trim_sql(expression) 1071 1072 trim_type = f"{trim_type} " if trim_type else "" 1073 remove_chars = f"{remove_chars} " if remove_chars else "" 1074 from_part = "FROM " if trim_type or remove_chars else "" 1075 collation = f" COLLATE {collation}" if collation else "" 1076 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1097def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1098 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1099 if bad_args: 1100 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1101 1102 return self.func( 1103 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1104 )
1107def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1108 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1109 if bad_args: 1110 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1111 1112 return self.func( 1113 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1114 )
1117def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1118 names = [] 1119 for agg in aggregations: 1120 if isinstance(agg, exp.Alias): 1121 names.append(agg.alias) 1122 else: 1123 """ 1124 This case corresponds to aggregations without aliases being used as suffixes 1125 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1126 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1127 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1128 """ 1129 agg_all_unquoted = agg.transform( 1130 lambda node: ( 1131 exp.Identifier(this=node.name, quoted=False) 1132 if isinstance(node, exp.Identifier) 1133 else node 1134 ) 1135 ) 1136 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1137 1138 return names
1178def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1179 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1180 if expression.args.get("count"): 1181 self.unsupported(f"Only two arguments are supported in function {name}.") 1182 1183 return self.func(name, expression.this, expression.expression) 1184 1185 return _arg_max_or_min_sql
1188def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1189 this = expression.this.copy() 1190 1191 return_type = expression.return_type 1192 if return_type.is_type(exp.DataType.Type.DATE): 1193 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1194 # can truncate timestamp strings, because some dialects can't cast them to DATE 1195 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1196 1197 expression.this.replace(exp.cast(this, return_type)) 1198 return expression
1201def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1202 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1203 if cast and isinstance(expression, exp.TsOrDsAdd): 1204 expression = ts_or_ds_add_cast(expression) 1205 1206 return self.func( 1207 name, 1208 unit_to_var(expression), 1209 expression.expression, 1210 expression.this, 1211 ) 1212 1213 return _delta_sql
1216def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1217 unit = expression.args.get("unit") 1218 1219 if isinstance(unit, exp.Placeholder): 1220 return unit 1221 if unit: 1222 return exp.Literal.string(unit.name) 1223 return exp.Literal.string(default) if default else None
1253def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1254 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1255 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1256 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1257 1258 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1261def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1262 """Remove table refs from columns in when statements.""" 1263 alias = expression.this.args.get("alias") 1264 1265 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1266 return self.dialect.normalize_identifier(identifier).name if identifier else None 1267 1268 targets = {normalize(expression.this.this)} 1269 1270 if alias: 1271 targets.add(normalize(alias.this)) 1272 1273 for when in expression.expressions: 1274 when.transform( 1275 lambda node: ( 1276 exp.column(node.this) 1277 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1278 else node 1279 ), 1280 copy=False, 1281 ) 1282 1283 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1286def build_json_extract_path( 1287 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1288) -> t.Callable[[t.List], F]: 1289 def _builder(args: t.List) -> F: 1290 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1291 for arg in args[1:]: 1292 if not isinstance(arg, exp.Literal): 1293 # We use the fallback parser because we can't really transpile non-literals safely 1294 return expr_type.from_arg_list(args) 1295 1296 text = arg.name 1297 if is_int(text): 1298 index = int(text) 1299 segments.append( 1300 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1301 ) 1302 else: 1303 segments.append(exp.JSONPathKey(this=text)) 1304 1305 # This is done to avoid failing in the expression validator due to the arg count 1306 del args[2:] 1307 return expr_type( 1308 this=seq_get(args, 0), 1309 expression=exp.JSONPath(expressions=segments), 1310 only_json_types=arrow_req_json_type, 1311 ) 1312 1313 return _builder
1316def json_extract_segments( 1317 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1318) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1319 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1320 path = expression.expression 1321 if not isinstance(path, exp.JSONPath): 1322 return rename_func(name)(self, expression) 1323 1324 segments = [] 1325 for segment in path.expressions: 1326 path = self.sql(segment) 1327 if path: 1328 if isinstance(segment, exp.JSONPathPart) and ( 1329 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1330 ): 1331 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1332 1333 segments.append(path) 1334 1335 if op: 1336 return f" {op} ".join([self.sql(expression.this), *segments]) 1337 return self.func(name, expression.this, *segments) 1338 1339 return _json_extract_segments
1349def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1350 cond = expression.expression 1351 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1352 alias = cond.expressions[0] 1353 cond = cond.this 1354 elif isinstance(cond, exp.Predicate): 1355 alias = "_u" 1356 else: 1357 self.unsupported("Unsupported filter condition") 1358 return "" 1359 1360 unnest = exp.Unnest(expressions=[expression.this]) 1361 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1362 return self.sql(exp.Array(expressions=[filtered]))
1374def build_default_decimal_type( 1375 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1376) -> t.Callable[[exp.DataType], exp.DataType]: 1377 def _builder(dtype: exp.DataType) -> exp.DataType: 1378 if dtype.expressions or precision is None: 1379 return dtype 1380 1381 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1382 return exp.DataType.build(f"DECIMAL({params})") 1383 1384 return _builder
1387def build_timestamp_from_parts(args: t.List) -> exp.Func: 1388 if len(args) == 2: 1389 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1390 # so we parse this into Anonymous for now instead of introducing complexity 1391 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1392 1393 return exp.TimestampFromParts.from_arg_list(args)