From 43930ec7b7bbced23374d9a046f938a412b5f1c4 Mon Sep 17 00:00:00 2001 From: Sean Rose Date: Wed, 29 Jan 2025 16:15:01 -0800 Subject: [PATCH 1/2] feat(dryrun): Add `--validate-union-schemas` option (DENG-7746). --- bigquery_etl/cli/dryrun.py | 35 +++- bigquery_etl/dryrun.py | 357 ++++++++++++++++++++++++++------ bigquery_etl/schema/__init__.py | 66 +++++- 3 files changed, 388 insertions(+), 70 deletions(-) diff --git a/bigquery_etl/cli/dryrun.py b/bigquery_etl/cli/dryrun.py index 7b6abb23fec..91b71508ccd 100644 --- a/bigquery_etl/cli/dryrun.py +++ b/bigquery_etl/cli/dryrun.py @@ -5,6 +5,7 @@ import os import re import sys +import traceback from functools import partial from multiprocessing.pool import Pool from typing import List, Set, Tuple @@ -53,6 +54,13 @@ is_flag=True, default=False, ) +@click.option( + "--validate_union_schemas", + "--validate-union-schemas", + help="Require any subqueries being unioned to have exactly matching schemas.", + is_flag=True, + default=False, +) @click.option( "--respect-skip/--ignore-skip", help="Respect or ignore query skip configuration. Default is --respect-skip.", @@ -67,6 +75,7 @@ def dryrun( paths: List[str], use_cloud_function: bool, validate_schemas: bool, + validate_union_schemas: bool, respect_skip: bool, project: str, ): @@ -112,9 +121,10 @@ def dryrun( sql_file_valid = partial( _sql_file_valid, - use_cloud_function, - respect_skip, - validate_schemas, + use_cloud_function=use_cloud_function, + respect_skip=respect_skip, + validate_schemas=validate_schemas, + validate_union_schemas=validate_union_schemas, credentials=credentials, id_token=id_token, ) @@ -133,7 +143,13 @@ def dryrun( def _sql_file_valid( - use_cloud_function, respect_skip, validate_schemas, sqlfile, credentials, id_token + sqlfile, + use_cloud_function, + respect_skip, + validate_schemas, + validate_union_schemas, + credentials, + id_token, ) -> Tuple[bool, str]: """Dry run the SQL file.""" result = DryRun( @@ -151,4 +167,15 @@ def _sql_file_valid( success = False return success, sqlfile + if validate_union_schemas: + try: + success = result.validate_union_schemas() + except Exception: + click.echo( + f"Failed to validate union schemas in {sqlfile}:\n{traceback.format_exc(limit=0)}", + err=True, + ) + success = False + return success, sqlfile + return result.is_valid(), sqlfile diff --git a/bigquery_etl/dryrun.py b/bigquery_etl/dryrun.py index d61ebb6e6f0..c00f6546fb2 100644 --- a/bigquery_etl/dryrun.py +++ b/bigquery_etl/dryrun.py @@ -15,14 +15,17 @@ import json import re import sys +from datetime import datetime, timedelta, timezone from enum import Enum from os.path import basename, dirname, exists from pathlib import Path -from typing import Optional, Set +from typing import Any, Iterable, Optional, Set from urllib.request import Request, urlopen import click import google.auth +import sqlglot +import sqlglot.optimizer.scope from google.auth.transport.requests import Request as GoogleAuthRequest from google.cloud import bigquery from google.oauth2.id_token import fetch_id_token @@ -38,10 +41,15 @@ from backports.cached_property import cached_property # type: ignore -QUERY_PARAMETER_TYPE_VALUES = { - "DATE": "2019-01-01", - "DATETIME": "2019-01-01 00:00:00", - "TIMESTAMP": "2019-01-01 00:00:00", +TMP_DATASET = "bigquery-etl-integration-test.tmp" + +YESTERDAY = datetime.now(timezone.utc).date() - timedelta(days=1) + +QUERY_PARAMETER_TYPE_VALUES: dict[str, str | bool | int] = { + # Use yesterday's date in case the date/time parameters are used in time travel queries. 
+ "DATE": YESTERDAY.isoformat(), + "DATETIME": YESTERDAY.isoformat() + " 00:00:00", + "TIMESTAMP": YESTERDAY.isoformat() + " 00:00:00", "STRING": "foo", "BOOL": True, "FLOAT64": 1, @@ -212,14 +220,8 @@ def get_sql(self): return sql - @cached_property - def dry_run_result(self): - """Dry run the provided SQL file.""" - if self.content: - sql = self.content - else: - sql = self.get_sql() - + def get_query_parameters(self) -> list[bigquery.ScalarQueryParameter]: + """Get query parameters to use for the dry run.""" query_parameters = [] scheduling_metadata = self.metadata.scheduling if self.metadata else {} if date_partition_parameter := scheduling_metadata.get( @@ -242,69 +244,90 @@ def dry_run_result(self): QUERY_PARAMETER_TYPE_VALUES.get(parameter_type), ) ) + return query_parameters + + def _dry_run_query( + self, + sql: str, + query_parameters: Iterable[bigquery.ScalarQueryParameter], + target_project: str, + target_dataset: str, + ) -> dict[str, Any]: + """Dry run the provided query.""" + if self.use_cloud_function: + json_data = { + "project": self.project or target_project, + "dataset": self.dataset or target_dataset, + "query": sql, + "query_parameters": [ + query_parameter.to_api_repr() + for query_parameter in query_parameters + ], + } + + if self.table: + json_data["table"] = self.table + + r = urlopen( + Request( + self.dry_run_url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.id_token}", + }, + data=json.dumps(json_data).encode("utf8"), + method="POST", + ) + ) + return json.load(r) + else: + self.client.project = target_project + job_config = bigquery.QueryJobConfig( + dry_run=True, + use_query_cache=False, + default_dataset=f"{target_project}.{target_dataset}", + query_parameters=list(query_parameters), + ) + job = self.client.query(sql, job_config=job_config) + return { + "valid": True, + "referencedTables": [ + ref.to_api_repr() for ref in job.referenced_tables + ], + "schema": {"fields": [field.to_api_repr() for field in job.schema]}, + } + + @cached_property + def dry_run_result(self): + """Dry run the provided SQL file.""" + if self.content: + sql = self.content + else: + sql = self.get_sql() + + query_parameters = self.get_query_parameters() project = basename(dirname(dirname(dirname(self.sqlfile)))) dataset = basename(dirname(dirname(self.sqlfile))) + try: - if self.use_cloud_function: - json_data = { - "project": self.project or project, - "dataset": self.dataset or dataset, - "query": sql, - "query_parameters": [ - query_parameter.to_api_repr() - for query_parameter in query_parameters - ], - } - - if self.table: - json_data["table"] = self.table - - r = urlopen( - Request( - self.dry_run_url, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self.id_token}", - }, - data=json.dumps(json_data).encode("utf8"), - method="POST", - ) - ) - return json.load(r) - else: - self.client.project = project - job_config = bigquery.QueryJobConfig( - dry_run=True, - use_query_cache=False, - default_dataset=f"{project}.{dataset}", - query_parameters=query_parameters, - ) - job = self.client.query(sql, job_config=job_config) + result = self._dry_run_query(sql, query_parameters, project, dataset) + + if not self.use_cloud_function: try: - dataset_labels = self.client.get_dataset(job.default_dataset).labels + result["datasetLabels"] = self.client.get_dataset( + f"{project}.{dataset}" + ).labels except Exception as e: # Most users do not have bigquery.datasets.get permission in # moz-fx-data-shared-prod # This should not 
prevent the dry run from running since the dataset # labels are usually not required if "Permission bigquery.datasets.get denied on dataset" in str(e): - dataset_labels = [] + result["datasetLabels"] = [] else: - raise e - - result = { - "valid": True, - "referencedTables": [ - ref.to_api_repr() for ref in job.referenced_tables - ], - "schema": ( - job._properties.get("statistics", {}) - .get("query", {}) - .get("schema", {}) - ), - "datasetLabels": dataset_labels, - } + raise + if ( self.project is not None and self.table is not None @@ -321,7 +344,7 @@ def dry_run_result(self): }, } - return result + return result except Exception as e: print(f"{self.sqlfile!s:59} ERROR\n", e) @@ -597,6 +620,210 @@ def validate_schema(self): click.echo(f"Schemas for {query_file_path} are valid.") return True + def validate_union_schemas(self) -> bool: + """Check whether subqueries being unioned have exactly matching schemas.""" + # Delay import to prevent circular imports in `bigquery_etl.schema`. + from .schema import Schema, SchemaAssertError + + result = True + + target_project = basename(dirname(dirname(dirname(self.sqlfile)))) + target_dataset = basename(dirname(dirname(self.sqlfile))) + + query_parameters_by_name = { + param.name: param for param in self.get_query_parameters() + } + + def _replace_parameters(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression: + if ( + isinstance(node, sqlglot.exp.Parameter) + and node.name in query_parameters_by_name + ): + query_parameter = query_parameters_by_name[node.name] + value_literal = sqlglot.exp.Literal( + this=query_parameter.value, + is_string=isinstance(query_parameter.value, str), + ) + if query_parameter.type_ in ("STRING", "BOOL", "INT64", "INTEGER"): + return value_literal + return sqlglot.cast( + value_literal, query_parameter.type_, dialect="bigquery" + ) + return node + + function_defs_by_name: dict[str, sqlglot.exp.Create] = {} + + def _inline_functions(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression: + if ( + isinstance(node, sqlglot.exp.Anonymous) + and node.name in function_defs_by_name + and not ( + isinstance(node.parent, sqlglot.exp.Dot) + and node == node.parent.right + ) + ): + function_def = function_defs_by_name[node.name] + function_params: list[sqlglot.exp.ColumnDef] = ( + function_def.this.expressions + ) + function_args: list[sqlglot.exp.Expression] = node.expressions + function_body: sqlglot.exp.Expression = function_def.expression + + # If we know the return type, replace the function call with null cast as the return type. + for function_property in function_def.args["properties"].expressions: + if isinstance(function_property, sqlglot.exp.ReturnsProperty): + return sqlglot.cast(sqlglot.exp.Null(), function_property.this) + + # If the function has no arguments, replace the function call with the function body as is. + if not function_args: + return function_body + + # Replace the function call with the function body modified to inline the function arguments. 
+ function_args_by_name = {} + for function_arg_index, function_arg in enumerate(function_args): + if isinstance(function_arg, sqlglot.exp.Kwarg): + function_args_by_name[function_arg.name] = ( + function_arg.expression + ) + elif function_arg_index < len(function_params): + function_args_by_name[ + function_params[function_arg_index].name + ] = function_arg + + def _inline_function_args( + node: sqlglot.exp.Expression, + ) -> sqlglot.exp.Expression: + if ( + isinstance(node, sqlglot.exp.Column) + and node.parts[0].name in function_args_by_name + ): + function_arg_name = node.parts[0].name + function_arg = function_args_by_name[function_arg_name] + if isinstance( + function_arg, + ( + sqlglot.exp.Column, + sqlglot.exp.Dot, + sqlglot.exp.Literal, + sqlglot.exp.Paren, + ), + ): + inline_function_arg = function_arg + else: + inline_function_arg = sqlglot.exp.paren(function_arg) + if len(node.parts) == 1: + if isinstance( + node.parent, (sqlglot.exp.Select, sqlglot.exp.Struct) + ): + return sqlglot.alias( + inline_function_arg, function_arg_name + ) + return inline_function_arg + # Recursively transform the final column expression part in case it's a + # `* REPLACE (...)` with embedded function parameter references. + return sqlglot.exp.Dot.build( + [ + inline_function_arg, + *node.parts[1:-1], + node.parts[-1].transform(_inline_function_args), + ] + ) + return node + + return function_body.transform(_inline_function_args) + return node + + for sql_expression in sqlglot.parse(self.get_sql(), dialect="bigquery"): + if not sql_expression: + continue + if ( + isinstance(sql_expression, sqlglot.exp.Create) + and sql_expression.kind == "FUNCTION" + ): + function_defs_by_name[sql_expression.this.name] = sql_expression + continue + + sql_scope = sqlglot.optimizer.scope.build_scope(sql_expression) + if not sql_scope: + continue + + union_expression_scopes: list[sqlglot.optimizer.scope.Scope] = [ + scope + for scope in sql_scope.traverse() + if isinstance(scope.expression, sqlglot.exp.Union) + ] + for union_number, union_scope in enumerate( + union_expression_scopes, start=1 + ): + union: sqlglot.exp.Union = union_scope.expression + left_select = union.left.copy() + # Always use the very first `SELECT` statement in the union as the basis for comparison. + while isinstance(left_select, sqlglot.exp.Union): + left_select = left_select.left + right_select = union.right.copy() + if union_scope.cte_sources: + for cte_name, cte_scope in union_scope.cte_sources.items(): + left_select.with_(cte_name, cte_scope.expression, copy=False) + right_select.with_(cte_name, cte_scope.expression, copy=False) + + if query_parameters_by_name: + left_select.transform(_replace_parameters, copy=False) + right_select.transform(_replace_parameters, copy=False) + + if function_defs_by_name: + left_select.transform(_inline_functions, copy=False) + right_select.transform(_inline_functions, copy=False) + + left_select_sql = left_select.sql(dialect="bigquery", pretty=True) + right_select_sql = right_select.sql(dialect="bigquery", pretty=True) + + # Use `CREATE VIEW` statements to avoid having to add `WHERE` clauses for tables' partition columns. 
+ left_dryrun_sql = f"CREATE OR REPLACE VIEW `{TMP_DATASET}.union_left` AS\n{left_select_sql}" + right_dryrun_sql = f"CREATE OR REPLACE VIEW `{TMP_DATASET}.union_right` AS\n{right_select_sql}" + + try: + left_dryrun_result = self._dry_run_query( + left_dryrun_sql, [], target_project, target_dataset + ) + except Exception as e: + raise Exception( + f"Dryrun error for left side of union #{union_number}." + ) from e + if not left_dryrun_result["valid"]: + raise Exception( + f"Dryrun error for left side of union #{union_number}: {left_dryrun_result['errors']}" + ) + + try: + right_dryrun_result = self._dry_run_query( + right_dryrun_sql, [], target_project, target_dataset + ) + except Exception as e: + raise Exception( + f"Dryrun error for right side of union #{union_number}." + ) from e + if not right_dryrun_result["valid"]: + raise Exception( + f"Dryrun error for right side of union #{union_number}: {right_dryrun_result['errors']}" + ) + + left_select_schema = Schema.from_json(left_dryrun_result["schema"]) + right_select_schema = Schema.from_json(right_dryrun_result["schema"]) + try: + left_select_schema.assert_exactly_unionable_with( + right_select_schema + ) + except SchemaAssertError as e: + click.echo( + click.style( + f"ERROR: Schema mismatch in union #{union_number} in {self.sqlfile}: {e}", + fg="red", + ), + err=True, + ) + result = False + return result + def sql_file_valid(sqlfile): """Dry run SQL files.""" diff --git a/bigquery_etl/schema/__init__.py b/bigquery_etl/schema/__init__.py index 97dbbe73a64..83b5583188a 100644 --- a/bigquery_etl/schema/__init__.py +++ b/bigquery_etl/schema/__init__.py @@ -4,7 +4,7 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence import attr import yaml @@ -17,6 +17,12 @@ SCHEMA_FILE = "schema.yaml" +class SchemaAssertError(Exception): + """Schema assert error.""" + + pass + + @attr.s(auto_attribs=True) class Schema: """Query schema representation and helpers.""" @@ -273,6 +279,64 @@ def _traverse( f"Field {prefix}.{field_path} is missing in schema" ) + def _assert_fields_exactly_unionable( + self, + fields: Sequence[dict[str, Any]], + other_fields: Sequence[dict[str, Any]], + parent_field_name: Optional[str] = None, + ) -> None: + """Raise an error if the two lists of fields don't match in terms of field positions, names, and types.""" + parent_field_note = f" in {parent_field_name}" if parent_field_name else "" + + if len(fields) != len(other_fields): + raise SchemaAssertError( + f"Different number of fields{parent_field_note} ({len(fields)} vs {len(other_fields)})." + ) + + for field_number, (field, other_field) in enumerate( + zip(fields, other_fields), start=1 + ): + if field["name"] != other_field["name"]: + raise SchemaAssertError( + f"Field #{field_number}{parent_field_note} has different names ({field['name']} vs {other_field['name']})." 
+ ) + + field_name = ( + f"{parent_field_name}.{field['name']}" + if parent_field_name + else field["name"] + ) + field_is_repeated = field.get("mode") == "REPEATED" + other_field_is_repeated = other_field.get("mode") == "REPEATED" + field_type = ( + f"REPEATED {field['type']}" if field_is_repeated else field["type"] + ) + other_field_type = ( + f"REPEATED {other_field['type']}" + if other_field_is_repeated + else other_field["type"] + ) + + if field_type != other_field_type: + raise SchemaAssertError( + f"Field {field_name} has different types ({field_type} vs {other_field_type})." + ) + + if field["type"] == "RECORD": + self._assert_fields_exactly_unionable( + field["fields"], + other_field["fields"], + parent_field_name=( + field_name + ("[]" if field_is_repeated else "") + ), + ) + + def assert_exactly_unionable_with(self, other: "Schema") -> None: + """Raise an error if this schema doesn't match the other schema in terms of field positions, names, and types.""" + self._assert_fields_exactly_unionable( + self.schema["fields"], other.schema["fields"] + ) + def to_yaml_file(self, yaml_path: Path): """Write schema to the YAML file path.""" with open(yaml_path, "w") as out: From a99ce3b5a043480eab331bcf81166797d013e75b Mon Sep 17 00:00:00 2001 From: Sean Rose <1994030+sean-rose@users.noreply.github.com> Date: Fri, 31 Jan 2025 12:49:56 -0800 Subject: [PATCH 2/2] Remove redundant `pass`. Co-authored-by: Ben Wu <12437227+BenWu@users.noreply.github.com> --- bigquery_etl/schema/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bigquery_etl/schema/__init__.py b/bigquery_etl/schema/__init__.py index 83b5583188a..1deef0a35eb 100644 --- a/bigquery_etl/schema/__init__.py +++ b/bigquery_etl/schema/__init__.py @@ -20,8 +20,6 @@ class SchemaAssertError(Exception): """Schema assert error.""" - pass - @attr.s(auto_attribs=True) class Schema: