Matrix multiplication fixes

This commit is contained in:
Clemens Schwaighofer
2024-11-18 14:44:18 +09:00
parent 0a45300c21
commit cb3d5e1f27
2 changed files with 112 additions and 10 deletions

View File

@@ -156,12 +156,32 @@ final class CoreLibsConvertMathTest extends TestCase
public function providerMultiplyMatrices(): array public function providerMultiplyMatrices(): array
{ {
return [ return [
'single' => [ '[3] x [3] => [3x1]' => [
[1, 2, 3], [1, 2, 3],
[1, 2, 3], [1, 2, 3],
[14] [14]
], ],
'double first' => [ '[3] x [3x1]' => [
[1, 2, 3],
[[1], [2], [3]],
[14]
],
'[3] x [3x1]' => [
[1, 2, 3],
[[1], [2], [3]],
[14]
],
'[1x3L] x [3x1]' => [
[[1, 2, 3]],
[[1], [2], [3]],
[14]
],
'[1x3] x [3x1]' => [
[[1], [2], [3]],
[[1], [2], [3]],
[1, 2, 3]
],
'[2x3] x [3] => [3x1]' => [
[ [
[1, 2, 3], [1, 2, 3],
[1, 2, 3] [1, 2, 3]
@@ -172,7 +192,18 @@ final class CoreLibsConvertMathTest extends TestCase
14 14
] ]
], ],
'double both' => [ '[2x3] x [3x1]' => [
[
[1, 2, 3],
[1, 2, 3]
],
[[1], [2], [3]],
[
14,
14
]
],
'[2x3] x [2x3] => [3x3]' => [
[ [
[1, 2, 3], [1, 2, 3],
[1, 2, 3], [1, 2, 3],
@@ -186,7 +217,37 @@ final class CoreLibsConvertMathTest extends TestCase
[3, 6, 9] [3, 6, 9]
] ]
], ],
'tripple first, single second' => [ '[2x3] x [3x3]' => [
[
[1, 2, 3],
[1, 2, 3],
],
[
[1, 2, 3],
[1, 2, 3],
[0, 0, 0],
],
[
[3, 6, 9],
[3, 6, 9]
]
],
'[2x3] x [3x2]' => [
'a' => [
[1, 2, 3],
[1, 2, 3],
],
'b' => [
[1, 1],
[2, 2],
[3, 3],
],
'prod' => [
[14, 14],
[14, 14],
]
],
'[3x3] x [3] => [1x3]' => [
[ [
[1, 2, 3], [1, 2, 3],
[1, 2, 3], [1, 2, 3],
@@ -199,7 +260,7 @@ final class CoreLibsConvertMathTest extends TestCase
14 14
] ]
], ],
'tripple first, double second' => [ '[3x3] x [2x3] => [3x3]' => [
[ [
[1, 2, 3], [1, 2, 3],
[1, 2, 3], [1, 2, 3],
@@ -215,7 +276,24 @@ final class CoreLibsConvertMathTest extends TestCase
[3, 6, 9], [3, 6, 9],
] ]
], ],
'single first, tripple second' => [ '[3x3] x [3x3]' => [
[
[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
],
[
[1, 2, 3],
[1, 2, 3],
// [0, 0, 0],
],
[
[3, 6, 9],
[3, 6, 9],
[3, 6, 9],
]
],
'[3] x [3x3]' => [
[1, 2, 3], [1, 2, 3],
[ [
[1, 2, 3], [1, 2, 3],
@@ -226,7 +304,7 @@ final class CoreLibsConvertMathTest extends TestCase
[6, 12, 18], [6, 12, 18],
] ]
], ],
'double first, tripple second' => [ '[2x3] x [3x3]' => [
[ [
[1, 2, 3], [1, 2, 3],
[1, 2, 3], [1, 2, 3],

View File

@@ -136,6 +136,28 @@ class Math
* *
* It returns an array which is the product of the two number matrices passed as parameters. * It returns an array which is the product of the two number matrices passed as parameters.
* *
* NOTE:
* if the right side (B matrix) has a missing row, this row will be fillwed with 0 instead of
* throwing an error:
* A:
* [
* [1, 2, 3],
* [4, 5, 6],
* ]
* B:
* [
* [7, 8, 9],
* [10, 11, 12],
* ]
* The B will get a third row with [0, 0, 0] added to make the multiplication work as it will be
* rewritten as
* B-rewrite:
* [
* [7, 10, 0],
* [8, 11, 12],
* [0, 0, 0] <- automatically added
* ]
*
* @param array<float|int|array<int|float>> $a m x n matrice * @param array<float|int|array<int|float>> $a m x n matrice
* @param array<float|int|array<int|float>> $b n x p matrice * @param array<float|int|array<int|float>> $b n x p matrice
* *
@@ -161,8 +183,9 @@ class Math
$p = count($b[0]); $p = count($b[0]);
// transpose $b: // transpose $b:
// so that we can multiply row by row
$bCols = array_map( $bCols = array_map(
callback: fn ($k) => \array_map( callback: fn ($k) => array_map(
(fn ($i) => is_array($i) ? $i[$k] : 0), (fn ($i) => is_array($i) ? $i[$k] : 0),
$b, $b,
), ),
@@ -175,7 +198,8 @@ class Math
array_reduce( array_reduce(
array: $row, array: $row,
callback: fn ($a, $v, $i = null) => $a + $v * ( callback: fn ($a, $v, $i = null) => $a + $v * (
$col[$i ?? array_search($v, $row) ?: 0] // if last entry missing for full copy add a 0 to it
$col[$i ?? array_search($v, $row, true)] ?? 0 /** @phpstan-ignore-line */
), ),
initial: 0, initial: 0,
) : ) :
@@ -191,7 +215,7 @@ class Math
if ($m === 1) { if ($m === 1) {
// Avoid [[a, b, c, ...]]: // Avoid [[a, b, c, ...]]:
$product = $product[0]; return $product[0];
} }
if ($p === 1) { if ($p === 1) {