ysl2007 il y a 3 mois
commit
cd15baaffc
4 fichiers modifiés avec 595 ajouts et 0 suppressions
  1. 85 0
      README.md
  2. BIN
      __pycache__/recall_eval.cpython-313.pyc
  3. 508 0
      recall_eval.py
  4. 2 0
      requirements.txt

+ 85 - 0
README.md

@@ -0,0 +1,85 @@
+# BI Recall Eval Script
+
+用于评估“根据用户问题召回仪表板/数据集”的三路召回表现。
+
+## 输入
+
+1. `id_map.parquet`
+
+必需字段:
+
+- `card_id`
+- `card_name`
+- `dashboard_id`
+- `dashboard_name`
+- `dataset_id`
+- `dataset_name`
+- `branch_code`
+
+2. `questions.txt`
+
+推荐格式每行一条,`card_id` 与问题之间用制表符(Tab)分隔(下例中的 `\t` 表示一个制表符);空行和以 `#` 开头的行会被忽略:
+
+```txt
+123456\t本月深圳分行存款趋势
+234567\t零售客户AUM周报看板
+```
+
+脚本也兼容 `card_id,question`、`card_id|question` 和 `card_id question`。
+
+## 召回逻辑
+
+1. 第一路:从全量仪表板中随机采样 `500` 个,以 `80%` 概率强制放入正确卡片所在仪表板。
+2. 第二路:调用召回接口,`recall_type=dashboard`。
+3. 第三路:调用召回接口,`recall_type=dataset`。
+
+命中规则:
+
+- 第一路、第二路:只要召回到正确卡片关联的任一 `dashboard_id` 即视为命中。
+- 第三路:只要召回到正确卡片关联的任一 `dataset_id` 即视为命中。
+- 总召回:三路结果取并集后,只要任一路命中即视为命中。
+
+## 安装依赖
+
+```bash
+python3 -m pip install -r requirements.txt
+```
+
+## 运行
+
+```bash
+python3 recall_eval.py \
+  --questions-txt ./questions.txt \
+  --id-map-parquet ./id_map.parquet \
+  --api-url http://your-recall-api/recall \
+  --output-dir ./output
+```
+
+## 输出
+
+- `output/recall_summary.json`:汇总指标
+- `output/recall_details.csv`:逐题明细,包含第二路/第三路召回结果和错误信息
+
+## 接口约定
+
+脚本默认使用 `POST` JSON:
+
+```json
+{
+  "question": "本月深圳分行存款趋势",
+  "recall_type": "dashboard"
+}
+```
+
+返回支持以下常见结构之一:
+
+- 顶层数组
+- `{"data": [...]}`
+- `{"items": [...]}`
+- `{"results": [...]}`
+- `{"records": [...]}`
+
+每条记录至少包含一个可识别的 ID 字段,例如:
+
+- `dashboard_id`
+- `dataset_id`
+- `id`

BIN
__pycache__/recall_eval.cpython-313.pyc


+ 508 - 0
recall_eval.py

