-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdebug_layer.js
More file actions
71 lines (65 loc) · 1.76 KB
/
debug_layer.js
File metadata and controls
71 lines (65 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Opt this script into strict mode for everything that follows.
'use strict';
/**
 * A custom pass-through layer used for debugging.
 *
 * The layer returns its input tensor unchanged. When the global flag
 * `enable_debug_layer` is truthy, each forward pass also logs the layer
 * object, the input's shape, its values, its min and max, and the extra
 * call arguments it received.
 *
 * This custom layer is written in a way that can be saved and loaded:
 * it implements getConfig(), exposes a static className, and is registered
 * with tf.serialization right after the class definition.
 */
class DebugLayer extends tf.layers.Layer {
  /**
   * @param {Object} [config={}] - Standard tf.layers.Layer arguments
   *   (name, trainable, inputShape, ...). An optional `alpha` field is
   *   stored on the instance; this class itself never reads it.
   */
  constructor(config = {}) {
    super(config);
    this.alpha = config.alpha;
  }

  /**
   * build() is called when the custom layer object is connected to an
   * upstream layer for the first time. This is where weights would be
   * created; DebugLayer has no weights, so there is nothing to do.
   */
  build(inputShape) {}

  /**
   * call() contains the layer's numerical computation. For DebugLayer this
   * is the identity: the input tensor is returned as-is.
   *
   * @param {tf.Tensor|tf.Tensor[]} input - The input tensor, or an array
   *   whose first element is the input tensor.
   * @param {...*} kwargs - Extra call arguments; they are only logged.
   * @returns {tf.Tensor} The (first) input tensor, unmodified.
   */
  call(input, ...kwargs) {
    // TF.js may pass either a single tensor or an array of tensors here;
    // indexing a bare Tensor with [0] yields undefined, so accept both.
    const tensor = Array.isArray(input) ? input[0] : input;

    if (!enable_debug_layer) {
      return tensor;
    }

    log(this);
    // tidy() is used so the temporary tensors created by min()/max() do
    // not leak GPU memory (assumes the project's tidy() wraps tf.tidy —
    // confirm). The returned tensor is the caller's input, not a temporary.
    return tidy(() => {
      log(`=== DebugLayer ${this.name} ===`);
      log(`shape: [${tensor.shape.join(", ")}]`);
      log("input:", array_sync(tensor));
      log("min:", array_sync(min(tensor)));
      log("max:", array_sync(max(tensor)));
      log("kwargs:", kwargs);
      log(`=== DebugLayer ${this.name} End ==`);
      return tensor;
    });
  }

  /**
   * getConfig() generates the JSON object that is used when saving and
   * loading the custom layer object.
   *
   * Only the base Layer config is returned: `alpha` is deliberately not
   * serialized, so a reloaded layer will not get back the alpha it was
   * created with.
   */
  getConfig() {
    return super.getConfig();
  }

  /**
   * The static className getter is required by the registration step
   * below; it is the name under which the class is stored in saved models.
   */
  static get className() {
    return "DebugLayer";
  }
}

/**
 * Register the custom layer, so TensorFlow.js knows which class constructor
 * to call when deserializing a saved instance of the custom layer.
 */
tf.serialization.registerClass(DebugLayer);