<?php

declare(strict_types=1);

namespace Phpml\Tests\Classification;

use Phpml\Classification\MLPClassifier;
use Phpml\Exception\InvalidArgumentException;
use Phpml\ModelManager;
use Phpml\NeuralNetwork\ActivationFunction;
use Phpml\NeuralNetwork\ActivationFunction\HyperbolicTangent;
use Phpml\NeuralNetwork\ActivationFunction\PReLU;
use Phpml\NeuralNetwork\ActivationFunction\Sigmoid;
use Phpml\NeuralNetwork\ActivationFunction\ThresholdedReLU;
use Phpml\NeuralNetwork\Node\Neuron;
use PHPUnit\Framework\TestCase;

/**
 * Tests for MLPClassifier: layer and synapse construction, full and partial
 * backpropagation training, pluggable activation functions, persistence through
 * ModelManager, and constructor / partialTrain() argument validation.
 */
class MLPClassifierTest extends TestCase
{
    public function testMLPClassifierLayersInitialization(): void
    {
        $mlp = new MLPClassifier(2, [2], [0, 1]);
        $layers = $mlp->getLayers();

        // input, one hidden and output layer
        $this->assertCount(3, $layers);

        // input layer: 3 nodes, not all of them neurons (presumably 2 inputs + a bias node)
        $this->assertCount(3, $layers[0]->getNodes());
        $this->assertNotContainsOnly(Neuron::class, $layers[0]->getNodes());

        // hidden layer: 3 nodes, not all of them neurons (presumably 2 neurons + a bias node)
        $this->assertCount(3, $layers[1]->getNodes());
        $this->assertNotContainsOnly(Neuron::class, $layers[1]->getNodes());

        // output layer: one neuron per class
        $this->assertCount(2, $layers[2]->getNodes());
        $this->assertContainsOnly(Neuron::class, $layers[2]->getNodes());
    }

    public function testSynapsesGeneration(): void
    {
        $mlp = new MLPClassifier(2, [2], [0, 1]);
        $layers = $mlp->getLayers();

        // Every hidden-layer neuron must have exactly one synapse to each input-layer node.
        foreach ($layers[1]->getNodes() as $node) {
            if ($node instanceof Neuron) {
                $synapses = $node->getSynapses();
                $this->assertCount(3, $synapses);

                $synapsesNodes = $this->getSynapsesNodes($synapses);
                foreach ($layers[0]->getNodes() as $prevNode) {
                    $this->assertContains($prevNode, $synapsesNodes);
                }
            }
        }
    }

    public function testBackpropagationLearning(): void
    {
        // Single hidden layer, 2 classes.
        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
        $network->train(
            [[1, 0], [0, 1], [1, 1], [0, 0]],
            ['a', 'b', 'a', 'b']
        );

        $this->assertEquals('a', $network->predict([1, 0]));
        $this->assertEquals('b', $network->predict([0, 1]));
        $this->assertEquals('a', $network->predict([1, 1]));
        $this->assertEquals('b', $network->predict([0, 0]));
    }

    public function testBackpropagationTrainingReset(): void
    {
        // Single hidden layer, 2 classes.
        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
        $network->train(
            [[1, 0], [0, 1]],
            ['a', 'b']
        );

        $this->assertEquals('a', $network->predict([1, 0]));
        $this->assertEquals('b', $network->predict([0, 1]));

        // A second train() call with swapped labels must flip the predictions,
        // i.e. train() does not keep what was learned before.
        $network->train(
            [[1, 0], [0, 1]],
            ['b', 'a']
        );

        $this->assertEquals('b', $network->predict([1, 0]));
        $this->assertEquals('a', $network->predict([0, 1]));
    }

    public function testBackpropagationPartialTraining(): void
    {
        // Single hidden layer, 2 classes.
        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
        $network->partialTrain(
            [[1, 0], [0, 1]],
            ['a', 'b']
        );

        $this->assertEquals('a', $network->predict([1, 0]));
        $this->assertEquals('b', $network->predict([0, 1]));

        // A further partialTrain() call adds new samples while the earlier ones
        // must still be classified correctly.
        $network->partialTrain(
            [[1, 1], [0, 0]],
            ['a', 'b']
        );

        $this->assertEquals('a', $network->predict([1, 0]));
        $this->assertEquals('b', $network->predict([0, 1]));
        $this->assertEquals('a', $network->predict([1, 1]));
        $this->assertEquals('b', $network->predict([0, 0]));
    }

    public function testBackpropagationLearningMultilayer(): void
    {
        // Two hidden layers, 3 classes.
        $network = new MLPClassifier(5, [3, 2], ['a', 'b', 'c'], 2000);
        $network->train(
            [[1, 0, 0, 0, 0], [0, 1, 1, 0, 0], [1, 1, 1, 1, 1], [0, 0, 0, 0, 0]],
            ['a', 'b', 'a', 'c']
        );

        $this->assertEquals('a', $network->predict([1, 0, 0, 0, 0]));
        $this->assertEquals('b', $network->predict([0, 1, 1, 0, 0]));
        $this->assertEquals('a', $network->predict([1, 1, 1, 1, 1]));
        $this->assertEquals('c', $network->predict([0, 0, 0, 0, 0]));
    }

