| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508 |
- #!/usr/bin/env python3
- """Evaluate dashboard/dataset recall quality for BI search questions."""
from __future__ import annotations

import argparse
import csv
import json
import math
import random
import re
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import requests
# --- Recall API ------------------------------------------------------------
# Placeholder endpoint: invoke_recall_api refuses to call it, so --api-url must
# be supplied for the API routes to work.
DEFAULT_RECALL_API_URL = "http://replace-with-your-recall-api"
# Per-request timeout in seconds.
DEFAULT_RECALL_TIMEOUT_SECONDS = 10

# --- Simulated first recall route -----------------------------------------
DEFAULT_FIRST_RECALL_SIZE = 500
DEFAULT_FIRST_RECALL_HIT_RATE = 0.8
DEFAULT_RANDOM_SEED = 20260311

# --- Values sent as "recall_type" in the recall API request body -----------
RECALL_TYPE_DASHBOARD = "dashboard"
RECALL_TYPE_DATASET = "dataset"

# --- Columns the mapping parquet file must contain -------------------------
REQUIRED_COLUMNS = {
    # card identity
    "card_id",
    "card_name",
    # dashboard the card belongs to
    "dashboard_id",
    "dashboard_name",
    # dataset backing the card
    "dataset_id",
    "dataset_name",
    # organisational branch
    "branch_code",
}
@dataclass(frozen=True)
class QuestionCase:
    """A single evaluation question parsed from the questions file (immutable)."""

    # 1-based line number in the questions file; reported in outputs.
    line_no: int
    # Card identifier used to look up the gold mapping.
    card_id: str
    # Question text sent to the recall API.
    question: str
@dataclass
class CardMapping:
    """Gold answers for one card, merged from every mapping row with its card_id."""

    # Card display name (the first non-empty value seen while loading).
    card_name: str
    # Gold dashboards: dashboard-route results are scored against this set.
    dashboard_ids: set[str]
    # Dashboard names seen for this card.
    dashboard_names: set[str]
    # Gold datasets: dataset-route results are scored against this set.
    dataset_ids: set[str]
    # Dataset names seen for this card.
    dataset_names: set[str]
    # Branch codes seen for this card (written to the detail CSV).
    branch_codes: set[str]
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            so existing zero-argument callers behave as before.

    Returns:
        The parsed option namespace.

    Exits (``SystemExit`` with status 2, via ``parser.error``) when a numeric
    option is outside its valid range.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate three-route recall performance for BI dashboard/dataset search."
    )
    parser.add_argument(
        "--questions-txt",
        required=True,
        help="Question file. Recommended format: one line per case as 'card_id<TAB>question'.",
    )
    parser.add_argument(
        "--id-map-parquet",
        required=True,
        help="Parquet file containing card/dashboard/dataset mapping information.",
    )
    parser.add_argument(
        "--output-dir",
        default="output",
        help="Directory for summary JSON and per-question detail CSV.",
    )
    parser.add_argument(
        "--api-url",
        default=DEFAULT_RECALL_API_URL,
        help="Recall API URL. The script POSTs JSON: {'question': ..., 'recall_type': ...}.",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=DEFAULT_RECALL_TIMEOUT_SECONDS,
        help="Recall API timeout in seconds.",
    )
    parser.add_argument(
        "--first-recall-size",
        type=int,
        default=DEFAULT_FIRST_RECALL_SIZE,
        help="Sample size for the simulated first recall route.",
    )
    parser.add_argument(
        "--first-recall-hit-rate",
        type=float,
        default=DEFAULT_FIRST_RECALL_HIT_RATE,
        help="Probability of forcing the gold dashboard into first-route results.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_RANDOM_SEED,
        help="Random seed used for reproducible first-route simulation.",
    )
    args = parser.parse_args(argv)

    # Reject out-of-range values up front. Otherwise they fail mid-run
    # (a negative sample size raises inside random.sample, and a non-positive
    # timeout makes every request fail) or silently produce meaningless numbers
    # (a probability outside [0, 1]).
    if args.timeout <= 0:
        parser.error("--timeout must be a positive number of seconds.")
    if args.first_recall_size < 0:
        parser.error("--first-recall-size must be non-negative.")
    if not 0.0 <= args.first_recall_hit_rate <= 1.0:
        parser.error("--first-recall-hit-rate must be between 0 and 1.")
    return args
