Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jun 22, 2024
2 parents 4db3120 + f61d53f commit 14c269e
Show file tree
Hide file tree
Showing 40 changed files with 891 additions and 2,369 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ members = [
"crates/luminal_nn",
"crates/luminal_training",
]
- exclude = ["crates/luminal_cuda", "crates/luminal_metal"]
+ exclude = ["crates/luminal_cuda", "crates/luminal_metal", "crates/luminal_metal_super"]
2 changes: 1 addition & 1 deletion crates/luminal_metal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
itertools = "0.12.1"
luminal = { path = "../.." }
- metal-rs = { version = "0.27.0", package = "metal", features = ["mps"] }
+ metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] }
num-traits = "0.2.18"
regex = "1.10.4"
rustc-hash = "1.1.0"
Expand Down
10 changes: 8 additions & 2 deletions crates/luminal_metal/src/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.sorted_by_key(|(i, _, _)| *i)
.map(|(_, _, s)| s)
.collect::<Vec<_>>(),
- );
+ )
+ .into_iter()
+ .map(|s| s.simplify())
+ .collect();
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
Expand Down Expand Up @@ -283,7 +286,10 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let output_buffer_sizes = graph
.node_custom::<MetalKernelWrapper, _>(op, "metal", ())
.unwrap()
- .output_buffer_sizes(&input_shapes);
+ .output_buffer_sizes(&input_shapes)
+ .into_iter()
+ .map(|s| s.simplify())
+ .collect();
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
Expand Down
361 changes: 360 additions & 1 deletion crates/luminal_metal/src/kernels/bf16.h

Large diffs are not rendered by default.

365 changes: 0 additions & 365 deletions crates/luminal_metal/src/kernels/bf16_math.h

This file was deleted.

4 changes: 2 additions & 2 deletions crates/luminal_metal/src/kernels/gemm.metal
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2023 Apple Inc.

- #include "KERNEL_PATH/bf16.h"
- #include "KERNEL_PATH/gemm.h"
+ BF16.H
+ GEMM.H

using namespace metal;

Expand Down
6 changes: 3 additions & 3 deletions crates/luminal_metal/src/kernels/gemv.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#include <metal_stdlib>
#include <metal_simdgroup>

- #include "KERNEL_PATH/bf16.h"
- #include "KERNEL_PATH/defines.h"
- #include "KERNEL_PATH/utils.h"
+ BF16.H
+ DEFINES.H
+ UTILS.H

using namespace metal;

Expand Down
228 changes: 0 additions & 228 deletions crates/luminal_metal/src/kernels/softmax.metal

This file was deleted.

Loading

0 comments on commit 14c269e

Please sign in to comment.