Skip to content

Commit 145dc33

Browse files
Fix some type ambiguities (#585)
1 parent 30432ce commit 145dc33

File tree

5 files changed

+16
-6
lines changed

5 files changed

+16
-6
lines changed

lib/mps/random.jl

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ function Random.rand!(rng::RNG, A::AbstractArray{T, N}) where {T <: Union{Unifor
8383
end
8484
return A
8585
end
86+
8687
function Random.randn!(rng::RNG, A::AbstractArray{T, N}) where {T <: Float32, N}
8788
isempty(A) && return A
8889
if MTL.can_alloc_nocopy(pointer(A), sizeof(A))

lib/mpsgraphs/execution.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor())
3-
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor)
2+
MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, resultsDictionary, nil, MPSGraphExecutionDescriptor())
3+
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary, targetOperations, executionDescriptor::MPSGraphExecutionDescriptor)
44
@objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
55
feeds:feeds::id{MPSGraphTensorDataDictionary}
66
targetOperations:targetOperations::id{Object}
@@ -9,7 +9,7 @@ function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MP
99
return resultsDictionary
1010
end
1111

12-
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())
12+
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor::MPSGraphExecutionDescriptor=MPSGraphExecutionDescriptor())
1313
obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
1414
feeds:feeds::id{MPSGraphTensorDataDictionary}
1515
targetTensors:targetTensors::id{NSArray}

lib/mpsgraphs/matmul.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ end
8686
)
8787

8888
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
89-
encode!(cmdbuf, graph, NSDictionary(feeds), nil, NSDictionary(resultdict), default_exec_desc())
89+
encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
9090
commit!(cmdbuf)
9191
wait_completed(cmdbuf)
9292

src/indexing.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,12 @@ function Base.findall(bools::WrappedMtlArray{Bool})
5252
return ys
5353
end
5454

55-
function Base.findall(f::Function, A::WrappedMtlArray)
55+
@inline function _findall(f, A)
5656
bools = map(f, A)
5757
ys = findall(bools)
5858
unsafe_free!(bools)
5959
return ys
60-
end
60+
end
61+
62+
Base.findall(f::Function, A::WrappedMtlArray) = _findall(f, A)
63+
Base.findall(f::Base.Fix2{typeof(in)}, A::WrappedMtlArray) = _findall(f, A)

test/array.jl

+6
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,12 @@ end
540540
let x = rand(Float32, 1000, 1000)
541541
@test findall(y->y>Float32(0.5), x) == Array(findall(y->y>Float32(0.5), MtlArray(x)))
542542
end
543+
544+
# ambiguity
545+
let f = in(3)
546+
x = MtlArray([1, 2, 3, 4, 5, 3])
547+
@test Array(findall(f, x)) == [3, 6]
548+
end
543549
end
544550

545551
@testset "broadcast" begin

0 commit comments

Comments
 (0)