def read_questions(path: Path) -> list[QuestionCase]:
    """Read evaluation cases from a text file.

    Blank lines and lines starting with ``#`` (after stripping) are skipped.
    Every other line is split into ``(card_id, question)`` by
    ``split_question_line``. Line numbers are 1-based.

    Raises:
        ValueError: if a line cannot be split, or if no case is found.
    """
    cases: list[QuestionCase] = []
    # "utf-8-sig" discards a leading byte-order mark (common in files saved by
    # Windows editors or Excel). With plain "utf-8" the BOM survives strip()
    # (U+FEFF is not whitespace) and ends up glued to the first card_id, so that
    # case would never match the mapping. Files without a BOM decode identically.
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            card_id, question = split_question_line(line, line_no)
            cases.append(QuestionCase(line_no=line_no, card_id=card_id, question=question))
    if not cases:
        raise ValueError(f"No valid questions found in {path}")
    return cases
def split_question_line(line: str, line_no: int) -> tuple[str, str]:
    """Split one question-file line into ``(card_id, question)``.

    * If the line contains a TAB (the recommended format), everything before the
      first TAB is the card_id and everything after it is the question.
    * Otherwise the card_id is the leading token up to the FIRST comma, pipe or
      whitespace. That separator, plus any surrounding whitespace, is dropped,
      and the rest of the line is the question. Commas and pipes inside the
      question text are therefore preserved, e.g. ``"123 sales, by region"``
      -> ``("123", "sales, by region")``.

    Both parts are stripped of surrounding whitespace.

    Raises:
        ValueError: if no non-empty card_id/question pair can be extracted.
    """
    if "\t" in line:
        left, right = line.split("\t", 1)
        card_id, question = left.strip(), right.strip()
        if card_id and question:
            return card_id, question

    # Splitting on "," before trying "|" or whitespace would cut a
    # space- or pipe-delimited line at a comma inside the question. Instead,
    # take the first token and whichever separator actually follows it.
    match = re.match(r"\s*([^\s,|]+)(?:\s*[,|]\s*|\s+)(\S.*)", line)
    if match:
        return match.group(1), match.group(2).strip()

    raise ValueError(
        f"Invalid question line {line_no}: '{line}'. Expected 'card_id<TAB>question'."
    )
def load_mapping(path: Path) -> tuple[dict[str, CardMapping], list[str]]:
    """Load the card -> gold dashboard/dataset mapping from a parquet file.

    Rows are grouped by normalized ``card_id``; rows whose card_id is missing or
    blank are skipped. Ids, names and branch codes from every row of a card are
    merged into a single ``CardMapping``. Empty ids and empty text values are
    ignored.

    Returns:
        ``(card_map, all_dashboard_ids)``, where ``all_dashboard_ids`` is the
        sorted list of every distinct dashboard id seen in the file.

    Raises:
        RuntimeError: if ``pyarrow`` is not installed.
        ValueError: if required columns are missing or the file has no rows.
    """
    try:
        import pyarrow.parquet as pq
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "Missing dependency 'pyarrow'. Install it first, for example: pip install pyarrow requests"
        ) from exc

    table = pq.read_table(path)
    absent = REQUIRED_COLUMNS - set(table.column_names)
    if absent:
        raise ValueError(f"{path} is missing required columns: {sorted(absent)}")

    rows = table.to_pylist()
    if not rows:
        raise ValueError(f"{path} does not contain any rows")

    card_map: dict[str, CardMapping] = {}
    dashboard_universe: set[str] = set()
    for row in rows:
        card_id = normalize_id(row.get("card_id"))
        if card_id is None:
            continue

        card_name = normalize_text(row.get("card_name"))
        entry = card_map.get(card_id)
        if entry is None:
            entry = CardMapping(
                card_name=card_name,
                dashboard_ids=set(),
                dashboard_names=set(),
                dataset_ids=set(),
                dataset_names=set(),
                branch_codes=set(),
            )
            card_map[card_id] = entry
        elif not entry.card_name:
            # Keep the first non-empty card name seen for this card.
            entry.card_name = card_name

        dashboard_id = normalize_id(row.get("dashboard_id"))
        if dashboard_id is not None:
            entry.dashboard_ids.add(dashboard_id)
            dashboard_universe.add(dashboard_id)

        dataset_id = normalize_id(row.get("dataset_id"))
        if dataset_id is not None:
            entry.dataset_ids.add(dataset_id)

        text_columns = (
            (entry.dashboard_names, "dashboard_name"),
            (entry.dataset_names, "dataset_name"),
            (entry.branch_codes, "branch_code"),
        )
        for bucket, column in text_columns:
            text = normalize_text(row.get(column))
            if text:
                bucket.add(text)

    return card_map, sorted(dashboard_universe)
