-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmulti_activation.js
More file actions
105 lines (84 loc) · 3 KB
/
multi_activation.js
File metadata and controls
105 lines (84 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"use strict";
/**
* MultiActivation Layer: kombinierte Aktivierungen mit trainierbaren Gewichten
* out = a*ReLU(x) + b*Snake(x) + c*ELU(x) + d*Sin(x) ...
* alpha für Snake ist ebenfalls trainierbar
*/
class MultiActivation extends tf.layers.Layer {
    /**
     * MultiActivation layer: a weighted combination of several activations
     * with trainable mixing coefficients:
     *
     *   out = a*ReLU(x) + b*Snake(x) + c*ELU(x) + d*sin(x)
     *
     * where Snake(x) = x + sin^2(x) / alpha and alpha is trainable as well.
     * (NOTE(review): the canonical Snake activation is x + sin^2(alpha*x)/alpha;
     * this implementation applies sin to x directly — kept as-is to preserve
     * trained-model behavior.)
     *
     * @param {Object} [config] - layer configuration
     * @param {Object} [config.initWeights] - initial mixing weights
     *   `{relu, snake, elu, sin}`; each missing entry defaults to 1
     * @param {number} [config.snakeAlpha=1] - initial Snake alpha
     * @param {boolean} [config.trainable=true] - whether all weights train
     */
    constructor(config = {}) {
        super(config);
        // Initial values for the per-activation mixing weights.
        this.initWeights = Object.assign({
            relu: 1,
            snake: 1,
            elu: 1,
            sin: 1
        }, config.initWeights || {});
        // Initial Snake alpha. Also accept the legacy "snakeAlphaVal" key,
        // which an earlier getConfig() emitted but the constructor never
        // consumed — without this, deserialized layers silently reset
        // alpha to 1.
        this.snakeAlphaVal = config.snakeAlpha ?? config.snakeAlphaVal ?? 1;
        // Trainable flag; an explicit `false` must be honored, so test
        // against undefined rather than truthiness.
        this.trainable = config.trainable !== undefined ? config.trainable : true;
    }

    /**
     * Creates the five trainable scalars: one mixing factor per activation
     * plus the Snake alpha.
     * @param {Array} inputShape - shape of the incoming tensor
     */
    build(inputShape) {
        // addWeight(name, shape, dtype, initializer, regularizer, trainable):
        // scalar ([] shape) weights, initialized from the configured values.
        this.aRelu = this.addWeight("aRelu", [], "float32",
            tf.initializers.constant({ value: this.initWeights.relu }),
            null, this.trainable
        );
        this.aSnake = this.addWeight("aSnake", [], "float32",
            tf.initializers.constant({ value: this.initWeights.snake }),
            null, this.trainable
        );
        this.aElu = this.addWeight("aElu", [], "float32",
            tf.initializers.constant({ value: this.initWeights.elu }),
            null, this.trainable
        );
        this.aSin = this.addWeight("aSin", [], "float32",
            tf.initializers.constant({ value: this.initWeights.sin }),
            null, this.trainable
        );
        // Trainable Snake alpha (denominator of the sin^2 term).
        this.snakeAlpha = this.addWeight("snakeAlpha", [], "float32",
            tf.initializers.constant({ value: this.snakeAlphaVal }),
            null, this.trainable
        );
        super.build(inputShape);
    }

    /**
     * Forward pass: element-wise weighted sum of the four activations.
     * Output shape equals input shape.
     * @param {tf.Tensor|tf.Tensor[]} input - input tensor (or 1-element list)
     * @param {Object} kwargs - unused layer-call kwargs
     * @returns {tf.Tensor} combined activation output
     */
    call(input, kwargs) {
        return tf.tidy(() => {
            const x = input instanceof Array ? input[0] : input;
            // ReLU
            const reluOut = tf.relu(x);
            // Snake: x + sin^2(x) / alpha
            const snakeOut = tf.add(x, tf.div(tf.square(tf.sin(x)), this.snakeAlpha.read()));
            // ELU
            const eluOut = tf.elu(x);
            // Plain sine
            const sinOut = tf.sin(x);
            // Combine with the trainable mixing factors.
            const out = tf.addN([
                tf.mul(this.aRelu.read(), reluOut),
                tf.mul(this.aSnake.read(), snakeOut),
                tf.mul(this.aElu.read(), eluOut),
                tf.mul(this.aSin.read(), sinOut)
            ]);
            return out;
        });
    }

    /**
     * Serializes the layer config. Note this stores the *initial* values;
     * trained weight values are saved separately by the framework.
     * @returns {Object} config consumable by the constructor
     */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = Object.assign({}, baseConfig, {
            initWeights: this.initWeights,
            // Serialize under "snakeAlpha" so the constructor reads it back
            // (the old "snakeAlphaVal" key was never consumed on load).
            snakeAlpha: this.snakeAlphaVal,
            trainable: this.trainable
        });
        return config;
    }

    static get className() { return "MultiActivation"; }
}
// Register the class so models containing this layer can be deserialized
// (className "MultiActivation" is used as the lookup key).
tf.serialization.registerClass(MultiActivation);
// Convenience factory mirroring the built-in tf.layers.* constructors.
tf.layers.MultiActivation = (config) => new MultiActivation(config);