Skip to content

Commit b8f4873

Browse files
authored
compare.py: Add --statistics to show statistical significance with a t-test. (#288)
* compare.py: Add --statistics to show statistical significance with a t-test. The alpha level can also be adjusted with --alpha
1 parent 591c457 commit b8f4873

File tree

1 file changed

+124
-14
lines changed

1 file changed

+124
-14
lines changed

utils/compare.py

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,29 @@ def merge_values(values, merge_function):
127127
return values.groupby(level=1).apply(merge_function)
128128

129129

130-
def get_values(values):
131-
# Create data view without diff column.
132-
if "diff" in values.columns:
133-
values = values[[c for c in values.columns if c != "diff"]]
130+
def get_values(values, lhs_name, rhs_name):
131+
exclude_cols = ["diff", "t-value", "p-value", "significant"]
132+
exclude_cols.extend([f'std_{lhs_name}', f'std_{rhs_name}'])
133+
values = values[[c for c in values.columns if c not in exclude_cols]]
134134
has_two_runs = len(values.columns) == 2
135135
if has_two_runs:
136136
return (values.iloc[:, 0], values.iloc[:, 1])
137137
else:
138138
return (values.min(axis=1), values.max(axis=1))
139139

140140

141+
def get_default_metric(data, second_data=None):
142+
"""Find a default metric to use if none specified.
143+
data: Primary dataframe to check
144+
second_data: Optional secondary dataframe (for 'vs' mode with lhs/rhs)
145+
"""
146+
defaults = ["Exec_Time", "exec_time", "Value", "Runtime"]
147+
for defkey in defaults:
148+
if defkey in data.columns or (second_data is not None and defkey in second_data.columns):
149+
return [defkey]
150+
return []
151+
152+
141153
def add_diff_column(metric, values, absolute_diff=False):
142154
values0, values1 = get_values(values[metric])
143155
values0.fillna(0.0, inplace=True)
@@ -150,14 +162,68 @@ def add_diff_column(metric, values, absolute_diff=False):
150162
return values
151163

152164

153-
def add_geomean_row(metrics, data, dataout):
165+
def compute_statistics(lhs_d, rhs_d, metrics, alpha, lhs_name, rhs_name):
166+
stats_dict = {}
167+
168+
for metric in metrics:
169+
if metric not in lhs_d.columns or metric not in rhs_d.columns:
170+
continue
171+
172+
stats_dict[metric] = {}
173+
174+
# Group by program (more efficient than unique+loc)
175+
for program, lhs_group in lhs_d.groupby(level=1):
176+
lhs_values = lhs_group[metric].dropna()
177+
rhs_values = rhs_d.loc[(slice(None), program), metric].dropna()
178+
179+
# Compute t-test if we have enough samples
180+
if len(lhs_values) >= 2 and len(rhs_values) >= 2:
181+
stats_dict[metric][program] = {
182+
f'std_{lhs_name}': lhs_values.std(ddof=1),
183+
f'std_{rhs_name}': rhs_values.std(ddof=1),
184+
}
185+
t_stat, p_val = stats.ttest_ind(lhs_values, rhs_values)
186+
stats_dict[metric][program]['t-value'] = t_stat
187+
stats_dict[metric][program]['p-value'] = p_val
188+
stats_dict[metric][program]['significant'] = "Y" if p_val < alpha else "N"
189+
else:
190+
stats_dict[metric][program] = {
191+
f'std_{lhs_name}': float('nan'),
192+
f'std_{rhs_name}': float('nan'),
193+
't-value': float('nan'),
194+
'p-value': float('nan'),
195+
'significant': ""
196+
}
197+
198+
return stats_dict
199+
200+
201+
def add_precomputed_statistics(data, stats_dict, stat_col_names):
202+
"""Add precomputed statistics to the unstacked dataframe."""
203+
for metric in data.columns.levels[0]:
204+
if metric not in stats_dict:
205+
continue
206+
207+
for stat_name in stat_col_names:
208+
values = []
209+
for program in data.index:
210+
if program in stats_dict[metric]:
211+
values.append(stats_dict[metric][program][stat_name])
212+
else:
213+
values.append(float('nan') if stat_name != 'significant' else "")
214+
data[(metric, stat_name)] = values
215+
216+
return data
217+
218+
219+
def add_geomean_row(metrics, data, dataout, lhs_name, rhs_name):
154220
"""
155221
Normalize values1 over values0, compute geomean difference and add a
156222
summary row to dataout.
157223
"""
158224
gm = pd.DataFrame(index=[GEOMEAN_ROW], columns=dataout.columns, dtype="float64")
159225
for metric in metrics:
160-
values0, values1 = get_values(data[metric])
226+
values0, values1 = get_values(data[metric], lhs_name, rhs_name)
161227
# Avoid infinite values in the diff and instead use NaN, as otherwise
162228
# the computation of the geometric mean will fail.
163229
values0 = values0.replace({0: float("NaN")})
@@ -249,6 +315,8 @@ def print_result(
249315
sortkey="diff",
250316
sort_by_abs=True,
251317
absolute_diff=False,
318+
lhs_name="lhs",
319+
rhs_name="rhs"
252320
):
253321
metrics = d.columns.levels[0]
254322
if sort_by_abs:
@@ -272,6 +340,16 @@ def print_result(
272340
if not absolute_diff:
273341
for m in metrics:
274342
formatters[(m, "diff")] = format_relative_diff
343+
# Add formatters for statistical columns
344+
for m in metrics:
345+
if (m, "p-value") in dataout.columns:
346+
formatters[(m, "p-value")] = lambda x: "%.4f" % x if not pd.isna(x) else ""
347+
if (m, "t-value") in dataout.columns:
348+
formatters[(m, "t-value")] = lambda x: "%.3f" % x if not pd.isna(x) else ""
349+
if (m, f'std_{lhs_name}') in dataout.columns:
350+
formatters[(m, f'std_{lhs_name}')] = lambda x: "%.3f" % x if not pd.isna(x) else ""
351+
if (m, f'std_{rhs_name}') in dataout.columns:
352+
formatters[(m, f'std_{rhs_name}')] = lambda x: "%.3f" % x if not pd.isna(x) else ""
275353
# Turn index into a column so we can format it...
276354
formatted_program = dataout.index.to_series()
277355
if shorten_names:
@@ -302,7 +380,7 @@ def strip_name_fully(name):
302380
# as it will otherwise interfere with common prefix/suffix computation.
303381
if show_diff_column and not absolute_diff:
304382
# geometric mean only makes sense for relative differences.
305-
dataout = add_geomean_row(metrics, d, dataout)
383+
dataout = add_geomean_row(metrics, d, dataout, lhs_name, rhs_name)
306384

307385
def float_format(x):
308386
if x == "":
@@ -320,7 +398,10 @@ def float_format(x):
320398
formatters=formatters,
321399
)
322400
print(out)
323-
print(d.describe())
401+
exclude_from_summary = ["t-value", "p-value", "significant"]
402+
exclude_from_summary.extend([f'std_{lhs_name}', f'std_{rhs_name}'])
403+
d_summary = d.drop(columns=exclude_from_summary, level=1, errors='ignore')
404+
print(d_summary.describe())
324405

325406

326407
def main():
@@ -400,6 +481,19 @@ def main():
400481
default=False,
401482
help="Don't use abs() when sorting results",
402483
)
484+
parser.add_argument(
485+
"--statistics",
486+
action="store_true",
487+
dest="statistics",
488+
default=False,
489+
help="Add statistical analysis columns (std, t-value, p-value, significance)",
490+
)
491+
parser.add_argument(
492+
"--alpha",
493+
type=float,
494+
default=0.05,
495+
help="Significance level for statistical tests (default: 0.05)",
496+
)
403497
config = parser.parse_args()
404498