def normalize_id(value: Any) -> str | None:
    """Convert a raw identifier into a canonical, non-empty string.

    Returns ``None`` for ``None``, float NaN, and values that are blank after
    stripping.

    Integral floats are rendered without a fractional part. For example, an id
    stored as ``123.0`` (an integer column that was upcast to float) becomes
    ``"123"``, so it compares equal to the same id read as text from the
    questions file or the recall API. Non-integral floats keep ``str()`` form.
    """
    if value is None:
        return None
    if isinstance(value, float):
        if math.isnan(value):
            return None
        if value.is_integer():
            value = int(value)
    text = str(value).strip()
    return text or None
def normalize_text(value: Any) -> str:
    """Return ``value`` as a whitespace-stripped string; ``None`` becomes ``""``."""
    return "" if value is None else str(value).strip()
def simulate_first_recall(
    all_dashboard_ids: list[str],
    gold_dashboard_ids: set[str],
    sample_size: int,
    force_hit_rate: float,
    rng: random.Random,
) -> set[str]:
    """Simulate the first recall route by randomly sampling dashboard ids.

    With probability ``force_hit_rate``, and only when at least one gold
    dashboard is present in ``all_dashboard_ids``, one gold dashboard is forced
    into the result and the remaining slots are filled by random sampling.
    Otherwise the sample is drawn from non-gold dashboards when there are
    enough of them (a deliberate miss), and from all dashboards when there
    are not.

    Args:
        all_dashboard_ids: Universe of dashboard ids to sample from.
        gold_dashboard_ids: Correct dashboards for the current card.
        sample_size: Requested result size, capped at ``len(all_dashboard_ids)``.
            A value <= 0 yields an empty result.
        force_hit_rate: Probability of forcing a gold dashboard into the result.
        rng: Source of all randomness. A seeded instance gives reproducible output.

    Returns:
        The sampled dashboard ids.
    """
    capped_size = min(sample_size, len(all_dashboard_ids))
    if capped_size <= 0:
        # Nothing to sample. This also keeps size 0 from returning a forced gold
        # id, and keeps negative sizes from raising inside rng.sample.
        return set()

    gold = set(gold_dashboard_ids)
    # Sorted so that rng.choice is reproducible for a given seed: iterating a
    # set of str follows hash order, which changes between interpreter runs.
    gold_candidates = sorted(gold.intersection(all_dashboard_ids))

    if gold_candidates and rng.random() < force_hit_rate:
        forced_dashboard = rng.choice(gold_candidates)
        remaining = [
            dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id != forced_dashboard
        ]
        selected = set(rng.sample(remaining, k=capped_size - 1))
        selected.add(forced_dashboard)
        return selected

    non_gold_dashboards = [
        dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id not in gold
    ]
    # Prefer an all-miss sample; fall back to the full universe when there are
    # too few non-gold dashboards to fill it.
    population = (
        non_gold_dashboards if len(non_gold_dashboards) >= capped_size else all_dashboard_ids
    )
    return set(rng.sample(population, k=capped_size))
def invoke_recall_api(
    api_url: str,
    question: str,
    recall_type: str,
    timeout_seconds: int,
) -> tuple[set[str], list[dict[str, Any]]]:
    """POST one question to the recall API and collect the returned ids.

    The request body is ``{"question": question, "recall_type": recall_type}``.

    Returns:
        ``(ids, records)``: the distinct non-empty ids extracted from the
        response records, and the records themselves.

    Raises:
        ValueError: if ``api_url`` is still the placeholder default.
        requests.RequestException: on connection errors, timeouts, or a 4xx/5xx
            status (``raise_for_status``).
        Whatever ``response.json()`` raises when the body is not valid JSON.
    """
    if api_url == DEFAULT_RECALL_API_URL:
        raise ValueError(
            "Recall API URL is still the placeholder value. Pass --api-url with the actual endpoint."
        )

    request_body = {"question": question, "recall_type": recall_type}
    response = requests.post(api_url, json=request_body, timeout=timeout_seconds)
    response.raise_for_status()

    records = extract_records(response.json())
    ids: set[str] = set()
    for record in records:
        record_id = extract_record_id(record, recall_type)
        if record_id:
            ids.add(record_id)
    return ids, records