@@ -0,0 +1,508 @@
+#!/usr/bin/env python3
+"""Evaluate dashboard/dataset recall quality for BI search questions."""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import random
+import sys
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import requests
+
+# Placeholder endpoint; invoke_recall_api refuses to call the API while the
+# configured URL still equals this value.
+DEFAULT_RECALL_API_URL = "http://replace-with-your-recall-api"
+# Default per-request timeout, in seconds, for the recall API.
+DEFAULT_RECALL_TIMEOUT_SECONDS = 10
+# Default number of dashboards sampled by the simulated first recall route.
+DEFAULT_FIRST_RECALL_SIZE = 500
+# Default probability of forcing a gold dashboard into the first-route sample.
+DEFAULT_FIRST_RECALL_HIT_RATE = 0.8
+# Default seed for the random generator that drives the first-route simulation.
+DEFAULT_RANDOM_SEED = 20260311
+
+# Values sent as "recall_type" in the recall API request body.
+RECALL_TYPE_DASHBOARD = "dashboard"
+RECALL_TYPE_DATASET = "dataset"
+
+# Columns that load_mapping requires in the id-map parquet file.
+REQUIRED_COLUMNS = {
+    "card_id",
+    "card_name",
+    "dashboard_id",
+    "dashboard_name",
+    "dataset_id",
+    "dataset_name",
+    "branch_code",
+}
+
+
+@dataclass(frozen=True)
+class QuestionCase:
+    """One evaluation case parsed from a line of the questions file."""
+
+    # 1-based line number of the case inside the questions file.
+    line_no: int
+    # Identifier of the card the question is expected to lead to.
+    card_id: str
+    # Question text sent to the recall API.
+    question: str
+
+
+@dataclass
+class CardMapping:
+    """Everything the id-map file associates with a single card."""
+
+    # First non-empty card name seen for the card ("" when none was found).
+    card_name: str
+    # Dashboard IDs listed for the card; used as gold for the dashboard routes.
+    dashboard_ids: set[str]
+    # Non-empty dashboard names listed for the card.
+    dashboard_names: set[str]
+    # Dataset IDs listed for the card; used as gold for the dataset route.
+    dataset_ids: set[str]
+    # Non-empty dataset names listed for the card.
+    dataset_names: set[str]
+    # Non-empty branch codes listed for the card.
+    branch_codes: set[str]
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Evaluate three-route recall performance for BI dashboard/dataset search."
+    )
+    parser.add_argument(
+        "--questions-txt",
+        required=True,
+        help="Question file. Recommended format: one line per case as 'card_id<TAB>question'.",
+    )
+    parser.add_argument(
+        "--id-map-parquet",
+        required=True,
+        help="Parquet file containing card/dashboard/dataset mapping information.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        default="output",
+        help="Directory for summary JSON and per-question detail CSV.",
+    )
+    parser.add_argument(
+        "--api-url",
+        default=DEFAULT_RECALL_API_URL,
+        help="Recall API URL. The script POSTs JSON: {'question': ..., 'recall_type': ...}.",
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=DEFAULT_RECALL_TIMEOUT_SECONDS,
+        help="Recall API timeout in seconds.",
+    )
+    parser.add_argument(
+        "--first-recall-size",
+        type=int,
+        default=DEFAULT_FIRST_RECALL_SIZE,
+        help="Sample size for the simulated first recall route.",
+    )
+    parser.add_argument(
+        "--first-recall-hit-rate",
+        type=float,
+        default=DEFAULT_FIRST_RECALL_HIT_RATE,
+        help="Probability of forcing the gold dashboard into first-route results.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=DEFAULT_RANDOM_SEED,
+        help="Random seed used for reproducible first-route simulation.",
+    )
+    return parser.parse_args()
+
+
+def read_questions(path: Path) -> list[QuestionCase]:
+    cases: list[QuestionCase] = []
+    with path.open("r", encoding="utf-8") as handle:
+        for line_no, raw_line in enumerate(handle, start=1):
+            line = raw_line.strip()
+            if not line or line.startswith("#"):
+                continue
+            card_id, question = split_question_line(line, line_no)
+            cases.append(QuestionCase(line_no=line_no, card_id=card_id, question=question))
+    if not cases:
+        raise ValueError(f"No valid questions found in {path}")
+    return cases
+
+
+def split_question_line(line: str, line_no: int) -> tuple[str, str]:
+    for separator in ("\t", ",", "|"):
+        if separator in line:
+            left, right = line.split(separator, 1)
+            card_id = left.strip()
+            question = right.strip()
+            if card_id and question:
+                return card_id, question
+    parts = line.split(maxsplit=1)
+    if len(parts) == 2 and parts[0].strip() and parts[1].strip():
+        return parts[0].strip(), parts[1].strip()
+    raise ValueError(
+        f"Invalid question line {line_no}: '{line}'. Expected 'card_id<TAB>question'."
+    )
+
+
+def load_mapping(path: Path) -> tuple[dict[str, CardMapping], list[str]]:
+    try:
+        import pyarrow.parquet as pq
+    except ModuleNotFoundError as exc:
+        raise RuntimeError(
+            "Missing dependency 'pyarrow'. Install it first, for example: pip install pyarrow requests"
+        ) from exc
+
+    table = pq.read_table(path)
+    missing_columns = REQUIRED_COLUMNS.difference(table.column_names)
+    if missing_columns:
+        raise ValueError(f"{path} is missing required columns: {sorted(missing_columns)}")
+
+    rows = table.to_pylist()
+    if not rows:
+        raise ValueError(f"{path} does not contain any rows")
+
+    card_map: dict[str, CardMapping] = {}
+    all_dashboard_ids: set[str] = set()
+    for row in rows:
+        card_id = normalize_id(row.get("card_id"))
+        if card_id is None:
+            continue
+        mapping = card_map.setdefault(
+            card_id,
+            CardMapping(
+                card_name=normalize_text(row.get("card_name")),
+                dashboard_ids=set(),
+                dashboard_names=set(),
+                dataset_ids=set(),
+                dataset_names=set(),
+                branch_codes=set(),
+            ),
+        )
+        dashboard_id = normalize_id(row.get("dashboard_id"))
+        dataset_id = normalize_id(row.get("dataset_id"))
+        dashboard_name = normalize_text(row.get("dashboard_name"))
+        dataset_name = normalize_text(row.get("dataset_name"))
+        branch_code = normalize_text(row.get("branch_code"))
+        if dashboard_id is not None:
+            mapping.dashboard_ids.add(dashboard_id)
+            all_dashboard_ids.add(dashboard_id)
+        if dataset_id is not None:
+            mapping.dataset_ids.add(dataset_id)
+        if dashboard_name:
+            mapping.dashboard_names.add(dashboard_name)
+        if dataset_name:
+            mapping.dataset_names.add(dataset_name)
+        if branch_code:
+            mapping.branch_codes.add(branch_code)
+        if not mapping.card_name:
+            mapping.card_name = normalize_text(row.get("card_name"))
+    return card_map, sorted(all_dashboard_ids)
+
+
+def normalize_id(value: Any) -> str | None:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
+
+
+def normalize_text(value: Any) -> str:
+    if value is None:
+        return ""
+    return str(value).strip()
+
+
+def simulate_first_recall(
+    all_dashboard_ids: list[str],
+    gold_dashboard_ids: set[str],
+    sample_size: int,
+    force_hit_rate: float,
+    rng: random.Random,
+) -> set[str]:
+    if not all_dashboard_ids:
+        return set()
+
+    capped_size = min(sample_size, len(all_dashboard_ids))
+    gold_candidates = [
+        dashboard_id for dashboard_id in gold_dashboard_ids if dashboard_id in all_dashboard_ids
+    ]
+    should_force_hit = bool(gold_candidates) and rng.random() < force_hit_rate
+
+    if should_force_hit:
+        forced_dashboard = rng.choice(gold_candidates)
+        remaining = [
+            dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id != forced_dashboard
+        ]
+        picked = rng.sample(remaining, k=max(capped_size - 1, 0))
+        selected = set(picked)
+        selected.add(forced_dashboard)
+        return selected
+
+    non_gold_dashboards = [
+        dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id not in gold_dashboard_ids
+    ]
+    if len(non_gold_dashboards) >= capped_size:
+        return set(rng.sample(non_gold_dashboards, k=capped_size))
+
+    return set(rng.sample(all_dashboard_ids, k=capped_size))
+
+
+def invoke_recall_api(
+    api_url: str,
+    question: str,
+    recall_type: str,
+    timeout_seconds: int,
+) -> tuple[set[str], list[dict[str, Any]]]:
+    if api_url == DEFAULT_RECALL_API_URL:
+        raise ValueError(
+            "Recall API URL is still the placeholder value. Pass --api-url with the actual endpoint."
+        )
+
+    response = requests.post(
+        api_url,
+        json={"question": question, "recall_type": recall_type},
+        timeout=timeout_seconds,
+    )
+    response.raise_for_status()
+    payload = response.json()
+    records = extract_records(payload)
+    ids = {record_id for record in records if (record_id := extract_record_id(record, recall_type))}
+    return ids, records
+
+
+def extract_records(payload: Any) -> list[dict[str, Any]]:
+    if isinstance(payload, list):
+        return [item for item in payload if isinstance(item, dict)]
+    if not isinstance(payload, dict):
+        return []
+    for key in ("data", "items", "results", "records"):
+        value = payload.get(key)
+        if isinstance(value, list):
+            return [item for item in value if isinstance(item, dict)]
+    return [payload]
+
+
+def extract_record_id(record: dict[str, Any], recall_type: str) -> str | None:
+    preferred_keys = [f"{recall_type}_id", "id"]
+    fallback_keys = ["dashboard_id", "dataset_id"]
+    for key in preferred_keys + fallback_keys:
+        value = normalize_id(record.get(key))
+        if value is not None:
+            return value
+    return None
+
+
+def safe_api_recall(
+    api_url: str,
+    question: str,
+    recall_type: str,
+    timeout_seconds: int,
+) -> tuple[set[str], list[dict[str, Any]], str]:
+    try:
+        ids, records = invoke_recall_api(api_url, question, recall_type, timeout_seconds)
+        return ids, records, ""
+    except Exception as exc:  # noqa: BLE001
+        return set(), [], str(exc)
+
+
+def ensure_output_dir(path: Path) -> None:
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def write_details_csv(path: Path, rows: list[dict[str, Any]]) -> None:
+    fieldnames = [
+        "line_no",
+        "card_id",
+        "card_name",
+        "question",
+        "gold_dashboard_ids",
+        "gold_dataset_ids",
+        "first_recall_ids",
+        "first_recall_hit",
+        "dashboard_recall_ids",
+        "dashboard_recall_hit",
+        "dashboard_recall_error",
+        "dashboard_recall_records",
+        "dataset_recall_ids",
+        "dataset_recall_hit",
+        "dataset_recall_error",
+        "dataset_recall_records",
+        "union_hit",
+        "branch_codes",
+    ]
+    with path.open("w", encoding="utf-8", newline="") as handle:
+        writer = csv.DictWriter(handle, fieldnames=fieldnames)
+        writer.writeheader()
+        for row in rows:
+            writer.writerow(row)
+
+
+def write_summary_json(path: Path, summary: dict[str, Any]) -> None:
+    with path.open("w", encoding="utf-8") as handle:
+        json.dump(summary, handle, ensure_ascii=False, indent=2)
+
+
+def rate(hits: int, total: int) -> float:
+    if total == 0:
+        return 0.0
+    return hits / total
+
+
+def main() -> int:
+    args = parse_args()
+    rng = random.Random(args.seed)
+
+    questions_path = Path(args.questions_txt)
+    mapping_path = Path(args.id_map_parquet)
+    output_dir = Path(args.output_dir)
+
+    questions = read_questions(questions_path)
+    card_map, all_dashboard_ids = load_mapping(mapping_path)
+    ensure_output_dir(output_dir)
+
+    first_hits = 0
+    dashboard_hits = 0
+    dataset_hits = 0
+    union_hits = 0
+    detail_rows: list[dict[str, Any]] = []
+    missing_cards: list[dict[str, Any]] = []
+    error_counter: dict[str, int] = defaultdict(int)
+
+    for case in questions:
+        mapping = card_map.get(case.card_id)
+        if mapping is None:
+            missing_cards.append(
+                {"line_no": case.line_no, "card_id": case.card_id, "question": case.question}
+            )
+            detail_rows.append(
+                {
+                    "line_no": case.line_no,
+                    "card_id": case.card_id,
+                    "card_name": "",
+                    "question": case.question,
+                    "gold_dashboard_ids": "",
+                    "gold_dataset_ids": "",
+                    "first_recall_ids": "",
+                    "first_recall_hit": False,
+                    "dashboard_recall_ids": "",
+                    "dashboard_recall_hit": False,
+                    "dashboard_recall_error": "card_id not found in id_map.parquet",
+                    "dashboard_recall_records": "",
+                    "dataset_recall_ids": "",
+                    "dataset_recall_hit": False,
+                    "dataset_recall_error": "card_id not found in id_map.parquet",
+                    "dataset_recall_records": "",
+                    "union_hit": False,
+                    "branch_codes": "",
+                }
+            )
+            continue
+
+        first_ids = simulate_first_recall(
+            all_dashboard_ids=all_dashboard_ids,
+            gold_dashboard_ids=mapping.dashboard_ids,
+            sample_size=args.first_recall_size,
+            force_hit_rate=args.first_recall_hit_rate,
+            rng=rng,
+        )
+        dashboard_ids, dashboard_records, dashboard_error = safe_api_recall(
+            api_url=args.api_url,
+            question=case.question,
+            recall_type=RECALL_TYPE_DASHBOARD,
+            timeout_seconds=args.timeout,
+        )
+        dataset_ids, dataset_records, dataset_error = safe_api_recall(
+            api_url=args.api_url,
+            question=case.question,
+            recall_type=RECALL_TYPE_DATASET,
+            timeout_seconds=args.timeout,
+        )
+
+        first_hit = bool(first_ids.intersection(mapping.dashboard_ids))
+        dashboard_hit = bool(dashboard_ids.intersection(mapping.dashboard_ids))
+        dataset_hit = bool(dataset_ids.intersection(mapping.dataset_ids))
+        union_hit = first_hit or dashboard_hit or dataset_hit
+
+        first_hits += int(first_hit)
+        dashboard_hits += int(dashboard_hit)
+        dataset_hits += int(dataset_hit)
+        union_hits += int(union_hit)
+
+        if dashboard_error:
+            error_counter["dashboard_api_errors"] += 1
+        if dataset_error:
+            error_counter["dataset_api_errors"] += 1
+
+        detail_rows.append(
+            {
+                "line_no": case.line_no,
+                "card_id": case.card_id,
+                "card_name": mapping.card_name,
+                "question": case.question,
+                "gold_dashboard_ids": ",".join(sorted(mapping.dashboard_ids)),
+                "gold_dataset_ids": ",".join(sorted(mapping.dataset_ids)),
+                "first_recall_ids": ",".join(sorted(first_ids)),
+                "first_recall_hit": first_hit,
+                "dashboard_recall_ids": ",".join(sorted(dashboard_ids)),
+                "dashboard_recall_hit": dashboard_hit,
+                "dashboard_recall_error": dashboard_error,
+                "dashboard_recall_records": json.dumps(
+                    dashboard_records, ensure_ascii=False, separators=(",", ":")
+                ),
+                "dataset_recall_ids": ",".join(sorted(dataset_ids)),
+                "dataset_recall_hit": dataset_hit,
+                "dataset_recall_error": dataset_error,
+                "dataset_recall_records": json.dumps(
+                    dataset_records, ensure_ascii=False, separators=(",", ":")
+                ),
+                "union_hit": union_hit,
+                "branch_codes": ",".join(sorted(mapping.branch_codes)),
+            }
+        )
+
+        if dashboard_records:
+            detail_rows[-1]["dashboard_recall_ids"] = ",".join(sorted(dashboard_ids))
+        if dataset_records:
+            detail_rows[-1]["dataset_recall_ids"] = ",".join(sorted(dataset_ids))
+
+    total = len(questions)
+    first_rate = rate(first_hits, total)
+    dashboard_rate = rate(dashboard_hits, total)
+    dataset_rate = rate(dataset_hits, total)
+    union_rate = rate(union_hits, total)
+    uplift_vs_first = union_rate - first_rate
+
+    summary = {
+        "input": {
+            "questions_txt": str(questions_path),
+            "id_map_parquet": str(mapping_path),
+            "api_url": args.api_url,
+            "first_recall_size": args.first_recall_size,
+            "first_recall_hit_rate": args.first_recall_hit_rate,
+            "seed": args.seed,
+        },
+        "counts": {
+            "total_questions": total,
+            "first_recall_hits": first_hits,
+            "dashboard_recall_hits": dashboard_hits,
+            "dataset_recall_hits": dataset_hits,
+            "union_hits": union_hits,
+            "missing_card_mappings": len(missing_cards),
+        },
+        "rates": {
+            "first_recall_rate": first_rate,
+            "dashboard_recall_rate": dashboard_rate,
+            "dataset_recall_rate": dataset_rate,
+            "union_recall_rate": union_rate,
+            "uplift_vs_first_recall": uplift_vs_first,
+        },
+        "errors": dict(error_counter),
+        "missing_cards": missing_cards,
+        "outputs": {
+            "detail_csv": str(output_dir / "recall_details.csv"),
+            "summary_json": str(output_dir / "recall_summary.json"),
+        },
+    }
+
+    write_details_csv(output_dir / "recall_details.csv", detail_rows)
+    write_summary_json(output_dir / "recall_summary.json", summary)
+
+    print(json.dumps(summary, ensure_ascii=False, indent=2))
+    return 0
+
+
+# Script entry point: exit with main()'s return code, or print the error to
+# stderr and exit with status 1 when main() raises.
+if __name__ == "__main__":
+    try:
+        # SystemExit is not an Exception subclass, so the handler below does
+        # not intercept the normal exit.
+        raise SystemExit(main())
+    except Exception as exc:  # noqa: BLE001
+        print(f"[ERROR] {exc}", file=sys.stderr)
+        raise SystemExit(1) from exc

+ 2 - 0
requirements.txt

@@ -0,0 +1,2 @@
+pyarrow>=15.0.0
+requests>=2.31.0