Skip to content

Commit 750a6ee

Browse files
authored
add support for Tensors from slices (#90)
* feat: tensor from slice + CPU to array perf * chore: crystal tool format
1 parent 2aa2f6d commit 750a6ee

File tree

6 files changed

+122
-31
lines changed

6 files changed

+122
-31
lines changed

spec/extensions/slice_spec.cr

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) 2021 Crystal Data Contributors
2+
#
3+
# MIT License
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining
6+
# a copy of this software and associated documentation files (the
7+
# "Software"), to deal in the Software without restriction, including
8+
# without limitation the rights to use, copy, modify, merge, publish,
9+
# distribute, sublicense, and/or sell copies of the Software, and to
10+
# permit persons to whom the Software is furnished to do so, subject to
11+
# the following conditions:
12+
#
13+
# The above copyright notice and this permission notice shall be
14+
# included in all copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19+
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20+
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22+
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23+
24+
describe Slice do
25+
it "creates a Tensor from a stdlib slice" do
26+
s = Slice.new(5) { |i| (i + 10).to_u8 }
27+
t = s.to_tensor
28+
t.shape.should eq [5]
29+
t.to_a.should eq s.to_a
30+
end
31+
end

spec/linalg/work_spec.cr

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
# Copyright (c) 2022 Crystal Data Contributors
2-
#
3-
# MIT License
4-
#
5-
# Permission is hereby granted, free of charge, to any person obtaining
6-
# a copy of this software and associated documentation files (the
7-
# "Software"), to deal in the Software without restriction, including
8-
# without limitation the rights to use, copy, modify, merge, publish,
9-
# distribute, sublicense, and/or sell copies of the Software, and to
10-
# permit persons to whom the Software is furnished to do so, subject to
11-
# the following conditions:
12-
#
13-
# The above copyright notice and this permission notice shall be
14-
# included in all copies or substantial portions of the Software.
15-
#
16-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17-
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18-
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19-
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20-
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21-
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22-
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23-
24-
describe Num::WorkPool do
25-
it "can allocate complex numbers" do
26-
pool = Num::WorkPool.new
27-
pool.get_cmplx(10)
28-
end
29-
end
1+
# Copyright (c) 2022 Crystal Data Contributors
2+
#
3+
# MIT License
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining
6+
# a copy of this software and associated documentation files (the
7+
# "Software"), to deal in the Software without restriction, including
8+
# without limitation the rights to use, copy, modify, merge, publish,
9+
# distribute, sublicense, and/or sell copies of the Software, and to
10+
# permit persons to whom the Software is furnished to do so, subject to
11+
# the following conditions:
12+
#
13+
# The above copyright notice and this permission notice shall be
14+
# included in all copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19+
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20+
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22+
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23+
24+
describe Num::WorkPool do
25+
it "can allocate complex numbers" do
26+
pool = Num::WorkPool.new
27+
pool.get_cmplx(10)
28+
end
29+
end

src/api.cr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ require "complex"
88
require "arrow"
99
{% end %}
1010

11+
require "./extensions/slice"
1112
require "./extensions/array"
1213
require "./extensions/number"
1314

src/extensions/slice.cr

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2020 Crystal Data Contributors
2+
#
3+
# MIT License
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining
6+
# a copy of this software and associated documentation files (the
7+
# "Software"), to deal in the Software without restriction, including
8+
# without limitation the rights to use, copy, modify, merge, publish,
9+
# distribute, sublicense, and/or sell copies of the Software, and to
10+
# permit persons to whom the Software is furnished to do so, subject to
11+
# the following conditions:
12+
#
13+
# The above copyright notice and this permission notice shall be
14+
# included in all copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19+
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20+
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22+
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23+
24+
struct Slice(T)
25+
# Converts a standard library slice to a `Tensor`.
26+
# The type of Tensor is inferred from the element type, and alternative
27+
# shapes can be provided.
28+
#
29+
# ## Arguments
30+
#
31+
# * device : `Num::Storage` - The storage backend on which to place the `Tensor`
32+
# * shape : `Array(Int32)?` - An optional shape this slice represents
33+
#
34+
# ## Examples
35+
#
36+
# ```
37+
# s = Slice.new(200) { |i| (i + 10).to_u8 }
38+
# typeof(s.to_tensor) # => Tensor(UInt8, CPU(Float32))
39+
# ```
40+
def to_tensor(device = CPU, shape : Array(Int32)? = nil)
41+
Tensor.from_slice self, device: device, shape: shape
42+
end
43+
end

src/tensor/allocation.cr

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Tensor(T, S)
6767
# without having to provide a specific generic type to the array.
6868
# Since the storage instance is S, having the array passed
6969
# allows Num to infer T
70-
private def initialize(@data : S, shape : Array(Int), from_array : Array(T))
70+
private def initialize(@data : S, shape : Array(Int), from_array : Array(T) | Slice(T))
7171
assert_types
7272
@shape = shape.map &.to_i
7373
@strides = Num::Internal.shape_to_strides(shape, Num::RowMajor)
@@ -155,6 +155,22 @@ class Tensor(T, S)
155155
new(storage, shape, from_array: flat)
156156
end
157157

158+
# Creates a Tensor from a standard library slice onto a specified device.
159+
# The type of Tensor is inferred from the element type, and alternative
160+
# shapes can be provided.
161+
#
162+
# ## Examples
163+
#
164+
# ```
165+
# s = Slice.new(200) { |i| (i + 10).to_u8 }
166+
# Tensor.from_array(s, device: OCL) # => [200] Tensor stored on a GPU
167+
# ```
168+
def self.from_slice(s : Slice, device = CPU, shape : Array(Int32)? = nil)
169+
shape = shape || [s.size]
170+
storage = device.new(s.to_unsafe, shape, Num::Internal.shape_to_strides(shape))
171+
new(storage, shape, from_array: s)
172+
end
173+
158174
# Creates a `Tensor` of a provided shape, filled with 0. The generic type
159175
# must be specified.
160176
#

src/tensor/backends/cpu/impl_convert.cr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ module Num
3636
# a.to_a # => [0, 1, 2, 3]
3737
# ```
3838
def to_a(arr : Tensor(U, CPU(U))) forall U
39-
a = [] of U
39+
a = Array(U).new(arr.size)
4040
each(arr) do |el|
4141
a << el
4242
end

0 commit comments

Comments
 (0)