def extract_records(payload: Any) -> list[dict[str, Any]]:
    """Pull the result records out of a decoded recall-API response.

    Accepted shapes:
      * a JSON array: its object elements;
      * an object holding a list under ``data``/``items``/``results``/``records``
        (first such key wins): that list's object elements;
      * an object holding, under one of those keys, another object that in turn
        (possibly further nested) holds such a list, e.g. the common envelope
        ``{"code": 0, "data": {"records": [...]}}``: the inner list's object
        elements;
      * any other object: treated as a single record.

    Non-object list elements are dropped. A payload that is neither a list nor
    an object yields ``[]``.
    """
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    if not isinstance(payload, dict):
        return []
    for key in ("data", "items", "results", "records"):
        value = payload.get(key)
        if isinstance(value, list):
            return [item for item in value if isinstance(item, dict)]
        if isinstance(value, dict):
            nested = extract_records(value)
            # Unwrap only when the nested object really held a result list.
            # A nested object that is itself "a single record" is ignored here,
            # so a record carrying a plain "data" sub-object is still returned
            # whole below.
            if not (len(nested) == 1 and nested[0] is value):
                return nested
    return [payload]
def extract_record_id(record: dict[str, Any], recall_type: str) -> str | None:
    """Return the normalized id of ``record`` for ``recall_type``, or ``None``.

    Keys are tried in order: ``"<recall_type>_id"``, then ``"id"``.

    For the two known recall types (dashboard/dataset), the OTHER type's id key
    is not used as a fallback. The caller scores dashboard results against gold
    dashboard ids and dataset results against gold dataset ids, so a dashboard
    id mistaken for a dataset id (or vice versa) could register a false hit if
    the two id spaces overlap. For any other recall type, ``dashboard_id`` and
    ``dataset_id`` remain fallbacks.
    """
    keys = [f"{recall_type}_id", "id"]
    if recall_type not in (RECALL_TYPE_DASHBOARD, RECALL_TYPE_DATASET):
        keys += ["dashboard_id", "dataset_id"]
    for key in keys:
        value = normalize_id(record.get(key))
        if value is not None:
            return value
    return None
def safe_api_recall(
    api_url: str,
    question: str,
    recall_type: str,
    timeout_seconds: int,
) -> tuple[set[str], list[dict[str, Any]], str]:
    """Call ``invoke_recall_api`` without letting one failure abort the run.

    Returns:
        ``(ids, records, error)``. On success ``error`` is ``""``. On any
        exception, ids and records are empty and ``error`` is
        ``"<ExceptionType>: <message>"`` (just the type name when the message is
        empty). A failure therefore always yields a non-empty, truthy error string.
    """
    try:
        ids, records = invoke_recall_api(api_url, question, recall_type, timeout_seconds)
    except Exception as exc:  # noqa: BLE001 - best effort per question by design
        # str(exc) alone can be "" for exceptions raised without a message.
        # Callers test the error for truthiness, so an empty string would be
        # mistaken for success and the failure would go uncounted.
        detail = str(exc)
        error = f"{type(exc).__name__}: {detail}" if detail else type(exc).__name__
        return set(), [], error
    return ids, records, ""
def ensure_output_dir(path: Path) -> None:
    """Create directory ``path`` and any missing parents; an existing directory is fine."""
    path.mkdir(exist_ok=True, parents=True)
