Skip to content

Commit 942af82

Browse files
Update epoch_metric and recall in test for generating data with different rank (#2675)
* Update epoch_metric and recall in test for generating data with different rank

  Update `epoch_metric` and `recall`

* Update test_recall.py

* Update test_recall.py

Co-authored-by: Sadra Barikbin <[email protected]>
1 parent 02d4c81 commit 942af82

File tree

2 files changed: +34 additions, -28 deletions (lines changed)

2 files changed

+34
-28
lines changed

tests/ignite/metrics/test_epoch_metric.py

+12 -10
Original file line number | Diff line number | Diff line change
@@ -159,34 +159,36 @@ def _test_distrib_integration(device=None):
159159
device = idist.device() if idist.device().type != "xla" else "cpu"
160160

161161
rank = idist.get_rank()
162-
torch.manual_seed(12)
162+
torch.manual_seed(12 + rank)
163163

164-
n_iters = 60
165-
s = 16
164+
n_iters = 3
165+
batch_size = 2
166166
n_classes = 7
167167

168-
offset = n_iters * s
169-
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),), device=device)
170-
y_preds = torch.rand(offset * idist.get_world_size(), n_classes, device=device)
168+
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,), device=device)
169+
y_preds = torch.rand(n_iters * batch_size, n_classes, device=device)
171170

172171
def update(engine, i):
173172
return (
174-
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
175-
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
173+
y_preds[i * batch_size : (i + 1) * batch_size, :],
174+
y_true[i * batch_size : (i + 1) * batch_size],
176175
)
177176

178177
engine = Engine(update)
179178

180179
def assert_data_fn(all_preds, all_targets):
181-
assert all_preds.equal(y_preds), f"{all_preds.shape} vs {y_preds.shape}"
182-
assert all_targets.equal(y_true), f"{all_targets.shape} vs {y_true.shape}"
183180
return (all_preds.argmax(dim=1) == all_targets).sum().item()
184181

185182
ep_metric = EpochMetric(assert_data_fn, check_compute_fn=False, device=device)
186183
ep_metric.attach(engine, "epm")
187184

188185
data = list(range(n_iters))
186+
189187
engine.run(data=data, max_epochs=3)
188+
189+
y_preds = idist.all_gather(y_preds)
190+
y_true = idist.all_gather(y_true)
191+
190192
assert engine.state.metrics["epm"] == (y_preds.argmax(dim=1) == y_true).sum().item()
191193

192194

tests/ignite/metrics/test_recall.py

+22 -18
Original file line number | Diff line number | Diff line change
@@ -430,22 +430,18 @@ def _test_distrib_integration_multiclass(device):
430430

431431
from ignite.engine import Engine
432432

433-
rank = idist.get_rank()
434-
torch.manual_seed(12)
435-
436433
def _test(average, n_epochs, metric_device):
437434
n_iters = 60
438-
s = 16
435+
batch_size = 16
439436
n_classes = 7
440437

441-
offset = n_iters * s
442-
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
443-
y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
438+
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
439+
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
444440

445441
def update(engine, i):
446442
return (
447-
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
448-
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
443+
y_preds[i * batch_size : (i + 1) * batch_size, :],
444+
y_true[i * batch_size : (i + 1) * batch_size],
449445
)
450446

451447
engine = Engine(update)
@@ -457,6 +453,9 @@ def update(engine, i):
457453
data = list(range(n_iters))
458454
engine.run(data=data, max_epochs=n_epochs)
459455

456+
y_preds = idist.all_gather(y_preds)
457+
y_true = idist.all_gather(y_true)
458+
460459
assert "re" in engine.state.metrics
461460
assert re._updated is True
462461
res = engine.state.metrics["re"]
@@ -475,7 +474,9 @@ def update(engine, i):
475474
metric_devices = [torch.device("cpu")]
476475
if device.type != "xla":
477476
metric_devices.append(idist.device())
478-
for _ in range(2):
477+
rank = idist.get_rank()
478+
for i in range(2):
479+
torch.manual_seed(12 + rank + i)
479480
for metric_device in metric_devices:
480481
_test(average=False, n_epochs=1, metric_device=metric_device)
481482
_test(average=False, n_epochs=2, metric_device=metric_device)
@@ -491,22 +492,20 @@ def _test_distrib_integration_multilabel(device):
491492

492493
from ignite.engine import Engine
493494

494-
rank = idist.get_rank()
495495
torch.manual_seed(12)
496496

497497
def _test(average, n_epochs, metric_device):
498498
n_iters = 60
499-
s = 16
499+
batch_size = 16
500500
n_classes = 7
501501

502-
offset = n_iters * s
503-
y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
504-
y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
502+
y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
503+
y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
505504

506505
def update(engine, i):
507506
return (
508-
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
509-
y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
507+
y_preds[i * batch_size : (i + 1) * batch_size, ...],
508+
y_true[i * batch_size : (i + 1) * batch_size, ...],
510509
)
511510

512511
engine = Engine(update)
@@ -518,6 +517,9 @@ def update(engine, i):
518517
data = list(range(n_iters))
519518
engine.run(data=data, max_epochs=n_epochs)
520519

520+
y_preds = idist.all_gather(y_preds)
521+
y_true = idist.all_gather(y_true)
522+
521523
assert "re" in engine.state.metrics
522524
assert re._updated is True
523525
res = engine.state.metrics["re"]
@@ -540,7 +542,9 @@ def update(engine, i):
540542
metric_devices = ["cpu"]
541543
if device.type != "xla":
542544
metric_devices.append(idist.device())
543-
for _ in range(2):
545+
rank = idist.get_rank()
546+
for i in range(2):
547+
torch.manual_seed(12 + rank + i)
544548
for metric_device in metric_devices:
545549
_test(average=False, n_epochs=1, metric_device=metric_device)
546550
_test(average=False, n_epochs=2, metric_device=metric_device)

0 commit comments

Comments (0)