Skip to content

Commit 429b06f

Browse files
committed
dynaml.tensorflow: Added Dynamic System simulation and inference
- FiniteHorizonCTRNN: a continuous-time recurrent neural network (CTRNN) layer.
- MVTimeSeriesLoss: an L2 loss for time slices of multivariate time series.
1 parent fcc51d5 commit 429b06f

File tree

3 files changed

+197
-0
lines changed

3 files changed

+197
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.tensorflow.layers
20+
21+
import org.platanios.tensorflow.api.{Output, Shape, tf}
22+
import org.platanios.tensorflow.api.learn.Mode
23+
import org.platanios.tensorflow.api.learn.layers.Layer
24+
import org.platanios.tensorflow.api.ops.variables.{Initializer, RandomNormalInitializer}
25+
26+
/**
  * Represents a Continuous Time Recurrent Neural Network (CTRNN).
  *
  * The layer simulates the discretized dynamics of the CTRNN for
  * a fixed number of time steps using an explicit Euler scheme:
  *
  * x(t + dt) = x(t) + dt * (-x(t)·T + tanh(x(t)·G + b)·W)
  *
  * The forward pass stacks the `horizon` simulated states along a new
  * trailing axis, so the output has one more dimension than the input.
  *
  * @param name                    Name scope of the layer's variables.
  * @param units                   Dimensionality of the network state.
  * @param horizon                 Number of Euler steps to simulate.
  * @param timestep                Euler integration step size (dt).
  * @param weightsInitializer      Initializer for the interaction weight matrix W.
  * @param biasInitializer         Initializer for the bias vector b.
  * @param gainInitializer         Initializer for the gain matrix G.
  * @param timeConstantInitializer Initializer for the time-constant matrix T.
  *
  * @author mandar2812 date: 2018/03/06
  * */
