|
|
@@ -140,6 +140,29 @@ def get_fields_rename_map(field_info):
|
|
|
ret[one_map["name"]] = one_map["alias"]
|
|
|
return ret
|
|
|
|
|
|
+# 递归解析嵌套的计算字段
|
|
|
+def resolve_calculation_formula(formula, calculation_fields, visited=None):
|
|
|
+ if not formula:
|
|
|
+ return formula
|
|
|
+ if visited is None:
|
|
|
+ visited = set()
|
|
|
+
|
|
|
+ def replace_calculation_field(match):
|
|
|
+ field_key = match.group(1).strip()
|
|
|
+ field_def = calculation_fields.get(field_key)
|
|
|
+ if not field_def:
|
|
|
+ return match.group(0)
|
|
|
+ field_id = field_def.get("field_id") or field_key
|
|
|
+ if field_id in visited:
|
|
|
+ raise ValueError(f"计算字段存在循环引用: {field_key}")
|
|
|
+ nested_formula = field_def["calculation"].get("formula", "")
|
|
|
+ if "consolidation" in nested_formula:
|
|
|
+ return match.group(0)
|
|
|
+ resolved = resolve_calculation_formula(nested_formula, calculation_fields, visited | {field_id})
|
|
|
+ return f"({resolved})"
|
|
|
+
|
|
|
+ return re.sub(r"\[([^\[\]]+)\]", replace_calculation_field, formula)
|
|
|
+
|
|
|
def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map, added_fields_info, dataset_id):
|
|
|
sql_part = 'WITH tmp as (\nSELECT *,\n'
|
|
|
with_expressions = []
|
|
|
@@ -166,6 +189,8 @@ def build_with_part(new_date_fields, new_dimension_fields, dataset_fid_name_map,
|
|
|
tmp_part = get_consolidation_field(consolidation_dict)
|
|
|
tmp_part += f" AS `{new_name}`"
|
|
|
else:
|
|
|
+ # 递归解析计算字段是否有嵌套情况
|
|
|
+ formula = resolve_calculation_formula(formula, added_fields_info, {fid})
|
|
|
tmp_part = quote_identifier(formula, formula=True) + f" AS `{new_name}`"
|
|
|
with_expressions.append(tmp_part)
|
|
|
sql_part += ',\n'.join(with_expressions)
|
|
|
@@ -178,41 +203,44 @@ def process_measure_fields(measure_fields, measure_aggs, calculation_fields, car
|
|
|
if len(measure_fields) < len(measure_aggs):
|
|
|
print(f"警告: 卡片 {card_id} {card_name}: 数值字段数量小于聚合函数数量,不合法")
|
|
|
print(f"警告: 卡片 {card_id} {card_name}: 不添加任何数值字段.")
|
|
|
- return [], [], False
|
|
|
+ return [], [], [], False
|
|
|
## 数值字段 大于 聚合函数数量,存在聚合类型的计算字段,尝试填充
|
|
|
elif len(measure_fields) > len(measure_aggs):
|
|
|
## 计算数值字段数量
|
|
|
num_cals = 0
|
|
|
for field in measure_fields:
|
|
|
- if field in calculation_fields and calculation_fields[field]["calculation"]["isAggregated"] is True:
|
|
|
+ if field in calculation_fields: # and calculation_fields[field]["calculation"]["isAggregated"] is True:
|
|
|
num_cals += 1
|
|
|
## 如果不存在任何计算字段,补全剩余的NUL聚合函数
|
|
|
if num_cals == 0:
|
|
|
measure_aggs.extend(['NULL'] * (len(measure_fields) - len(measure_aggs)))
|
|
|
- return measure_fields, measure_aggs, True
|
|
|
+ return [quote_identifier(field) for field in measure_fields], measure_aggs, [False] * len(measure_fields), True
|
|
|
## 如果存在计算字段,且相加后的 聚合函数数量 仍小于 数值字段数量,不合法
|
|
|
if num_cals + len(measure_aggs) != len(measure_fields):
|
|
|
print(f"警告: 卡片 {card_id} {card_name}: 数值字段数量大于聚合函数数量,不合法")
|
|
|
print(f"警告: 卡片 {card_id} {card_name}: 不添加任何数值字段.")
|
|
|
- return [], [], False
|
|
|
+ return [], [], [], False
|
|
|
## 通过验证,填充聚合函数
|
|
|
- new_measure_fields, new_measure_aggs, agg_flag = [], [], False
|
|
|
+ new_measure_fields, new_measure_aggs, measure_is_aggregated, agg_flag = [], [], [], False
|
|
|
for i, field in enumerate(measure_fields):
|
|
|
## 非计算字段
|
|
|
if field not in calculation_fields:
|
|
|
new_measure_fields.append(quote_identifier(field))
|
|
|
new_measure_aggs.append(measure_aggs.pop(0))
|
|
|
+ measure_is_aggregated.append(False)
|
|
|
## 计算字段
|
|
|
else:
|
|
|
formula = calculation_fields[field]["calculation"]["formula"]
|
|
|
- formula = formula.replace('\n', '')
|
|
|
+ formula = resolve_calculation_formula(formula, calculation_fields, {calculation_fields[field]["field_id"]})
|
|
|
new_measure_fields.append(quote_identifier(formula, formula=True))
|
|
|
if calculation_fields[field]["calculation"]["isAggregated"] is True:
|
|
|
new_measure_aggs.append("NUL")
|
|
|
+ measure_is_aggregated.append(True)
|
|
|
agg_flag = True
|
|
|
else:
|
|
|
new_measure_aggs.append(measure_aggs.pop(0))
|
|
|
- return new_measure_fields, new_measure_aggs, agg_flag
|
|
|
+ measure_is_aggregated.append(False)
|
|
|
+ return new_measure_fields, new_measure_aggs, measure_is_aggregated, agg_flag
|
|
|
|
|
|
# sql部分去重
|
|
|
def dedupe_sql_parts(parts):
|
|
|
@@ -545,18 +573,20 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
|
|
|
selected_fid_alias_map[fid] = field
|
|
|
|
|
|
# 加工计算字段
|
|
|
- new_measure_fields, measure_aggs, agg_flag = process_measure_fields(measure_fields, measure_aggs, added_fields_info, card_id, card_name)
|
|
|
+ new_measure_fields, measure_aggs, measure_is_aggregated, agg_flag = process_measure_fields(measure_fields, measure_aggs, added_fields_info, card_id, card_name)
|
|
|
if agg_flag:
|
|
|
has_aggregation = True
|
|
|
for i, field in enumerate(new_measure_fields):
|
|
|
fid = measure_fids[i]
|
|
|
alias = fields_rename_map.get(field.strip('`'))
|
|
|
+ # measure_agg是NUL,不需要聚合(等同于维度字段)或公式本身已经有聚合函数
|
|
|
agg_func_template = AGGREGATION_MAP.get(measure_aggs[i])
|
|
|
if not agg_func_template:
|
|
|
if not alias or alias == "null":
|
|
|
alias = measure_fields[i]
|
|
|
select_parts.append(f"{field} AS {quote_identifier(alias)}")
|
|
|
- if field and re.search(r"\b(sum|avg|count|max|min|stddev|variance|collect_list|collect_set|percentile|percentile_approx)|\s*\(", field, flags=re.IGNORECASE) is None:
|
|
|
+ # 属于计算字段,但没有聚合函数,等同于维度字段,需要加入groupbyby。
|
|
|
+ if not measure_is_aggregated[i] and field and re.search(r"\b(sum|avg|count|max|min|stddev|variance|collect_list|collect_set|percentile|percentile_approx)|\s*\(", field, flags=re.IGNORECASE) is None:
|
|
|
non_aggregated_select_parts.append(field)
|
|
|
selected_fid_alias_map[fid] = alias
|
|
|
else:
|
|
|
@@ -621,18 +651,22 @@ def build_sql_query(card_data, added_fields_info, dataset_fid_name_map):
|
|
|
# 返回 select, where, groupby, orderby
|
|
|
return ("\n".join(sql_parts)).strip(), json.dumps(filter_conditions, ensure_ascii=False), group_by_clause, order_by_clause
|
|
|
|
|
|
-def generate():
|
|
|
+def generate(start=None, end=None, test_card_id=None):
|
|
|
res_list = []
|
|
|
df = pd.read_csv("data/card.csv").fillna("").reset_index()
|
|
|
add_field_info = pd.read_csv("data/calc.csv").fillna('').set_index("card_id")
|
|
|
all_field_info = pd.read_csv("data/field.csv").fillna('').set_index("ds_id")
|
|
|
for i, row in df.iterrows():
|
|
|
- if i > 100:
|
|
|
+ if start and i < start:
|
|
|
+ continue
|
|
|
+ if end and i > end:
|
|
|
break
|
|
|
- row = row.to_dict()
|
|
|
+ card_id = row["card_id"]
|
|
|
+ if test_card_id and card_id != test_card_id:
|
|
|
+ continue
|
|
|
if row["card_type_cd"] != '图表' or row["ds_id"] == "":
|
|
|
continue
|
|
|
- card_id = row["card_id"]
|
|
|
+
|
|
|
try:
|
|
|
added_fields_info = add_field_info.loc[[card_id]]
|
|
|
except KeyError:
|