405499
if config.show_diff is None:
@@ -425,15 +519,30 @@ def main():
425519

426520
# Read inputs
427521
files = config.files
522+
stats_dict = None
523+
stat_col_names = None
428524
if "vs" in files:
429525
split = files.index("vs")
430526
lhs = files[0:split]
431527
rhs = files[split + 1 :]
432528

433529
# Combine the multiple left and right hand sides.
434530
lhs_d = readmulti(lhs)
435-
lhs_merged = merge_values(lhs_d, config.merge_function)
436531
rhs_d = readmulti(rhs)
532+
533+
# Compute statistics on raw data before merging (if requested)
534+
if config.statistics:
535+
metrics_for_stats = config.metrics if len(config.metrics) > 0 else get_default_metric(lhs_d, rhs_d)
536+
stats_dict = compute_statistics(
537+
lhs_d, rhs_d, metrics_for_stats,
538+
alpha=config.alpha,
539+
lhs_name=config.lhs_name,
540+
rhs_name=config.rhs_name
541+
)
542+
stat_col_names = [f'std_{config.lhs_name}', f'std_{config.rhs_name}', 't-value', 'p-value', 'significant']
543+
544+
# Merge data
545+
lhs_merged = merge_values(lhs_d, config.merge_function)
437546
rhs_merged = merge_values(rhs_d, config.merge_function)
438547

439548
# Combine to new dataframe
@@ -448,11 +557,7 @@ def main():
448557
# Decide which metric to display / what is our "main" metric
449558
metrics = config.metrics
450559
if len(metrics) == 0:
451-
defaults = ["Exec_Time", "exec_time", "Value", "Runtime"]
452-
for defkey in defaults:
453-
if defkey in data.columns:
454-
metrics = [defkey]
455-
break
560+
metrics = get_default_metric(data)
456561
if len(metrics) == 0:
457562
sys.stderr.write("No default metric found and none specified\n")
458563
sys.stderr.write("Available metrics:\n")
@@ -508,6 +613,9 @@ def main():
508613
for metric in data.columns.levels[0]:
509614
data = add_diff_column(metric, data, absolute_diff=config.absolute_diff)
510615

616+
if config.statistics and stats_dict is not None:
617+
data = add_precomputed_statistics(data, stats_dict, stat_col_names)
618+
511619
sortkey = "diff"
512620
# TODO: should we still be sorting by diff even if the diff is hidden?
513621
if len(config.files) == 1:
@@ -526,6 +634,8 @@ def main():
526634
sortkey,
527635
config.no_abs_sort,
528636
config.absolute_diff,
637+
config.lhs_name,
638+
config.rhs_name,
529639
)
530640

531641

0 commit comments

Comments
 (0)