#!/usr/bin/env python3
"""Evaluate dashboard/dataset recall quality for BI search questions.

Every question in the question file is scored against three recall routes:

* a simulated "first" route that randomly samples dashboard ids from the id
  map and forces a gold dashboard into the sample with a configurable
  probability;
* the recall API queried with ``recall_type="dashboard"``;
* the recall API queried with ``recall_type="dataset"``.

A route hits a question when it returns at least one gold id mapped to the
question's card. Per-question details go to ``recall_details.csv`` and the
aggregate counts/rates to ``recall_summary.json`` in the output directory;
the summary is also printed to stdout.
"""

from __future__ import annotations

import argparse
import csv
import json
import math
import random
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

try:
    import requests
except ModuleNotFoundError:  # pragma: no cover - depends on the environment
    # Reported by _require_requests() with an install hint, so the pure
    # helpers below stay importable (and testable) without the package.
    requests = None  # type: ignore[assignment]

DEFAULT_RECALL_API_URL = "http://replace-with-your-recall-api"
DEFAULT_RECALL_TIMEOUT_SECONDS = 10
DEFAULT_FIRST_RECALL_SIZE = 500
DEFAULT_FIRST_RECALL_HIT_RATE = 0.8
DEFAULT_RANDOM_SEED = 20260311

RECALL_TYPE_DASHBOARD = "dashboard"
RECALL_TYPE_DATASET = "dataset"

REQUIRED_COLUMNS = {
    "card_id",
    "card_name",
    "dashboard_id",
    "dashboard_name",
    "dataset_id",
    "dataset_name",
    "branch_code",
}

# Column order of recall_details.csv; every detail row uses exactly these keys.
DETAIL_FIELDNAMES = (
    "line_no",
    "card_id",
    "card_name",
    "question",
    "gold_dashboard_ids",
    "gold_dataset_ids",
    "first_recall_ids",
    "first_recall_hit",
    "dashboard_recall_ids",
    "dashboard_recall_hit",
    "dashboard_recall_error",
    "dashboard_recall_records",
    "dataset_recall_ids",
    "dataset_recall_hit",
    "dataset_recall_error",
    "dataset_recall_records",
    "union_hit",
    "branch_codes",
)

_PLACEHOLDER_URL_MESSAGE = (
    "Recall API URL is still the placeholder value. Pass --api-url with the actual endpoint."
)
_MISSING_CARD_ERROR = "card_id not found in id_map.parquet"


@dataclass(frozen=True)
class QuestionCase:
    """One question parsed from the question file."""

    line_no: int  # 1-based line number in the question file
    card_id: str  # card whose mapped dashboards/datasets are the gold answers
    question: str


