Skip to content

Commit 9c91420

Browse files
authored
Autograd negation and tweaks (#86)
* Tweak to also accept Tensor class as `like:` param
* Fix grad context's `#variable` method for creating var from scalar
* Tweak to carry on `requires_grad`, add negation for grad var
* Test negation for grad var
* Fix compiler warning about `Backend::Storage#initialize` param name
1 parent 5dfe5cf commit 9c91420

File tree

7 files changed

+61
-7
lines changed

7 files changed

+61
-7
lines changed

spec/grad/gates_arithmetic_spec.cr

+28
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,34 @@
2424
require "../spec_helper"
2525

2626
describe Num::Grad do
27+
it "backpropogates for negation" do
28+
ctx = Num::Grad::Context(Float32Tensor).new
29+
30+
a = ctx.variable([1.0_f32, 2.0_f32])
31+
32+
result = -a
33+
result.backprop
34+
35+
expected = [-1_f32, -1_f32].to_tensor
36+
37+
Num::Testing.tensor_equal(a.grad, expected).should be_true
38+
end
39+
40+
{% if flag?(:opencl) %}
41+
it "backpropogates for negation opencl", tags: "opencl" do
42+
ctx = Num::Grad::Context(Float32ClTensor).new
43+
44+
a = ctx.variable([1.0_f32, 2.0_f32].to_tensor(OCL))
45+
46+
result = -a
47+
result.backprop
48+
49+
expected = [-1_f32, -1_f32].to_tensor
50+
51+
Num::Testing.tensor_equal(a.grad.cpu, expected).should be_true
52+
end
53+
{% end %}
54+
2755
it "backpropogates for addition" do
2856
ctx = Num::Grad::Context(Float32Tensor).new
2957

spec/grad/primitives_spec.cr

+9
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ describe Num::Grad::Context do
4040
t_var = ctx.variable(t)
4141
t_var.context.should eq ctx
4242
end
43+
44+
it "can create a variable with scalar" do
45+
ctx = Num::Grad::Context(Float32Tensor).new
46+
t = 3.14_f32
47+
t_var = ctx.variable(t)
48+
t_var.context.should eq ctx
49+
t_var.value[0].should eq t # has the scalar
50+
t_var.value.size.should eq 1 # has only one element
51+
end
4352
end
4453

4554
describe Num::Grad do

src/grad/primitives/context.cr

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class Num::Grad::Context(T)
103103
# ctx.variable(3.0)
104104
# ```
105105
def variable(value : Number, requires_grad : Bool = true) : Num::Grad::Variable(T)
106-
Num::Grad::Variable.new(self, T.new(value), requires_grad)
106+
Num::Grad::Variable.new(self, Num.as_tensor(value, like: T), requires_grad)
107107
end
108108

109109
# Creates a new variable within the `Context`. This variable

src/grad/variable.cr

+18-2
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ class Num::Grad::Variable(T)
221221
# x.sum(1) # => [[3.0], [7.0]]
222222
# ```
223223
def sum(axis : Int) : Num::Grad::Variable(T)
224-
result = @context.variable(Num.sum(@value, axis, dims: true))
224+
s = Num.sum(@value, axis, dims: true)
225+
result = @context.variable(s, requires_grad: @requires_grad)
225226
if self.is_grad_needed
226227
gate = Num::Grad::SumGate(T).new self
227228
gate.cache(result, self)
@@ -246,10 +247,25 @@ class Num::Grad::Variable(T)
246247
# ```
247248
def mean(axis : Int) : Num::Grad::Variable(T)
248249
s = sum(axis)
249-
b = @context.variable(Num.as_tensor(@value.shape[axis], like: s.value))
250+
sz = Num.as_tensor(@value.shape[axis], like: s.value)
251+
b = @context.variable(sz, requires_grad: @requires_grad)
250252
s / b
251253
end
252254

255+
# Negates the variable
256+
#
257+
# ## Examples
258+
#
259+
# ```
260+
# ctx = Num::Grad::Context(Tensor(Float64, CPU(Float64))).new
261+
# x = ctx.variable([1.0, 2.0])
262+
# -x # => [-1.0, -2.0]
263+
# ```
264+
def -
265+
zero = @context.variable(0, requires_grad: @requires_grad)
266+
zero - self
267+
end
268+
253269
private macro num_op(fn, gate_cls)
254270
def {{fn.id}} : Num::Grad::Variable(T)
255271
result = @context.variable(Num.{{ fn.id }}(@value))

src/tensor/backends/agnostic/impl_manipulate.cr

+2-1
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ module Num
331331
# ```
332332
# t = Tensor(Float32, OCL(Float32)).from_array([0.5, 0.2])
333333
# x = Num.as_tensor(12, like: t)
334-
def as_tensor(value : Number, like : Tensor(U, V)) forall U, V
334+
# x = Num.as_tensor(12, like: Tensor(Float32, OCL(Float32)))
335+
def as_tensor(value : Number, like : Tensor(U, V) | Tensor(U, V).class) forall U, V
335336
Tensor(U, V).from_array([U.new(value)], device = V)
336337
end
337338
end

src/tensor/backends/cpu/impl_allocation.cr

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ class CPU(T) < Num::Backend::Storage(T)
109109
# a = Pointer(Int32).malloc(10)
110110
# s = CPU.new(a, [5, 2])
111111
# ```
112-
def initialize(data : Pointer(T), shape : Array(Int), strides : Array(Int))
113-
@data = data
112+
def initialize(hostptr : Pointer(T), shape : Array(Int), strides : Array(Int))
113+
@data = hostptr
114114
end
115115

116116
# Converts a CPU storage to a crystal pointer

src/tensor/backends/util_storage.cr

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ abstract class Num::Backend::Storage(T)
2626
abstract def initialize(shape : Array(Int), strides : Array(Int))
2727
abstract def initialize(shape : Array(Int), order : Num::OrderType, value : T)
2828
abstract def initialize(shape : Array(Int), strides : Array(Int), value : T)
29-
abstract def initialize(data : Pointer(T), shape : Array(Int), strides : Array(Int))
29+
abstract def initialize(hostptr : Pointer(T), shape : Array(Int), strides : Array(Int))
3030
abstract def update_metadata(shape : Array(Int32), strides : Array(Int32))
3131
abstract def to_unsafe
3232
end

0 commit comments

Comments
 (0)