def write_details_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write per-question detail rows to ``path`` as a UTF-8 CSV with a header.

    Columns appear in the fixed order below. Keys missing from a row are written
    as empty cells, and a row containing a key not listed raises ``ValueError``
    (``csv.DictWriter`` defaults).
    """
    columns = (
        # the case itself
        "line_no",
        "card_id",
        "card_name",
        "question",
        # gold answers
        "gold_dashboard_ids",
        "gold_dataset_ids",
        # route 1: simulated first recall
        "first_recall_ids",
        "first_recall_hit",
        # route 2: dashboard recall API
        "dashboard_recall_ids",
        "dashboard_recall_hit",
        "dashboard_recall_error",
        "dashboard_recall_records",
        # route 3: dataset recall API
        "dataset_recall_ids",
        "dataset_recall_hit",
        "dataset_recall_error",
        "dataset_recall_records",
        # combined result and extra context
        "union_hit",
        "branch_codes",
    )
    with path.open("w", encoding="utf-8", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
def write_summary_json(path: Path, summary: dict[str, Any]) -> None:
    """Write ``summary`` to ``path`` as UTF-8, 2-space-indented JSON.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(path, mode="w", encoding="utf-8") as fp:
        json.dump(summary, fp, indent=2, ensure_ascii=False)
def rate(hits: int, total: int) -> float:
    """Return ``hits / total``, or ``0.0`` when ``total`` is zero."""
    return hits / total if total != 0 else 0.0
def _missing_card_row(case: QuestionCase) -> dict[str, Any]:
    """Build the detail-CSV row for a question whose card_id has no mapping."""
    not_found = "card_id not found in id_map.parquet"
    return {
        "line_no": case.line_no,
        "card_id": case.card_id,
        "card_name": "",
        "question": case.question,
        "gold_dashboard_ids": "",
        "gold_dataset_ids": "",
        "first_recall_ids": "",
        "first_recall_hit": False,
        "dashboard_recall_ids": "",
        "dashboard_recall_hit": False,
        "dashboard_recall_error": not_found,
        "dashboard_recall_records": "",
        "dataset_recall_ids": "",
        "dataset_recall_hit": False,
        "dataset_recall_error": not_found,
        "dataset_recall_records": "",
        "union_hit": False,
        "branch_codes": "",
    }


def _joined(values: set[str]) -> str:
    """Render a set of ids as a sorted, comma-separated string."""
    return ",".join(sorted(values))


def _compact_json(records: list[dict[str, Any]]) -> str:
    """Serialize API records as compact JSON, keeping non-ASCII characters readable."""
    return json.dumps(records, ensure_ascii=False, separators=(",", ":"))


def _evaluate_case(
    case: QuestionCase,
    mapping: CardMapping,
    all_dashboard_ids: list[str],
    args: argparse.Namespace,
    rng: random.Random,
) -> dict[str, Any]:
    """Run the three recall routes for one mapped question and build its detail row.

    A route "hits" when its ids intersect the card's gold ids. The simulated
    first route and the dashboard API are scored against gold dashboard ids;
    the dataset API is scored against gold dataset ids.
    """
    first_ids = simulate_first_recall(
        all_dashboard_ids=all_dashboard_ids,
        gold_dashboard_ids=mapping.dashboard_ids,
        sample_size=args.first_recall_size,
        force_hit_rate=args.first_recall_hit_rate,
        rng=rng,
    )
    dashboard_ids, dashboard_records, dashboard_error = safe_api_recall(
        api_url=args.api_url,
        question=case.question,
        recall_type=RECALL_TYPE_DASHBOARD,
        timeout_seconds=args.timeout,
    )
    dataset_ids, dataset_records, dataset_error = safe_api_recall(
        api_url=args.api_url,
        question=case.question,
        recall_type=RECALL_TYPE_DATASET,
        timeout_seconds=args.timeout,
    )

    first_hit = bool(first_ids & mapping.dashboard_ids)
    dashboard_hit = bool(dashboard_ids & mapping.dashboard_ids)
    dataset_hit = bool(dataset_ids & mapping.dataset_ids)

    return {
        "line_no": case.line_no,
        "card_id": case.card_id,
        "card_name": mapping.card_name,
        "question": case.question,
        "gold_dashboard_ids": _joined(mapping.dashboard_ids),
        "gold_dataset_ids": _joined(mapping.dataset_ids),
        "first_recall_ids": _joined(first_ids),
        "first_recall_hit": first_hit,
        "dashboard_recall_ids": _joined(dashboard_ids),
        "dashboard_recall_hit": dashboard_hit,
        "dashboard_recall_error": dashboard_error,
        "dashboard_recall_records": _compact_json(dashboard_records),
        "dataset_recall_ids": _joined(dataset_ids),
        "dataset_recall_hit": dataset_hit,
        "dataset_recall_error": dataset_error,
        "dataset_recall_records": _compact_json(dataset_records),
        "union_hit": first_hit or dashboard_hit or dataset_hit,
        "branch_codes": _joined(mapping.branch_codes),
    }


