Skip to content

adds available_device to test_precision_recall_curve #3335 #3368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 125 additions & 60 deletions tests/ignite/metrics/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve


def to_numpy_float32(x):
if isinstance(x, torch.Tensor):
if x.device.type == "mps":
x = x.to("cpu") # Explicitly move from MPS to CPU
return x.detach().to(dtype=torch.float32).numpy()
elif isinstance(x, np.ndarray):
return x.astype(np.float32)
return x


@pytest.fixture()
def mock_no_sklearn():
with patch.dict("sys.modules", {"sklearn.metrics": None}):
Expand All @@ -28,112 +38,144 @@ def test_no_sklearn(mock_no_sklearn):
pr_curve.compute()


def test_precision_recall_curve():
def test_precision_recall_curve(available_device):
size = 100
np_y_pred = np.random.rand(size, 1)
np_y = np.zeros((size,))
np_y[size // 2 :] = 1
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
y_true[size // 2 :] = 1.0
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(to_numpy_float32(y_true), to_numpy_float32(y_pred))

precision_recall_curve_metric = PrecisionRecallCurve()
y_pred = torch.from_numpy(np_y_pred)
y = torch.from_numpy(np_y)
sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

precision_recall_curve_metric.update((y_pred, y))
precision_recall_curve_metric = PrecisionRecallCurve(device=available_device)
assert precision_recall_curve_metric._device == torch.device(available_device)

precision_recall_curve_metric.update((y_pred, y_true))
precision, recall, thresholds = precision_recall_curve_metric.compute()
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()

assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

assert np.allclose(precision, sk_precision, rtol=1e-6)
assert np.allclose(recall, sk_recall, rtol=1e-6)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_integration_precision_recall_curve_with_output_transform():
np.random.seed(1)
def test_integration_precision_recall_curve_with_output_transform(available_device):
torch.manual_seed(1)
size = 100
np_y_pred = np.random.rand(size, 1)
np_y = np.zeros((size,))
np_y[size // 2 :] = 1
np.random.shuffle(np_y)
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
y_true[size // 2 :] = 1.0
perm = torch.randperm(size)
y_pred = y_pred[perm]
y_true = y_true[perm]

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(to_numpy_float32(y_true), to_numpy_float32(y_pred))

batch_size = 10

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
y_true_batch = y_true[idx : idx + batch_size]
y_pred_batch = y_pred[idx : idx + batch_size]
return idx, y_pred_batch, y_true_batch

engine = Engine(update_fn)

precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (x[1], x[2]))
precision_recall_curve_metric = PrecisionRecallCurve(
output_transform=lambda x: (x[1], x[2]), device=available_device
)
assert precision_recall_curve_metric._device == torch.device(available_device)
precision_recall_curve_metric.attach(engine, "precision_recall_curve")

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall

precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

assert np.allclose(precision, sk_precision, rtol=1e-6)
assert np.allclose(recall, sk_recall, rtol=1e-6)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_integration_precision_recall_curve_with_activated_output_transform():
def test_integration_precision_recall_curve_with_activated_output_transform(available_device):
np.random.seed(1)
size = 100
np_y_pred = np.random.rand(size, 1)
np_y_pred_sigmoid = torch.sigmoid(torch.from_numpy(np_y_pred)).numpy()
np_y = np.zeros((size,))
np_y[size // 2 :] = 1
np.random.shuffle(np_y)

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred_sigmoid)
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
y_true[size // 2 :] = 1.0
perm = torch.randperm(size)
y_pred = y_pred[perm]
y_true = y_true[perm]

sigmoid_y_pred = torch.sigmoid(y_pred).cpu().numpy()
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(
to_numpy_float32(y_true), to_numpy_float32(sigmoid_y_pred)
)

batch_size = 10

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
y_true_batch = y_true[idx : idx + batch_size]
y_pred_batch = y_pred[idx : idx + batch_size]
return idx, y_pred_batch, y_true_batch

engine = Engine(update_fn)

precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
precision_recall_curve_metric = PrecisionRecallCurve(
output_transform=lambda x: (torch.sigmoid(x[1]), x[2]), device=available_device
)
assert precision_recall_curve_metric._device == torch.device(available_device)
precision_recall_curve_metric.attach(engine, "precision_recall_curve")

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
precision = precision.cpu().numpy()
recall = recall.cpu().numpy()
thresholds = thresholds.cpu().numpy()
precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
assert np.allclose(precision, sk_precision, rtol=1e-6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from what I understand, pytest.approx may convert float32 parameter into float64. This would break on MPS

assert np.allclose(recall, sk_recall, rtol=1e-6)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_check_compute_fn():
def test_check_compute_fn(available_device):
y_pred = torch.zeros((8, 13))
y_pred[:, 1] = 1
y_true = torch.zeros_like(y_pred)
output = (y_pred, y_true)

em = PrecisionRecallCurve(check_compute_fn=True)
em = PrecisionRecallCurve(check_compute_fn=True, device=available_device)
assert em._device == torch.device(available_device)

em.reset()
with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
em.update(output)

em = PrecisionRecallCurve(check_compute_fn=False)
em = PrecisionRecallCurve(check_compute_fn=False, device=available_device)
assert em._device == torch.device(available_device)
em.update(output)


Expand All @@ -157,15 +199,29 @@ def _test(y_pred, y, batch_size, metric_device):
y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

np_y = y.cpu().numpy()
np_y_pred = y_pred.cpu().numpy()
np_y = to_numpy_float32(y)
np_y_pred = to_numpy_float32(y_pred)

res = prc.compute()

assert isinstance(res, Tuple)
assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0].cpu().numpy())
assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1].cpu().numpy())
assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2].cpu().numpy())
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)

assert np.allclose(
to_numpy_float32(res[0]),
to_numpy_float32(sk_precision),
rtol=1e-6,
)
assert np.allclose(
to_numpy_float32(res[1]),
to_numpy_float32(sk_recall),
rtol=1e-6,
)
assert np.allclose(
to_numpy_float32(res[2]),
to_numpy_float32(sk_thresholds),
rtol=1e-6,
)

def get_test_cases():
test_cases = [
Expand Down Expand Up @@ -222,17 +278,26 @@ def update(engine, i):

precision, recall, thresholds = engine.state.metrics["prc"]

np_y_true = y_true.cpu().numpy().ravel()
np_y_preds = y_preds.cpu().numpy().ravel()
np_y_true = to_numpy_float32(y_true).ravel()
np_y_preds = to_numpy_float32(y_preds).ravel()

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)

sk_precision = sk_precision.astype(np.float32)
sk_recall = sk_recall.astype(np.float32)
sk_thresholds = sk_thresholds.astype(np.float32)

precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

assert precision.shape == sk_precision.shape
assert recall.shape == sk_recall.shape
assert thresholds.shape == sk_thresholds.shape
assert pytest.approx(precision.cpu().numpy()) == sk_precision
assert pytest.approx(recall.cpu().numpy()) == sk_recall
assert pytest.approx(thresholds.cpu().numpy()) == sk_thresholds

assert np.allclose(precision, sk_precision, rtol=1e-6)
assert np.allclose(recall, sk_recall, rtol=1e-6)
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

metric_devices = ["cpu"]
if device.type != "xla":
Expand Down
Loading