@dataclass
class CardMapping:
    """Gold ids and descriptive metadata aggregated over all id-map rows of one card."""

    card_name: str = ""
    dashboard_ids: set[str] = field(default_factory=set)
    dashboard_names: set[str] = field(default_factory=set)
    dataset_ids: set[str] = field(default_factory=set)
    dataset_names: set[str] = field(default_factory=set)
    branch_codes: set[str] = field(default_factory=set)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the evaluation run."""
    parser = argparse.ArgumentParser(
        description="Evaluate three-route recall performance for BI dashboard/dataset search."
    )
    parser.add_argument(
        "--questions-txt",
        required=True,
        help="Question file. Recommended format: one line per case as 'card_id<TAB>question'.",
    )
    parser.add_argument(
        "--id-map-parquet",
        required=True,
        help="Parquet file containing card/dashboard/dataset mapping information.",
    )
    parser.add_argument(
        "--output-dir",
        default="output",
        help="Directory for summary JSON and per-question detail CSV.",
    )
    parser.add_argument(
        "--api-url",
        default=DEFAULT_RECALL_API_URL,
        help="Recall API URL. The script POSTs JSON: {'question': ..., 'recall_type': ...}.",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=DEFAULT_RECALL_TIMEOUT_SECONDS,
        help="Recall API timeout in seconds.",
    )
    parser.add_argument(
        "--first-recall-size",
        type=int,
        default=DEFAULT_FIRST_RECALL_SIZE,
        help="Sample size for the simulated first recall route.",
    )
    parser.add_argument(
        "--first-recall-hit-rate",
        type=float,
        default=DEFAULT_FIRST_RECALL_HIT_RATE,
        help="Probability of forcing the gold dashboard into first-route results.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_RANDOM_SEED,
        help="Random seed used for reproducible first-route simulation.",
    )
    return parser.parse_args()


def read_questions(path: Path) -> list[QuestionCase]:
    """Parse the question file into ``QuestionCase`` records.

    Blank lines and lines starting with ``#`` are skipped. Line numbers are
    1-based and count the skipped lines, so they point back into the file.
    The file is decoded as UTF-8; a leading byte-order mark is stripped so it
    does not end up glued to the first card id (``str.strip`` keeps it).

    Raises:
        ValueError: if a line cannot be split into card id and question, or
            the file contains no questions at all.
    """
    cases: list[QuestionCase] = []
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            card_id, question = split_question_line(line, line_no)
            cases.append(QuestionCase(line_no=line_no, card_id=card_id, question=question))
    if not cases:
        raise ValueError(f"No valid questions found in {path}")
    return cases


def split_question_line(line: str, line_no: int) -> tuple[str, str]:
    """Split one question line into ``(card_id, question)``.

    A tab is the preferred separator and always wins. ``,`` and ``|`` are
    tried next, but only when the text before the separator is a single
    token; card ids are assumed not to contain whitespace, so in
    ``"123 Sales, by region?"`` the comma belongs to the question. Finally the
    line is split on its first run of whitespace.

    Raises:
        ValueError: if no separator yields a non-empty id and question.
    """
    for separator in ("\t", ",", "|"):
        if separator not in line:
            continue
        left, right = line.split(separator, 1)
        card_id = left.strip()
        question = right.strip()
        if separator != "\t" and len(card_id.split()) > 1:
            # Multi-token "id": this separator is part of the question text.
            continue
        if card_id and question:
            return card_id, question

    parts = line.split(maxsplit=1)
    if len(parts) == 2 and parts[0].strip() and parts[1].strip():
        return parts[0].strip(), parts[1].strip()
    raise ValueError(
        f"Invalid question line {line_no}: '{line}'. Expected 'card_id<TAB>question'."
    )


def load_mapping(path: Path) -> tuple[dict[str, CardMapping], list[str]]:
    """Load the parquet id map and aggregate it per card.

    Rows without a card id are skipped. A card's name is taken from the
    first row that provides a non-empty one.

    Returns:
        ``(card_map, all_dashboard_ids)``: the per-card aggregation keyed by
        normalized card id, and every dashboard id seen, sorted (the sorted
        order keeps the first-route simulation reproducible).

    Raises:
        RuntimeError: if ``pyarrow`` is not installed.
        ValueError: if required columns are missing or the file has no rows.
    """
    try:
        import pyarrow.parquet as pq
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "Missing dependency 'pyarrow'. Install it first, for example: pip install pyarrow requests"
        ) from exc

    table = pq.read_table(path)
    missing_columns = REQUIRED_COLUMNS.difference(table.column_names)
    if missing_columns:
        raise ValueError(f"{path} is missing required columns: {sorted(missing_columns)}")

    rows = table.to_pylist()
    if not rows:
        raise ValueError(f"{path} does not contain any rows")

    card_map: dict[str, CardMapping] = {}
    all_dashboard_ids: set[str] = set()
    for row in rows:
        card_id = normalize_id(row.get("card_id"))
        if card_id is None:
            continue

        mapping = card_map.get(card_id)
        if mapping is None:
            mapping = CardMapping(card_name=normalize_text(row.get("card_name")))
            card_map[card_id] = mapping
        elif not mapping.card_name:
            mapping.card_name = normalize_text(row.get("card_name"))

        dashboard_id = normalize_id(row.get("dashboard_id"))
        if dashboard_id is not None:
            mapping.dashboard_ids.add(dashboard_id)
            all_dashboard_ids.add(dashboard_id)
        dataset_id = normalize_id(row.get("dataset_id"))
        if dataset_id is not None:
            mapping.dataset_ids.add(dataset_id)

        for column, bucket in (
            ("dashboard_name", mapping.dashboard_names),
            ("dataset_name", mapping.dataset_names),
            ("branch_code", mapping.branch_codes),
        ):
            text = normalize_text(row.get(column))
            if text:
                bucket.add(text)

    return card_map, sorted(all_dashboard_ids)


def normalize_id(value: Any) -> str | None:
    """Return *value* as a stripped id string, or ``None`` when it is empty.

    Integral floats lose their fractional part (``123.0`` -> ``"123"``) and
    NaN counts as missing: integer id columns containing nulls may be stored
    as floating point, and such ids must still match the plain integer ids
    written in the question file or returned by the API.
    """
    if value is None:
        return None
    if isinstance(value, float):
        if math.isnan(value):
            return None
        if value.is_integer():
            value = int(value)
    text = str(value).strip()
    return text or None


def normalize_text(value: Any) -> str:
    """Return *value* as a stripped string, with ``None`` mapped to ``""``."""
    if value is None:
        return ""
    return str(value).strip()


def simulate_first_recall(
    all_dashboard_ids: list[str],
    gold_dashboard_ids: set[str],
    sample_size: int,
    force_hit_rate: float,
    rng: random.Random,
) -> set[str]:
    """Simulate the first recall route as a random sample of dashboard ids.

    With probability *force_hit_rate* (drawn only when some gold dashboard is
    present in the pool) one gold dashboard is forced into the sample.
    Otherwise the sample is drawn from non-gold dashboards when there are
    enough of them, falling back to the whole pool.

    Args:
        all_dashboard_ids: candidate pool. It is iterated in the given order,
            so pass a deterministically ordered list for reproducible output.
        gold_dashboard_ids: dashboards that count as a hit for the question.
        sample_size: requested number of ids; capped at the pool size.
        force_hit_rate: probability of forcing a gold dashboard in.
        rng: random source; its state advances with every call.

    Returns:
        At most ``min(sample_size, len(all_dashboard_ids))`` ids. The result is
        empty when the pool is empty or *sample_size* is not positive.
    """
    capped_size = min(sample_size, len(all_dashboard_ids))
    if capped_size <= 0:
        return set()

    # Sorted so rng.choice() is reproducible: the iteration order of a set of
    # strings changes between interpreter runs (hash randomization).
    gold_candidates = sorted(set(gold_dashboard_ids).intersection(all_dashboard_ids))
    if gold_candidates and rng.random() < force_hit_rate:
        forced_dashboard = rng.choice(gold_candidates)
        remaining = [
            dashboard_id for dashboard_id in all_dashboard_ids if dashboard_id != forced_dashboard
        ]
        selected = set(rng.sample(remaining, k=capped_size - 1))
        selected.add(forced_dashboard)
        return selected

    non_gold_dashboards = [
        dashboard_id
        for dashboard_id in all_dashboard_ids
        if dashboard_id not in gold_dashboard_ids
    ]
    pool = non_gold_dashboards if len(non_gold_dashboards) >= capped_size else all_dashboard_ids
    return set(rng.sample(pool, k=capped_size))


def _require_requests() -> Any:
    """Return the ``requests`` module, or raise RuntimeError with an install hint."""
    if requests is None:
        raise RuntimeError(
            "Missing dependency 'requests'. Install it first, for example: pip install pyarrow requests"
        )
    return requests


def invoke_recall_api(
    api_url: str,
    question: str,
    recall_type: str,
    timeout_seconds: int,
) -> tuple[set[str], list[dict[str, Any]]]:
    """POST one question to the recall API and extract the returned ids.

    Returns:
        ``(ids, records)``: the ids extracted from the records and the raw
        record dicts themselves.

    Raises:
        ValueError: if *api_url* is still the placeholder default.
        RuntimeError: if ``requests`` is not installed.
        requests.RequestException: on transport errors or a non-2xx status.
    """
    if api_url == DEFAULT_RECALL_API_URL:
        raise ValueError(_PLACEHOLDER_URL_MESSAGE)
    http = _require_requests()
    response = http.post(
        api_url,
        json={"question": question, "recall_type": recall_type},
        timeout=timeout_seconds,
    )
    response.raise_for_status()
    records = extract_records(response.json())
    ids = {
        record_id
        for record in records
        if (record_id := extract_record_id(record, recall_type))
    }
    return ids, records


def extract_records(payload: Any) -> list[dict[str, Any]]:
    """Pull the list of record dicts out of an API response payload.

    Accepts a bare list, or a dict carrying the list under ``data``,
    ``items``, ``results`` or ``records`` (first match wins). Non-dict
    entries are dropped. Any other dict is treated as a single record; any
    other payload type yields no records.
    """
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    if not isinstance(payload, dict):
        return []
    for key in ("data", "items", "results", "records"):
        value = payload.get(key)
        if isinstance(value, list):
            return [item for item in value if isinstance(item, dict)]
    return [payload]


def extract_record_id(record: dict[str, Any], recall_type: str) -> str | None:
    """Return the first non-empty id found in *record*, or ``None``.

    Keys are tried in order: ``<recall_type>_id``, ``id``, ``dashboard_id``,
    ``dataset_id``.
    """
    # NOTE(review): the fallback keys can return the *other* route's id (e.g.
    # a dataset_id for a dashboard recall) when the preferred keys are absent;
    # confirm the API never mixes id types in one record.
    preferred_keys = [f"{recall_type}_id", "id"]
    fallback_keys = ["dashboard_id", "dataset_id"]
    for key in preferred_keys + fallback_keys:
        value = normalize_id(record.get(key))
        if value is not None:
            return value
    return None


def safe_api_recall(
    api_url: str,
    question: str,
    recall_type: str,
    timeout_seconds: int,
) -> tuple[set[str], list[dict[str, Any]], str]:
    """Best-effort wrapper around ``invoke_recall_api``.

    Returns:
        ``(ids, records, error)``. On failure the ids and records are empty
        and *error* holds the exception text, so one failing question is
        recorded in the detail CSV instead of aborting the whole run.
    """
    try:
        ids, records = invoke_recall_api(api_url, question, recall_type, timeout_seconds)
    except Exception as exc:  # noqa: BLE001 - deliberate per-question best effort
        return set(), [], str(exc)
    return ids, records, ""


def ensure_output_dir(path: Path) -> None:
    """Create *path* (and any missing parents) if it does not exist."""
    path.mkdir(parents=True, exist_ok=True)


def write_details_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write per-question detail rows to *path* as UTF-8 CSV with a header row."""
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(DETAIL_FIELDNAMES))
        writer.writeheader()
        writer.writerows(rows)


