1
+ /*
2
+ Licensed to the Apache Software Foundation (ASF) under one
3
+ or more contributor license agreements. See the NOTICE file
4
+ distributed with this work for additional information
5
+ regarding copyright ownership. The ASF licenses this file
6
+ to you under the Apache License, Version 2.0 (the
7
+ "License"); you may not use this file except in compliance
8
+ with the License. You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing,
13
+ software distributed under the License is distributed on an
14
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ KIND, either express or implied. See the License for the
16
+ specific language governing permissions and limitations
17
+ under the License.
18
+ * */
19
+ package io .github .mandar2812 .dynaml .tensorflow .utils
20
+
21
+ import com .quantifind .charts .Highcharts .{regression , title , xAxis , yAxis }
22
+ import io .github .mandar2812 .dynaml .evaluation .RegressionMetricsTF
23
+ import io .github .mandar2812 .dynaml .tensorflow .dtf
24
+ import org .platanios .tensorflow .api ._
25
+
26
+ /**
27
+ * Generalisation of [[RegressionMetricsTF ]] to more complex output structures, like
28
+ * matrix outputs/labels.
29
+ *
30
+ * @author mandar2812 date: 9/03/2018
31
+ * */
32
+ class GenRegressionMetricsTF (preds : Tensor , targets : Tensor ) extends RegressionMetricsTF (preds, targets) {
33
+ private val num_outputs =
34
+ if (preds.shape.toTensor().size == 1 ) 1
35
+ else preds.shape.toTensor()(0 :: - 1 ).prod().scalar.asInstanceOf [Int ]
36
+
37
+ private lazy val (_ , rmse , mae, corr) = GenRegressionMetricsTF .calculate(preds, targets)
38
+
39
+ private lazy val modelyield =
40
+ (preds.max(axes = 0 ) - preds.min(axes = 0 )).divide(targets.max(axes = 0 ) - targets.min(axes = 0 ))
41
+
42
+ override protected def run (): Tensor = dtf.stack(Seq (rmse, mae, corr, modelyield), axis = - 1 )
43
+
44
+ override def generatePlots (): Unit = {
45
+ println(" Generating Plot of Fit for each target" )
46
+
47
+ if (num_outputs == 1 ) {
48
+ val (pr, tar) = (
49
+ scoresAndLabels._1.entriesIterator.map(_.asInstanceOf [Float ]),
50
+ scoresAndLabels._2.entriesIterator.map(_.asInstanceOf [Float ]))
51
+
52
+ regression(pr.zip(tar).toSeq)
53
+
54
+ title(" Goodness of fit: " + name)
55
+ xAxis(" Predicted " + name)
56
+ yAxis(" Actual " + name)
57
+
58
+ } else {
59
+ (0 until num_outputs).foreach(output => {
60
+ val (pr, tar) = (
61
+ scoresAndLabels._1(:: , output).entriesIterator.map(_.asInstanceOf [Float ]),
62
+ scoresAndLabels._2(:: , output).entriesIterator.map(_.asInstanceOf [Float ]))
63
+
64
+ regression(pr.zip(tar).toSeq)
65
+ })
66
+ }
67
+ }
68
+ }
69
+
70
+ object GenRegressionMetricsTF {
71
+
72
+ protected def calculate (preds : Tensor , targets : Tensor ): (Tensor , Tensor , Tensor , Tensor ) = {
73
+ val error = targets.subtract(preds)
74
+
75
+ println(" Shape of error tensor: " + error.shape.toString()+ " \n " )
76
+
77
+ val num_instances = error.shape(0 )
78
+ val rmse = error.square.mean(axes = 0 ).sqrt
79
+
80
+ val mae = error.abs.mean(axes = 0 )
81
+
82
+ val corr = {
83
+
84
+ val mean_preds = preds.mean(axes = 0 )
85
+
86
+ val mean_targets = targets.mean(axes = 0 )
87
+
88
+ val preds_c = preds.subtract(mean_preds)
89
+
90
+ val targets_c = targets.subtract(mean_targets)
91
+
92
+ val (sigma_t, sigma_p) = (targets_c.square.mean(axes = 0 ).sqrt, preds_c.square.mean(axes = 0 ).sqrt)
93
+
94
+ preds_c.multiply(targets_c).mean(axes = 0 ).divide(sigma_t.multiply(sigma_p))
95
+ }
96
+
97
+ (error, rmse, mae, corr)
98
+ }
99
+ }
0 commit comments