Skip to content

Commit

Permalink
fix(tensor/broadcast): update implementation to handle 1D->2D
Browse files Browse the repository at this point in the history
  • Loading branch information
drewxs committed Mar 10, 2024
1 parent 95cd5d0 commit bfbc6e3
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions src/linalg/tensor/transformation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,42 @@ impl Tensor {
/// # use engram::*;
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[1.0, 2.0, 8.0], [3.0, 4.0, 9.0]];
/// let c = a.broadcast_to(&b);
/// let c = a.broadcast(&b);
/// assert_eq!(b.data, vec![vec![1.0, 2.0, 8.0], vec![3.0, 4.0, 9.0]]);
/// ```
pub fn broadcast_to(&self, other: &Tensor) -> Tensor {
let mut res = Tensor::zeros(other.rows, other.cols);
for i in 0..other.rows {
for j in 0..other.cols {
res.data[i][j] = self.data[i % self.rows][j % self.cols];
pub fn broadcast(&self, other: &Tensor) -> Tensor {
if self.rows == other.rows && self.cols == other.cols {
return other.clone();
}

if other.rows == 1 && self.cols == other.cols {
let mut res = Tensor::zeros(self.rows, other.cols);
for i in 0..self.rows {
for j in 0..other.cols {
res.data[i][j] = other.data[0][j];
}
}
Tensor {
rows: self.rows,
cols: other.cols,
data: res.data,
grad: None,
}
} else if other.cols == 1 && self.rows == other.rows {
let mut res = Tensor::zeros(self.rows, self.cols);
for i in 0..other.rows {
for j in 0..self.cols {
res.data[i][j] = other.data[i][0];
}
}
Tensor {
rows: other.rows,
cols: self.cols,
data: res.data,
grad: None,
}
} else {
other.clone()
}
res
}
}

0 comments on commit bfbc6e3

Please sign in to comment.