Optimized memory usage and speed for covar type "full" #23
base: master
benchmark.md (new file, @@ -0,0 +1,39 @@):
# Benchmark
GPU: Tesla T4 (16GB DRAM)

- covariance_type = "full"
- init_params = "random"
- n_iter = 20
- delta = 0

| setup | original | k-loop | optimized (single) | optimized (double) |
| --- | --- | --- | --- | --- |
| n_features=16, n_components=16, n_data=100,000 | 6.9s | 6.9s | 0.5s | 3.44s |
| n_features=16, n_components=16, n_data=1,000,000 | OOM | 68.8s | 3.7s | 34.0s |
| n_features=64, n_components=64, n_data=100,000 | OOM | 161s | 3.57s | 13.9s |
| n_features=64, n_components=64, n_data=1,000,000 | OOM | OOM | 44.4s | 527s |
| n_features=256, n_components=256, n_data=100,000 | OOM | OOM | NAN | 686s |
| n_features=256, n_components=16, n_data=1,000,000 | OOM | OOM | 60s | 454s |

### Notes:
- OOM: Out Of Memory
- NAN: Covar contains NaN
- Single/Double: dtype of covariance matrix
- k-loop: almost the same as the original `GaussianMixture`, except that

```python
var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
                keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
```

in `_m_step` is replaced with

```python
var = torch.empty(1, self.n_components, self.n_features, self.n_features, device=x.device, dtype=resp.dtype)
eps = (torch.eye(self.n_features) * self.eps).to(x.device)
resp_sum = resp.sum(dim=0).squeeze(-1)  # (k,) total responsibility per component
for i in range(self.n_components):
    sub_mu = mu[:, i, :]
    sub_resp = resp[:, i, :]
    sub_x_mu = (x - sub_mu).squeeze(1)
    outer = torch.matmul(sub_x_mu[:, :, None], sub_x_mu[:, None, :])
    outer_sum = torch.sum(outer * sub_resp[:, :, None], dim=0, keepdim=True)
    sub_var = outer_sum / resp_sum[i] + eps
    var[:, i, :, :] = sub_var
```
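For context, a benchmark run with the settings above could look roughly like the following sketch. The module name `gmm`, the random input data, and the timing code are assumptions, not part of this PR:

```python
import time

import torch

from gmm import GaussianMixture  # module name assumed

n_data, n_features, n_components = 100_000, 16, 16
x = torch.randn(n_data, n_features, device="cuda")

model = GaussianMixture(
    n_components,
    n_features,
    covariance_type="full",
    init_params="random",
    covariance_data_type="double",  # "float" for the single-precision column
).cuda()

start = time.time()
model.fit(x, delta=0, n_iter=20)
print(f"fit took {time.time() - start:.1f}s")
```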
Changes to the `GaussianMixture` module:

```diff
@@ -1,9 +1,11 @@
 import torch
 import numpy as np
+import math
 
 from math import pi
 from scipy.special import logsumexp
-from utils import calculate_matmul, calculate_matmul_n_times
+from utils import calculate_matmul, calculate_matmul_n_times, find_optimal_splits
+from tqdm import tqdm
 
 
 class GaussianMixture(torch.nn.Module):
```

> Review comment (on the `tqdm` import): I'd recommend removing this to keep the repository light on dependencies — users that require this functionality can always add it.
```diff
@@ -15,30 +17,31 @@ class GaussianMixture(torch.nn.Module):
     probabilities are shaped (n, k, 1) if they relate to an individual sample,
     or (1, k, 1) if they assign membership probabilities to one of the mixture components.
     """
-    def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None):
+    def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None, covariance_data_type="double"):
         """
         Initializes the model and brings all tensors into their required shape.
         The class expects data to be fed as a flat tensor in (n, d).
         The class owns:
             x: torch.Tensor (n, 1, d)
             mu: torch.Tensor (1, k, d)
             var: torch.Tensor (1, k, d) or (1, k, d, d)
             pi: torch.Tensor (1, k, 1)
             covariance_type: str
             eps: float
             init_params: str
             log_likelihood: float
             n_components: int
             n_features: int
         args:
             n_components: int
             n_features: int
         options:
             mu_init: torch.Tensor (1, k, d)
             var_init: torch.Tensor (1, k, d) or (1, k, d, d)
             covariance_type: str
             eps: float
             init_params: str
+            covariance_data_type: str or torch.dtype
         """
         super(GaussianMixture, self).__init__()
```

> Review comment (on the new `covariance_data_type` argument): Any reservations against going with default […]?
```diff
@@ -50,6 +53,15 @@ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6,
         self.eps = eps
 
         self.log_likelihood = -np.inf
+        self.safe_mode = True
+        self.prev_log_prob = None
+
+        assert covariance_data_type in ["float", "double", torch.float, torch.double]
+        if covariance_data_type == "float":
+            covariance_data_type = torch.float
+        elif covariance_data_type == "double":
+            covariance_data_type = torch.double
+        self.covariance_data_type = covariance_data_type
 
         self.covariance_type = covariance_type
         self.init_params = init_params
```
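As the assertion above implies, the new argument accepts either a string or a `torch.dtype`. A short usage sketch (the import path and the other constructor values are illustrative only):

```python
import torch

from gmm import GaussianMixture  # module name assumed

# Both of these normalize to torch.double internally.
gmm_a = GaussianMixture(n_components=16, n_features=64, covariance_data_type="double")
gmm_b = GaussianMixture(n_components=16, n_features=64, covariance_data_type=torch.double)

# Single precision is faster but can yield NaN/inf covariances (see the benchmark notes).
gmm_c = GaussianMixture(n_components=16, n_features=64, covariance_data_type="float")
```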
```diff
@@ -72,17 +84,17 @@ def _init_params(self):
             if self.var_init is not None:
                 # (1, k, d)
                 assert self.var_init.size() == (1, self.n_components, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (self.n_components, self.n_features)
-                self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
+                self.var = torch.nn.Parameter(self.var_init.to(self.covariance_data_type), requires_grad=False)
             else:
-                self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False)
+                self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features, dtype=self.covariance_data_type), requires_grad=False)
         elif self.covariance_type == "full":
             if self.var_init is not None:
                 # (1, k, d, d)
                 assert self.var_init.size() == (1, self.n_components, self.n_features, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i, %i)" % (self.n_components, self.n_features, self.n_features)
-                self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
+                self.var = torch.nn.Parameter(self.var_init.to(self.covariance_data_type), requires_grad=False)
             else:
                 self.var = torch.nn.Parameter(
-                    torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1),
+                    torch.eye(self.n_features, dtype=self.covariance_data_type).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1),
                     requires_grad=False
                 )
```
```diff
@@ -139,14 +151,15 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):
         i = 0
         j = np.inf
 
 
+        pbar = tqdm(total=n_iter)
         while (i <= n_iter) and (j >= delta):
 
             log_likelihood_old = self.log_likelihood
             mu_old = self.mu
             var_old = self.var
 
-            self.__em(x)
+            self.__em(x, use_prev_log_prob=True)
             self.log_likelihood = self.__score(x)
 
             if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood):
```
```diff
@@ -170,8 +183,10 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):
                 # When score decreases, revert to old parameters
                 self.__update_mu(mu_old)
                 self.__update_var(var_old)
+            pbar.update(1)
 
         self.params_fitted = True
+        pbar.close()
 
 
     def predict(self, x, probs=False):
```
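If the `tqdm` dependency were dropped as the reviewer suggests, a minimal stand-in that matches the `update`/`close` calls above could simply print progress; the class below is an illustration, not part of this PR:

```python
class _PlainProgress:
    # Drop-in replacement for the tqdm(total=...) usage above, with no extra dependency.
    def __init__(self, total, every=10):
        self.total, self.every, self.n = total, every, 0

    def update(self, k=1):
        self.n += k
        if self.n % self.every == 0 or self.n >= self.total:
            print(f"EM iteration {self.n}/{self.total}")

    def close(self):
        pass
```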
```diff
@@ -188,7 +203,8 @@ def predict(self, x, probs=False):
         """
         x = self.check_size(x)
 
-        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+        weighted_log_prob = self._estimate_log_prob(x)
+        weighted_log_prob.add_(torch.log(self.pi))
 
         if probs:
             p_k = torch.exp(weighted_log_prob)
```

> Review comment (on the in-place `add_`): While carrying this out in-place preserves memory, spreading this across two lines here and in 369 and 466 decreases readability somewhat. Alternatively, I reckon this could be moved into […]
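One hypothetical way to follow the reviewer's suggestion would be a small helper that hides the in-place weighting; the name `_weighted_log_prob` below is not part of this PR, purely an illustration:

```python
def _weighted_log_prob(self, x):
    # log p(x|k) + log pi_k, accumulated in place to avoid an extra (n, k, 1) allocation
    weighted_log_prob = self._estimate_log_prob(x)
    weighted_log_prob.add_(torch.log(self.pi))
    return weighted_log_prob
```

`predict`, `_e_step`, and `__score` could then all call this helper instead of repeating the two-line pattern.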
```diff
@@ -257,25 +273,47 @@ def _estimate_log_prob(self, x):
             log_prob: torch.Tensor (n, k, 1)
         """
         x = self.check_size(x)
+        N, _, D = x.shape
+        K = self.n_components
 
         if self.covariance_type == "full":
             mu = self.mu
             var = self.var
 
-            precision = torch.inverse(var)
-            d = x.shape[-1]
-
-            log_2pi = d * np.log(2. * pi)
-
-            log_det = self._calculate_log_det(precision)
-
-            x_mu_T = (x - mu).unsqueeze(-2)
-            x_mu = (x - mu).unsqueeze(-1)
-
-            x_mu_T_precision = calculate_matmul_n_times(self.n_components, x_mu_T, precision)
-            x_mu_T_precision_x_mu = calculate_matmul(x_mu_T_precision, x_mu)
-
-            return -.5 * (log_2pi - log_det + x_mu_T_precision_x_mu)
+            x = x.to(var.dtype)
+            mu = mu.to(var.dtype)
+
+            precision = torch.inverse(var)
+
+            log_2pi = D * np.log(2. * pi)
+
+            log_det = self._calculate_log_det(precision)  #[K, 1]
+
+            x_mu_T_precision_x_mu = torch.empty(N, K, 1, device=x.device, dtype=x.dtype)
+
+            def get_required_memory(sub_K):
+                x_mu_requires = sub_K * N * D * x.element_size()
+                x_mu_T_precision_requires = sub_K * N * D * x.element_size()
+                calculate_matmul_requires = (sub_K * N * D + sub_K * N) * x.element_size()
+                return x_mu_requires + x_mu_T_precision_requires + calculate_matmul_requires
+
+            n_splits = find_optimal_splits(K, get_required_memory, x.device, safe_mode=self.safe_mode)
+            sub_K = math.ceil(K / n_splits)
+            for i in range(n_splits):
+                K_start = i * sub_K
+                K_end = min((i + 1) * sub_K, K)
+                sub_x_mu = x - mu[:, K_start: K_end, :]  #[N, sub_K, D]
+                sub_x_mu_T_precision = (sub_x_mu.transpose(0, 1) @ precision[:, K_start: K_end]).transpose(0, 2)
+                sub_x_mu_T_precision_x_mu = calculate_matmul(sub_x_mu_T_precision, sub_x_mu[:, :, :, None])  #[N, sub_K, 1]
+                x_mu_T_precision_x_mu[:, K_start: K_end] = sub_x_mu_T_precision_x_mu
+                del sub_x_mu, sub_x_mu_T_precision
+
+            log_prob = x_mu_T_precision_x_mu
+            log_prob.add_(log_2pi)
+            log_prob.add_(-log_det)
+            log_prob.mul_(-0.5)
+
+            return log_prob
 
         elif self.covariance_type == "diag":
             mu = self.mu
```

> Review comment (on the chunked computation of `x_mu_T_precision_x_mu`): Unless there are reservations/concerns, I would consider moving this into its own utility function, in the interest of preserving readability of the code (happy to take care of this once it has been merged).
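`find_optimal_splits` itself lives in `utils.py` and is not shown in this diff. Judging from the call site above, it takes the number of components, a callback that estimates the peak memory a chunk of `sub_K` components would need, the target device, and a `safe_mode` flag, and returns how many chunks to process sequentially. A rough sketch of such a helper, purely as an assumption about its behaviour, might be:

```python
import math

import torch


def find_optimal_splits(K, get_required_memory, device, safe_mode=True):
    """Smallest number of chunks whose per-chunk memory estimate fits on `device`."""
    if device.type != "cuda":
        return 1  # assume host memory is not the bottleneck
    free_bytes, _ = torch.cuda.mem_get_info(device)
    if safe_mode:
        free_bytes = int(0.8 * free_bytes)  # leave headroom for fragmentation
    n_splits = 1
    while n_splits < K and get_required_memory(math.ceil(K / n_splits)) > free_bytes:
        n_splits += 1
    return n_splits
```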
```diff
@@ -287,64 +325,105 @@ def _estimate_log_prob(self, x):
             return -.5 * (self.n_features * np.log(2. * pi) + log_p) + log_det
 
 
     def _calculate_log_det(self, var):
         """
         Calculate log determinant in log space, to prevent overflow errors.
         args:
             var: torch.Tensor (1, k, d, d)
         """
-        log_det = torch.empty(size=(self.n_components,)).to(var.device)
-
-        for k in range(self.n_components):
-            log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0,k]))).sum()
+        assert (var != var).sum() == 0, "`var` contains NaN, set `covariance_data_type` to double"
+        assert (var.abs() == float("inf")).sum() == 0, "`var` contains inf, set `covariance_data_type` to double"
+
+        if self.covariance_data_type == torch.float:
+            evals = torch.linalg.eigvals(var[0])
+            # evals, _ = torch.linalg.eig(var[0, k])
+            log_det = torch.log(evals).sum(dim=-1).to(var.dtype)
+
+        elif self.covariance_data_type == torch.double:
+            cholesky = torch.linalg.cholesky(var[0])
+            diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1)
+            del cholesky
+            log_det = 2 * torch.log(diagonal).sum(dim=-1)
 
         return log_det.unsqueeze(-1)
```
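The double-precision branch relies on the identity `logdet(S) = 2 * sum(log(diag(L)))` for the Cholesky factor `S = L @ L.T`, while the single-precision branch computes the same quantity from eigenvalues instead. A quick standalone check that the two paths agree (illustrative only, not part of the PR):

```python
import torch

torch.manual_seed(0)
d = 4
a = torch.randn(d, d, dtype=torch.double)
sigma = a @ a.T + 1e-3 * torch.eye(d, dtype=torch.double)  # a well-conditioned SPD matrix

# Double path: via the Cholesky factor.
chol = torch.linalg.cholesky(sigma)
log_det_chol = 2 * torch.log(torch.diagonal(chol)).sum()

# Float path (kept in double here for comparison): via eigenvalues.
log_det_eig = torch.log(torch.linalg.eigvals(sigma)).sum().real

print(torch.allclose(log_det_chol, log_det_eig))  # True
```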
```diff
-    def _e_step(self, x):
+    def _e_step(self, x, use_prev_log_prob=False):
         """
         Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components.
         Also returns the mean of the mean of the logarithms of the probabilities (as is done in sklearn).
         This is the so-called expectation step of the EM-algorithm.
         args:
             x: torch.Tensor (n, d) or (n, 1, d)
+            use_prev_log_prob: bool
         returns:
             log_prob_norm: torch.Tensor (1)
             log_resp: torch.Tensor (n, k, 1)
         """
         x = self.check_size(x)
 
-        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+        if self.prev_log_prob is not None and use_prev_log_prob:
+            weighted_log_prob = self.prev_log_prob
+        else:
+            weighted_log_prob = self._estimate_log_prob(x)
+            weighted_log_prob.add_(torch.log(self.pi))
 
         log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
-        log_resp = weighted_log_prob - log_prob_norm
+        log_resp = weighted_log_prob
+        log_resp.sub_(log_prob_norm)
 
         return torch.mean(log_prob_norm), log_resp
```
```diff
-    def _m_step(self, x, log_resp):
+    def _m_step(self, x, resp):
         """
         From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm.
         args:
             x: torch.Tensor (n, d) or (n, 1, d)
-            log_resp: torch.Tensor (n, k, 1)
+            resp: torch.Tensor (n, k, 1)
         returns:
             pi: torch.Tensor (1, k, 1)
             mu: torch.Tensor (1, k, d)
             var: torch.Tensor (1, k, d)
         """
         x = self.check_size(x)
+        N, _, D = x.shape
+        K = self.n_components
 
-        resp = torch.exp(log_resp)
+        resp_sum = resp.sum(dim=0).squeeze(-1)  #[K]
 
         pi = torch.sum(resp, dim=0, keepdim=True) + self.eps
-        mu = torch.sum(resp * x, dim=0, keepdim=True) / pi
+        mu = (resp.transpose(0, 1)[:, :, 0] @ x[:, 0, :].to(resp.dtype))[None, :, :]
+        mu.div_(pi)
 
         if self.covariance_type == "full":
-            eps = (torch.eye(self.n_features) * self.eps).to(x.device)
-            var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
-                            keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
+            var = torch.empty(1, K, D, D, device=x.device, dtype=resp.dtype)
+            eps = (torch.eye(D) * self.eps).to(x.device)
+
+            def get_required_memory(sub_K):
+                sub_x_mu_requires = N * D * sub_K * resp.element_size()
+                sub_x_mu_resp_requires = 2 * N * D * sub_K * resp.element_size()
+                sub_var_requires = D * D * sub_K * resp.element_size()
+                return sub_x_mu_requires + sub_x_mu_resp_requires + sub_var_requires
+
+            n_splits = find_optimal_splits(K, get_required_memory, x.device, safe_mode=self.safe_mode)
+            sub_K = math.ceil(K / n_splits)
+
+            for i in range(n_splits):
+                K_start = i * sub_K
+                K_end = min((i + 1) * sub_K, K)
+                sub_mu = mu[:, K_start: K_end, :]  #[1, sub_K, D]
+                sub_resp = (resp[:, K_start: K_end, :]).permute(1, 2, 0)  #[N, sub_K, 1]
+                sub_x_mu = (x - sub_mu).permute(1, 2, 0)  #[sub_K, D, N]
+                sub_x_mu_resp = (sub_x_mu * sub_resp).transpose(-1, -2)  #[sub_K, N, D]
+                var[:, K_start: K_end, :, :] = sub_x_mu @ sub_x_mu_resp  #[sub_K, D, D]
+                del sub_x_mu, sub_x_mu_resp
+            var.div_(resp_sum[None, :, None, None])
+            var.add_(eps[None, None, :, :])
 
         elif self.covariance_type == "diag":
             x2 = (resp * x * x).sum(0, keepdim=True) / pi
             mu2 = mu * mu
```

> Review comment (on the new `var` buffer and chunked covariance update): Nice! 👍 Same thought as before however, given the additional complexity that's introduced here, it might make sense to define these optimizations in some other place.
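To see why the chunking matters, plug the largest double-precision benchmark setting (n_data=1,000,000, n_features=64, n_components=64) into `get_required_memory` above: processing all components at once would need roughly 90 GiB of intermediate buffers, far beyond the Tesla T4's 16 GB, while a single-component chunk needs about 1.4 GiB. A small standalone check of that arithmetic:

```python
def m_step_required_memory(N, D, sub_K, element_size=8):
    # Mirrors get_required_memory in _m_step above; element_size=8 for double.
    sub_x_mu = N * D * sub_K * element_size
    sub_x_mu_resp = 2 * N * D * sub_K * element_size
    sub_var = D * D * sub_K * element_size
    return sub_x_mu + sub_x_mu_resp + sub_var


GiB = 1024 ** 3
print(m_step_required_memory(1_000_000, 64, sub_K=64) / GiB)  # ~91.6 -> all 64 components at once
print(m_step_required_memory(1_000_000, 64, sub_K=1) / GiB)   # ~1.4  -> one component per chunk
```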
```diff
@@ -356,14 +435,15 @@ def _m_step(self, x, log_resp):
         return pi, mu, var
 
 
-    def __em(self, x):
+    def __em(self, x, use_prev_log_prob=False):
         """
         Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.
         args:
             x: torch.Tensor (n, 1, d)
         """
-        _, log_resp = self._e_step(x)
-        pi, mu, var = self._m_step(x, log_resp)
+        _, resp = self._e_step(x, use_prev_log_prob)
+        resp.exp_()
+        pi, mu, var = self._m_step(x, resp)
 
         self.__update_pi(pi)
         self.__update_mu(mu)
```
```diff
@@ -382,8 +462,10 @@ def __score(self, x, as_average=True):
             per_sample_score: torch.Tensor (n)
 
         """
-        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+        weighted_log_prob = self._estimate_log_prob(x)
+        weighted_log_prob.add_(torch.log(self.pi))
         per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)
+        self.prev_log_prob = weighted_log_prob
 
         if as_average:
             return per_sample_score.mean()
```
> Review comment: Awesome results, thanks for sharing these! Before merging with master, I would suggest removing `benchmark.md`.