    public function testBackpropagationLearningMulticlass(): void
    {
        // Two hidden layers, more than 2 classes, mixing string and integer labels.
        $network = new MLPClassifier(5, [3, 2], ['a', 'b', 4], 1000);
        $network->train(
            [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 0, 0, 0, 0]],
            ['a', 'b', 'a', 'a', 4]
        );

        $this->assertEquals('a', $network->predict([1, 0, 0, 0, 0]));
        $this->assertEquals('b', $network->predict([0, 1, 0, 0, 0]));
        $this->assertEquals('a', $network->predict([0, 0, 1, 1, 0]));
        $this->assertEquals('a', $network->predict([1, 1, 1, 1, 1]));
        $this->assertEquals(4, $network->predict([0, 0, 0, 0, 0]));
    }

    /**
     * @dataProvider activationFunctionsProvider
     */
    public function testBackpropagationActivationFunctions(ActivationFunction $activationFunction): void
    {
        $network = new MLPClassifier(5, [3], ['a', 'b'], 1000, $activationFunction);
        $network->train(
            [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 1, 1, 1, 1]],
            ['a', 'b', 'a', 'a']
        );

        $this->assertEquals('a', $network->predict([1, 0, 0, 0, 0]));
        $this->assertEquals('b', $network->predict([0, 1, 0, 0, 0]));
        $this->assertEquals('a', $network->predict([0, 0, 1, 1, 0]));
        $this->assertEquals('a', $network->predict([1, 1, 1, 1, 1]));
    }

    /**
     * Activation functions the network must be able to learn with.
     *
     * Static so it works with every PHPUnit version (PHPUnit 10+ requires static providers).
     */
    public static function activationFunctionsProvider(): array
    {
        return [
            [new Sigmoid()],
            [new HyperbolicTangent()],
            [new PReLU()],
            [new ThresholdedReLU()],
        ];
    }

    public function testSaveAndRestore(): void
    {
        // Train a new network on the OR problem.
        $samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
        $targets = [0, 1, 1, 1];
        $classifier = new MLPClassifier(2, [2], [0, 1], 1000);
        $classifier->train($samples, $targets);

        $testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]];
        $predicted = $classifier->predict($testSamples);

        $restoredClassifier = $this->saveAndRestore($classifier);

        // The restored model must equal the original and predict identically.
        $this->assertEquals($classifier, $restoredClassifier);
        $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
    }

    public function testSaveAndRestoreWithPartialTraining(): void
    {
        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
        $network->partialTrain(
            [[1, 0], [0, 1]],
            ['a', 'b']
        );

        $this->assertEquals('a', $network->predict([1, 0]));
        $this->assertEquals('b', $network->predict([0, 1]));

        // Partial training must be able to continue on a restored model.
        $restoredNetwork = $this->saveAndRestore($network);
        $restoredNetwork->partialTrain(
            [[1, 1], [0, 0]],
            ['a', 'b']
        );

        $this->assertEquals('a', $restoredNetwork->predict([1, 0]));
        $this->assertEquals('b', $restoredNetwork->predict([0, 1]));
        $this->assertEquals('a', $restoredNetwork->predict([1, 1]));
        $this->assertEquals('b', $restoredNetwork->predict([0, 0]));
    }

    public function testThrowExceptionOnInvalidLayersNumber(): void
    {
        $this->expectException(InvalidArgumentException::class);
        new MLPClassifier(2, [], [0, 1]);
    }

    public function testThrowExceptionOnInvalidPartialTrainingClasses(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $classifier = new MLPClassifier(2, [2], [0, 1]);
        // The classes passed here ([0, 1, 2]) differ from those given to the constructor ([0, 1]).
        $classifier->partialTrain(
            [[0, 1], [1, 0]],
            [0, 2],
            [0, 1, 2]
        );
    }

    public function testThrowExceptionOnInvalidClassesNumber(): void
    {
        $this->expectException(InvalidArgumentException::class);
        new MLPClassifier(2, [2], [0]);
    }

    public function testOutputWithLabels(): void
    {
        $output = (new MLPClassifier(2, [2, 2], ['T', 'F']))->getOutput();

        // The output is keyed by the class labels, in constructor order.
        $this->assertEquals(['T', 'F'], array_keys($output));
    }

    /**
     * Collects the node each synapse points to.
     *
     * @param array $synapses synapses of a single neuron
     *
     * @return array the target nodes, in synapse order
     */
    private function getSynapsesNodes(array $synapses): array
    {
        $nodes = [];
        foreach ($synapses as $synapse) {
            $nodes[] = $synapse->getNode();
        }

        return $nodes;
    }

    /**
     * Round-trips a classifier through ModelManager using a temporary file.
     *
     * The temporary file is always removed afterwards, even if saving or restoring
     * throws, so test runs do not leave files behind in the system temp directory.
     */
    private function saveAndRestore(MLPClassifier $classifier): MLPClassifier
    {
        $filepath = tempnam(sys_get_temp_dir(), 'mlp-classifier-test-');
        if ($filepath === false) {
            $this->fail('Unable to create a temporary file for the model.');
        }

        $modelManager = new ModelManager();
        try {
            $modelManager->saveToFile($classifier, $filepath);
            $restored = $modelManager->restoreFromFile($filepath);
        } finally {
            if (is_file($filepath)) {
                unlink($filepath);
            }
        }

        $this->assertInstanceOf(MLPClassifier::class, $restored);

        return $restored;
    }
}