Unverified Commit d30c212f authored by Arkadiusz Kondas's avatar Arkadiusz Kondas Committed by GitHub
Browse files

Check if feature exist when predict target in NaiveBayes (#327)

* Check if feature exist when predict target in NaiveBayes

* Fix typo
parent 18c36b97
......@@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Classification;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;
......@@ -137,6 +138,10 @@ class NaiveBayes implements Classifier
*/
private function sampleProbability(array $sample, int $feature, string $label): float
{
if (!isset($sample[$feature])) {
throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
}
$value = $sample[$feature];
if ($this->dataType[$label][$feature] == self::NOMINAL) {
if (!isset($this->discreteProb[$label][$feature][$value]) ||
......
......@@ -5,6 +5,7 @@ declare(strict_types=1);
namespace Phpml\Tests\Classification;
use Phpml\Classification\NaiveBayes;
use Phpml\Exception\InvalidArgumentException;
use Phpml\ModelManager;
use PHPUnit\Framework\TestCase;
......@@ -125,4 +126,19 @@ class NaiveBayesTest extends TestCase
self::assertEquals($classifier, $restoredClassifier);
self::assertEquals($predicted, $restoredClassifier->predict($testSamples));
}
public function testInconsistentFeaturesInSamples(): void
{
$trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
$trainLabels = ['1996', '1997', '1998'];
$testSamples = [[3, 1, 1], [5, 1], [4, 3, 8]];
$classifier = new NaiveBayes();
$classifier->train($trainSamples, $trainLabels);
$this->expectException(InvalidArgumentException::class);
$classifier->predict($testSamples);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment