Browse Source

v3.25.3 嵌套计算字段

ysl2007 2 months ago
parent
commit
36426634fe
1 changed files with 47 additions and 13 deletions
  1. 47 13
      carddef2sql.py

+ 47 - 13
carddef2sql.py

@@ -140,6 +140,29 @@ def get_fields_rename_map(field_info):
             ret[one_map["name"]] = one_map["alias"]
             ret[one_map["name"]] = one_map["alias"]
     return ret
     return ret
 
 
+# 递归解析嵌套的计算字段
+def resolve_calculation_formula(formula, calculation_fields, visited=None):
+    if not formula:
+        return formula
+    if visited is None:
+        visited = set()
+
+    def replace_calculation_field(match):
+        field_key = match.group(1).strip()
+        field_def = calculation_fields.get(field_key)
+        if not field_def:
+            return match.group(0)
+        field_id = field_def.get("field_id") or field_key
+        if field_id in visited:
+            raise ValueError(f"计算字段存在循环引用: {field_key}")
+        nested_formula = field_def["calculation"].get("formula", "")
+        if "consolidation" in nested_formula:
+            return match.group(0)
+        resolved = resolve_calculation_formula(nested_formula, calculation_fields, visited | {field_id})
+        return f"({resolved})"
+
+    return re.sub(r"\[([^\[\]]+)\]", replace_calculation_field, formula)
+
 def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map, added_fields_info, dataset_id):
 def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map, added_fields_info, dataset_id):
     sql_part = 'WITH tmp as (\nSELECT *,\n'
     sql_part = 'WITH tmp as (\nSELECT *,\n'
     with_expressions = []
     with_expressions = []
@@ -166,6 +189,8 @@ def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map,
             tmp_part = get_consolidation_field(consolidation_dict)
             tmp_part = get_consolidation_field(consolidation_dict)
             tmp_part += f" AS `{new_name}`"
             tmp_part += f" AS `{new_name}`"
         else:
         else:
+            # 递归解析计算字段是否有嵌套情况
+            formula = resolve_calculation_formula(formula, added_fields_info, {fid})
             tmp_part = quote_identifier(formula, formula=True) + f" AS `{new_name}`"
             tmp_part = quote_identifier(formula, formula=True) + f" AS `{new_name}`"
         with_expressions.append(tmp_part)
         with_expressions.append(tmp_part)
     sql_part += ',\n'.join(with_expressions)
     sql_part += ',\n'.join(with_expressions)