def write_summary_json(path: Path, summary: dict[str, Any]) -> None:
    """Write *summary* to *path* as indented UTF-8 JSON (non-ASCII kept as-is)."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)


def rate(hits: int, total: int) -> float:
    """Return ``hits / total``, or 0.0 when *total* is zero."""
    if total == 0:
        return 0.0
    return hits / total


def _validate_args(args: argparse.Namespace) -> None:
    """Fail fast on configuration errors.

    Without this check, such errors would only surface per question. There
    they are swallowed by ``safe_api_recall`` and the run reports all-zero
    API recall instead of stopping.
    """
    if args.api_url == DEFAULT_RECALL_API_URL:
        raise ValueError(_PLACEHOLDER_URL_MESSAGE)
    if args.timeout <= 0:
        raise ValueError("--timeout must be a positive number of seconds.")
    if args.first_recall_size < 0:
        raise ValueError("--first-recall-size must be >= 0.")
    if not 0.0 <= args.first_recall_hit_rate <= 1.0:
        raise ValueError("--first-recall-hit-rate must be between 0 and 1.")
    _require_requests()


def _join_sorted(values: set[str]) -> str:
    """Render a set of ids/codes as a sorted, comma-separated string."""
    return ",".join(sorted(values))


def _compact_json(records: list[dict[str, Any]]) -> str:
    """Serialize API records compactly for a single CSV cell."""
    return json.dumps(records, ensure_ascii=False, separators=(",", ":"))


def _missing_card_row(case: QuestionCase) -> dict[str, Any]:
    """Build the detail row for a question whose card id is not in the id map."""
    row: dict[str, Any] = dict.fromkeys(DETAIL_FIELDNAMES, "")
    row.update(
        line_no=case.line_no,
        card_id=case.card_id,
        question=case.question,
        first_recall_hit=False,
        dashboard_recall_hit=False,
        dashboard_recall_error=_MISSING_CARD_ERROR,
        dataset_recall_hit=False,
        dataset_recall_error=_MISSING_CARD_ERROR,
        union_hit=False,
    )
    return row


def _evaluate_case(
    case: QuestionCase,
    mapping: CardMapping,
    args: argparse.Namespace,
    all_dashboard_ids: list[str],
    rng: random.Random,
) -> dict[str, Any]:
    """Run all three recall routes for one question and build its detail row."""
    first_ids = simulate_first_recall(
        all_dashboard_ids=all_dashboard_ids,
        gold_dashboard_ids=mapping.dashboard_ids,
        sample_size=args.first_recall_size,
        force_hit_rate=args.first_recall_hit_rate,
        rng=rng,
    )
    dashboard_ids, dashboard_records, dashboard_error = safe_api_recall(
        api_url=args.api_url,
        question=case.question,
        recall_type=RECALL_TYPE_DASHBOARD,
        timeout_seconds=args.timeout,
    )
    dataset_ids, dataset_records, dataset_error = safe_api_recall(
        api_url=args.api_url,
        question=case.question,
        recall_type=RECALL_TYPE_DATASET,
        timeout_seconds=args.timeout,
    )

    first_hit = bool(first_ids & mapping.dashboard_ids)
    dashboard_hit = bool(dashboard_ids & mapping.dashboard_ids)
    dataset_hit = bool(dataset_ids & mapping.dataset_ids)
    return {
        "line_no": case.line_no,
        "card_id": case.card_id,
        "card_name": mapping.card_name,
        "question": case.question,
        "gold_dashboard_ids": _join_sorted(mapping.dashboard_ids),
        "gold_dataset_ids": _join_sorted(mapping.dataset_ids),
        "first_recall_ids": _join_sorted(first_ids),
        "first_recall_hit": first_hit,
        "dashboard_recall_ids": _join_sorted(dashboard_ids),
        "dashboard_recall_hit": dashboard_hit,
        "dashboard_recall_error": dashboard_error,
        "dashboard_recall_records": _compact_json(dashboard_records),
        "dataset_recall_ids": _join_sorted(dataset_ids),
        "dataset_recall_hit": dataset_hit,
        "dataset_recall_error": dataset_error,
        "dataset_recall_records": _compact_json(dataset_records),
        "union_hit": first_hit or dashboard_hit or dataset_hit,
        "branch_codes": _join_sorted(mapping.branch_codes),
    }


def main() -> int:
    """Run the evaluation end to end, write the outputs and print the summary.

    Questions whose card id is missing from the id map still count towards
    the total, but as misses on every route.

    Returns:
        0 on success; failures propagate as exceptions.
    """
    args = parse_args()
    _validate_args(args)
    rng = random.Random(args.seed)

    questions_path = Path(args.questions_txt)
    mapping_path = Path(args.id_map_parquet)
    output_dir = Path(args.output_dir)
    detail_csv_path = output_dir / "recall_details.csv"
    summary_json_path = output_dir / "recall_summary.json"

    questions = read_questions(questions_path)
    card_map, all_dashboard_ids = load_mapping(mapping_path)
    ensure_output_dir(output_dir)

    detail_rows: list[dict[str, Any]] = []
    missing_cards: list[dict[str, Any]] = []
    error_counter: dict[str, int] = defaultdict(int)

    for case in questions:
        mapping = card_map.get(case.card_id)
        if mapping is None:
            missing_cards.append(
                {"line_no": case.line_no, "card_id": case.card_id, "question": case.question}
            )
            detail_rows.append(_missing_card_row(case))
            continue

        row = _evaluate_case(case, mapping, args, all_dashboard_ids, rng)
        if row["dashboard_recall_error"]:
            error_counter["dashboard_api_errors"] += 1
        if row["dataset_recall_error"]:
            error_counter["dataset_api_errors"] += 1
        detail_rows.append(row)

    first_hits = sum(int(row["first_recall_hit"]) for row in detail_rows)
    dashboard_hits = sum(int(row["dashboard_recall_hit"]) for row in detail_rows)
    dataset_hits = sum(int(row["dataset_recall_hit"]) for row in detail_rows)
    union_hits = sum(int(row["union_hit"]) for row in detail_rows)

    total = len(questions)
    first_rate = rate(first_hits, total)
    dashboard_rate = rate(dashboard_hits, total)
    dataset_rate = rate(dataset_hits, total)
    union_rate = rate(union_hits, total)

    summary = {
        "input": {
            "questions_txt": str(questions_path),
            "id_map_parquet": str(mapping_path),
            "api_url": args.api_url,
            "first_recall_size": args.first_recall_size,
            "first_recall_hit_rate": args.first_recall_hit_rate,
            "seed": args.seed,
        },
        "counts": {
            "total_questions": total,
            "first_recall_hits": first_hits,
            "dashboard_recall_hits": dashboard_hits,
            "dataset_recall_hits": dataset_hits,
            "union_hits": union_hits,
            "missing_card_mappings": len(missing_cards),
        },
        "rates": {
            "first_recall_rate": first_rate,
            "dashboard_recall_rate": dashboard_rate,
            "dataset_recall_rate": dataset_rate,
            "union_recall_rate": union_rate,
            "uplift_vs_first_recall": union_rate - first_rate,
        },
        "errors": dict(error_counter),
        "missing_cards": missing_cards,
        "outputs": {
            "detail_csv": str(detail_csv_path),
            "summary_json": str(summary_json_path),
        },
    }

    write_details_csv(detail_csv_path, detail_rows)
    write_summary_json(summary_json_path, summary)
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return 0


if __name__ == "__main__":
    try:
        raise SystemExit(main())
    except Exception as exc:  # noqa: BLE001
        print(f"[ERROR] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc