Learn by Directing AI
All materials

evaluation-suite.py

evaluation-suite.py
"""
Evaluation suite for the MedConnect matching model.
Computes quality metrics and checks against configurable thresholds.
"""

import json
import sys

# Deployment-gate thresholds; tune them to the project requirements.
# "recall", "precision" and "f1" are minimums (a lower score fails), while
# "max_fairness_gap" is a ceiling on the largest difference in recall
# between regions (a bigger gap fails).
THRESHOLDS = dict(
    recall=0.55,
    precision=0.50,
    f1=0.55,
    max_fairness_gap=0.15,
)


def load_model(model_path="model/matching_model.pkl"):
    """Deserialize and return the trained matching model.

    Args:
        model_path: Filesystem path of the pickled model artifact.

    Returns:
        The object that was pickled into ``model_path``.

    Note:
        ``pickle.load`` can execute arbitrary code embedded in the file, so
        only point this at artifacts produced by a trusted training run.
    """
    import pickle

    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)


def load_test_data(data_path="data/test_data.csv"):
    """Load the test dataset for evaluation.

    Args:
        data_path: Path to a CSV file whose first row is the header.

    Returns:
        A list of dicts, one per data row, mapping header names to the raw
        string cell values (no type conversion happens here).
    """
    import csv

    # newline="" hands line-ending handling to the csv module, as the csv
    # docs require; without it a "\r\n" inside a quoted field is rewritten
    # to "\n". "utf-8-sig" reads plain UTF-8 and also strips a leading BOM
    # (common in spreadsheet exports) that would otherwise be glued onto
    # the first column name.
    with open(data_path, "r", newline="", encoding="utf-8-sig") as f:
        return list(csv.DictReader(f))


def compute_metrics(model, data):
    """
    Compute evaluation metrics for the matching model.

    Placeholder: ``model`` and ``data`` are currently ignored and every
    score is reported as 0.0. A zero score sits below any positive minimum
    threshold, so the evaluation will not pass until real logic is added.

    Returns a dict with:
    - accuracy: overall accuracy
    - recall: overall recall for positive class
    - precision: overall precision for positive class
    - f1: overall F1 score
    - per_region_recall: dict mapping region -> recall
    - fairness_gap: max difference in recall across regions
    """
    # TODO: replace with real evaluation of `model` over `data`.
    zero_scores = {
        name: 0.0 for name in ("accuracy", "recall", "precision", "f1")
    }
    # A fresh per-region dict is built on every call, so callers may mutate it.
    return {
        **zero_scores,
        "per_region_recall": {},
        "fairness_gap": 0.0,
    }


def check_thresholds(metrics, thresholds):
    """
    Compare computed metrics against thresholds.

    "recall", "precision" and "f1" must each be at least their threshold;
    "fairness_gap" must be at most ``thresholds["max_fairness_gap"]``.
    A NaN metric (e.g. produced by a NumPy 0/0 over an empty class or
    region) fails its check instead of slipping through: every comparison
    with NaN is False, so the tests are written as "not within bounds".

    Args:
        metrics: dict with numeric "recall", "precision", "f1" and
            "fairness_gap" entries.
        thresholds: dict with "recall", "precision", "f1" and
            "max_fairness_gap" entries.

    Returns:
        (passed: bool, failures: list of str) -- ``passed`` is True only
        when ``failures`` is empty.

    Raises:
        KeyError: if a required metric or threshold key is missing.
    """
    failures = []

    # Minimum-bound gates, in a fixed order so the report is stable.
    for name in ("recall", "precision", "f1"):
        value, floor = metrics[name], thresholds[name]
        # `not (value >= floor)` instead of `value < floor` so NaN fails.
        if not (value >= floor):
            failures.append(
                f"FAIL: {name} {value:.3f} below threshold {floor}"
            )

    # Maximum-bound gate; same NaN-safe inversion as above.
    gap, ceiling = metrics["fairness_gap"], thresholds["max_fairness_gap"]
    if not (gap <= ceiling):
        failures.append(
            f"FAIL: fairness_gap {gap:.3f} exceeds threshold {ceiling}"
        )

    return not failures, failures


def main():
    """Run the full evaluation suite and gate deployment on the result.

    Loads the model and test data from their default paths, prints the
    computed metrics, any per-region recall and the configured thresholds,
    then prints a PASS/FAIL verdict. On failure every failing check is
    listed and the process exits with status 1.
    """
    rule = "=" * 60

    print(rule)
    print("MedConnect Matching Model -- Evaluation Suite")
    print(rule)

    model = load_model()
    data = load_test_data()
    metrics = compute_metrics(model, data)

    print("\nMetrics:")
    print(f"  accuracy:     {metrics['accuracy']:.3f}")
    print(f"  recall:       {metrics['recall']:.3f}")
    print(f"  precision:    {metrics['precision']:.3f}")
    print(f"  f1:           {metrics['f1']:.3f}")
    print(f"  fairness_gap: {metrics['fairness_gap']:.3f}")

    if metrics["per_region_recall"]:
        print("\nPer-region recall:")
        for region, recall in sorted(metrics["per_region_recall"].items()):
            print(f"  {region}: {recall:.3f}")

    print("\nThresholds:")
    for metric, threshold in THRESHOLDS.items():
        print(f"  {metric}: {threshold}")

    passed, failures = check_thresholds(metrics, THRESHOLDS)

    print("\n" + rule)
    if passed:
        print("RESULT: PASS -- All metrics above threshold.")
    else:
        print("RESULT: FAIL -- Blocking deployment.")
        for failure in failures:
            print(f"  {failure}")
    # Close the report on both paths before any exit, so a failing run's
    # output is framed the same way as a passing one.
    print(rule)

    if not passed:
        # Non-zero exit status signals the failed gate to the caller.
        sys.exit(1)


if __name__ == "__main__":
    main()