@@ -178,41 +203,44 @@ def process_measure_fields(measure_fields, measure_aggs, calculation_fields, car
     if len(measure_fields) < len(measure_aggs):
     if len(measure_fields) < len(measure_aggs):
         print(f"警告: 卡片 {card_id} {card_name}: 数值字段数量小于聚合函数数量,不合法")
         print(f"警告: 卡片 {card_id} {card_name}: 数值字段数量小于聚合函数数量,不合法")
         print(f"警告: 卡片 {card_id} {card_name}: 不添加任何数值字段.")
         print(f"警告: 卡片 {card_id} {card_name}: 不添加任何数值字段.")
-        return [], [], False
+        return [], [], [], False
     ## 数值字段 大于 聚合函数数量,存在聚合类型的计算字段,尝试填充
     ## 数值字段 大于 聚合函数数量,存在聚合类型的计算字段,尝试填充
     elif len(measure_fields) > len(measure_aggs):
     elif len(measure_fields) > len(measure_aggs):
         ## 计算数值字段数量
         ## 计算数值字段数量
         num_cals = 0
         num_cals = 0
         for field in measure_fields:
         for field in measure_fields:
-            if field in calculation_fields and calculation_fields[field]["calculation"]["isAggregated"] is True:
+            if field in calculation_fields: # and calculation_fields[field]["calculation"]["isAggregated"] is True:
                 num_cals += 1
                 num_cals += 1
         ## 如果不存在任何计算字段,补全剩余的NUL聚合函数
         ## 如果不存在任何计算字段,补全剩余的NUL聚合函数
         if num_cals == 0:
         if num_cals == 0:
             measure_aggs.extend(['NULL'] * (len(measure_fields) - len(measure_aggs)))
             measure_aggs.extend(['NULL'] * (len(measure_fields) - len(measure_aggs)))
-            return measure_fields, measure_aggs, True
+            return [quote_identifier(field) for field in measure_fields], measure_aggs, [False] * len(measure_fields), True
         ## 如果存在计算字段,且相加后的 聚合函数数量 仍小于 数值字段数量,不合法
         ## 如果存在计算字段,且相加后的 聚合函数数量 仍小于 数值字段数量,不合法
         if num_cals + len(measure_aggs) != len(measure_fields):
         if num_cals + len(measure_aggs) != len(measure_fields):
             print(f"警告: 卡片 {card_id} {card_name}: 数值字段数量大于聚合函数数量,不合法")
             print(f"警告: 卡片 {card_id} {card_name}: 数值字段数量大于聚合函数数量,不合法")
             print(f"警告: 卡片 {card_id} {card_name}: 不添加任何数值字段.")
             print(f"警告: 卡片 {card_id} {card_name}: 不添加任何数值字段.")
-            return [], [], False
+            return [], [], [], False
     ## 通过验证,填充聚合函数
     ## 通过验证,填充聚合函数
-    new_measure_fields, new_measure_aggs, agg_flag = [], [], False
+    new_measure_fields, new_measure_aggs, measure_is_aggregated, agg_flag = [], [], [], False
     for i, field in enumerate(measure_fields):
     for i, field in enumerate(measure_fields):
         ## 非计算字段
         ## 非计算字段
         if field not in calculation_fields:
         if field not in calculation_fields:
             new_measure_fields.append(quote_identifier(field))
             new_measure_fields.append(quote_identifier(field))
             new_measure_aggs.append(measure_aggs.pop(0))
             new_measure_aggs.append(measure_aggs.pop(0))
+            measure_is_aggregated.append(False)
         ## 计算字段
         ## 计算字段
         else:
         else:
             formula = calculation_fields[field]["calculation"]["formula"]
             formula = calculation_fields[field]["calculation"]["formula"]
-            formula = formula.replace('\n', '')
+            formula = resolve_calculation_formula(formula, calculation_fields, {calculation_fields[field]["field_id"]})
             new_measure_fields.append(quote_identifier(formula, formula=True))
             new_measure_fields.append(quote_identifier(formula, formula=True))
             if calculation_fields[field]["calculation"]["isAggregated"] is True:
             if calculation_fields[field]["calculation"]["isAggregated"] is True:
                 new_measure_aggs.append("NUL")
                 new_measure_aggs.append("NUL")
+                measure_is_aggregated.append(True)
                 agg_flag = True
                 agg_flag = True
             else:
             else:
                 new_measure_aggs.append(measure_aggs.pop(0))
                 new_measure_aggs.append(measure_aggs.pop(0))
-    return new_measure_fields, new_measure_aggs, agg_flag
+                measure_is_aggregated.append(False)
+    return new_measure_fields, new_measure_aggs, measure_is_aggregated, agg_flag
 
 
 # sql部分去重
 # sql部分去重
 def dedupe_sql_parts(parts):
 def dedupe_sql_parts(parts):
@@ -545,18 +573,20 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
             selected_fid_alias_map[fid] = field
             selected_fid_alias_map[fid] = field
     
     
     # 加工计算字段
     # 加工计算字段
-    new_measure_fields, measure_aggs, agg_flag = process_measure_fields(measure_fields, measure_aggs, added_fields_info, card_id, card_name)
+    new_measure_fields, measure_aggs, measure_is_aggregated, agg_flag = process_measure_fields(measure_fields, measure_aggs, added_fields_info, card_id, card_name)
     if agg_flag:
     if agg_flag:
         has_aggregation = True
         has_aggregation = True
     for i, field in enumerate(new_measure_fields):
     for i, field in enumerate(new_measure_fields):
         fid = measure_fids[i]
         fid = measure_fids[i]
         alias = fields_rename_map.get(field.strip('`'))
         alias = fields_rename_map.get(field.strip('`'))
+        # measure_agg是NUL,不需要聚合(等同于维度字段)或公式本身已经有聚合函数
         agg_func_template = AGGREGATION_MAP.get(measure_aggs[i])
         agg_func_template = AGGREGATION_MAP.get(measure_aggs[i])
         if not agg_func_template:
         if not agg_func_template:
             if not alias or alias == "null":
             if not alias or alias == "null":
                 alias = measure_fields[i]
                 alias = measure_fields[i]
             select_parts.append(f"{field} AS {quote_identifier(alias)}")
             select_parts.append(f"{field} AS {quote_identifier(alias)}")
-            if field and re.search(r"\b(sum|avg|count|max|min|stddev|variance|collect_list|collect_set|percentile|percentile_approx)|\s*\(", field, flags=re.IGNORECASE) is None:
+            # 属于计算字段,但没有聚合函数,等同于维度字段,需要加入groupbyby。
+            if not measure_is_aggregated[i] and field and re.search(r"\b(sum|avg|count|max|min|stddev|variance|collect_list|collect_set|percentile|percentile_approx)|\s*\(", field, flags=re.IGNORECASE) is None:
                 non_aggregated_select_parts.append(field)
                 non_aggregated_select_parts.append(field)
             selected_fid_alias_map[fid] = alias
             selected_fid_alias_map[fid] = alias
         else:
         else:
@@ -621,18 +651,22 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
     # 返回 select, where, groupby, orderby
     # 返回 select, where, groupby, orderby
     return ("\n".join(sql_parts)).strip(), json.dumps(filter_conditions, ensure_ascii=False), group_by_clause, order_by_clause
     return ("\n".join(sql_parts)).strip(), json.dumps(filter_conditions, ensure_ascii=False), group_by_clause, order_by_clause
 
 
-def generate():
+def generate(start=None, end=None, test_card_id=None):
     res_list = []
     res_list = []
     df = pd.read_csv("data/card.csv").fillna("").reset_index()
     df = pd.read_csv("data/card.csv").fillna("").reset_index()
     add_field_info = pd.read_csv("data/calc.csv").fillna('').set_index("card_id")
     add_field_info = pd.read_csv("data/calc.csv").fillna('').set_index("card_id")
     all_field_info = pd.read_csv("data/field.csv").fillna('').set_index("ds_id")
     all_field_info = pd.read_csv("data/field.csv").fillna('').set_index("ds_id")
     for i, row in df.iterrows():
     for i, row in df.iterrows():
-        if i > 100:
+        if start and i < start:
+            continue
+        if end and i > end:
             break
             break
-        row = row.to_dict()
+        card_id = row["card_id"]
+        if test_card_id and card_id != test_card_id:
+            continue
         if row["card_type_cd"] != '图表' or row["ds_id"] == "":
         if row["card_type_cd"] != '图表' or row["ds_id"] == "":
             continue
             continue
-        card_id = row["card_id"]
+
         try:
         try:
             added_fields_info = add_field_info.loc[[card_id]]
             added_fields_info = add_field_info.loc[[card_id]]
         except KeyError:
         except KeyError: