@@ -430,22 +430,18 @@ def _test_distrib_integration_multiclass(device):
430
430
431
431
from ignite .engine import Engine
432
432
433
- rank = idist .get_rank ()
434
- torch .manual_seed (12 )
435
-
436
433
def _test (average , n_epochs , metric_device ):
437
434
n_iters = 60
438
- s = 16
435
+ batch_size = 16
439
436
n_classes = 7
440
437
441
- offset = n_iters * s
442
- y_true = torch .randint (0 , n_classes , size = (offset * idist .get_world_size (),)).to (device )
443
- y_preds = torch .rand (offset * idist .get_world_size (), n_classes ).to (device )
438
+ y_true = torch .randint (0 , n_classes , size = (n_iters * batch_size ,)).to (device )
439
+ y_preds = torch .rand (n_iters * batch_size , n_classes ).to (device )
444
440
445
441
def update (engine , i ):
446
442
return (
447
- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , :],
448
- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset ],
443
+ y_preds [i * batch_size : (i + 1 ) * batch_size , :],
444
+ y_true [i * batch_size : (i + 1 ) * batch_size ],
449
445
)
450
446
451
447
engine = Engine (update )
@@ -457,6 +453,9 @@ def update(engine, i):
457
453
data = list (range (n_iters ))
458
454
engine .run (data = data , max_epochs = n_epochs )
459
455
456
+ y_preds = idist .all_gather (y_preds )
457
+ y_true = idist .all_gather (y_true )
458
+
460
459
assert "re" in engine .state .metrics
461
460
assert re ._updated is True
462
461
res = engine .state .metrics ["re" ]
@@ -475,7 +474,9 @@ def update(engine, i):
475
474
metric_devices = [torch .device ("cpu" )]
476
475
if device .type != "xla" :
477
476
metric_devices .append (idist .device ())
478
- for _ in range (2 ):
477
+ rank = idist .get_rank ()
478
+ for i in range (2 ):
479
+ torch .manual_seed (12 + rank + i )
479
480
for metric_device in metric_devices :
480
481
_test (average = False , n_epochs = 1 , metric_device = metric_device )
481
482
_test (average = False , n_epochs = 2 , metric_device = metric_device )
@@ -491,22 +492,20 @@ def _test_distrib_integration_multilabel(device):
491
492
492
493
from ignite .engine import Engine
493
494
494
- rank = idist .get_rank ()
495
495
torch .manual_seed (12 )
496
496
497
497
def _test (average , n_epochs , metric_device ):
498
498
n_iters = 60
499
- s = 16
499
+ batch_size = 16
500
500
n_classes = 7
501
501
502
- offset = n_iters * s
503
- y_true = torch .randint (0 , 2 , size = (offset * idist .get_world_size (), n_classes , 6 , 8 )).to (device )
504
- y_preds = torch .randint (0 , 2 , size = (offset * idist .get_world_size (), n_classes , 6 , 8 )).to (device )
502
+ y_true = torch .randint (0 , 2 , size = (n_iters * batch_size , n_classes , 6 , 8 )).to (device )
503
+ y_preds = torch .randint (0 , 2 , size = (n_iters * batch_size , n_classes , 6 , 8 )).to (device )
505
504
506
505
def update (engine , i ):
507
506
return (
508
- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , ...],
509
- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset , ...],
507
+ y_preds [i * batch_size : (i + 1 ) * batch_size , ...],
508
+ y_true [i * batch_size : (i + 1 ) * batch_size , ...],
510
509
)
511
510
512
511
engine = Engine (update )
@@ -518,6 +517,9 @@ def update(engine, i):
518
517
data = list (range (n_iters ))
519
518
engine .run (data = data , max_epochs = n_epochs )
520
519
520
+ y_preds = idist .all_gather (y_preds )
521
+ y_true = idist .all_gather (y_true )
522
+
521
523
assert "re" in engine .state .metrics
522
524
assert re ._updated is True
523
525
res = engine .state .metrics ["re" ]
@@ -540,7 +542,9 @@ def update(engine, i):
540
542
metric_devices = ["cpu" ]
541
543
if device .type != "xla" :
542
544
metric_devices .append (idist .device ())
543
- for _ in range (2 ):
545
+ rank = idist .get_rank ()
546
+ for i in range (2 ):
547
+ torch .manual_seed (12 + rank + i )
544
548
for metric_device in metric_devices :
545
549
_test (average = False , n_epochs = 1 , metric_device = metric_device )
546
550
_test (average = False , n_epochs = 2 , metric_device = metric_device )
0 commit comments