diff --git a/src/linalg/tensor/transformation.rs b/src/linalg/tensor/transformation.rs index 0992fdf..8a724f7 100644 --- a/src/linalg/tensor/transformation.rs +++ b/src/linalg/tensor/transformation.rs @@ -51,16 +51,42 @@ impl Tensor { /// # use engram::*; /// let a = tensor![[1.0, 2.0], [3.0, 4.0]]; /// let b = tensor![[1.0, 2.0, 8.0], [3.0, 4.0, 9.0]]; - /// let c = a.broadcast_to(&b); + /// let c = a.broadcast(&b); /// assert_eq!(b.data, vec![vec![1.0, 2.0, 8.0], vec![3.0, 4.0, 9.0]]); /// ``` - pub fn broadcast_to(&self, other: &Tensor) -> Tensor { - let mut res = Tensor::zeros(other.rows, other.cols); - for i in 0..other.rows { - for j in 0..other.cols { - res.data[i][j] = self.data[i % self.rows][j % self.cols]; + pub fn broadcast(&self, other: &Tensor) -> Tensor { + if self.rows == other.rows && self.cols == other.cols { + return other.clone(); + } + + if other.rows == 1 && self.cols == other.cols { + let mut res = Tensor::zeros(self.rows, other.cols); + for i in 0..self.rows { + for j in 0..other.cols { + res.data[i][j] = other.data[0][j]; + } + } + Tensor { + rows: self.rows, + cols: other.cols, + data: res.data, + grad: None, } + } else if other.cols == 1 && self.rows == other.rows { + let mut res = Tensor::zeros(self.rows, self.cols); + for i in 0..other.rows { + for j in 0..self.cols { + res.data[i][j] = other.data[i][0]; + } + } + Tensor { + rows: other.rows, + cols: self.cols, + data: res.data, + grad: None, + } + } else { + other.clone() } - res } }