MLPClassifier.php 1.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\Network\MultilayerPerceptron;

/**
 * Classifier backed by a multilayer perceptron.
 *
 * Training translates each target into its key within $this->classes before
 * back-propagating; prediction returns the key of the strongest network output.
 * NOTE(review): $this->classes and $this->backpropagation are not declared here;
 * presumably they are provided by MultilayerPerceptron - confirm.
 */
class MLPClassifier extends MultilayerPerceptron implements Classifier
{
    /**
     * Returns the key under which the given target is stored in $this->classes.
     *
     * @param mixed $target
     *
     * @throws InvalidArgumentException when $target is not one of the accepted classes
     */
    public function getTargetClass($target): int
    {
        // One strict lookup serves as both the membership test and the result.
        $index = array_search($target, $this->classes, true);

        if ($index === false) {
            // Arrays and objects without __toString() cannot be formatted with "%s"
            // (warning / Error), which would mask the intended exception; report
            // their type name instead. All other values are formatted as before.
            $printable = is_array($target) || (is_object($target) && !method_exists($target, '__toString'))
                ? gettype($target)
                : $target;

            throw new InvalidArgumentException(
                sprintf('Target with value "%s" is not part of the accepted classes', $printable)
            );
        }

        return $index;
    }

    /**
     * Feeds the sample through the network and returns the key of the output
     * entry holding the largest value (the first one wins on ties).
     *
     * @return mixed key of the strongest output, or null when the output is empty
     */
    protected function predictSample(array $sample)
    {
        $output = $this->setInput($sample)->getOutput();

        $predictedClass = null;
        // Seed the maximum from the first output rather than 0, so a class is
        // still selected when every output value is zero or negative.
        $max = null;
        foreach ($output as $class => $value) {
            if ($max === null || $value > $max) {
                $predictedClass = $class;
                $max = $value;
            }
        }

        return $predictedClass;
    }

    /**
     * Runs one training step for a single sample.
     *
     * @param mixed $target
     *
     * @throws InvalidArgumentException when $target is not one of the accepted classes
     */
    protected function trainSample(array $sample, $target): void
    {
        // Feed-forward.
        $this->setInput($sample);

        // Back-propagate.
        $this->backpropagation->backpropagate($this->getLayers(), $this->getTargetClass($target));
    }
}