Prechádzať zdrojové kódy

v4.21 WITH优化,WHERE开窗

ysl2007 1 mesiac pred
rodič
commit
9334247d18
1 zmenil súbory, kde vykonal 243 pridanie a 32 odobranie
  1. 243 32
      carddef2sql.py

+ 243 - 32
carddef2sql.py

@@ -55,6 +55,13 @@ FILTER_OPERATOR_MAP = {
 IDENTIFIER_QUOTE = '`'
 QUOTE_FLAG = True
 
+WINDOW_MAX_OVER_PATTERN = re.compile(
+    r"max\s*\(\s*(?P<arg>.*?)\s*\)\s*over\s*\(\s*(?P<window>.*?)\s*\)",
+    flags=re.IGNORECASE | re.DOTALL,
+)
+
+AGGREGATION_PATTERN= re.compile(r"\b(sum|avg|count|max|min|stddev|variance|collect_list|collect_set|percentile|percentile_approx)|\s*\(", flags=re.IGNORECASE)
+
 # 副词
 ADV_FILTER_EXP_MAP = {
     'TODAY': "{field} = '{{today}}'",
@@ -163,7 +170,166 @@ def resolve_calculation_formula(formula, calculation_fields, visited=None):
 
     return re.sub(r"\[([^\[\]]+)\]", replace_calculation_field, formula)
 
-def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map, added_fields_info, dataset_id):
+def extract_formula_field_refs(formula):
+    # 提取公式中以 [字段] 形式引用的字段,供依赖收集使用。
+    if not formula:
+        return set()
+    refs = set()
+    for match in re.findall(r"\[([^\[\]]+)\]", formula):
+        field_name = match.strip()
+        if field_name and not re.fullmatch(r"\d+", field_name):
+            refs.add(field_name)
+    return refs
+
+def collect_formula_dependencies(formula, calculation_fields, visited=None):
+    # 递归下钻计算字段,收集最终依赖到的数据集原始字段。
+    if not formula:
+        return set()
+    if visited is None:
+        visited = set()
+
+    if "consolidation" in formula:
+        consolidation_dict = json.loads(formula)["consolidation"]
+        source_name = consolidation_dict.get("sourceName")
+        if not source_name:
+            return set()
+        source_field = calculation_fields.get(source_name)
+        if not source_field:
+            return {source_name}
+        source_field_id = source_field.get("field_id") or source_name
+        if source_field_id in visited:
+            raise ValueError(f"计算字段存在循环引用: {source_name}")
+        nested_formula = source_field["calculation"].get("formula", "")
+        return collect_formula_dependencies(nested_formula, calculation_fields, visited | {source_field_id})
+
+    dependencies = set()
+    for field_name in extract_formula_field_refs(formula):
+        field_def = calculation_fields.get(field_name)
+        if not field_def:
+            dependencies.add(field_name)
+            continue
+        field_id = field_def.get("field_id") or field_name
+        if field_id in visited:
+            raise ValueError(f"计算字段存在循环引用: {field_name}")
+        nested_formula = field_def["calculation"].get("formula", "")
+        dependencies.update(collect_formula_dependencies(nested_formula, calculation_fields, visited | {field_id}))
+    return dependencies
+
+def collect_filter_dependencies(filter_relation_str, calculation_fields):
+    # 筛选条件里的公式也可能依赖额外字段,需要提前纳入 WITH 基础列。
+    if not filter_relation_str or filter_relation_str == "[]":
+        return set()
+    dependencies = set()
+    try:
+        raw_conditions = json.loads(filter_relation_str)
+    except Exception:
+        return dependencies
+
+    for cond_dict in raw_conditions:
+        field_name = cond_dict.get("name")
+        if field_name:
+            dependencies.add(field_name)
+        formula = cond_dict.get("formula")
+        if formula:
+            dependencies.update(collect_formula_dependencies(formula, calculation_fields))
+        consolidation = cond_dict.get("consolidation")
+        if consolidation:
+            source_name = consolidation.get("sourceName")
+            if source_name:
+                source_field = calculation_fields.get(source_name)
+                if not source_field:
+                    dependencies.add(source_name)
+                else:
+                    nested_formula = source_field["calculation"].get("formula", "")
+                    dependencies.update(collect_formula_dependencies(nested_formula, calculation_fields, {source_field.get("field_id") or source_name}))
+    return dependencies
+
+def collect_with_base_fields(
+    all_field_names,
+    measure_fields,
+    new_date_fields,
+    new_dimension_fields,
+    dataset_fid_name_map,
+    added_fields_info,
+    filter_relation_str,
+):
+    # WITH 只保留后续 SELECT / WHERE / ORDER BY 真正需要的底层字段,
+    # 避免把整张数据集无差别 SELECT 进临时表。
+    dataset_field_names = set(dataset_fid_name_map.values())
+    required_fields = {name for name in all_field_names if name in dataset_field_names}
+
+    for fid, _ in new_date_fields:
+        old_fid = fid.split('_')[0]
+        if old_fid in dataset_fid_name_map:
+            required_fields.add(dataset_fid_name_map[old_fid])
+        elif old_fid in added_fields_info:
+            formula = added_fields_info[old_fid]["calculation"].get("formula", "")
+            required_fields.update(collect_formula_dependencies(formula, added_fields_info, {old_fid}))
+
+    for fid, _ in new_dimension_fields:
+        formula = added_fields_info[fid]["calculation"].get("formula", "")
+        required_fields.update(collect_formula_dependencies(formula, added_fields_info, {fid}))
+
+    for field in measure_fields:
+        if field not in added_fields_info:
+            continue
+        field_id = added_fields_info[field]["field_id"]
+        formula = added_fields_info[field]["calculation"].get("formula", "")
+        required_fields.update(collect_formula_dependencies(formula, added_fields_info, {field_id}))
+
+    required_fields.update(collect_filter_dependencies(filter_relation_str, added_fields_info))
+    return required_fields
+
+def resolve_window_expression_fields(expression, calculation_fields):
+    # 窗口函数内部若引用了计算字段,需要先还原为公式,
+    # 否则 WITH 中生成的窗口列仍会依赖一个并不存在的别名字段。
+    if not expression:
+        return expression
+
+    def replace_identifier(match):
+        field_name = match.group(1).strip()
+        field_def = calculation_fields.get(field_name)
+        if not field_def:
+            return match.group(0)
+
+        field_id = field_def.get("field_id") or field_name
+        formula = field_def["calculation"].get("formula", "")
+        if "consolidation" in formula:
+            resolved_formula = get_consolidation_field(json.loads(formula)["consolidation"])
+        else:
+            resolved_formula = resolve_calculation_formula(formula, calculation_fields, {field_id})
+            resolved_formula = quote_identifier(resolved_formula, formula=True)
+        return f"({resolved_formula})"
+
+    return re.sub(r"`([^`]+)`", replace_identifier, expression)
+
+def rewrite_window_max_over(expression, calculation_fields, window_alias_map, window_select_expressions):
+    # Hive / SparkSQL 不允许在 WHERE/HAVING 中直接使用窗口函数。
+    # 这里将 max() over(...) 提取到 WITH 中,WHERE 里只保留对中间列的判断。
+    if not expression:
+        return expression
+
+    def replace_window(match):
+        raw_expression = resolve_window_expression_fields(match.group(0).strip(), calculation_fields)
+        normalized_expression = re.sub(r"\s+", " ", raw_expression).lower()
+        alias = window_alias_map.get(normalized_expression)
+        if not alias:
+            alias = f"window_max_over_{len(window_alias_map) + 1}"
+            window_alias_map[normalized_expression] = alias
+            window_select_expressions.append(f"{raw_expression} AS {quote_identifier(alias)}")
+        return quote_identifier(alias)
+
+    return WINDOW_MAX_OVER_PATTERN.sub(replace_window, expression)
+
+def build_with_part(
+    new_date_fields,
+    new_dimension_fields,
+    dataset_fid_name_map,
+    added_fields_info,
+    dataset_id,
+    required_base_fields,
+    extra_with_expressions=None,
+):
     override_field_names = set()
     for _, new_name in new_date_fields:
         override_field_names.add(new_name)
@@ -175,6 +341,9 @@ def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map,
     for field_name in dataset_fid_name_map.values():
         if field_name in override_field_names or field_name in seen_columns:
             continue
+        # 仅保留依赖收集阶段判定为需要的原始字段。
+        if field_name not in required_base_fields:
+            continue
         seen_columns.add(field_name)
         base_columns.append(quote_identifier(field_name))
 
@@ -207,6 +376,10 @@ def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map,
             tmp_part = quote_identifier(formula, formula=True) + f" AS `{new_name}`"
         with_expressions.append(tmp_part)
 
+    if extra_with_expressions:
+        # 额外字段主要承载从 WHERE 中抽出的窗口函数中间列。
+        with_expressions.extend(extra_with_expressions)
+
     select_parts = base_columns + with_expressions
     sql_part = "WITH tmp as (\nSELECT " + ",\n".join(select_parts)
     sql_part += f"\nFROM {quote_identifier(str(dataset_id))}\n)"
@@ -253,8 +426,8 @@ def process_measure_fields(measure_fields, measure_aggs, calculation_fields, car
                 measure_is_aggregated.append(True)
                 agg_flag = True
             else:
-                new_measure_aggs.append(measure_aggs.pop(0))
-                measure_is_aggregated.append(False)
+                new_measure_aggs.append('NUL')
+                measure_is_aggregated.append(True)
     return new_measure_fields, new_measure_aggs, measure_is_aggregated, agg_flag
 
 # sql部分去重
@@ -418,10 +591,16 @@ def get_consolidation_field(consolidation_dict):
     field += "\nEND"
     return field
 
-def parse_filter_string(filter_relation_str):
+def parse_filter_string(filter_relation_str, calculation_fields=None, window_alias_map=None, window_select_expressions=None):
     conditions = {}
     if not filter_relation_str or filter_relation_str == "[]":
         return conditions
+    if calculation_fields is None:
+        calculation_fields = {}
+    if window_alias_map is None:
+        window_alias_map = {}
+    if window_select_expressions is None:
+        window_select_expressions = []
 
     raw_conditions = json.loads(filter_relation_str)
     for cond_dict in raw_conditions:
@@ -449,6 +628,8 @@ def parse_filter_string(filter_relation_str):
                 continue
             if 'formula' in cond_dict:
                 field = quote_identifier(cond_dict['formula'], formula=True)
+                # 先改写窗口函数,避免将非法的 over(...) 留在 WHERE 条件中。
+                field = rewrite_window_max_over(field, calculation_fields, window_alias_map, window_select_expressions)
             else:
                 field = quote_identifier(cond_dict['name'])
             expression = ADV_FILTER_EXP_MAP.get(cond_dict["advFilter"])
@@ -461,6 +642,8 @@ def parse_filter_string(filter_relation_str):
         elif op_dict == 'SPARK_EXPR':
             if 'formula' in cond_dict:
                 formula = quote_identifier(cond_dict['formula'], formula=True)
+                # SPARK_EXPR 中也可能直接出现窗口函数,处理方式与普通公式一致。
+                formula = rewrite_window_max_over(formula, calculation_fields, window_alias_map, window_select_expressions)
                 conditions[fdId] = {"exp": formula, "agg": is_aggregated}
             else:
                 if isinstance(cond_dict['filterValue'], list) and len(cond_dict['filterValue']) == 1:
@@ -473,7 +656,7 @@ def parse_filter_string(filter_relation_str):
 
         # 处理条件
         value_nums = op_dict["val_nums"]
-        if value_nums != 0 and len(values) != value_nums:
+        if value_nums != 9 and len(values) != value_nums:
             print(f"警告: 无法解析筛选条件,值数量与操作符不匹配。跳过此条件。")
             continue
 
@@ -491,6 +674,7 @@ def parse_filter_string(filter_relation_str):
             # 公式,非 consolidation情况
             if "formula" in cond_dict:
                 field = quote_identifier(cond_dict["formula"], formula=True)
+                field = rewrite_window_max_over(field, calculation_fields, window_alias_map, window_select_expressions)
             
         if op_name in ("NI", "IN") and len(values) == 0:
             print(f"警告: 无法解析筛选条件,IN或NI中参数个数为0。跳过此条件。")
@@ -527,6 +711,9 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
 
     measure_fids = parse_multi_value_field(card_data.get("num_value_field_id", []))
     measure_fields = parse_multi_value_field(card_data.get("num_value_field_name", []))
+    # 处理用于转置行列的特殊无ID“度量名”字段
+    if "度量名" in dimension_fields and len(dimension_fields) == len(dimension_fids) + 1:
+        dimension_fields.remove("度量名")
     measure_aggs = parse_multi_value_field(card_data.get("num_value_field_merge_way", []))
     filter_relation_str = card_data.get("filters_field_value_name_rela")
 
@@ -552,25 +739,6 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
     # 只需要更新有重命名的字段即可
     selected_fid_alias_map = dict(zip(dimension_fids+measure_fids, dimension_fields+measure_fields))
 
-    # 构建WITH
-    with_part = ""
-    new_date_fields = []
-    # 日期转换
-    for fid, name in all_field_id_name_map.items():
-        fid_splits = fid.split('_')
-        if len(fid_splits) == 2:
-            new_date_fields.append((fid, name))
-            old_fid = fid_splits[0]
-            selected_fid_alias_map[old_fid] = name
-    # 新增维度字段
-    new_dimension_fields = []
-    for fid, name in dimension_fid_name_map.items():
-        if fid in added_fields_info:
-            new_dimension_fields.append((fid, name))
-    # 如果有新增日期字段、新增维度字段,构建WITH
-    if new_date_fields or new_dimension_fields:
-        with_part = build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map, added_fields_info, dataset_id)
-    
     # 构建SELECT
     select_parts = []
     has_aggregation = False
@@ -601,7 +769,11 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
                 alias = measure_fields[i]
             select_parts.append(f"{field} AS {quote_identifier(alias)}")
             # 属于计算字段,但没有聚合函数,等同于维度字段,需要加入groupbyby。
-            if not measure_is_aggregated[i] and field and re.search(r"\b(sum|avg|count|max|min|stddev|variance|collect_list|collect_set|percentile|percentile_approx)|\s*\(", field, flags=re.IGNORECASE) is None:
+            if not measure_is_aggregated[i] and field and re.search(AGGREGATION_PATTERN, field) is None:
+                if re.match(r"\d+", field):
+                    non_aggregated_select_parts.append(quote_identifier(field))
+                else:
+                    non_aggregated_select_parts.append(field)
                 non_aggregated_select_parts.append(field)
             selected_fid_alias_map[fid] = alias
         else:
@@ -626,21 +798,60 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
     else:
         select_clause = "SELECT " + ",\n    ".join(select_parts)
     
-    # 构建FROM
-    if with_part:
-        from_clause = "FROM tmp"
-    else:
-        from_clause = f"FROM {quote_identifier(str(dataset_id))}"
-    
     # 构建WHERE
     filter_conditions = {}
+    window_alias_map = {}
+    window_select_expressions = []
     try:
-        filter_conditions = parse_filter_string(filter_relation_str)
+        # parse_filter_string 会顺便收集需要下推到 WITH 的窗口函数表达式。
+        filter_conditions = parse_filter_string(filter_relation_str, added_fields_info, window_alias_map, window_select_expressions)
     except Exception as e:
         print(f"错误: 卡片 {card_id} {card_name} 解析筛选条件出错:{e}。WHERE字句缺失。")
         print("详细错误信息:")
         print(traceback.format_exc())
 
+    # 构建WITH
+    with_part = ""
+    new_date_fields = []
+    # 日期转换
+    for fid, name in all_field_id_name_map.items():
+        fid_splits = fid.split('_')
+        if len(fid_splits) == 2:
+            new_date_fields.append((fid, name))
+            old_fid = fid_splits[0]
+            selected_fid_alias_map[old_fid] = name
+    # 新增维度字段
+    new_dimension_fields = []
+    for fid, name in dimension_fid_name_map.items():
+        if fid in added_fields_info:
+            new_dimension_fields.append((fid, name))
+    # 只要存在派生日期、计算维度或窗口筛选中的任一情况,就需要 WITH。
+    if new_date_fields or new_dimension_fields or window_select_expressions:
+        required_base_fields = collect_with_base_fields(
+            all_field_names,
+            measure_fields,
+            new_date_fields,
+            new_dimension_fields,
+            dataset_fid_name_map,
+            added_fields_info,
+            filter_relation_str,
+        )
+        with_part = build_with_part(
+            new_date_fields,
+            new_dimension_fields,
+            dataset_fid_name_map,
+            added_fields_info,
+            dataset_id,
+            required_base_fields,
+            window_select_expressions,
+        )
+
+    # 构建FROM
+    if with_part:
+        from_clause = "FROM tmp"
+    else:
+        from_clause = f"FROM {quote_identifier(str(dataset_id))}"
+
     # 构建GROUPBY
     group_by_clause = ""
     if has_aggregation: