Optimized memory usage and speed for covar type "full" #23

Open · wants to merge 31 commits into master
39 changes: 39 additions & 0 deletions benchmark.md
@@ -0,0 +1,39 @@
# Benchmark
Owner commented:

Awesome results, thanks for sharing these!

Before merging with master, I would suggest removing benchmark.md.

GPU: Tesla T4 (16GB DRAM)

- covariance_type = "full"
- init_params = "random"
- n_iter = 20
- delta = 0

| setup | original | k-loop | optimized (single) | optimized (double) |
| --- | --- | --- | --- | --- |
| n_features=16, n_components=16, n_data=100,000 | 6.9s | 6.9s | 0.5s | 3.44s |
| n_features=16, n_components=16, n_data=1,000,000 | OOM | 68.8s | 3.7s | 34.0s |
| n_features=64, n_components=64, n_data=100,000 | OOM | 161s | 3.57s | 13.9s |
| n_features=64, n_components=64, n_data=1,000,000 | OOM | OOM | 44.4s | 527s |
| n_features=256, n_components=256, n_data=100,000 | OOM | OOM | NAN | 686s |
| n_features=256, n_components=16, n_data=1,000,000 | OOM | OOM | 60s | 454s |

### Notes:
- OOM: Out Of Memory
- NAN: Covar contains NaN
- Single/Double: dtype of covariance matrix
- k-loop: almost the same as original `GaussianMixture`, except
```python
var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
                keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
```
in `_m_step` is replaced with
```python
var = torch.empty(1, self.n_components, self.n_features, self.n_features, device=x.device, dtype=resp.dtype)
eps = (torch.eye(self.n_features) * self.eps).to(x.device)
for i in range(self.n_components):
    sub_mu = mu[:, i, :]
    sub_resp = resp[:, i, :]
    sub_x_mu = (x - sub_mu).squeeze(1)
    outer = torch.matmul(sub_x_mu[:, :, None], sub_x_mu[:, None, :])
    outer_sum = torch.sum(outer * sub_resp[:, :, None], dim=0, keepdim=True)
    # resp_sum: per-component responsibility sums, shape (K,), computed earlier in _m_step
    sub_var = outer_sum / resp_sum[i] + eps
    var[:, i, :, :] = sub_var
```
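
To put the table in context: for n_data=1,000,000 and n_features=n_components=64 in single precision, even the intermediate (N, K, D) residual tensor built by the vectorized code takes roughly 1,000,000 × 64 × 64 × 4 bytes ≈ 16 GB, and the (N, K, D, D) outer-product tensor in the original M-step is far larger, which is why the unsplit variants run out of memory on a 16 GB T4. A minimal driver along the following lines could reproduce one row of the table; the synthetic data is an assumption, while the constructor and `fit` arguments mirror this PR:

```python
# Sketch of a benchmark driver for the setup above (synthetic data is an assumption).
import time
import torch
from gmm import GaussianMixture

n_data, n_features, n_components = 100_000, 16, 16
x = torch.randn(n_data, n_features, device="cuda")

model = GaussianMixture(
    n_components=n_components,
    n_features=n_features,
    covariance_type="full",
    init_params="random",
    covariance_data_type="float",  # use "double" for the double-precision column
).cuda()

start = time.time()
model.fit(x, delta=0, n_iter=20)
print(f"fit finished in {time.time() - start:.2f}s")
```
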
202 changes: 142 additions & 60 deletions gmm.py
@@ -1,9 +1,11 @@
import torch
import numpy as np
import math
Owner commented:

gmm.py:5 imports from math, so it'd make sense to either replace all occurrences of pi or import ceil alongside it.
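
For instance (an illustrative sketch of the reviewer's suggestion, not code in this PR):

```python
# Option 1: import ceil alongside pi and drop the bare "import math"
from math import pi, ceil        # then write ceil(K / n_splits) below

# Option 2: keep only "import math" and qualify both usages
# log_2pi = D * np.log(2. * math.pi); sub_K = math.ceil(K / n_splits)
```
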


from math import pi
from scipy.special import logsumexp
from utils import calculate_matmul, calculate_matmul_n_times
from utils import calculate_matmul, calculate_matmul_n_times, find_optimal_splits
from tqdm import tqdm
Owner commented:

I'd recommend removing this to keep the repository light on dependencies — users that require this functionality can always add it.
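
If the hard dependency were dropped, one option is a guarded import so the progress bar stays optional (a sketch of that pattern, not part of this PR):

```python
try:
    from tqdm import tqdm
except ImportError:
    # Fallback stand-in exposing the interface fit() uses (total=, update, close)
    class tqdm:
        def __init__(self, *args, **kwargs): pass
        def update(self, n=1): pass
        def close(self): pass
```
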



class GaussianMixture(torch.nn.Module):
@@ -15,30 +17,31 @@ class GaussianMixture(torch.nn.Module):
probabilities are shaped (n, k, 1) if they relate to an individual sample,
or (1, k, 1) if they assign membership probabilities to one of the mixture components.
"""
def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None):
def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None, covariance_data_type="double"):
Owner commented:

Any reservations against going with "float" as the default type (it matches the torch.Tensor default)?

"""
Initializes the model and brings all tensors into their required shape.
The class expects data to be fed as a flat tensor in (n, d).
The class owns:
x: torch.Tensor (n, 1, d)
mu: torch.Tensor (1, k, d)
var: torch.Tensor (1, k, d) or (1, k, d, d)
pi: torch.Tensor (1, k, 1)
covariance_type: str
eps: float
init_params: str
log_likelihood: float
n_components: int
n_features: int
x: torch.Tensor (n, 1, d)
mu: torch.Tensor (1, k, d)
var: torch.Tensor (1, k, d) or (1, k, d, d)
pi: torch.Tensor (1, k, 1)
covariance_type: str
eps: float
init_params: str
log_likelihood: float
n_components: int
n_features: int
args:
n_components: int
n_features: int
n_components: int
n_features: int
options:
mu_init: torch.Tensor (1, k, d)
var_init: torch.Tensor (1, k, d) or (1, k, d, d)
covariance_type: str
eps: float
init_params: str
mu_init: torch.Tensor (1, k, d)
var_init: torch.Tensor (1, k, d) or (1, k, d, d)
covariance_type: str
eps: float
init_params: str
covariance_data_type: str or torch.dtype
Owner commented:

Since mu is getting matched against this type, might as well go ahead and introduce this as dtype altogether, right?

"""
super(GaussianMixture, self).__init__()

@@ -50,6 +53,15 @@ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6,
self.eps = eps

self.log_likelihood = -np.inf
self.safe_mode = True
self.prev_log_prob = None

assert covariance_data_type in ["float", "double", torch.float, torch.double]
if covariance_data_type == "float":
covariance_data_type = torch.float
elif covariance_data_type == "double":
covariance_data_type = torch.double
self.covariance_data_type = covariance_data_type

self.covariance_type = covariance_type
self.init_params = init_params
@@ -72,17 +84,17 @@ def _init_params(self):
if self.var_init is not None:
# (1, k, d)
assert self.var_init.size() == (1, self.n_components, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (self.n_components, self.n_features)
self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
self.var = torch.nn.Parameter(self.var_init.to(self.covariance_data_type), requires_grad=False)
else:
self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False)
self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features, dtype=self.covariance_data_type), requires_grad=False)
elif self.covariance_type == "full":
if self.var_init is not None:
# (1, k, d, d)
assert self.var_init.size() == (1, self.n_components, self.n_features, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i, %i)" % (self.n_components, self.n_features, self.n_features)
self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
self.var = torch.nn.Parameter(self.var_init.to(self.covariance_data_type), requires_grad=False)
else:
self.var = torch.nn.Parameter(
torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1),
torch.eye(self.n_features, dtype=self.covariance_data_type).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1),
requires_grad=False
)

@@ -139,14 +151,15 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):

i = 0
j = np.inf


pbar = tqdm(total=n_iter)
while (i <= n_iter) and (j >= delta):

log_likelihood_old = self.log_likelihood
mu_old = self.mu
var_old = self.var

self.__em(x)
self.__em(x, use_prev_log_prob=True)
self.log_likelihood = self.__score(x)

if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood):
@@ -170,8 +183,10 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):
# When score decreases, revert to old parameters
self.__update_mu(mu_old)
self.__update_var(var_old)
pbar.update(1)

self.params_fitted = True
pbar.close()


def predict(self, x, probs=False):
@@ -188,7 +203,8 @@ def predict(self, x, probs=False):
"""
x = self.check_size(x)

weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
weighted_log_prob = self._estimate_log_prob(x)
weighted_log_prob.add_(torch.log(self.pi))
Owner commented:

While carrying this out in-place preserves memory, spreading this across two lines here and in 369 and 466 decreases readability somewhat. Alternatively, I reckon this could be moved into _estimate_log_prob.
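
One way to keep the in-place update while avoiding the repeated two-line pattern would be a small wrapper; a sketch (the name `_estimate_weighted_log_prob` is illustrative, not part of this PR):

```python
def _estimate_weighted_log_prob(self, x):
    # log p(x | component) + log pi, added in place to avoid an extra (n, k, 1) allocation
    weighted_log_prob = self._estimate_log_prob(x)
    weighted_log_prob.add_(torch.log(self.pi))
    return weighted_log_prob
```

predict, _e_step and __score could then each call this helper on a single line.
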


if probs:
p_k = torch.exp(weighted_log_prob)
@@ -257,25 +273,47 @@ def _estimate_log_prob(self, x):
log_prob: torch.Tensor (n, k, 1)
"""
x = self.check_size(x)
N, _, D = x.shape
K = self.n_components

if self.covariance_type == "full":
mu = self.mu
var = self.var

precision = torch.inverse(var)
d = x.shape[-1]

log_2pi = d * np.log(2. * pi)

log_det = self._calculate_log_det(precision)
x = x.to(var.dtype)
Owner commented:

Since self.covariance_data_type has been allocated, maybe use that instead?

mu = mu.to(var.dtype)

x_mu_T = (x - mu).unsqueeze(-2)
x_mu = (x - mu).unsqueeze(-1)

x_mu_T_precision = calculate_matmul_n_times(self.n_components, x_mu_T, precision)
x_mu_T_precision_x_mu = calculate_matmul(x_mu_T_precision, x_mu)
precision = torch.inverse(var)

return -.5 * (log_2pi - log_det + x_mu_T_precision_x_mu)
log_2pi = D * np.log(2. * pi)

log_det = self._calculate_log_det(precision) #[K, 1]

x_mu_T_precision_x_mu = torch.empty(N, K, 1, device=x.device, dtype=x.dtype)
Owner commented:

Unless there are reservations/concerns, I would consider moving this into its own utility function, in the interest of preserving readability of the code (happy to take care of this once it has been merged).
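
For reference, the quantity the loop below computes chunk-by-chunk is the per-component Mahalanobis term; factored into a standalone helper without the memory splitting, it might look like this sketch (the einsum formulation and names are illustrative, not part of this PR):

```python
import torch

def mahalanobis_squared(x, mu, precision):
    # x: (n, 1, d), mu: (1, k, d), precision: (1, k, d, d)  ->  (n, k, 1)
    diff = x - mu                                              # broadcasts to (n, k, d)
    quad = torch.einsum("nkd,kde,nke->nk", diff, precision[0], diff)
    return quad.unsqueeze(-1)
```
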


def get_required_memory(sub_K):
x_mu_requires = sub_K * N * D * x.element_size()
x_mu_T_precision_requires = sub_K * N * D * x.element_size()
calculate_matmul_requires = (sub_K * N * D + sub_K * N) * x.element_size()
return x_mu_requires + x_mu_T_precision_requires + calculate_matmul_requires

n_splits = find_optimal_splits(K, get_required_memory, x.device, safe_mode=self.safe_mode)
sub_K = math.ceil(K / n_splits)
for i in range(n_splits):
K_start = i * sub_K
K_end = min((i + 1) * sub_K, K)
sub_x_mu = x - mu[:, K_start: K_end, :] #[N, sub_K, D]
sub_x_mu_T_precision = (sub_x_mu.transpose(0, 1) @ precision[:, K_start: K_end]).transpose(0, 2)
sub_x_mu_T_precision_x_mu = calculate_matmul(sub_x_mu_T_precision, sub_x_mu[:, :, :, None]) #[N, sub_K, 1]
x_mu_T_precision_x_mu[:, K_start: K_end] = sub_x_mu_T_precision_x_mu
del sub_x_mu, sub_x_mu_T_precision

log_prob = x_mu_T_precision_x_mu
log_prob.add_(log_2pi)
log_prob.add_(-log_det)
log_prob.mul_(-0.5)

return log_prob

elif self.covariance_type == "diag":
mu = self.mu
@@ -287,64 +325,105 @@ def _estimate_log_prob(self, x):
return -.5 * (self.n_features * np.log(2. * pi) + log_p) + log_det



def _calculate_log_det(self, var):
"""
Calculate log determinant in log space, to prevent overflow errors.
args:
var: torch.Tensor (1, k, d, d)
"""
log_det = torch.empty(size=(self.n_components,)).to(var.device)

for k in range(self.n_components):
log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0,k]))).sum()
assert (var != var).sum() == 0, "`var` contains NaN, set `covariance_data_type` to double"
assert (var.abs() == float("inf")).sum() == 0, "`var` contains inf, set `covariance_data_type` to double"

if self.covariance_data_type == torch.float:
evals = torch.linalg.eigvals(var[0])
# evals, _ = torch.linalg.eig(var[0, k])
log_det = torch.log(evals).sum(dim=-1).to(var.dtype)

elif self.covariance_data_type == torch.double:
cholesky = torch.linalg.cholesky(var[0])
diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1)
del cholesky
log_det = 2 * torch.log(diagonal).sum(dim=-1)

return log_det.unsqueeze(-1)


def _e_step(self, x):
def _e_step(self, x, use_prev_log_prob=False):
"""
Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components.
Also returns the mean of the logarithms of the probabilities (as is done in sklearn).
This is the so-called expectation step of the EM-algorithm.
args:
x: torch.Tensor (n, d) or (n, 1, d)
x: torch.Tensor (n, d) or (n, 1, d)
use_prev_log_prob: bool
returns:
log_prob_norm: torch.Tensor (1)
log_resp: torch.Tensor (n, k, 1)
log_prob_norm: torch.Tensor (1)
log_resp: torch.Tensor (n, k, 1)
"""
x = self.check_size(x)

weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
if self.prev_log_prob is not None and use_prev_log_prob:
weighted_log_prob = self.prev_log_prob
else:
weighted_log_prob = self._estimate_log_prob(x)
weighted_log_prob.add_(torch.log(self.pi))

log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
log_resp = weighted_log_prob - log_prob_norm

log_resp = weighted_log_prob
log_resp.sub_(log_prob_norm)

return torch.mean(log_prob_norm), log_resp


def _m_step(self, x, log_resp):
def _m_step(self, x, resp):
"""
From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm.
args:
x: torch.Tensor (n, d) or (n, 1, d)
log_resp: torch.Tensor (n, k, 1)
resp: torch.Tensor (n, k, 1)
returns:
pi: torch.Tensor (1, k, 1)
mu: torch.Tensor (1, k, d)
var: torch.Tensor (1, k, d)
"""
x = self.check_size(x)
N, _, D = x.shape
K = self.n_components

resp = torch.exp(log_resp)
resp_sum = resp.sum(dim=0).squeeze(-1) #[K]

pi = torch.sum(resp, dim=0, keepdim=True) + self.eps
mu = torch.sum(resp * x, dim=0, keepdim=True) / pi

mu = (resp.transpose(0, 1)[:, :, 0] @ x[:, 0, :].to(resp.dtype) )[None, :, :]
mu.div_(pi)

if self.covariance_type == "full":
eps = (torch.eye(self.n_features) * self.eps).to(x.device)
var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
var = torch.empty(1, K, D, D, device=x.device, dtype=resp.dtype)
Owner commented:

Nice! 👍

Same thought as before, however: given the additional complexity that's introduced here, it might make sense to define these optimizations in some other place.
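
As an illustration of what such a factored-out helper could look like for the full-covariance update (einsum-based, without the splitting over components; names are illustrative, not part of this PR):

```python
import torch

def full_covariance_update(x, mu, resp, resp_sum, eps):
    # x: (n, 1, d), mu: (1, k, d), resp: (n, k, 1), resp_sum: (k,)  ->  (1, k, d, d)
    diff = x - mu                                                    # (n, k, d)
    var = torch.einsum("nk,nkd,nke->kde", resp[..., 0], diff, diff)
    var = var / resp_sum[:, None, None]
    var = var + torch.eye(x.shape[-1], device=x.device, dtype=var.dtype) * eps
    return var.unsqueeze(0)
```
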

eps = (torch.eye(D) * self.eps).to(x.device)

def get_required_memory(sub_K):
sub_x_mu_requires = N * D * sub_K * resp.element_size()
sub_x_mu_resp_requires = 2 * N * D * sub_K * resp.element_size()
sub_var_requires = D * D * sub_K * resp.element_size()
return sub_x_mu_requires + sub_x_mu_resp_requires + sub_var_requires

n_splits = find_optimal_splits(K, get_required_memory, x.device, safe_mode=self.safe_mode)
sub_K = math.ceil(K / n_splits)

for i in range(n_splits):
K_start = i * sub_K
K_end = min((i + 1) * sub_K, K)
sub_mu = mu[:, K_start: K_end, :] #[1, sub_K, D]
sub_resp = (resp[:, K_start: K_end, :]).permute(1, 2, 0) #[N, sub_K, 1]
sub_x_mu = (x - sub_mu).permute(1, 2, 0) #[sub_K, D, N]
sub_x_mu_resp = (sub_x_mu * sub_resp).transpose(-1, -2) #[sub_K, N, D]
var[:, K_start: K_end, :, :] = sub_x_mu @ sub_x_mu_resp #[sub_K, D, D]
del sub_x_mu, sub_x_mu_resp
var.div_(resp_sum[None, :, None, None])
var.add_(eps[None, None, :, :])


elif self.covariance_type == "diag":
x2 = (resp * x * x).sum(0, keepdim=True) / pi
mu2 = mu * mu
@@ -356,14 +435,15 @@ def _m_step(self, x, log_resp):
return pi, mu, var


def __em(self, x):
def __em(self, x, use_prev_log_prob=False):
"""
Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.
args:
x: torch.Tensor (n, 1, d)
"""
_, log_resp = self._e_step(x)
pi, mu, var = self._m_step(x, log_resp)
_, resp = self._e_step(x, use_prev_log_prob)
resp.exp_()
pi, mu, var = self._m_step(x, resp)

self.__update_pi(pi)
self.__update_mu(mu)
@@ -382,8 +462,10 @@ def __score(self, x, as_average=True):
per_sample_score: torch.Tensor (n)

"""
weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
weighted_log_prob = self._estimate_log_prob(x)
weighted_log_prob.add_(torch.log(self.pi))
per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)
self.prev_log_prob = weighted_log_prob

if as_average:
return per_sample_score.mean()