recall_eval.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. #!/usr/bin/env python3
  2. """Evaluate dashboard/dataset recall quality for BI search questions."""
  3. from __future__ import annotations
  4. import argparse
  5. import csv
  6. import json
  7. import random
  8. import sys
  9. from collections import defaultdict
  10. from dataclasses import dataclass
  11. from pathlib import Path
  12. from typing import Any
  13. import requests
  14. DEFAULT_RECALL_API_URL = "http://replace-with-your-recall-api"
  15. DEFAULT_RECALL_TIMEOUT_SECONDS = 10
  16. DEFAULT_FIRST_RECALL_SIZE = 500
  17. DEFAULT_FIRST_RECALL_HIT_RATE = 0.8
  18. DEFAULT_RANDOM_SEED = 20260311
  19. RECALL_TYPE_DASHBOARD = "dashboard"
  20. RECALL_TYPE_DATASET = "dataset"
  21. REQUIRED_COLUMNS = {
  22. "card_id",
  23. "card_name",
  24. "dashboard_id",
  25. "dashboard_name",
  26. "dataset_id",
  27. "dataset_name",
  28. "branch_code",
  29. }
  30. @dataclass(frozen=True)
  31. class QuestionCase:
  32. line_no: int
  33. card_id: str
  34. question: str
  35. @dataclass
  36. class CardMapping:
  37. card_name: str
  38. dashboard_ids: set[str]
  39. dashboard_names: set[str]
  40. dataset_ids: set[str]
  41. dataset_names: set[str]
  42. branch_codes: set[str]
  43. def parse_args() -> argparse.Namespace:
  44. parser = argparse.ArgumentParser(
  45. description="Evaluate three-route recall performance for BI dashboard/dataset search."
  46. )
  47. parser.add_argument(
  48. "--questions-txt",
  49. required=True,
  50. help="Question file. Recommended format: one line per case as 'card_id<TAB>question'.",
  51. )
  52. parser.add_argument(
  53. "--id-map-parquet",
  54. required=True,
  55. help="Parquet file containing card/dashboard/dataset mapping information.",
  56. )
  57. parser.add_argument(
  58. "--output-dir",
  59. default="output",
  60. help="Directory for summary JSON and per-question detail CSV.",
  61. )
  62. parser.add_argument(
  63. "--api-url",
  64. default=DEFAULT_RECALL_API_URL,
  65. help="Recall API URL. The script POSTs JSON: {'question': ..., 'recall_type': ...}.",
  66. )
  67. parser.add_argument(
  68. "--timeout",
  69. type=int,
  70. default=DEFAULT_RECALL_TIMEOUT_SECONDS,
  71. help="Recall API timeout in seconds.",
  72. )
  73. parser.add_argument(
  74. "--first-recall-size",
  75. type=int,
  76. default=DEFAULT_FIRST_RECALL_SIZE,
  77. help="Sample size for the simulated first recall route.",
  78. )
  79. parser.add_argument(
  80. "--first-recall-hit-rate",
  81. type=float,
  82. default=DEFAULT_FIRST_RECALL_HIT_RATE,
  83. help="Probability of forcing the gold dashboard into first-route results.",
  84. )
  85. parser.add_argument(
  86. "--seed",
  87. type=int,
  88. default=DEFAULT_RANDOM_SEED,
  89. help="Random seed used for reproducible first-route simulation.",
  90. )
  91. return parser.parse_args()
  92. def read_questions(path: Path) -> list[QuestionCase]:
  93. cases: list[QuestionCase] = []
  94. with path.open("r", encoding="utf-8") as handle:
  95. for line_no, raw_line in enumerate(handle, start=1):
  96. line = raw_line.strip()
  97. if not line or line.startswith("#"):
  98. continue
  99. card_id, question = split_question_line(line, line_no)
  100. cases.append(QuestionCase(line_no=line_no, card_id=card_id, question=question))
  101. if not cases:
  102. raise ValueError(f"No valid questions found in {path}")
  103. return cases
  104. def split_question_line(line: str, line_no: int) -> tuple[str, str]:
  105. for separator in ("\t", ",", "|"):
  106. if separator in line:
  107. left, right = line.split(separator, 1)
  108. card_id = left.strip()
  109. question = right.strip()
  110. if card_id and question:
  111. return card_id, question
  112. parts = line.split(maxsplit=1)
  113. if len(parts) == 2 and parts[0].strip() and parts[1].strip():
  114. return parts[0].strip(), parts[1].strip()
  115. raise ValueError(
  116. f"Invalid question line {line_no}: '{line}'. Expected 'card_id<TAB>question'."
  117. )
  118. def load_mapping(path: Path) -> tuple[dict[str, CardMapping], list[str]]:
  119. try:
  120. import pyarrow.parquet as pq
  121. except ModuleNotFoundError as exc:
  122. raise RuntimeError(
  123. "Missing dependency 'pyarrow'. Install it first, for example: pip install pyarrow requests"
  124. ) from exc
  125. table = pq.read_table(path)
  126. missing_columns = REQUIRED_COLUMNS.difference(table.column_names)
  127. if missing_columns:
  128. raise ValueError(f"{path} is missing required columns: {sorted(missing_columns)}")
  129. rows = table.to_pylist()
  130. if not rows:
  131. raise ValueError(f"{path} does not contain any rows")
  132. card_map: dict[str, CardMapping] = {}
  133. all_dashboard_ids: set[str] = set()
  134. for row in rows:
  135. card_id = normalize_id(row.get("card_id"))
  136. if card_id is None:
  137. continue
  138. mapping = card_map.setdefault(
  139. card_id,
  140. CardMapping(
  141. card_name=normalize_text(row.get("card_name")),
  142. dashboard_ids=set(),
  143. dashboard_names=set(),
  144. dataset_ids=set(),
  145. dataset_names=set(),
  146. branch_codes=set(),
  147. ),
  148. )
  149. dashboard_id = normalize_id(row.get("dashboard_id"))
  150. dataset_id = normalize_id(row.get("dataset_id"))
  151. dashboard_name = normalize_text(row.get("dashboard_name"))
  152. dataset_name = normalize_text(row.get("dataset_name"))
  153. branch_code = normalize_text(row.get("branch_code"))
  154. if dashboard_id is not None:
  155. mapping.dashboard_ids.add(dashboard_id)
  156. all_dashboard_ids.add(dashboard_id)
  157. if dataset_id is not None:
  158. mapping.dataset_ids.add(dataset_id)
  159. if dashboard_name:
  160. mapping.dashboard_names.add(dashboard_name)
  161. if dataset_name:
  162. mapping.dataset_names.add(dataset_name)
  163. if branch_code:
  164. mapping.branch_codes.add(branch_code)
  165. if not mapping.card_name:
  166. mapping.card_name = normalize_text(row.get("card_name"))
  167. return card_map, sorted(all_dashboard_ids)
  168. def normalize_id(value: Any) -> str | None:
  169. if value is None:
  170. return None
  171. text = str(value).strip()
  172. return text or None
  173. def normalize_text(value: Any) -> str:
  174. if value is None:
  175. return ""
  176. return str(value).strip()
  177. def simulate_first_recall(
  178. all_dashboard_ids: list[str],
  179. gold_dashboard_ids: set[str],
  180. sample_size: int,
  181. force_hit_rate: float,
  182. rng: random.Random,
  183. ) -> set[str]:
  184. if not all_dashboard_ids:
  185. return set()
  186. capped_size = min(sample_size, len(all_dashboard_ids))
  187. gold_candidates = [
  188. dashboard_id for dashboard_id in gold_dashboard_ids if dashboard_id in all_dashboard_ids
  189. ]
  190. should_force_hit = bool(gold_candidates) and rng.random() < force_hit_rate
  191. if should_force_hit:
  192. forced_dashboard = rng.choice(gold_candidates)
  193. remaining = [
  194. dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id != forced_dashboard
  195. ]
  196. picked = rng.sample(remaining, k=max(capped_size - 1, 0))
  197. selected = set(picked)
  198. selected.add(forced_dashboard)
  199. return selected
  200. non_gold_dashboards = [
  201. dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id not in gold_dashboard_ids
  202. ]
  203. if len(non_gold_dashboards) >= capped_size:
  204. return set(rng.sample(non_gold_dashboards, k=capped_size))
  205. return set(rng.sample(all_dashboard_ids, k=capped_size))
  206. def invoke_recall_api(
  207. api_url: str,
  208. question: str,
  209. recall_type: str,
  210. timeout_seconds: int,
  211. ) -> tuple[set[str], list[dict[str, Any]]]:
  212. if api_url == DEFAULT_RECALL_API_URL:
  213. raise ValueError(
  214. "Recall API URL is still the placeholder value. Pass --api-url with the actual endpoint."
  215. )
  216. response = requests.post(
  217. api_url,
  218. json={"question": question, "recall_type": recall_type},
  219. timeout=timeout_seconds,
  220. )
  221. response.raise_for_status()
  222. payload = response.json()
  223. records = extract_records(payload)
  224. ids = {record_id for record in records if (record_id := extract_record_id(record, recall_type))}
  225. return ids, records
  226. def extract_records(payload: Any) -> list[dict[str, Any]]:
  227. if isinstance(payload, list):
  228. return [item for item in payload if isinstance(item, dict)]
  229. if not isinstance(payload, dict):
  230. return []
  231. for key in ("data", "items", "results", "records"):
  232. value = payload.get(key)
  233. if isinstance(value, list):
  234. return [item for item in value if isinstance(item, dict)]
  235. return [payload]
  236. def extract_record_id(record: dict[str, Any], recall_type: str) -> str | None:
  237. preferred_keys = [f"{recall_type}_id", "id"]
  238. fallback_keys = ["dashboard_id", "dataset_id"]
  239. for key in preferred_keys + fallback_keys:
  240. value = normalize_id(record.get(key))
  241. if value is not None:
  242. return value
  243. return None
  244. def safe_api_recall(
  245. api_url: str,
  246. question: str,
  247. recall_type: str,
  248. timeout_seconds: int,
  249. ) -> tuple[set[str], list[dict[str, Any]], str]:
  250. try:
  251. ids, records = invoke_recall_api(api_url, question, recall_type, timeout_seconds)
  252. return ids, records, ""
  253. except Exception as exc: # noqa: BLE001
  254. return set(), [], str(exc)
  255. def ensure_output_dir(path: Path) -> None:
  256. path.mkdir(parents=True, exist_ok=True)
  257. def write_details_csv(path: Path, rows: list[dict[str, Any]]) -> None:
  258. fieldnames = [
  259. "line_no",
  260. "card_id",
  261. "card_name",
  262. "question",
  263. "gold_dashboard_ids",
  264. "gold_dataset_ids",
  265. "first_recall_ids",
  266. "first_recall_hit",
  267. "dashboard_recall_ids",
  268. "dashboard_recall_hit",
  269. "dashboard_recall_error",
  270. "dashboard_recall_records",
  271. "dataset_recall_ids",
  272. "dataset_recall_hit",
  273. "dataset_recall_error",
  274. "dataset_recall_records",
  275. "union_hit",
  276. "branch_codes",
  277. ]
  278. with path.open("w", encoding="utf-8", newline="") as handle:
  279. writer = csv.DictWriter(handle, fieldnames=fieldnames)
  280. writer.writeheader()
  281. for row in rows:
  282. writer.writerow(row)
  283. def write_summary_json(path: Path, summary: dict[str, Any]) -> None:
  284. with path.open("w", encoding="utf-8") as handle:
  285. json.dump(summary, handle, ensure_ascii=False, indent=2)
  286. def rate(hits: int, total: int) -> float:
  287. if total == 0:
  288. return 0.0
  289. return hits / total
  290. def main() -> int:
  291. args = parse_args()
  292. rng = random.Random(args.seed)
  293. questions_path = Path(args.questions_txt)
  294. mapping_path = Path(args.id_map_parquet)
  295. output_dir = Path(args.output_dir)
  296. questions = read_questions(questions_path)
  297. card_map, all_dashboard_ids = load_mapping(mapping_path)
  298. ensure_output_dir(output_dir)
  299. first_hits = 0
  300. dashboard_hits = 0
  301. dataset_hits = 0
  302. union_hits = 0
  303. detail_rows: list[dict[str, Any]] = []
  304. missing_cards: list[dict[str, Any]] = []
  305. error_counter: dict[str, int] = defaultdict(int)
  306. for case in questions:
  307. mapping = card_map.get(case.card_id)
  308. if mapping is None:
  309. missing_cards.append(
  310. {"line_no": case.line_no, "card_id": case.card_id, "question": case.question}
  311. )
  312. detail_rows.append(
  313. {
  314. "line_no": case.line_no,
  315. "card_id": case.card_id,
  316. "card_name": "",
  317. "question": case.question,
  318. "gold_dashboard_ids": "",
  319. "gold_dataset_ids": "",
  320. "first_recall_ids": "",
  321. "first_recall_hit": False,
  322. "dashboard_recall_ids": "",
  323. "dashboard_recall_hit": False,
  324. "dashboard_recall_error": "card_id not found in id_map.parquet",
  325. "dashboard_recall_records": "",
  326. "dataset_recall_ids": "",
  327. "dataset_recall_hit": False,
  328. "dataset_recall_error": "card_id not found in id_map.parquet",
  329. "dataset_recall_records": "",
  330. "union_hit": False,
  331. "branch_codes": "",
  332. }
  333. )
  334. continue
  335. first_ids = simulate_first_recall(
  336. all_dashboard_ids=all_dashboard_ids,
  337. gold_dashboard_ids=mapping.dashboard_ids,
  338. sample_size=args.first_recall_size,
  339. force_hit_rate=args.first_recall_hit_rate,
  340. rng=rng,
  341. )
  342. dashboard_ids, dashboard_records, dashboard_error = safe_api_recall(
  343. api_url=args.api_url,
  344. question=case.question,
  345. recall_type=RECALL_TYPE_DASHBOARD,
  346. timeout_seconds=args.timeout,
  347. )
  348. dataset_ids, dataset_records, dataset_error = safe_api_recall(
  349. api_url=args.api_url,
  350. question=case.question,
  351. recall_type=RECALL_TYPE_DATASET,
  352. timeout_seconds=args.timeout,
  353. )
  354. first_hit = bool(first_ids.intersection(mapping.dashboard_ids))
  355. dashboard_hit = bool(dashboard_ids.intersection(mapping.dashboard_ids))
  356. dataset_hit = bool(dataset_ids.intersection(mapping.dataset_ids))
  357. union_hit = first_hit or dashboard_hit or dataset_hit
  358. first_hits += int(first_hit)
  359. dashboard_hits += int(dashboard_hit)
  360. dataset_hits += int(dataset_hit)
  361. union_hits += int(union_hit)
  362. if dashboard_error:
  363. error_counter["dashboard_api_errors"] += 1
  364. if dataset_error:
  365. error_counter["dataset_api_errors"] += 1
  366. detail_rows.append(
  367. {
  368. "line_no": case.line_no,
  369. "card_id": case.card_id,
  370. "card_name": mapping.card_name,
  371. "question": case.question,
  372. "gold_dashboard_ids": ",".join(sorted(mapping.dashboard_ids)),
  373. "gold_dataset_ids": ",".join(sorted(mapping.dataset_ids)),
  374. "first_recall_ids": ",".join(sorted(first_ids)),
  375. "first_recall_hit": first_hit,
  376. "dashboard_recall_ids": ",".join(sorted(dashboard_ids)),
  377. "dashboard_recall_hit": dashboard_hit,
  378. "dashboard_recall_error": dashboard_error,
  379. "dashboard_recall_records": json.dumps(
  380. dashboard_records, ensure_ascii=False, separators=(",", ":")
  381. ),
  382. "dataset_recall_ids": ",".join(sorted(dataset_ids)),
  383. "dataset_recall_hit": dataset_hit,
  384. "dataset_recall_error": dataset_error,
  385. "dataset_recall_records": json.dumps(
  386. dataset_records, ensure_ascii=False, separators=(",", ":")
  387. ),
  388. "union_hit": union_hit,
  389. "branch_codes": ",".join(sorted(mapping.branch_codes)),
  390. }
  391. )
  392. if dashboard_records:
  393. detail_rows[-1]["dashboard_recall_ids"] = ",".join(sorted(dashboard_ids))
  394. if dataset_records:
  395. detail_rows[-1]["dataset_recall_ids"] = ",".join(sorted(dataset_ids))
  396. total = len(questions)
  397. first_rate = rate(first_hits, total)
  398. dashboard_rate = rate(dashboard_hits, total)
  399. dataset_rate = rate(dataset_hits, total)
  400. union_rate = rate(union_hits, total)
  401. uplift_vs_first = union_rate - first_rate
  402. summary = {
  403. "input": {
  404. "questions_txt": str(questions_path),
  405. "id_map_parquet": str(mapping_path),
  406. "api_url": args.api_url,
  407. "first_recall_size": args.first_recall_size,
  408. "first_recall_hit_rate": args.first_recall_hit_rate,
  409. "seed": args.seed,
  410. },
  411. "counts": {
  412. "total_questions": total,
  413. "first_recall_hits": first_hits,
  414. "dashboard_recall_hits": dashboard_hits,
  415. "dataset_recall_hits": dataset_hits,
  416. "union_hits": union_hits,
  417. "missing_card_mappings": len(missing_cards),
  418. },
  419. "rates": {
  420. "first_recall_rate": first_rate,
  421. "dashboard_recall_rate": dashboard_rate,
  422. "dataset_recall_rate": dataset_rate,
  423. "union_recall_rate": union_rate,
  424. "uplift_vs_first_recall": uplift_vs_first,
  425. },
  426. "errors": dict(error_counter),
  427. "missing_cards": missing_cards,
  428. "outputs": {
  429. "detail_csv": str(output_dir / "recall_details.csv"),
  430. "summary_json": str(output_dir / "recall_summary.json"),
  431. },
  432. }
  433. write_details_csv(output_dir / "recall_details.csv", detail_rows)
  434. write_summary_json(output_dir / "recall_summary.json", summary)
  435. print(json.dumps(summary, ensure_ascii=False, indent=2))
  436. return 0
  437. if __name__ == "__main__":
  438. try:
  439. raise SystemExit(main())
  440. except Exception as exc: # noqa: BLE001
  441. print(f"[ERROR] {exc}", file=sys.stderr)
  442. raise SystemExit(1) from exc