case class FiniteHorizonCTRNN(
  override val name: String, units: Int,
  horizon: Int, timestep: Double,
  weightsInitializer: Initializer = RandomNormalInitializer(),
  biasInitializer: Initializer = RandomNormalInitializer(),
  gainInitializer: Initializer = RandomNormalInitializer(),
  timeConstantInitializer: Initializer = RandomNormalInitializer()) extends
  Layer[Output, Output](name) {

  override val layerType: String = "FHCTRNN"

  override protected def _forward(input: Output, mode: Mode): Output = {

    // Trainable parameters of the CTRNN dynamics.
    val weights = tf.variable("Weights", input.dataType, Shape(units, units), weightsInitializer)
    val timeconstant = tf.variable("TimeConstant", input.dataType, Shape(units, units), timeConstantInitializer)
    // BUG FIX: "Gain" was previously created with `timeConstantInitializer`,
    // silently ignoring the user-supplied `gainInitializer` parameter.
    val gain = tf.variable("Gain", input.dataType, Shape(units, units), gainInitializer)
    val bias = tf.variable("Bias", input.dataType, Shape(units), biasInitializer)

    // Run `horizon` Euler steps starting from `input`; `scanLeft` keeps every
    // intermediate state, `.tail` drops the initial condition, and the states
    // are stacked along a new trailing axis.
    tf.stack(
      (1 to horizon).scanLeft(input)((x, _) => {
        // Decay term: -x·T, scaled by dt below.
        val decay = x.tensorDot(timeconstant.multiply(-1d), Seq(1), Seq(0))
        // Interaction term: tanh(x·G + b)·W.
        val interaction = x.tensorDot(gain, Seq(1), Seq(0)).add(bias).tanh.tensorDot(weights, Seq(1), Seq(0))

        x.add(decay.multiply(timestep)).add(interaction.multiply(timestep))
      }).tail,
      axis = -1)

  }
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.tensorflow.learn
20+
21+
import org.platanios.tensorflow.api.learn.Mode
22+
import org.platanios.tensorflow.api.learn.layers.Loss
23+
import org.platanios.tensorflow.api.ops.Output
24+
25+
/**
  * Squared-error (L2) loss for a time slice of a multivariate time series.
  *
  * Computes the element-wise squared difference between predictions and
  * targets, averages over the batch dimension (axis 0) and sums the
  * per-output means into a single scalar.
  *
  * @param name Name scope of the loss.
  *
  * @author mandar2812 date 9/03/2018
  * */
case class MVTimeSeriesLoss(override val name: String)
  extends Loss[(Output, Output)](name) {
  override val layerType: String = "L2Loss"

  override protected def _forward(input: (Output, Output), mode: Mode): Output = {
    // Destructure for readability: first element is predictions, second is targets.
    val (predictions, targets) = input
    val error = predictions.subtract(targets)
    // Mean over the batch axis, then sum across outputs to get a scalar loss.
    error.square.mean(axes = 0).sum()
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.tensorflow.utils
20+
21+
import com.quantifind.charts.Highcharts.{regression, title, xAxis, yAxis}
22+
import io.github.mandar2812.dynaml.evaluation.RegressionMetricsTF
23+
import io.github.mandar2812.dynaml.tensorflow.dtf
24+
import org.platanios.tensorflow.api._
25+
26+
/**
  * Generalisation of [[RegressionMetricsTF]] to more complex output structures, like
  * matrix outputs/labels.
  *
  * Computes RMSE, MAE, correlation and model yield per output dimension
  * (see the companion object's `calculate`), and renders goodness-of-fit
  * scatter plots via Highcharts.
  *
  * @param preds   Tensor of model predictions.
  * @param targets Tensor of ground-truth values, same shape as `preds`.
  *
  * @author mandar2812 date: 9/03/2018
  * */
class GenRegressionMetricsTF(preds: Tensor, targets: Tensor) extends RegressionMetricsTF(preds, targets) {
  // Number of output dimensions: 1 for rank-1 predictions, otherwise the
  // product of a slice of the shape tensor.
  // NOTE(review): `(0 :: -1)` takes all shape entries EXCEPT the last; for a
  // (samples, outputs) tensor this yields the sample count, not the output
  // count. `(1 ::)` may have been intended — confirm against callers.
  private val num_outputs =
    if (preds.shape.toTensor().size == 1) 1
    else preds.shape.toTensor()(0 :: -1).prod().scalar.asInstanceOf[Int]

  // Metrics are computed lazily; the error tensor itself is discarded here.
  private lazy val (_ , rmse , mae, corr) = GenRegressionMetricsTF.calculate(preds, targets)

  // Model yield: ratio of the predicted value range to the target value
  // range, per output dimension.
  private lazy val modelyield =
    (preds.max(axes = 0) - preds.min(axes = 0)).divide(targets.max(axes = 0) - targets.min(axes = 0))

  // Stack the four metric tensors along a new trailing axis.
  override protected def run(): Tensor = dtf.stack(Seq(rmse, mae, corr, modelyield), axis = -1)

  /**
    * Plot predicted vs. actual values: a single regression plot for scalar
    * outputs, or one plot per output column otherwise.
    *
    * Side effects: prints to stdout and opens Highcharts plots.
    * */
  override def generatePlots(): Unit = {
    println("Generating Plot of Fit for each target")

    if(num_outputs == 1) {
      // Scalar output: plot the flattened prediction/target pairs.
      // NOTE(review): assumes the underlying tensors hold Float entries —
      // the asInstanceOf cast will fail at runtime for other data types.
      val (pr, tar) = (
        scoresAndLabels._1.entriesIterator.map(_.asInstanceOf[Float]),
        scoresAndLabels._2.entriesIterator.map(_.asInstanceOf[Float]))

      regression(pr.zip(tar).toSeq)

      title("Goodness of fit: "+name)
      xAxis("Predicted "+name)
      yAxis("Actual "+name)

    } else {
      // One regression plot per output column (no per-plot titles/axes here).
      (0 until num_outputs).foreach(output => {
        val (pr, tar) = (
          scoresAndLabels._1(::, output).entriesIterator.map(_.asInstanceOf[Float]),
          scoresAndLabels._2(::, output).entriesIterator.map(_.asInstanceOf[Float]))

        regression(pr.zip(tar).toSeq)
      })
    }
  }
}
69+
70+
object GenRegressionMetricsTF {

  /**
    * Compute per-output regression metrics.
    *
    * @param preds   Tensor of model predictions.
    * @param targets Tensor of ground-truth values, same shape as `preds`.
    * @return A 4-tuple of tensors: (error, rmse, mae, corr), each reduced
    *         over axis 0 (the sample axis) except `error`, which retains
    *         the input shape.
    * */
  protected def calculate(preds: Tensor, targets: Tensor): (Tensor, Tensor, Tensor, Tensor) = {
    val error = targets.subtract(preds)

    println("Shape of error tensor: "+error.shape.toString()+"\n")

    // Root mean squared error, per output dimension.
    val rmse = error.square.mean(axes = 0).sqrt

    // Mean absolute error, per output dimension.
    val mae = error.abs.mean(axes = 0)

    // Pearson correlation between predictions and targets, per output:
    // cov(p, t) / (sigma_p * sigma_t), with moments taken over axis 0.
    val corr = {

      val mean_preds = preds.mean(axes = 0)

      val mean_targets = targets.mean(axes = 0)

      val preds_c = preds.subtract(mean_preds)

      val targets_c = targets.subtract(mean_targets)

      val (sigma_t, sigma_p) = (targets_c.square.mean(axes = 0).sqrt, preds_c.square.mean(axes = 0).sqrt)

      preds_c.multiply(targets_c).mean(axes = 0).divide(sigma_t.multiply(sigma_p))
    }

    // Removed unused local `num_instances` (was `error.shape(0)`, never read).
    (error, rmse, mae, corr)
  }
}

0 commit comments

Comments
 (0)