Skip to content

Commit

Permalink
small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 30, 2023
1 parent aed9b29 commit 2c1490c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
10 changes: 6 additions & 4 deletions examples/mps/matrix-multiplication/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use metal::mps::*;
use metal::*;
use rand::{thread_rng, Rng};
use std::io;
use std::io::Write;
use std::ops::{AddAssign, Mul};
use std::{array, io};

use rand::{thread_rng, Rng};

use metal::mps::*;
use metal::*;

fn main() {
correctness();
Expand Down
50 changes: 29 additions & 21 deletions src/mps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use super::*;
use half::{bf16, f16};
use objc::runtime::{BOOL, YES};
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

use half::{bf16, f16};
use objc::runtime::{BOOL, YES};

use super::*;

#[cfg_attr(
feature = "link",
link(name = "MetalPerformanceShaders", kind = "framework")
Expand Down Expand Up @@ -780,8 +782,11 @@ pub struct GEMMInput<T: MPSDataType> {
/// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8,
/// or MPSDataTypeInt16
impl Valid for GEMMInput<Float16> {}

impl Valid for GEMMInput<Float32> {}

impl Valid for GEMMInput<Int8> {}

impl Valid for GEMMInput<Int16> {}

/// Helper struct used to indicate a valid matrix multiplication result type.
Expand All @@ -791,6 +796,7 @@ pub struct GEMMResult<T: MPSDataType> {

/// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix.
impl Valid for GEMMResult<Float16> {}

impl Valid for GEMMResult<Float32> {}

/// Helper struct used to indicate valid matrix multiplication types.
Expand Down Expand Up @@ -819,7 +825,9 @@ where

/// These input types can produce a MPSDataTypeFloat16 result.
impl Valid for GEMMSpecification<Int8, Int8, Float16> {}

impl Valid for GEMMSpecification<Int16, Int16, Float16> {}

impl Valid for GEMMSpecification<Float16, Float16, Float16> {}

/// See <https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdescriptor?language=objc>
Expand Down Expand Up @@ -902,7 +910,8 @@ foreign_obj_type! {
/// Generic matrix for MPSDataTypes.
#[derive(Debug)]
pub struct Matrix<T: MPSDataType> {
entries: Vec<T::Type>, // row-major order
entries: Vec<T::Type>,
// row-major order
rows: NSUInteger,
columns: NSUInteger,
}
Expand Down Expand Up @@ -1158,9 +1167,11 @@ impl MatrixMultiplicationRef {
}

pub struct MatrixBuffer<T> {
pub buffer: Buffer,
buffer: Buffer,
rows: NSUInteger,
columns: NSUInteger,
count: usize,
allocated_size: usize,
_marker: PhantomData<T>,
}

Expand All @@ -1177,16 +1188,18 @@ impl<T: MPSDataType> MatrixBuffer<T> {
buffer,
rows,
columns,
count: (rows * columns) as usize,
allocated_size: length as usize,
_marker: PhantomData,
}
}

pub fn count(&self) -> usize {
(self.rows * self.columns) as usize
self.count
}

pub fn contents(&self) -> Vec<T::Type> {
self.buffer.read_to_vec(self.count())
self.buffer.read_to_vec(self.count)
}
}

Expand All @@ -1209,23 +1222,18 @@ where
GEMMResult<C>: Valid,
GEMMSpecification<A, B, C>: Valid,
{
let M = if transpose_left {
left.columns
} else {
left.rows
};
let N = if transpose_right {
right.rows
let (M, K) = if transpose_left {
(left.columns, left.rows)
} else {
right.columns
(left.rows, left.columns)
};
let K = if transpose_left {
left.rows
let (N, B_K) = if transpose_right {
(right.rows, right.columns)
} else {
left.columns
(right.columns, right.rows)
};

validate_shapes(M, N, K);
validate_shapes(M, N, K, B_K);

// Create descriptors for the matrices.
let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::TYPE_ID);
Expand All @@ -1249,7 +1257,6 @@ where
// Create matrix objects
let left_matrix =
MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap();

let right_matrix =
MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap();
let result_matrix =
Expand Down Expand Up @@ -1281,13 +1288,14 @@ where
result_buffer
}

fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger) {
fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger) {
// Certain constraints apply to the sizes of the matrices depending on the transposition
// operations and sizes requested at initialization time as well as the origins at the time
// this routine is called:
assert!(M > 0);
assert!(N > 0);
assert!(K > 0);
assert_eq!(K, B_K);
// Left column size must equal right row size.
assert_eq!(K, N);

Expand Down

0 comments on commit 2c1490c

Please sign in to comment.