@@ -743,6 +743,66 @@ def get_handler_for_metric(
   raise ValueError(f"Unsupported metric: {metric.name}")


+def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]:
+  """Calculates win/tie rates for comparison results.
+
+  Returns:
+    A dict mapping each metric name to {"win_rates": [...], "tie_rate": ...},
+    with one win rate per candidate index, all normalized by that metric's
+    number of valid comparisons.
+  """
+  if not eval_result.eval_case_results:
+    return {}
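+  # Size the per-metric win tally to the largest candidate count observed
+  # across all cases, since cases may carry different numbers of candidates.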
+  max_models = max(
+      (
+          len(case.response_candidate_results)
+          for case in eval_result.eval_case_results
+          if case.response_candidate_results
+      ),
+      default=0,
+  )
+  if max_models == 0:
+    return {}
+  stats = collections.defaultdict(
+      lambda: {"wins": [0] * max_models, "ties": 0, "valid_comparisons": 0}
+  )
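+  # Group each case's candidate scores by metric so candidates are compared
+  # head-to-head on the same metric within the same case.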
+  for case in eval_result.eval_case_results:
+    if not case.response_candidate_results:
+      continue
+    scores_by_metric = collections.defaultdict(list)
+    for idx, candidate in enumerate(case.response_candidate_results):
+      for name, res in (candidate.metric_results or {}).items():
+        if res.score is not None:
+          scores_by_metric[name].append({"score": res.score, "cand_idx": idx})
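+    # Each case counts as one valid comparison per metric: a unique top
+    # scorer takes the win, while a shared top score is recorded as a tie.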
+    for name, scores in scores_by_metric.items():
+      if not scores:
+        continue
+      stats[name]["valid_comparisons"] += 1
+      max_score = max(s["score"] for s in scores)
+      winners = [s["cand_idx"] for s in scores if s["score"] == max_score]
+      if len(winners) == 1:
+        stats[name]["wins"][winners[0]] += 1
+      else:
+        stats[name]["ties"] += 1
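+  # Normalize the raw counts into rates over each metric's valid comparisons.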
+  win_rates = {}
+  for name, metric_stats in stats.items():
+    if metric_stats["valid_comparisons"] > 0:
+      win_rates[name] = {
+          "win_rates": [
+              w / metric_stats["valid_comparisons"] for w in metric_stats["wins"]
+          ],
+          "tie_rate": metric_stats["ties"] / metric_stats["valid_comparisons"],
+      }
+  return win_rates
+
+
 def _aggregate_metric_results(
     metric_handlers: list[MetricHandler],
     eval_case_results: list[types.EvalCaseResult],
@@ -1001,18 +1061,29 @@ def compute_metrics_and_aggregate(
     )
     final_eval_case_results.append(eval_case_result)

-  aggregated_metric_results = _aggregate_metric_results(
-      metric_handlers, final_eval_case_results
-  )
-
   if submission_errors:
     logger.warning("Encountered %d submission errors.", len(submission_errors))
     logger.warning("Submission errors: %s", submission_errors)
   if execution_errors:
     logger.warning("Encountered %d execution errors.", len(execution_errors))
     logger.warning("Execution errors: %s", execution_errors)

-  return types.EvaluationResult(
+  aggregated_metric_results = _aggregate_metric_results(
+      metric_handlers, final_eval_case_results
+  )
+  eval_result = types.EvaluationResult(
       eval_case_results=final_eval_case_results,
       summary_metrics=aggregated_metric_results,
   )
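+  # Win rates are only meaningful when multiple response candidates were
+  # compared; a failure here should not sink the whole evaluation run.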
+  if evaluation_run_config.num_response_candidates > 1:
+    try:
+      eval_result.win_rates = calculate_win_rates(eval_result)
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      logger.error(
+          "Error calculating win rates: %s",
+          e,
+          exc_info=True,
+      )
+  return eval_result