From d677818adcecf4abbada66ae936268712457efe1 Mon Sep 17 00:00:00 2001 From: Matteo Interlandi Date: Tue, 30 Apr 2019 11:56:20 -0700 Subject: [PATCH] Add example and infrastructure to manage error handling. --- Test/TorchSharp/TorchSharp.cs | 17 ++++++++++++++-- TorchSharp/NN/LossFunction.cs | 9 +++++++-- .../Tensor/TorchTensorTyped.generated.cs | 4 +++- TorchSharp/Torch.cs | 20 ++++++++++++++++++- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/Test/TorchSharp/TorchSharp.cs b/Test/TorchSharp/TorchSharp.cs index 73d79a391..bc1c78ead 100644 --- a/Test/TorchSharp/TorchSharp.cs +++ b/Test/TorchSharp/TorchSharp.cs @@ -1,6 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; +using System.Diagnostics; using System.Linq; +using System.Runtime.InteropServices; using TorchSharp.JIT; using TorchSharp.NN; using TorchSharp.Tensor; @@ -164,7 +166,7 @@ public void TestSparse() Assert.IsTrue(sparse.IsSparse); Assert.IsFalse(i.IsSparse); Assert.IsFalse(v.IsSparse); - CollectionAssert.AreEqual(sparse.Indeces.Data().ToArray(), new long[] { 0, 1, 1, 2, 0, 2 }); + CollectionAssert.AreEqual(sparse.Indices.Data().ToArray(), new long[] { 0, 1, 1, 2, 0, 2 }); CollectionAssert.AreEqual(sparse.Values.Data().ToArray(), new float[] { 3, 4, 5 }); } } @@ -461,13 +463,24 @@ public void TestPoissonNLLLoss2() } } + # if DEBUG + [TestMethod] + public void TestErrorHandling() + { + using (TorchTensor input = FloatTensor.From(new float[] { 0.5f, 1.5f})) + using (TorchTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f })) + { + Assert.ThrowsException(() => NN.LossFunction.PoissonNLL()(input, target)); + } + } + #endif + [TestMethod] public void TestZeroGrad() { var lin1 = NN.Module.Linear(1000, 100); var lin2 = NN.Module.Linear(100, 10); var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2); - seq.ZeroGrad(); } diff --git a/TorchSharp/NN/LossFunction.cs b/TorchSharp/NN/LossFunction.cs index bd360c1f0..8dde282b8 100644 --- a/TorchSharp/NN/LossFunction.cs +++ b/TorchSharp/NN/LossFunction.cs @@ -36,11 +36,16 @@ public static Loss NLL(TorchTensor? weigths = null, Reduction reduction = Reduct } [DllImport("libTorchSharp")] - extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction); + extern static IntPtr THSNN_loss_poisson_nll(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction); public static Loss PoissonNLL(bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean) { - return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction)); + return (TorchTensor src, TorchTensor target) => + { + var tptr = THSNN_loss_poisson_nll(src.Handle, target.Handle, logInput, full, eps, (long)reduction); + Torch.AssertNoErrors(); + return new TorchTensor(tptr); + }; } } diff --git a/TorchSharp/Tensor/TorchTensorTyped.generated.cs b/TorchSharp/Tensor/TorchTensorTyped.generated.cs index 02da8babf..75b1579c6 100644 --- a/TorchSharp/Tensor/TorchTensorTyped.generated.cs +++ b/TorchSharp/Tensor/TorchTensorTyped.generated.cs @@ -876,7 +876,9 @@ static public TorchTensor Random(long[] size, string device = "cpu", bool requir { fixed (long* psizes = size) { - return new TorchTensor (THSTensor_rand ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Float, device, requiresGrad)); + var tptr = THSTensor_rand((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Float, device, requiresGrad); + Torch.AssertNoErrors(); + return new TorchTensor (tptr); } } } diff --git a/TorchSharp/Torch.cs b/TorchSharp/Torch.cs index 8311454f3..34076d753 100644 --- a/TorchSharp/Torch.cs +++ b/TorchSharp/Torch.cs @@ -1,7 +1,11 @@ -using System.Runtime.InteropServices; +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; namespace TorchSharp { + using Debug = System.Diagnostics.Debug; + public static class Torch { [DllImport("libTorchSharp")] @@ -19,5 +23,19 @@ public static bool IsCudaAvailable() { return THSTorch_isCudaAvailable(); } + + [DllImport("libTorchSharp")] + extern static IntPtr THSTorch_get_and_reset_last_err(); + + [Conditional("DEBUG")] + internal static void AssertNoErrors() + { + var error = THSTorch_get_and_reset_last_err(); + + if (error != IntPtr.Zero) + { + throw new ExternalException(Marshal.PtrToStringAnsi(error)); + } + } } }