dot.rs
use tensor::Tensor;
use blas;
macro_rules! add_impl {
    ($t:ty, $gemv:ident, $gemm:ident, $dot:ident) => (
        impl Tensor<$t> {
            /// Takes the product of two tensors. If both tensors are matrices (2D), a matrix
            /// multiplication is performed. If one operand is a matrix (2D) and the other a
            /// vector (1D), a matrix-vector product is taken. If both tensors are vectors (1D),
            /// the scalar product is taken.
            pub fn dot(&self, rhs: &Tensor<$t>) -> Tensor<$t> {
                if self.ndim() == 2 && rhs.ndim() == 1 { // matrix-vector product
                    assert_eq!(self.shape[1], rhs.shape[0]);
                    let mut t3 = Tensor::empty(&[self.shape[0]]);
                    {
                        let mut data = t3.slice_mut();
                        if cfg!(noblas) {
                            // Naive implementation, BLAS will be much faster
                            for i in 0..self.shape[0] {
                                let mut v = 0.0;
                                for k in 0..self.shape[1] {
                                    v += self.get2(i, k) * rhs.data[k];
                                }
                                data[i] = v;
                            }
                        } else {
                            let t1 = self.canonize();
                            let t2 = rhs.canonize();
                            blas::$gemv(b'T', t1.shape[1] as i32, t1.shape[0] as i32, 1.0, &t1.data,
                                        t1.shape[1] as i32, &t2.data, 1, 0.0, data, 1);
                        }
                    }
                    t3
                } else if self.ndim() == 1 && rhs.ndim() == 2 { // vector-matrix product
                    assert_eq!(self.shape[0], rhs.shape[0]);
                    let mut t3 = Tensor::empty(&[rhs.shape[1]]);
                    {
                        let mut data = t3.slice_mut();
                        if cfg!(noblas) {
                            // Naive implementation, BLAS will be much faster
                            for i in 0..rhs.shape[1] {
                                let mut v = 0.0;
                                for k in 0..rhs.shape[0] {
                                    v += self.data[k] * rhs.get2(k, i);
                                }
                                data[i] = v;
                            }
                        } else {
                            let t1 = self.canonize();
                            let t2 = rhs.canonize();
                            blas::$gemv(b'N', t2.shape[1] as i32, t2.shape[0] as i32, 1.0, &t2.data,
                                        t2.shape[1] as i32, &t1.data, 1, 0.0, data, 1);
                        }
                    }
                    t3
                } else if self.ndim() == 2 && rhs.ndim() == 2 { // matrix-matrix product
                    assert_eq!(self.shape[1], rhs.shape[0]);
                    let mut t3 = Tensor::empty(&[self.shape[0], rhs.shape[1]]);
                    if cfg!(noblas) {
                        // Naive implementation, BLAS will be much faster
                        for i in 0..self.shape[0] {
                            for j in 0..rhs.shape[1] {
                                let mut v = 0.0;
                                for k in 0..self.shape[1] {
                                    v += self.get2(i, k) * rhs.get2(k, j);
                                }
                                t3.set2(i, j, v);
                            }
                        }
                    } else {
                        // Note: gemm assumes column-major storage while we use row-major, so we
                        // have to re-arrange things a bit. Interpreting our row-major buffers as
                        // column-major gives the transposed matrices, so we ask BLAS for
                        // B^T * A^T = (A * B)^T, which read back as row-major is exactly A * B.
                        let t1 = self.canonize();
                        let t2 = rhs.canonize();
                        let mut data = t3.slice_mut();
                        blas::$gemm(b'N', b'N', t2.shape[1] as i32, t1.shape[0] as i32, t2.shape[0] as i32, 1.0,
                                    &t2.data, t2.shape[1] as i32, &t1.data, t2.shape[0] as i32, 0.0,
                                    data, t2.shape[1] as i32);
                    }
                    t3
                } else if self.ndim() == 1 && rhs.ndim() == 1 { // scalar product
                    assert_eq!(self.size(), rhs.size());
                    let mut v = 0.0;
                    if cfg!(noblas) {
                        // Naive implementation, BLAS will be much faster
                        for (v1, v2) in self.iter().zip(rhs.iter()) {
                            v += v1 * v2;
                        }
                    } else {
                        let t1 = self.canonize();
                        let t2 = rhs.canonize();
                        v = blas::$dot(t1.size() as i32, &t1.data, 1, &t2.data, 1);
                    }
                    Tensor::scalar(v)
                } else {
                    panic!("Dot product is not supported for the matrix dimensions provided");
                }
            }
        }
    )
}
add_impl!(f32, sgemv, sgemm, sdot);
add_impl!(f64, dgemv, dgemm, ddot);
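
// --- Usage sketch (illustration only, not part of the original dot.rs) ---
// A minimal example of the four cases `dot` handles. It assumes the crate
// exposes a 1-D constructor `Tensor::new(Vec<T>)`, a `reshape` method that
// returns a reshaped tensor, and `PartialEq`/`Debug` impls for `Tensor` so
// that `assert_eq!` works; those names and traits are assumptions and may
// differ from the actual Tensor API.
#[cfg(test)]
mod dot_usage_sketch {
    use tensor::Tensor;

    #[test]
    fn dot_covers_all_four_cases() {
        // A = [[1, 2], [3, 4]] stored row-major, x = [1, 1].
        let a = Tensor::new(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let x = Tensor::new(vec![1.0f64, 1.0]);

        let ax = a.dot(&x); // matrix-vector: [3, 7]
        let xa = x.dot(&a); // vector-matrix: [4, 6]
        let aa = a.dot(&a); // matrix-matrix: [[7, 10], [15, 22]]
        let xx = x.dot(&x); // scalar product: 2.0 as a 0-D tensor

        assert_eq!(ax, Tensor::new(vec![3.0, 7.0]));
        assert_eq!(xa, Tensor::new(vec![4.0, 6.0]));
        assert_eq!(aa, Tensor::new(vec![7.0, 10.0, 15.0, 22.0]).reshape(&[2, 2]));
        assert_eq!(xx, Tensor::scalar(2.0));
    }
}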