ysl2007 2 месяцев назад
Родитель
Сommit
7cc15a613b
1 измененных файлов с 232 добавлено и 1 удалено
  1. 232 1
      carddef2sql.py

+ 232 - 1
carddef2sql.py

@@ -399,4 +399,235 @@ def parse_filter_string(filter_relation_str):
             continue
             continue
         elif op_dict == 'SPARK_EXPR':
         elif op_dict == 'SPARK_EXPR':
             if 'formula' in cond_dict:
             if 'formula' in cond_dict:
-                
+                formula = quote_identifier(cond_dict['formula'], formula=True)
+                conditions[fdId] = {"exp": formula, "agg": is_aggregated}
+            else:
+                if isinstance(cond_dict['filterValue'], list) and len(cond_dict['filterValue']) == 1:
+                    field = quote_identifier(cond_dict['name'])
+                    value = cond_dict['filterValue'][0]
+                    conditions[fdId] = {"exp": f"{field} = {value}", "agg": is_aggregated}
+                else:
+                    print(f"警告: 无法解析筛选条件,SPARK_EXPR中未定义。跳过此条件。")
+            continue
+
+        # 处理条件
+        value_nums = op_dict["val_nums"]
+        if value_nums != 0 and len(values) != value_nums:
+            print(f"警告: 无法解析筛选条件,值数量与操作符不匹配。跳过此条件。")
+            continue
+
+        field = quote_identifier(field)
+        # consolidation 情况,将consolidation公式替换条件左边的field
+        if "consolidation" in cond_dict:
+            consolidation = cond_dict["consolidation"]
+            consolidation_field = get_consolidation_field(consolidation)
+            if not consolidation_field:
+                print(f"警告: 无法解析consolidation字段。跳过此条件。")
+                continue
+            else:
+                field = consolidation_field
+        else:
+            # 公式,非 consolidation情况
+            if "formula" in cond_dict:
+                field = quote_identifier(cond_dict["formula"], formula=True)
+            
+        if op_name in ("NI", "IN") and len(values) == 0:
+            print(f"警告: 无法解析筛选条件,IN或NI中参数个数为0。跳过此条件。")
+            continue
+            
+        # 特殊情况
+        if op_name in ('NI', 'IN') and None in values:
+            conditions[fdId] = {"exp": f"{field} IS NOT NULL", "agg": is_aggregated}
+            values = [x for x in values if x is not None]
+            if len(values) == 0:
+                continue
+        
+        # 填充模板所需要的参数
+        format_args = get_format_args(field, fd_type, op_dict, values)
+        condition_str = op_dict["template"].format(**format_args)
+        conditions[fdId] = {"exp": condition_str, "agg": is_aggregated}
+    return conditions
+
+def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
+    card_id = card_data["card_id"]
+    card_name = card_data["card_name"]
+    dataset_id = card_data.get("ds_id")
+    if not dataset_id:
+        print(f"错误: {card_id} {card_name} 数据集ID为空.")
+        return "", "", "", ""
+    
+    added_fields_info = get_added_fields_info(added_fields_info)
+    dataset_fid_name_map = get_fid_name_map(dataset_fid_name_map)
+
+    dimension_fids = parse_multi_value_field(card_data.get("field_id", []))
+    dimension_fields = parse_multi_value_field(card_data.get("field_name", []))
+    dimension_fid_name_map = dict(zip(dimension_fids, dimension_fields))
+    dimension_name_fid_map = dict(zip(dimension_fields, dimension_fids))
+
+    measure_fids = parse_multi_value_field(card_data.get("num_value_field_id", []))
+    measure_fields = parse_multi_value_field(card_data.get("num_value_field_name", []))
+    measure_aggs = parse_multi_value_field(card_data.get("num_value_field_merge_way", []))
+    filter_relation_str = card_data.get("filters_field_value_name_rela")
+
+    sort_fids = parse_multi_value_field(card_data.get("sort_field_id", []))
+    sort_fields = parse_multi_value_field(card_data.get("sort_field_name", []))
+    sort_method = parse_multi_value_field(card_data.get("sort_way", []))
+
+    all_field_ids = dimension_fids + \
+                    parse_multi_value_field(card_data.get("filters_field_id", [])) + \
+                    sort_fids + \
+                    measure_fids
+    all_field_names = dimension_fields + \
+                    parse_multi_value_field(card_data.get("filters_field_name", [])) + \
+                    sort_fields + \
+                    measure_fields
+    all_field_id_name_map = dict(zip(all_field_ids, all_field_names))
+
+    # 处理字段重命名关系
+    fields_rename_map = get_fields_rename_map(card_data.get("field_info", ""))
+    selected_fid_alias_map = dict(zip(dimension_fids+measure_fids, dimension_fields+measure_fields))
+
+    # 构建WITH
+    with_part = ""
+    new_date_fields = []
+    # 日期转换
+    for fid, name in all_field_id_name_map.items():
+        fid_splits = fid.split('_')
+        if len(fid_splits) == 2:
+            new_date_fields.append((fid, name))
+            old_fid = fid_splits[0]
+            selected_fid_alias_map[old_fid] = name
+    # 新增维度字段
+    new_dimension_fields = []
+    for fid, name in dimension_fid_name_map.items():
+        if fid in added_fields_info:
+            new_dimension_fields.append((fid, name))
+    # 如果有新增日期字段、新增维度字段,构建WITH
+    if new_date_fields or new_dimension_fields:
+        with_part = build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map, added_fields_info, dataset_id)
+    
+    # 构建SELECT
+    select_parts = []
+    has_aggregation = False
+    
+    # 添加维度字段
+    for field in dimension_fields:
+        fid = dimension_name_fid_map[field]
+        alias = fields_rename_map.get(field)
+        if alias and alias != "null":
+            select_parts.append(f"{quote_identifier(field)} AS {quote_identifier(alias)}")
+            selected_fid_alias_map[fid] = alias
+        else:
+            select_parts.append(f"{quote_identifier(field)}")
+            selected_fid_alias_map[fid] = field
+    
+    # 加工计算字段
+    new_measure_fields, measure_aggs, agg_flag = process_calculation_fields(measure_fields, measure_aggs, added_fields_info, card_id, card_name)
+    if agg_flag:
+        has_aggregation = True
+    for i, field in enumerate(new_measure_fields):
+        fid = measure_fids[i]
+        alias = fields_rename_map.get(field.strip('`'))
+        agg_func_template = AGGREGATION_MAP.get(measure_aggs[i])
+        if not agg_func_template:
+            if not alias or alias == "null":
+                alias = measure_fields[i]
+            select_parts.append(f"{field} AS {quote_identifier(alias)}")
+            selected_fid_alias_map[fid] = alias
+        else:
+            has_aggregation = True
+            # 特殊处理 count distinct
+            if '{}' in agg_func_template:
+                agg_expression = agg_func_template.format(field)
+            else:
+                agg_expression = f"{agg_func_template}({field})"
+            # 添加别名
+            if not alias or alias == "null":
+                suffix = AGGREGATION_SUFFIX_MAP.get(measure_aggs[i])
+                alias = f"{measure_fields[i]}_{suffix}"
+            select_parts.append(f"{agg_expression} AS {quote_identifier(alias)}")
+            selected_fid_alias_map[fid] = alias
+    
+    if not select_parts:
+        print(f"错误: {card_id} {card_name} 没有select字段。")
+        return '', '', '', ''
+    else:
+        select_clause = "SELECT " + ",\n    ".join(select_parts)
+    
+    # 构建FROM
+    if with_part:
+        from_clause = "FROM tmp"
+    else:
+        from_clause = f"FROM {quote_identifier(str(dataset_id))}"
+    
+    # 构建WHERE
+    filter_conditions = {}
+    try:
+        filter_conditions = parse_filter_string(filter_relation_str)
+    except Exception as e:
+        print(f"错误: 卡片 {card_id} {card_name} 解析筛选条件出错:{e}。WHERE字句缺失。")
+        print("详细错误信息:")
+        print(traceback.format_exc())
+
+    # 构建GROUPBY
+    group_by_clause = ""
+    if has_aggregation and dimension_fields:
+        group_by_parts = [quote_identifier(field) for field in dimension_fields]
+        group_by_clause = "GROUP BY " + ", ".join(group_by_parts)
+    
+    # 构建ORDERBY
+    order_by_clause = ""
+    if sort_fields and sort_method and len(sort_fields) == len(sort_method):
+        order_by_parts = []
+        for i, field in enumerate(sort_fields):
+            fid = sort_fids[i]
+            if fid not in selected_fid_alias_map:
+                continue
+            alias = selected_fid_alias_map[fid]
+            order_by_parts.append(f"{quote_identifier(alias)} {sort_method[i]}")
+        if order_by_parts:
+            order_by_clause = "ORDER BY " + ", ".join(order_by_parts)
+    
+    # 组装SQL
+    sql_parts = [with_part, select_clause, from_clause]
+
+    return ("\n".join(sql_parts)).strip(), json.dumps(filter_conditions, ensure_ascii=False), group_by_clause, order_by_clause
+
+def generate():
+    res_list = []
+    df = pd.read_parquet("data/dev_card.parquet").reset_index()
+    add_field_info = pd.read_parquet("data/dev_calc.parquet").set_index("card_id")
+    all_field_info = pd.read_parquet("data/dev_field.parquet").set_index("ds_id")
+    for i, row in df.iterrows():
+        if i > 100:
+            break
+        row = row.to_dict()
+        if row["card_type_cd"] != '图表' or row["ds_id"] == "":
+            continue
+        card_id = row["card_id"]
+        try:
+            added_fields_info = add_field_info.loc[[card_id]]
+        except KeyError:
+            added_fields_info = pd.DataFrame()
+        try:
+            dataset_fid_name_map = all_field_info.loc[[row["ds_id"]]]
+        except KeyError:
+            print(f"错误: 没有数据及字段信息: {card_id}")
+            continue
+
+        select, where, groupby, orderby = '', '', '', ''
+        try:
+            select, where, groupby, orderby = build_sql_query(row, added_fields_info, dataset_fid_name_map)
+        except Exception as e:
+            print(f"错误: 卡片 {card_id} 发生未知错误: {e}")
+            print(i, traceback.format_exc())
+        if not select:
+            print(f"{card_id} 生成失败")
+            continue
+        res_list.append([str(card_id), str(row["card_name"]), select, where, groupby, orderby])
+    res_df = pd.DataFrame(res_list, columns=["card_id", "card_name", "select", 'where', 'groupby', 'orderby'])
+    return res_df
+
+if __name__ == "__main__":
+    df = generate()
+    df.to_parquet("output/sql.parquet")