From 6656009d191e37a5d7b59f7bad14f5b8f022b9eb Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 14:41:07 +0300 Subject: [PATCH 01/31] Update gmm.py --- gmm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gmm.py b/gmm.py index 5aa7d28..998b9d9 100644 --- a/gmm.py +++ b/gmm.py @@ -50,6 +50,8 @@ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, self.eps = eps self.log_likelihood = -np.inf + self.safe_mode = True + self.prev_log_prob_norm = None self.covariance_type = covariance_type self.init_params = init_params From f77579726d7a8e4b7091f503cb1dd569b544f867 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 15:51:52 +0300 Subject: [PATCH 02/31] optimizations --- gmm.py | 137 ++++++++++++++++++++++++++++++++++++++++--------------- utils.py | 82 +++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 37 deletions(-) diff --git a/gmm.py b/gmm.py index 998b9d9..9c2c438 100644 --- a/gmm.py +++ b/gmm.py @@ -3,7 +3,8 @@ from math import pi from scipy.special import logsumexp -from utils import calculate_matmul, calculate_matmul_n_times +from utils import calculate_matmul, calculate_matmul_n_times, find_optimal_splits +from tqdm import tqdm class GaussianMixture(torch.nn.Module): @@ -141,14 +142,15 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False): i = 0 j = np.inf - + + pbar = tqdm(total=n_iter) while (i <= n_iter) and (j >= delta): log_likelihood_old = self.log_likelihood mu_old = self.mu var_old = self.var - self.__em(x) + self.__em(x, use_prev_log_prob_norm=True) self.log_likelihood = self.__score(x) if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood): @@ -172,8 +174,10 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False): # When score decreases, revert to old parameters self.__update_mu(mu_old) self.__update_var(var_old) + pbar.update(1) self.params_fitted = True + pbar.close() def predict(self, x, probs=False): @@ -190,7 +194,8 @@ def predict(self, x, probs=False): """ x = self.check_size(x) - weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi) + weighted_log_prob = self._estimate_log_prob(x) + weighted_log_prob.add_(torch.log(self.pi)) if probs: p_k = torch.exp(weighted_log_prob) @@ -259,25 +264,47 @@ def _estimate_log_prob(self, x): log_prob: torch.Tensor (n, k, 1) """ x = self.check_size(x) + N, _, D = x.shape + K = self.n_components if self.covariance_type == "full": mu = self.mu var = self.var - precision = torch.inverse(var) - d = x.shape[-1] - - log_2pi = d * np.log(2. * pi) + x = x.to(var.dtype) + mu = mu.to(var.dtype) - log_det = self._calculate_log_det(precision) - - x_mu_T = (x - mu).unsqueeze(-2) - x_mu = (x - mu).unsqueeze(-1) - - x_mu_T_precision = calculate_matmul_n_times(self.n_components, x_mu_T, precision) - x_mu_T_precision_x_mu = calculate_matmul(x_mu_T_precision, x_mu) + precision = torch.inverse(var) - return -.5 * (log_2pi - log_det + x_mu_T_precision_x_mu) + log_2pi = D * np.log(2. * pi) + + log_det = self._calculate_log_det(precision) #[K, 1] + + x_mu_T_precision_x_mu = torch.empty(N, K, 1, device=x.device, dtype=x.dtype) + + def get_required_memory(sub_K): + x_mu_requires = sub_K * N * D * x.element_size() + x_mu_T_precision_requires = sub_K * N * D * x.element_size() + calculate_matmul_requires = (sub_K * N * D + sub_K * N) * x.element_size() + return x_mu_requires + x_mu_T_precision_requires + calculate_matmul_requires + + n_splits = find_optimal_splits(K, get_required_memory, x.device, safe_mode=self.safe_mode) + sub_K = math.ceil(K / n_splits) + for i in range(n_splits): + K_start = i * sub_K + K_end = min((i + 1) * sub_K, K) + sub_x_mu = x - mu[:, K_start: K_end, :] #[N, sub_K, D] + sub_x_mu_T_precision = (sub_x_mu.transpose(0, 1) @ precision[:, K_start: K_end]).transpose(0, 2) + sub_x_mu_T_precision_x_mu = calculate_matmul(sub_x_mu_T_precision, sub_x_mu[:, :, :, None]) #[N, sub_K, 1] + x_mu_T_precision_x_mu[:, K_start: K_end] = sub_x_mu_T_precision_x_mu + del sub_x_mu, sub_x_mu_T_precision + + log_prob = x_mu_T_precision_x_mu + log_prob.add_(log_2pi) + log_prob.add_(-log_det) + log_prob.mul_(-0.5) + + return log_prob elif self.covariance_type == "diag": mu = self.mu @@ -289,64 +316,97 @@ def _estimate_log_prob(self, x): return -.5 * (self.n_features * np.log(2. * pi) + log_p) + log_det - def _calculate_log_det(self, var): """ Calculate log determinant in log space, to prevent overflow errors. args: var: torch.Tensor (1, k, d, d) """ - log_det = torch.empty(size=(self.n_components,)).to(var.device) - - for k in range(self.n_components): - log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0,k]))).sum() + cholesky = torch.linalg.cholesky(var[0]) + diag = torch.diagonal(cholesky, dim1=-2, dim2=-1) + del cholesky + log_det = 2 * torch.log(diagonal).sum(dim=-1) + return log_det.unsqueeze(-1) - def _e_step(self, x): + def _e_step(self, x, use_prev_log_prob_norm=False): """ Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components. Also returns the mean of the mean of the logarithms of the probabilities (as is done in sklearn). This is the so-called expectation step of the EM-algorithm. args: - x: torch.Tensor (n, d) or (n, 1, d) + x: torch.Tensor (n, d) or (n, 1, d) + use_prev_log_prob_norm: bool returns: - log_prob_norm: torch.Tensor (1) - log_resp: torch.Tensor (n, k, 1) + log_prob_norm: torch.Tensor (1) + log_resp: torch.Tensor (n, k, 1) """ x = self.check_size(x) - weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi) + if self.prev_log_prob_norm is not None and use_prev_log_prob_norm: + log_prob_norm = self.prev_log_prob_norm + else: + weighted_log_prob = self._estimate_log_prob(x) + weighted_log_prob.add_(torch.log(self.pi)) + log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) - log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) - log_resp = weighted_log_prob - log_prob_norm + log_resp = weighted_log_prob + log_resp.sub_(log_prob_norm) return torch.mean(log_prob_norm), log_resp - def _m_step(self, x, log_resp): + def _m_step(self, x, resp): """ From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm. args: x: torch.Tensor (n, d) or (n, 1, d) - log_resp: torch.Tensor (n, k, 1) + resp: torch.Tensor (n, k, 1) returns: pi: torch.Tensor (1, k, 1) mu: torch.Tensor (1, k, d) var: torch.Tensor (1, k, d) """ x = self.check_size(x) + N, _, D = x.shape + K = self.n_components resp = torch.exp(log_resp) + resp_sum = resp.sum(dim=0).squeeze(-1) #[K] pi = torch.sum(resp, dim=0, keepdim=True) + self.eps - mu = torch.sum(resp * x, dim=0, keepdim=True) / pi + + mu = (resp.transpose(0, 1)[:, :, 0] @ x[:, 0, :].double() )[None, :, :] + mu.div_(pi) if self.covariance_type == "full": - eps = (torch.eye(self.n_features) * self.eps).to(x.device) - var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0, - keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps + var = torch.empty(1, K, D, D, device=x.device, dtype=resp.dtype) + eps = (torch.eye(K) * self.eps).to(x.device) + + def get_required_memory(sub_K): + sub_x_mu_requires = N * D * sub_K * resp.element_size() + sub_x_mu_resp_requires = 2 * N * D * sub_K * resp.element_size() + sub_var_requires = D * D * sub_K * resp.element_size() + return sub_x_mu_requires + sub_x_mu_resp_requires + sub_var_requires + + n_splits = find_optimal_splits(K, get_required_memory, x.device, safe_mode=self.safe_mode) + sub_K = math.ceil(K / n_splits) + + for i in range(n_splits): + K_start = i * sub_K + K_end = min((i + 1) * sub_K, K) + sub_mu = mu[:, K_start: K_end, :] #[1, sub_K, D] + sub_resp = (resp[:, K_start: K_end, :]).permute(1, 2, 0) #[N, sub_K, 1] + sub_x_mu = (x - sub_mu).permute(1, 2, 0) #[sub_K, D, N] + sub_x_mu_resp = (sub_x_mu * sub_resp).transpose(-1, -2) #[sub_K, N, D] + var[:, K_start: K_end, :, :] = sub_x_mu @ sub_x_mu_resp #[sub_K, D, D] + del sub_x_mu, sub_x_mu_resp + var.div_(resp_sum[None, :, None, None]) + var.add_(eps[None, None, :, :]) + + elif self.covariance_type == "diag": x2 = (resp * x * x).sum(0, keepdim=True) / pi mu2 = mu * mu @@ -364,8 +424,9 @@ def __em(self, x): args: x: torch.Tensor (n, 1, d) """ - _, log_resp = self._e_step(x) - pi, mu, var = self._m_step(x, log_resp) + _, resp = self._e_step(x) + resp.exp_() + pi, mu, var = self._m_step(x, resp) self.__update_pi(pi) self.__update_mu(mu) @@ -384,8 +445,10 @@ def __score(self, x, as_average=True): per_sample_score: torch.Tensor (n) """ - weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi) + weighted_log_prob = self._estimate_log_prob(x) + weighted_log_prob.add_(torch.log(self.pi)) per_sample_score = torch.logsumexp(weighted_log_prob, dim=1) + self.prev_log_prob_norm = per_sample_score.unsqueeze(1) if as_average: return per_sample_score.mean() diff --git a/utils.py b/utils.py index 06c2206..79c21a0 100644 --- a/utils.py +++ b/utils.py @@ -1,4 +1,5 @@ import torch +import math def calculate_matmul_n_times(n_components, mat_a, mat_b): """ @@ -28,3 +29,84 @@ def calculate_matmul(mat_a, mat_b): """ assert mat_a.shape[-2] == 1 and mat_b.shape[-1] == 1 return torch.sum(mat_a.squeeze(-2) * mat_b.squeeze(-1), dim=2, keepdim=True) + + +def check_available_ram(device="cpu"): + """ + Returns available RAM on target device + args: + device: str or torch.device + """ + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, torch.device): + device = device + else: + raise RuntimeError("`device` must be str or torch.device") + + if device.type == "cpu": + return psutil.virtual_memory().available + else: + total = torch.cuda.get_device_properties(device).total_memory + used = torch.cuda.memory_allocated(device) + return total - used + +def will_it_fit(size, device="cpu", safe_mode=True): + """ + Returns True if an array of given byte size fits in target device. + + if self.safe_mode = False, this function simply compares the given byte size with the remaining RAM on target device. This option is faster, + but it doesn't take memory fragmentation into account. So it will still be possible to run out of memory. + + if self.safe_mode = True, it will try to allocate a tensor with the given size. if allocation fails, return False. + This option is recommended when the other option fails because of OOM. + + args: + size: int + device: str or torch.device + safe_mode: bool + returns: + result: bool + """ + if safe_mode: + try: + torch.empty(size, device=device, dtype=torch.uint8) + except: + return False + return True + else: + return check_available_ram(device) >= size + + +def find_optimal_splits(n, get_required_memory, device="cpu", safe_mode=True): + """ + Find an optimal number of split for `n`, such that `get_required_memory(math.ceil(n / n_split))` fits in target device's RAM. + get_required_memory should be a fucntion that receives `math.ceil(n/n_split)` and returns the required memory in bytes. + + args: + n: int + get_required_memory: function + device: str or torch.device + safe_mode: bool + + returns: + n_splits: int + + """ + splits = 1 + sub_n = n + break_next = False + while True: + if break_next: + break + if splits > n: + splits = n + break_next = True + sub_n = math.ceil(n / splits) + required_memory = get_required_memory(sub_n) + if will_it_fit(required_memory, device): + break + else: + splits *= 2 + continue + return splits \ No newline at end of file From f5537b785ca9b2e19414e9380eb18c97435193d1 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 15:57:58 +0300 Subject: [PATCH 03/31] bugfix --- gmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gmm.py b/gmm.py index 9c2c438..cee8f88 100644 --- a/gmm.py +++ b/gmm.py @@ -418,13 +418,13 @@ def get_required_memory(sub_K): return pi, mu, var - def __em(self, x): + def __em(self, x, use_prev_log_prob_norm=False): """ Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines. args: x: torch.Tensor (n, 1, d) """ - _, resp = self._e_step(x) + _, resp = self._e_step(x, use_prev_log_prob_norm) resp.exp_() pi, mu, var = self._m_step(x, resp) From 29e6d2ab05f9ed13f0c46f39b9208edf0d0036aa Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 16:00:32 +0300 Subject: [PATCH 04/31] bugfix --- gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index cee8f88..22f77fa 100644 --- a/gmm.py +++ b/gmm.py @@ -324,7 +324,7 @@ def _calculate_log_det(self, var): """ cholesky = torch.linalg.cholesky(var[0]) - diag = torch.diagonal(cholesky, dim1=-2, dim2=-1) + diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1) del cholesky log_det = 2 * torch.log(diagonal).sum(dim=-1) From 4cf41f8cb2f1c90769eb89b0a05428592a660b10 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 16:02:52 +0300 Subject: [PATCH 05/31] bugfix --- gmm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gmm.py b/gmm.py index 22f77fa..1f0e51f 100644 --- a/gmm.py +++ b/gmm.py @@ -1,5 +1,6 @@ import torch import numpy as np +import math from math import pi from scipy.special import logsumexp From c4c73d462789694667414473b7604b5eded34824 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 16:03:45 +0300 Subject: [PATCH 06/31] bugfix --- gmm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gmm.py b/gmm.py index 1f0e51f..f4655fa 100644 --- a/gmm.py +++ b/gmm.py @@ -374,7 +374,6 @@ def _m_step(self, x, resp): N, _, D = x.shape K = self.n_components - resp = torch.exp(log_resp) resp_sum = resp.sum(dim=0).squeeze(-1) #[K] pi = torch.sum(resp, dim=0, keepdim=True) + self.eps From 54e9cbdfb602b36f8064e0af1fffddba87167d37 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 16:19:42 +0300 Subject: [PATCH 07/31] bugfix --- gmm.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/gmm.py b/gmm.py index f4655fa..de94bc2 100644 --- a/gmm.py +++ b/gmm.py @@ -53,7 +53,7 @@ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, self.log_likelihood = -np.inf self.safe_mode = True - self.prev_log_prob_norm = None + self.prev_log_prob = None self.covariance_type = covariance_type self.init_params = init_params @@ -151,7 +151,7 @@ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False): mu_old = self.mu var_old = self.var - self.__em(x, use_prev_log_prob_norm=True) + self.__em(x, use_prev_log_prob=True) self.log_likelihood = self.__score(x) if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood): @@ -332,26 +332,27 @@ def _calculate_log_det(self, var): return log_det.unsqueeze(-1) - def _e_step(self, x, use_prev_log_prob_norm=False): + def _e_step(self, x, use_prev_log_prob=False): """ Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components. Also returns the mean of the mean of the logarithms of the probabilities (as is done in sklearn). This is the so-called expectation step of the EM-algorithm. args: x: torch.Tensor (n, d) or (n, 1, d) - use_prev_log_prob_norm: bool + use_prev_log_prob: bool returns: log_prob_norm: torch.Tensor (1) log_resp: torch.Tensor (n, k, 1) """ x = self.check_size(x) - if self.prev_log_prob_norm is not None and use_prev_log_prob_norm: - log_prob_norm = self.prev_log_prob_norm + if self.prev_log_prob is not None and use_prev_log_prob: + weighted_log_prob = self.prev_log_prob else: weighted_log_prob = self._estimate_log_prob(x) weighted_log_prob.add_(torch.log(self.pi)) - log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) + + log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) log_resp = weighted_log_prob log_resp.sub_(log_prob_norm) @@ -378,12 +379,12 @@ def _m_step(self, x, resp): pi = torch.sum(resp, dim=0, keepdim=True) + self.eps - mu = (resp.transpose(0, 1)[:, :, 0] @ x[:, 0, :].double() )[None, :, :] + mu = (resp.transpose(0, 1)[:, :, 0] @ x[:, 0, :].to(resp.dtype) )[None, :, :] mu.div_(pi) if self.covariance_type == "full": var = torch.empty(1, K, D, D, device=x.device, dtype=resp.dtype) - eps = (torch.eye(K) * self.eps).to(x.device) + eps = (torch.eye(D) * self.eps).to(x.device) def get_required_memory(sub_K): sub_x_mu_requires = N * D * sub_K * resp.element_size() @@ -418,13 +419,13 @@ def get_required_memory(sub_K): return pi, mu, var - def __em(self, x, use_prev_log_prob_norm=False): + def __em(self, x, use_prev_log_prob=False): """ Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines. args: x: torch.Tensor (n, 1, d) """ - _, resp = self._e_step(x, use_prev_log_prob_norm) + _, resp = self._e_step(x, use_prev_log_prob) resp.exp_() pi, mu, var = self._m_step(x, resp) @@ -448,7 +449,7 @@ def __score(self, x, as_average=True): weighted_log_prob = self._estimate_log_prob(x) weighted_log_prob.add_(torch.log(self.pi)) per_sample_score = torch.logsumexp(weighted_log_prob, dim=1) - self.prev_log_prob_norm = per_sample_score.unsqueeze(1) + self.prev_log_prob = weighted_log_prob if as_average: return per_sample_score.mean() From 69e58c0c8d9f968208678fa09b711cf485ddb3a4 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 16:50:08 +0300 Subject: [PATCH 08/31] bugfix --- gmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gmm.py b/gmm.py index de94bc2..ed9e9cc 100644 --- a/gmm.py +++ b/gmm.py @@ -86,7 +86,7 @@ def _init_params(self): self.var = torch.nn.Parameter(self.var_init, requires_grad=False) else: self.var = torch.nn.Parameter( - torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), + torch.eye(self.n_features, torch.float64).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), requires_grad=False ) @@ -351,7 +351,7 @@ def _e_step(self, x, use_prev_log_prob=False): else: weighted_log_prob = self._estimate_log_prob(x) weighted_log_prob.add_(torch.log(self.pi)) - + log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) log_resp = weighted_log_prob From 0f70e40ddc294d0947c58f9f13025788e6998b20 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 16:51:07 +0300 Subject: [PATCH 09/31] bugfix --- gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index ed9e9cc..3f9bc32 100644 --- a/gmm.py +++ b/gmm.py @@ -86,7 +86,7 @@ def _init_params(self): self.var = torch.nn.Parameter(self.var_init, requires_grad=False) else: self.var = torch.nn.Parameter( - torch.eye(self.n_features, torch.float64).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), + torch.eye(self.n_features, dtype=torch.float64).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), requires_grad=False ) From 083e2414b963d8153682645a89309f1f1b8f7038 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 17:00:58 +0300 Subject: [PATCH 10/31] configurable covar dtype --- gmm.py | 44 ++++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/gmm.py b/gmm.py index 3f9bc32..e8c9cd4 100644 --- a/gmm.py +++ b/gmm.py @@ -22,25 +22,26 @@ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, Initializes the model and brings all tensors into their required shape. The class expects data to be fed as a flat tensor in (n, d). The class owns: - x: torch.Tensor (n, 1, d) - mu: torch.Tensor (1, k, d) - var: torch.Tensor (1, k, d) or (1, k, d, d) - pi: torch.Tensor (1, k, 1) - covariance_type: str - eps: float - init_params: str - log_likelihood: float - n_components: int - n_features: int + x: torch.Tensor (n, 1, d) + mu: torch.Tensor (1, k, d) + var: torch.Tensor (1, k, d) or (1, k, d, d) + pi: torch.Tensor (1, k, 1) + covariance_type: str + eps: float + init_params: str + log_likelihood: float + n_components: int + n_features: int args: - n_components: int - n_features: int + n_components: int + n_features: int options: - mu_init: torch.Tensor (1, k, d) - var_init: torch.Tensor (1, k, d) or (1, k, d, d) - covariance_type: str - eps: float - init_params: str + mu_init: torch.Tensor (1, k, d) + var_init: torch.Tensor (1, k, d) or (1, k, d, d) + covariance_type: str + eps: float + init_params: str + covariance_data_type: str or torch.dtype """ super(GaussianMixture, self).__init__() @@ -55,6 +56,13 @@ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, self.safe_mode = True self.prev_log_prob = None + assert covariance_data_type in ["float", "double", torch.float, torch.double] + if covariance_data_type == "float": + covariance_data_type = torch.float + elif covariance_data_type == "double": + covariance_data_type = torch.double + self.covariance_data_type = covariance_data_type + self.covariance_type = covariance_type self.init_params = init_params @@ -86,7 +94,7 @@ def _init_params(self): self.var = torch.nn.Parameter(self.var_init, requires_grad=False) else: self.var = torch.nn.Parameter( - torch.eye(self.n_features, dtype=torch.float64).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), + torch.eye(self.n_features, dtype=self.covariance_data_type).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), requires_grad=False ) From 6d73ade31ce594959040d47016b838b7e94dd0b5 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 17:12:48 +0300 Subject: [PATCH 11/31] bugfix --- gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index e8c9cd4..c81a57f 100644 --- a/gmm.py +++ b/gmm.py @@ -17,7 +17,7 @@ class GaussianMixture(torch.nn.Module): probabilities are shaped (n, k, 1) if they relate to an individual sample, or (1, k, 1) if they assign membership probabilities to one of the mixture components. """ - def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None): + def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None, covariance_data_type="double"): """ Initializes the model and brings all tensors into their required shape. The class expects data to be fed as a flat tensor in (n, d). From 3ab5051484a18d662950c70980371931469ecf59 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 17:35:56 +0300 Subject: [PATCH 12/31] . --- gmm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gmm.py b/gmm.py index c81a57f..c0cf380 100644 --- a/gmm.py +++ b/gmm.py @@ -84,14 +84,14 @@ def _init_params(self): if self.var_init is not None: # (1, k, d) assert self.var_init.size() == (1, self.n_components, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (self.n_components, self.n_features) - self.var = torch.nn.Parameter(self.var_init, requires_grad=False) + self.var = torch.nn.Parameter(self.var_init.to(self.covariance_data_type), requires_grad=False) else: - self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False) + self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features, dtype=self.covariance_data_type), requires_grad=False) elif self.covariance_type == "full": if self.var_init is not None: # (1, k, d, d) assert self.var_init.size() == (1, self.n_components, self.n_features, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i, %i)" % (self.n_components, self.n_features, self.n_features) - self.var = torch.nn.Parameter(self.var_init, requires_grad=False) + self.var = torch.nn.Parameter(self.var_init.to(self.covariance_data_type), requires_grad=False) else: self.var = torch.nn.Parameter( torch.eye(self.n_features, dtype=self.covariance_data_type).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1), From e2d010a628d48b44888e33a0f69a9eef89eb0991 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 19:44:09 +0300 Subject: [PATCH 13/31] use eigvals to compute log det --- gmm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/gmm.py b/gmm.py index c0cf380..0a1b35a 100644 --- a/gmm.py +++ b/gmm.py @@ -332,11 +332,14 @@ def _calculate_log_det(self, var): var: torch.Tensor (1, k, d, d) """ - cholesky = torch.linalg.cholesky(var[0]) - diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1) - del cholesky - log_det = 2 * torch.log(diagonal).sum(dim=-1) + # cholesky = torch.linalg.cholesky(var[0]) + # diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1) + # del cholesky + # log_det = 2 * torch.log(diagonal).sum(dim=-1) + evals = torch.linalg.eigvals(var[0]) + log_det = evals.log().sum(dim=-1).to(var.dtype) + return log_det.unsqueeze(-1) From e9976f7b006d3a1996cea881449397c02e728b45 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:03:32 +0300 Subject: [PATCH 14/31] . --- gmm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gmm.py b/gmm.py index 0a1b35a..f4bd413 100644 --- a/gmm.py +++ b/gmm.py @@ -337,8 +337,12 @@ def _calculate_log_det(self, var): # del cholesky # log_det = 2 * torch.log(diagonal).sum(dim=-1) - evals = torch.linalg.eigvals(var[0]) - log_det = evals.log().sum(dim=-1).to(var.dtype) + + log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) + + for k in range(self.n_components): + evals = torch.linalg.eigvals(var[0, i]) + log_det = evals.log().sum().to(var.dtype) return log_det.unsqueeze(-1) From 1569ecc97174362922e22f647d616caba2d7bb46 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:04:21 +0300 Subject: [PATCH 15/31] . --- gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index f4bd413..640b420 100644 --- a/gmm.py +++ b/gmm.py @@ -341,7 +341,7 @@ def _calculate_log_det(self, var): log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) for k in range(self.n_components): - evals = torch.linalg.eigvals(var[0, i]) + evals = torch.linalg.eigvals(var[0, k]) log_det = evals.log().sum().to(var.dtype) return log_det.unsqueeze(-1) From d11fe279f0267e22428b09b3c3774c3df84e36b5 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:06:21 +0300 Subject: [PATCH 16/31] bugfix --- gmm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index 640b420..c39847e 100644 --- a/gmm.py +++ b/gmm.py @@ -341,7 +341,8 @@ def _calculate_log_det(self, var): log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) for k in range(self.n_components): - evals = torch.linalg.eigvals(var[0, k]) + # evals = torch.linalg.eigvals(var[0, k]) + evals, _ = torch.linalg.eig(var[0, k]) log_det = evals.log().sum().to(var.dtype) return log_det.unsqueeze(-1) From 04ccd94dbcc51b35715b6181b24c5b84f230aaad Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:22:35 +0300 Subject: [PATCH 17/31] . --- gmm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index c39847e..047f24b 100644 --- a/gmm.py +++ b/gmm.py @@ -338,12 +338,14 @@ def _calculate_log_det(self, var): # log_det = 2 * torch.log(diagonal).sum(dim=-1) + assert (var != var).sum() == 0, "var contains NaN" + assert (var.abs() == float("inf")).sum() == 0, "var contains inf" log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) for k in range(self.n_components): # evals = torch.linalg.eigvals(var[0, k]) evals, _ = torch.linalg.eig(var[0, k]) - log_det = evals.log().sum().to(var.dtype) + log_det = torch.log(evals).sum().to(var.dtype) return log_det.unsqueeze(-1) From 315520cf37741c0fe6e1a48219d53f8b3f71d76f Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:26:08 +0300 Subject: [PATCH 18/31] . --- gmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gmm.py b/gmm.py index 047f24b..dd50b87 100644 --- a/gmm.py +++ b/gmm.py @@ -338,8 +338,8 @@ def _calculate_log_det(self, var): # log_det = 2 * torch.log(diagonal).sum(dim=-1) - assert (var != var).sum() == 0, "var contains NaN" - assert (var.abs() == float("inf")).sum() == 0, "var contains inf" + assert (var != var).sum() == 0, "`var` contains NaN, set `covariance_data_type` to double" + assert (var.abs() == float("inf")).sum() == 0, "`var` contains inf, set `covariance_data_type` to double" log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) for k in range(self.n_components): From c42b198b70f27148380188de0ed2e4e234300c8d Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:30:49 +0300 Subject: [PATCH 19/31] . --- gmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gmm.py b/gmm.py index dd50b87..cef9836 100644 --- a/gmm.py +++ b/gmm.py @@ -343,8 +343,8 @@ def _calculate_log_det(self, var): log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) for k in range(self.n_components): - # evals = torch.linalg.eigvals(var[0, k]) - evals, _ = torch.linalg.eig(var[0, k]) + evals = torch.linalg.eigvals(var[0, k]) + # evals, _ = torch.linalg.eig(var[0, k]) log_det = torch.log(evals).sum().to(var.dtype) return log_det.unsqueeze(-1) From 65ee6607aaae4d9f9f10bf78a89ccd97e85b4835 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:32:55 +0300 Subject: [PATCH 20/31] . --- gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gmm.py b/gmm.py index cef9836..74c9ce0 100644 --- a/gmm.py +++ b/gmm.py @@ -345,7 +345,7 @@ def _calculate_log_det(self, var): for k in range(self.n_components): evals = torch.linalg.eigvals(var[0, k]) # evals, _ = torch.linalg.eig(var[0, k]) - log_det = torch.log(evals).sum().to(var.dtype) + log_det[i] = torch.log(evals).sum().to(var.dtype) return log_det.unsqueeze(-1) From bd1c6323e07d237edf1de9438b0ced42f4c5dc96 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:34:17 +0300 Subject: [PATCH 21/31] . --- gmm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/gmm.py b/gmm.py index 74c9ce0..348f945 100644 --- a/gmm.py +++ b/gmm.py @@ -331,6 +331,8 @@ def _calculate_log_det(self, var): args: var: torch.Tensor (1, k, d, d) """ + assert (var != var).sum() == 0, "`var` contains NaN, set `covariance_data_type` to double" + assert (var.abs() == float("inf")).sum() == 0, "`var` contains inf, set `covariance_data_type` to double" # cholesky = torch.linalg.cholesky(var[0]) # diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1) @@ -338,14 +340,16 @@ def _calculate_log_det(self, var): # log_det = 2 * torch.log(diagonal).sum(dim=-1) - assert (var != var).sum() == 0, "`var` contains NaN, set `covariance_data_type` to double" - assert (var.abs() == float("inf")).sum() == 0, "`var` contains inf, set `covariance_data_type` to double" - log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) + + # evals = torch.linalg.eigvals(var[0]) + # # evals, _ = torch.linalg.eig(var[0, k]) + # log_det = torch.log(evals).sum(dim=-1).to(var.dtype) + log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) for k in range(self.n_components): evals = torch.linalg.eigvals(var[0, k]) # evals, _ = torch.linalg.eig(var[0, k]) - log_det[i] = torch.log(evals).sum().to(var.dtype) + log_det[k] = torch.log(evals).sum().to(var.dtype) return log_det.unsqueeze(-1) From 7f862dddbcfe32929173c95cf3a796511a0e7ecb Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 20:40:06 +0300 Subject: [PATCH 22/31] . --- gmm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/gmm.py b/gmm.py index 348f945..c85bba5 100644 --- a/gmm.py +++ b/gmm.py @@ -341,15 +341,15 @@ def _calculate_log_det(self, var): - # evals = torch.linalg.eigvals(var[0]) - # # evals, _ = torch.linalg.eig(var[0, k]) - # log_det = torch.log(evals).sum(dim=-1).to(var.dtype) + evals = torch.linalg.eigvals(var[0]) + # evals, _ = torch.linalg.eig(var[0, k]) + log_det = torch.log(evals).sum(dim=-1).to(var.dtype) - log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) - for k in range(self.n_components): - evals = torch.linalg.eigvals(var[0, k]) - # evals, _ = torch.linalg.eig(var[0, k]) - log_det[k] = torch.log(evals).sum().to(var.dtype) + # log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) + # for k in range(self.n_components): + # evals = torch.linalg.eigvals(var[0, k]) + # # evals, _ = torch.linalg.eig(var[0, k]) + # log_det[k] = torch.log(evals).sum().to(var.dtype) return log_det.unsqueeze(-1) From 06baed345a6ef079e62d3d39c104c6925fdb34fc Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 21:25:01 +0300 Subject: [PATCH 23/31] Create benchmark.md --- benchmark.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 benchmark.md diff --git a/benchmark.md b/benchmark.md new file mode 100644 index 0000000..e58f59a --- /dev/null +++ b/benchmark.md @@ -0,0 +1,19 @@ +# Benchmark +GPU: Tesla T4 (16GM DRAM) + +- covariance_type = "full" +- init_params = "random" +- n_ter = 20 +- delta = 0 + +| setup | original | k-loop | optimized (single) | optimized (double) | +| --- | --- | --- | --- | --- | +| n_features=16, n_components=16, n_data=100,000 | 6.9 | 6.9s | 0.5s | 3.44s | +| n_features=16, n_components=16, n_data=1,000,000 | OOM | 68.8s | 3.7s | 34.0s | +| n_features=64, n_components=64, n_data=100,000 | OOM | 161s | 3.57s | 13.9s | +| n_features=64, n_components=64, n_data=1,000,000 | OOM | OOM | 44.4s | 527s | +| n_features=256, n_components=256, n_data=100,000 | OOM | OOM | NAN | 686s | +| n_features=256, n_components=16, n_data=1,000,000 | OOM | OOM | 60s | 454s | + +- OOM: Out Of Memory +- NAN: Covar contains NaN From d5215a281b65f0b4bc13b06bcb416b25fa541268 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 21:29:24 +0300 Subject: [PATCH 24/31] . --- gmm.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/gmm.py b/gmm.py index c85bba5..d39c142 100644 --- a/gmm.py +++ b/gmm.py @@ -334,23 +334,17 @@ def _calculate_log_det(self, var): assert (var != var).sum() == 0, "`var` contains NaN, set `covariance_data_type` to double" assert (var.abs() == float("inf")).sum() == 0, "`var` contains inf, set `covariance_data_type` to double" - # cholesky = torch.linalg.cholesky(var[0]) - # diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1) - # del cholesky - # log_det = 2 * torch.log(diagonal).sum(dim=-1) - - - - evals = torch.linalg.eigvals(var[0]) - # evals, _ = torch.linalg.eig(var[0, k]) - log_det = torch.log(evals).sum(dim=-1).to(var.dtype) - - # log_det = torch.empty(size=(self.n_components,), device=var.device, dtype=var.dtype) - # for k in range(self.n_components): - # evals = torch.linalg.eigvals(var[0, k]) - # # evals, _ = torch.linalg.eig(var[0, k]) - # log_det[k] = torch.log(evals).sum().to(var.dtype) - + if self.covariance_data_type == torch.float: + evals = torch.linalg.eigvals(var[0]) + # evals, _ = torch.linalg.eig(var[0, k]) + log_det = torch.log(evals).sum(dim=-1).to(var.dtype) + + elif self.covariance_data_type == torch.double: + cholesky = torch.linalg.cholesky(var[0]) + diagonal = torch.diagonal(cholesky, dim1=-2, dim2=-1) + del cholesky + log_det = 2 * torch.log(diagonal).sum(dim=-1) + return log_det.unsqueeze(-1) From d8941a4bc932e2268ba80f890f1fd3a067556018 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 21:39:20 +0300 Subject: [PATCH 25/31] . --- gmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gmm.py b/gmm.py index d39c142..e1e4c82 100644 --- a/gmm.py +++ b/gmm.py @@ -355,7 +355,7 @@ def _e_step(self, x, use_prev_log_prob=False): This is the so-called expectation step of the EM-algorithm. args: x: torch.Tensor (n, d) or (n, 1, d) - use_prev_log_prob: bool + use_prev_log_prob: bool returns: log_prob_norm: torch.Tensor (1) log_resp: torch.Tensor (n, k, 1) @@ -381,7 +381,7 @@ def _m_step(self, x, resp): From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm. args: x: torch.Tensor (n, d) or (n, 1, d) - resp: torch.Tensor (n, k, 1) + resp: torch.Tensor (n, k, 1) returns: pi: torch.Tensor (1, k, 1) mu: torch.Tensor (1, k, d) From 13b2d70e2ef2e7b05da3ec939b6e8d70eb50cc94 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 22:47:57 +0300 Subject: [PATCH 26/31] Update benchmark.md --- benchmark.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/benchmark.md b/benchmark.md index e58f59a..919a222 100644 --- a/benchmark.md +++ b/benchmark.md @@ -15,5 +15,24 @@ GPU: Tesla T4 (16GM DRAM) | n_features=256, n_components=256, n_data=100,000 | OOM | OOM | NAN | 686s | | n_features=256, n_components=16, n_data=1,000,000 | OOM | OOM | 60s | 454s | +### Notes: - OOM: Out Of Memory - NAN: Covar contains NaN +- k-loop: almost the same as original `GaussianMixture`, except +``` +var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0, + keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps +``` +in `_m_step` is replaced with +``` +var = torch.empty(1, self.n_components, self.n_features, self.n_features, device=x.device, dtype=resp.dtype) +eps = (torch.eye(self.n_features) * self.eps).to(x.device) +for i in range(self.n_components): + sub_mu = mu[:, i, :] + sub_resp = resp[:, i, :] + sub_x_mu = (x - sub_mu).squeeze(1) + outer = torch.matmul(sub_x_mu[:, :, None], sub_x_mu[:, None, :]) + outer_sum = torch.sum(outer * sub_resp[:, :, None], dim=0, keepdim=True) + sub_var = outer_sum / resp_sum[i] + eps + var[:, i, :, :] = sub_var +``` From e223a5739f0a59223baa4d57169507f234bdf371 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 22:48:18 +0300 Subject: [PATCH 27/31] Update benchmark.md --- benchmark.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark.md b/benchmark.md index 919a222..ce5f1a4 100644 --- a/benchmark.md +++ b/benchmark.md @@ -19,12 +19,12 @@ GPU: Tesla T4 (16GM DRAM) - OOM: Out Of Memory - NAN: Covar contains NaN - k-loop: almost the same as original `GaussianMixture`, except -``` +```python var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0, keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps ``` in `_m_step` is replaced with -``` +```python var = torch.empty(1, self.n_components, self.n_features, self.n_features, device=x.device, dtype=resp.dtype) eps = (torch.eye(self.n_features) * self.eps).to(x.device) for i in range(self.n_components): From 2572c1af254d8cbcf0458c77f003c794b583fce7 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 23:01:34 +0300 Subject: [PATCH 28/31] Update benchmark.md --- benchmark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark.md b/benchmark.md index ce5f1a4..afab70f 100644 --- a/benchmark.md +++ b/benchmark.md @@ -8,7 +8,7 @@ GPU: Tesla T4 (16GM DRAM) | setup | original | k-loop | optimized (single) | optimized (double) | | --- | --- | --- | --- | --- | -| n_features=16, n_components=16, n_data=100,000 | 6.9 | 6.9s | 0.5s | 3.44s | +| n_features=16, n_components=16, n_data=100,000 | 6.9s | 6.9s | 0.5s | 3.44s | | n_features=16, n_components=16, n_data=1,000,000 | OOM | 68.8s | 3.7s | 34.0s | | n_features=64, n_components=64, n_data=100,000 | OOM | 161s | 3.57s | 13.9s | | n_features=64, n_components=64, n_data=1,000,000 | OOM | OOM | 44.4s | 527s | From e50457b5a1a13e21c2281f83203e7f75d217986e Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Sun, 13 Mar 2022 23:03:06 +0300 Subject: [PATCH 29/31] Update benchmark.md --- benchmark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark.md b/benchmark.md index afab70f..db5037a 100644 --- a/benchmark.md +++ b/benchmark.md @@ -1,5 +1,5 @@ # Benchmark -GPU: Tesla T4 (16GM DRAM) +GPU: Tesla T4 (16GB DRAM) - covariance_type = "full" - init_params = "random" From d750bc7bab27097e621c36ace1f8859945cc2777 Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Mon, 14 Mar 2022 16:58:28 +0300 Subject: [PATCH 30/31] Update benchmark.md --- benchmark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark.md b/benchmark.md index db5037a..861244c 100644 --- a/benchmark.md +++ b/benchmark.md @@ -3,7 +3,7 @@ GPU: Tesla T4 (16GB DRAM) - covariance_type = "full" - init_params = "random" -- n_ter = 20 +- n_iter = 20 - delta = 0 | setup | original | k-loop | optimized (single) | optimized (double) | From 9bc2062573b376fb6a0c0d380df2141dc5cef27e Mon Sep 17 00:00:00 2001 From: DeMoriarty Date: Mon, 14 Mar 2022 16:59:52 +0300 Subject: [PATCH 31/31] Update benchmark.md --- benchmark.md | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark.md b/benchmark.md index 861244c..51e9b19 100644 --- a/benchmark.md +++ b/benchmark.md @@ -18,6 +18,7 @@ GPU: Tesla T4 (16GB DRAM) ### Notes: - OOM: Out Of Memory - NAN: Covar contains NaN +- Single/Double: dtype of covariance matrix - k-loop: almost the same as original `GaussianMixture`, except ```python var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,