def main() -> int:
    """Run the evaluation end to end.

    Writes ``recall_details.csv`` and ``recall_summary.json`` into
    ``--output-dir`` and prints the summary JSON to stdout.

    Every question counts toward each rate's denominator, including questions
    whose card_id is missing from the mapping. Those count as misses on every
    route and are listed under ``missing_cards``.

    Returns:
        The process exit code (0 on success).

    Raises:
        ValueError: if ``--api-url`` is still the placeholder value. Errors from
            reading the input files also propagate.
    """
    args = parse_args()
    # Fail fast: with the placeholder URL every API call would fail, and the run
    # would "succeed" with meaningless 0% API recall rates.
    if args.api_url == DEFAULT_RECALL_API_URL:
        raise ValueError(
            "Recall API URL is still the placeholder value. Pass --api-url with the actual endpoint."
        )

    rng = random.Random(args.seed)
    questions_path = Path(args.questions_txt)
    mapping_path = Path(args.id_map_parquet)
    output_dir = Path(args.output_dir)
    detail_csv_path = output_dir / "recall_details.csv"
    summary_json_path = output_dir / "recall_summary.json"

    questions = read_questions(questions_path)
    card_map, all_dashboard_ids = load_mapping(mapping_path)
    ensure_output_dir(output_dir)

    first_hits = dashboard_hits = dataset_hits = union_hits = 0
    detail_rows: list[dict[str, Any]] = []
    missing_cards: list[dict[str, Any]] = []
    error_counter: dict[str, int] = defaultdict(int)

    for case in questions:
        mapping = card_map.get(case.card_id)
        if mapping is None:
            missing_cards.append(
                {"line_no": case.line_no, "card_id": case.card_id, "question": case.question}
            )
            detail_rows.append(_missing_card_row(case))
            continue

        row = _evaluate_case(case, mapping, all_dashboard_ids, args, rng)
        detail_rows.append(row)
        first_hits += int(row["first_recall_hit"])
        dashboard_hits += int(row["dashboard_recall_hit"])
        dataset_hits += int(row["dataset_recall_hit"])
        union_hits += int(row["union_hit"])
        if row["dashboard_recall_error"]:
            error_counter["dashboard_api_errors"] += 1
        if row["dataset_recall_error"]:
            error_counter["dataset_api_errors"] += 1

    total = len(questions)
    first_rate = rate(first_hits, total)
    dashboard_rate = rate(dashboard_hits, total)
    dataset_rate = rate(dataset_hits, total)
    union_rate = rate(union_hits, total)

    summary = {
        "input": {
            "questions_txt": str(questions_path),
            "id_map_parquet": str(mapping_path),
            "api_url": args.api_url,
            "first_recall_size": args.first_recall_size,
            "first_recall_hit_rate": args.first_recall_hit_rate,
            "seed": args.seed,
        },
        "counts": {
            "total_questions": total,
            "first_recall_hits": first_hits,
            "dashboard_recall_hits": dashboard_hits,
            "dataset_recall_hits": dataset_hits,
            "union_hits": union_hits,
            "missing_card_mappings": len(missing_cards),
        },
        "rates": {
            "first_recall_rate": first_rate,
            "dashboard_recall_rate": dashboard_rate,
            "dataset_recall_rate": dataset_rate,
            "union_recall_rate": union_rate,
            "uplift_vs_first_recall": union_rate - first_rate,
        },
        "errors": dict(error_counter),
        "missing_cards": missing_cards,
        "outputs": {
            "detail_csv": str(detail_csv_path),
            "summary_json": str(summary_json_path),
        },
    }

    write_details_csv(detail_csv_path, detail_rows)
    write_summary_json(summary_json_path, summary)
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return 0
if __name__ == "__main__":
    # Top-level boundary: report any unexpected failure as a one-line error and
    # exit with status 1. SystemExit (normal exit, argparse errors) is not an
    # Exception subclass, so it passes through untouched.
    try:
        exit_code = main()
    except Exception as exc:  # noqa: BLE001
        print(f"[ERROR] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
    raise SystemExit(exit_code)
|