Walnut AI – a CPU-optimized neural network

Our brain is shaped like a walnut. And there is a reason for that.

The folds and bulges of our brain are nature's clever way of providing a very interesting structure for processing data.

Where two bulges meet, data can be exchanged, while inside each bulge the "thinking" stays more or less separate.

Let me introduce Walnut AI: an AI pattern built from a massive number of small, fixed-size matrices. Why fixed-size? Because as long as the size is small and fixed at compile time, the compiler can fully unroll the loops of every algorithm that operates on the matrix. This lets the network make use of the SSE and AVX extensions of modern CPUs.
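A back-of-the-envelope count shows what the fixed size buys for one 16×16 matrix-vector product. It assumes single-precision (32-bit) floats, which the post does not state, and the standard register widths of 128 bit for SSE and 256 bit for AVX:

```latex
\begin{aligned}
\text{scalar multiply-adds} &= 16 \cdot 16 = 256 \\
\text{floats per register} &= \tfrac{128\ \text{bit}}{32\ \text{bit}} = 4\ (\text{SSE}), \qquad \tfrac{256\ \text{bit}}{32\ \text{bit}} = 8\ (\text{AVX}) \\
\text{vector multiply-adds} &= \tfrac{256}{4} = 64\ (\text{SSE}), \qquad \tfrac{256}{8} = 32\ (\text{AVX})
\end{aligned}
```

Since every one of these counts is a compile-time constant (a row of 16 floats fills exactly two AVX registers), the compiler can emit exactly that many vector instructions with no loop counter and no remainder handling for leftover elements.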

Of course, a small matrix cannot process as many inputs as a typical AI scenario requires. But this is no obstacle: one can simply connect many small matrices into a large network, much like a bitonic sorting network is built from simple two-input sorter nodes. The bitonic pattern shuffles the data so that every value can, at least indirectly, be compared with every other value. It has a depth of O(log²N) when the nodes of each stage run in parallel, and it requires O(N·log²N) sorter nodes.
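To make the complexity figures concrete, here is a sketch of the wiring pattern only: the classic bitonic sorting network with two-input comparators, not Walnut's own network, whose nodes are 16×16 matrices. The function names `bitonic_network` and `apply_network` are hypothetical. For N = 2^k inputs there are k·(k + 1)/2 stages, each with N/2 comparators on disjoint wires, which gives the O(log²N) depth and O(N·log²N) node count.

```python
def bitonic_network(n):
    """Return the comparator stages of a bitonic sorting network for n = 2**k inputs.

    Each stage is a list of (i, j, ascending) comparators. Within a stage no two
    comparators share a wire, so a whole stage can run in parallel.
    """
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    stages = []
    size = 2
    while size <= n:              # length of the bitonic sequences being merged
        stride = size // 2
        while stride >= 1:        # distance between the two wires of a comparator
            stage = []
            for i in range(n):
                j = i ^ stride    # partner wire
                if j > i:
                    ascending = (i & size) == 0
                    stage.append((i, j, ascending))
            stages.append(stage)
            stride //= 2
        size *= 2
    return stages


def apply_network(stages, values):
    """Run the comparator stages on a copy of values and return the result."""
    v = list(values)
    for stage in stages:
        for i, j, ascending in stage:
            # Swap if the pair is out of order for this comparator's direction.
            if (v[i] > v[j]) == ascending:
                v[i], v[j] = v[j], v[i]
    return v


stages = bitonic_network(16)
print(len(stages), sum(len(s) for s in stages))  # depth and total comparators
```

For N = 16 this yields 10 stages of 8 comparators each, 80 comparators in total. Because the comparators of a stage never share a wire, the number of stages is the parallel depth.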

We have already trained a single walnut node to learn a number sequence, much like a language model such as ChatGPT learns sequences of tokens. Below is the output of a 16×16 matrix learning the pattern 1,1,0,1,1,0,1,1,0,1,1,0:

in  = 1.0000 0.5856 0.7723 0.7887 0.4882 0.2105 0.1576 0.0354 0.3160 0.0931 0.9431 0.9918 0.0214 0.4702 0.0218 0.0164 
out = 0.9997 0.9470 0.8375 0.9348 0.9596 0.6570 0.3475 0.1244 0.6792 0.1485 0.9865 0.9987 0.1671 0.5460 0.1714 0.1164 
err = 0.0003 0.0013 0.0000 -0.0003 0.0004 0.0001 -0.0004 0.0010 0.0002 0.0004 0.0014 0.0015 -0.0009 -0.0003 0.0009 -0.0011 

in  = 1.0000 0.9470 0.8375 0.9348 0.9596 0.6570 0.3475 0.1244 0.6792 0.1485 0.9865 0.9987 0.1671 0.5460 0.1714 0.1164 
out = 0.0007 0.9565 0.9029 0.9496 0.9265 0.6060 0.1951 0.0385 0.6349 0.1823 0.9919 0.9998 0.0632 0.6765 0.0597 0.0489 
err = -0.0007 -0.0000 -0.0000 0.0003 -0.0010 -0.0002 0.0013 0.0012 0.0006 -0.0013 -0.0004 -0.0001 0.0005 -0.0001 -0.0004 -0.0000 

in  = 0.0000 0.9565 0.9029 0.9496 0.9265 0.6060 0.1951 0.0385 0.6349 0.1823 0.9919 0.9998 0.0632 0.6765 0.0597 0.0489 
out = 0.9998 0.5857 0.7724 0.7888 0.4879 0.2104 0.1576 0.0354 0.3158 0.0931 0.9432 0.9918 0.0214 0.4704 0.0218 0.0164 
err = 0.0002 0.0009 -0.0005 -0.0004 0.0003 -0.0015 0.0001 0.0003 -0.0009 -0.0006 0.0000 -0.0013 0.0004 -0.0011 0.0003 0.0002 

in  = 1.0000 0.5857 0.7724 0.7888 0.4879 0.2104 0.1576 0.0354 0.3158 0.0931 0.9432 0.9918 0.0214 0.4704 0.0218 0.0164 
out = 0.9997 0.9470 0.8377 0.9348 0.9596 0.6569 0.3475 0.1245 0.6793 0.1484 0.9865 0.9987 0.1672 0.5461 0.1711 0.1164 
err = 0.0003 -0.0003 0.0008 0.0011 -0.0000 -0.0013 0.0001 0.0001 -0.0002 -0.0005 0.0001 0.0004 -0.0003 -0.0011 -0.0003 0.0012 

in  = 1.0000 0.9470 0.8377 0.9348 0.9596 0.6569 0.3475 0.1245 0.6793 0.1484 0.9865 0.9987 0.1672 0.5461 0.1711 0.1164 
out = 0.0007 0.9565 0.9030 0.9495 0.9265 0.6056 0.1951 0.0386 0.6348 0.1824 0.9919 0.9998 0.0632 0.6765 0.0596 0.0489 
err = -0.0007 0.0004 0.0005 0.0001 0.0005 0.0003 -0.0009 -0.0003 -0.0003 -0.0002 0.0002 0.0015 -0.0008 -0.0002 -0.0005 0.0003 

in  = 0.0000 0.9565 0.9030 0.9495 0.9265 0.6056 0.1951 0.0386 0.6348 0.1824 0.9919 0.9998 0.0632 0.6765 0.0596 0.0489 
out = 0.9998 0.5855 0.7724 0.7888 0.4882 0.2103 0.1577 0.0354 0.3157 0.0931 0.9433 0.9918 0.0214 0.4706 0.0218 0.0164 
err = 0.0002 -0.0002 -0.0000 -0.0006 0.0005 -0.0001 0.0003 -0.0011 0.0000 -0.0006 0.0008 0.0005 0.0002 0.0001 0.0010 0.0003 

in  = 1.0000 0.5855 0.7724 0.7888 0.4882 0.2103 0.1577 0.0354 0.3157 0.0931 0.9433 0.9918 0.0214 0.4706 0.0218 0.0164 
out = 0.9997 0.9470 0.8378 0.9348 0.9596 0.6568 0.3474 0.1245 0.6793 0.1484 0.9866 0.9987 0.1671 0.5464 0.1712 0.1164 
err = 0.0003 0.0012 0.0008 -0.0003 -0.0004 -0.0004 0.0002 -0.0007 0.0007 -0.0006 0.0006 0.0000 0.0009 0.0002 -0.0005 0.0004 

in  = 1.0000 0.9470 0.8378 0.9348 0.9596 0.6568 0.3474 0.1245 0.6793 0.1484 0.9866 0.9987 0.1671 0.5464 0.1712 0.1164 
out = 0.0007 0.9565 0.9030 0.9495 0.9266 0.6056 0.1952 0.0386 0.6349 0.1825 0.9919 0.9998 0.0632 0.6768 0.0596 0.0489 
err = -0.0007 -0.0000 0.0000 0.0002 -0.0004 0.0003 -0.0004 0.0012 0.0001 -0.0006 0.0012 0.0008 -0.0015 -0.0010 -0.0012 -0.0001 

in  = 0.0000 0.9565 0.9030 0.9495 0.9266 0.6056 0.1952 0.0386 0.6349 0.1825 0.9919 0.9998 0.0632 0.6768 0.0596 0.0489 
out = 0.9998 0.5857 0.7725 0.7889 0.4883 0.2103 0.1576 0.0354 0.3157 0.0932 0.9433 0.9918 0.0214 0.4709 0.0218 0.0164 
err = 0.0002 0.0007 0.0001 -0.0013 0.0001 -0.0003 -0.0000 0.0001 -0.0002 0.0003 0.0006 0.0002 -0.0005 0.0008 -0.0008 -0.0014 

in  = 1.0000 0.5857 0.7725 0.7889 0.4883 0.2103 0.1576 0.0354 0.3157 0.0932 0.9433 0.9918 0.0214 0.4709 0.0218 0.0164 
out = 0.9997 0.9470 0.8378 0.9348 0.9596 0.6567 0.3474 0.1244 0.6794 0.1485 0.9866 0.9987 0.1671 0.5466 0.1711 0.1163 
err = 0.0003 -0.0005 0.0007 0.0005 -0.0001 -0.0005 -0.0013 0.0005 -0.0006 0.0003 -0.0006 -0.0000 0.0001 0.0000 -0.0005 -0.0011 

in  = 1.0000 0.9470 0.8378 0.9348 0.9596 0.6567 0.3474 0.1244 0.6794 0.1485 0.9866 0.9987 0.1671 0.5466 0.1711 0.1163 
out = 0.0007 0.9565 0.9030 0.9495 0.9266 0.6055 0.1951 0.0385 0.6347 0.1823 0.9919 0.9998 0.0632 0.6772 0.0596 0.0489 
err = -0.0007 -0.0003 -0.0000 0.0003 0.0009 0.0002 0.0008 -0.0005 0.0001 -0.0014 0.0003 -0.0000 0.0000 0.0006 -0.0012 -0.0003 

in  = 0.0000 0.9565 0.9030 0.9495 0.9266 0.6055 0.1951 0.0385 0.6347 0.1823 0.9919 0.9998 0.0632 0.6772 0.0596 0.0489 
out = 0.9998 0.5856 0.7726 0.7887 0.4883 0.2103 0.1577 0.0354 0.3156 0.0932 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164 
err = 0.0002 0.0006 -0.0013 -0.0008 -0.0001 0.0012 -0.0004 -0.0012 0.0008 -0.0005 0.0009 -0.0010 -0.0004 -0.0007 -0.0005 -0.0007 

in  = 1.0000 0.5856 0.7726 0.7887 0.4883 0.2103 0.1577 0.0354 0.3156 0.0932 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164 
out = 0.9997 0.9470 0.8378 0.9347 0.9596 0.6568 0.3474 0.1244 0.6793 0.1483 0.9866 0.9987 0.1673 0.5470 0.1710 0.1162 
err = 0.0003 0.0008 -0.0009 -0.0002 0.0006 -0.0005 0.0012 0.0002 0.0002 -0.0005 -0.0004 -0.0004 -0.0002 -0.0005 0.0003 0.0001 

in  = 1.0000 0.9470 0.8378 0.9347 0.9596 0.6568 0.3474 0.1244 0.6793 0.1483 0.9866 0.9987 0.1673 0.5470 0.1710 0.1162 
out = 0.0007 0.9565 0.9030 0.9494 0.9265 0.6056 0.1952 0.0385 0.6344 0.1823 0.9920 0.9998 0.0633 0.6773 0.0595 0.0489 
err = -0.0007 -0.0001 -0.0005 0.0002 -0.0011 0.0009 0.0019 0.0003 -0.0014 -0.0008 -0.0005 0.0005 0.0002 0.0006 0.0004 0.0008 

in  = 0.0000 0.9565 0.9030 0.9494 0.9265 0.6056 0.1952 0.0385 0.6344 0.1823 0.9920 0.9998 0.0633 0.6773 0.0595 0.0489 
out = 0.9998 0.5857 0.7725 0.7886 0.4878 0.2103 0.1578 0.0354 0.3153 0.0931 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164 
err = 0.0002 -0.0012 -0.0005 0.0002 -0.0009 -0.0011 -0.0002 -0.0005 -0.0003 -0.0006 -0.0009 0.0008 -0.0006 0.0008 -0.0007 0.0003 

in  = 1.0000 0.5857 0.7725 0.7886 0.4878 0.2103 0.1578 0.0354 0.3153 0.0931 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164 
out = 0.9997 0.9469 0.8378 0.9347 0.9596 0.6567 0.3476 0.1243 0.6790 0.1482 0.9866 0.9987 0.1673 0.5473 0.1710 0.1162 
err = 0.0003 0.0005 -0.0007 -0.0001 -0.0009 -0.0007 -0.0006 0.0003 0.0005 0.0008 -0.0004 -0.0008 -0.0004 -0.0006 0.0001 -0.0006 

in  = 1.0000 0.9469 0.8378 0.9347 0.9596 0.6567 0.3476 0.1243 0.6790 0.1482 0.9866 0.9987 0.1673 0.5473 0.1710 0.1162 
out = 0.0007 0.9564 0.9030 0.9494 0.9264 0.6053 0.1955 0.0385 0.6341 0.1821 0.9920 0.9998 0.0632 0.6774 0.0595 0.0488 
err = -0.0007 0.0002 -0.0004 0.0007 -0.0001 0.0011 -0.0004 0.0009 -0.0004 -0.0011 -0.0011 0.0016 -0.0001 -0.0005 -0.0004 -0.0001 

in  = 0.0000 0.9564 0.9030 0.9494 0.9264 0.6053 0.1955 0.0385 0.6341 0.1821 0.9920 0.9998 0.0632 0.6774 0.0595 0.0488 
out = 0.9998 0.5853 0.7725 0.7886 0.4880 0.2104 0.1580 0.0354 0.3153 0.0931 0.9434 0.9918 0.0215 0.4712 0.0217 0.0164 
err = 0.0002 0.0007 0.0014 0.0003 -0.0002 0.0006 0.0013 0.0005 -0.0003 0.0002 -0.0001 0.0005 0.0014 0.0007 0.0002 0.0009 

in  = 1.0000 0.5853 0.7725 0.7886 0.4880 0.2104 0.1580 0.0354 0.3153 0.0931 0.9434 0.9918 0.0215 0.4712 0.0217 0.0164 
out = 0.9997 0.9469 0.8376 0.9347 0.9596 0.6568 0.3479 0.1244 0.6791 0.1482 0.9866 0.9987 0.1674 0.5473 0.1711 0.1162 
err = 0.0003 0.0002 -0.0009 0.0004 0.0004 0.0002 -0.0007 -0.0007 -0.0007 0.0005 0.0003 -0.0008 -0.0000 -0.0010 0.0007 0.0002 

in  = 1.0000 0.9469 0.8376 0.9347 0.9596 0.6568 0.3479 0.1244 0.6791 0.1482 0.9866 0.9987 0.1674 0.5473 0.1711 0.1162 
out = 0.0007 0.9564 0.9029 0.9494 0.9265 0.6058 0.1955 0.0385 0.6340 0.1822 0.9920 0.9998 0.0632 0.6778 0.0595 0.0488 
err = -0.0007 0.0005 -0.0017 0.0007 0.0003 0.0008 -0.0008 -0.0002 -0.0005 0.0003 -0.0005 0.0012 -0.0002 0.0001 -0.0004 -0.0001 

learn result:
-9.1130 -4.6952 10.3669 1.1419 -6.2698 -10.0712 -2.9436 -4.3510 -1.4037 -10.0334 5.1694 4.8659 -6.3518 5.0172 -18.3967 -18.4768 6.8743
2.5601 -0.8417 0.3259 -0.4397 -0.1577 0.9781 -0.4246 0.5796 0.1861 0.7129 -0.1788 0.7377 0.0229 -0.5120 0.4339 0.8746 0.3554
0.8910 0.8962 0.1511 -0.6213 -0.3017 0.7336 0.7339 -0.0504 -0.5286 -1.0849 0.0232 -0.7474 -0.5905 1.2326 1.2608 0.8690 0.8541
1.6852 0.0285 -0.6908 0.9214 -0.2172 0.4253 -0.0445 -0.8371 -0.0912 1.0301 0.9175 0.4505 -0.2547 0.3568 0.5283 0.3138 -0.7397
2.5996 -0.9552 -0.6625 0.6841 0.2753 -0.9938 -0.5806 0.8085 -0.6441 0.4657 1.0774 -0.1523 0.3393 0.3682 0.0272 0.8907 0.3349
1.5433 0.3355 -0.2034 0.5916 -0.1975 -0.5763 0.9588 0.1515 -0.9263 1.0299 0.0268 -0.2792 -0.1076 -0.7625 0.7556 -0.1802 -0.5422
0.3006 -0.4480 -0.6562 -0.4864 -0.1389 -0.4163 -0.7516 0.8076 -1.1816 -0.4384 -0.0586 0.8537 0.8005 0.9533 0.4370 0.0824 -0.3859
0.5443 -0.4398 -0.4003 -0.8587 0.1024 -0.7692 -0.7797 0.3381 -0.9221 0.6265 -0.7598 0.3027 -1.0755 0.8635 -0.8534 0.0740 -0.7462
1.2046 -0.6341 0.3949 0.2949 -0.0074 1.0891 0.3266 0.6937 -1.2099 -0.5910 -0.3737 -0.8537 -1.1373 -0.6702 -0.2335 1.0057 1.0412
0.8039 0.3254 -1.1183 -0.9075 0.9645 -0.6656 0.8144 -0.9859 0.7908 -0.8362 0.3005 -1.1398 -0.8621 0.6471 0.1135 -0.7613 -1.1906
1.9425 -0.0765 0.4977 0.6575 -0.2365 0.3545 0.3037 0.0675 0.3144 -0.6991 -0.1284 0.3520 -0.1917 0.8875 0.4458 1.2798 0.7816
3.3990 0.5631 -1.0584 0.7851 0.9132 1.2111 1.0663 0.6461 0.9690 0.6089 1.0361 1.3116 0.5554 -0.0895 -0.8398 0.9198 -0.3761
1.1812 -1.1888 0.2389 -0.7478 0.0856 -0.9482 0.4038 -0.1398 -0.5496 -0.2332 -1.1034 -0.0459 -0.5768 0.3246 -0.1756 0.9763 -0.4574
0.6902 0.4523 0.4499 -0.6451 0.1693 0.1208 1.0342 0.0904 -0.0714 0.7592 -0.0598 -0.7098 -0.3367 0.2257 1.1379 -0.2981 -0.2852
1.0571 -0.6162 -0.5068 -1.0256 -0.1029 -0.9256 0.1746 -1.0149 -0.3888 0.0819 -0.5433 -0.1432 -0.4409 -0.7199 0.1160 0.0509 0.2938
1.0808 -0.9413 0.1215 -1.2361 -0.0094 0.1238 -0.0850 -0.3479 -0.3395 -1.0339 -1.3508 -0.2073 0.4107 -0.9436 -0.7933 -0.6040 0.4724

Each in/out/err triple consists of three 16-vectors. The first element is the actual input/output, while the remaining 15 elements are the node's internal "thinking": the walnut bulge. The first output element is the node's prediction of the next value of the sequence, and the true value is passed in as the first input element of the next step. The 15 internal outputs are fed back into the very same slots of the next input.
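Reading the learn result together with the log suggests a concrete form for one node step. This is an inference checked against the printed numbers, not something the post states: each of the 16 rows holds 16 weights followed by a bias, and each output is the logistic sigmoid of the row's weighted sum plus bias. The Python sketch below (the names `LEARN_RESULT`, `step` and `next_input` are hypothetical) applies that rule to the input of the last printed step; its first two outputs come out at about 0.0007 and 0.9564, matching the log.

```python
import math

# The "learn result" printed above, copied verbatim (16 rows, 17 columns each).
LEARN_RESULT = """
-9.1130 -4.6952 10.3669 1.1419 -6.2698 -10.0712 -2.9436 -4.3510 -1.4037 -10.0334 5.1694 4.8659 -6.3518 5.0172 -18.3967 -18.4768 6.8743
2.5601 -0.8417 0.3259 -0.4397 -0.1577 0.9781 -0.4246 0.5796 0.1861 0.7129 -0.1788 0.7377 0.0229 -0.5120 0.4339 0.8746 0.3554
0.8910 0.8962 0.1511 -0.6213 -0.3017 0.7336 0.7339 -0.0504 -0.5286 -1.0849 0.0232 -0.7474 -0.5905 1.2326 1.2608 0.8690 0.8541
1.6852 0.0285 -0.6908 0.9214 -0.2172 0.4253 -0.0445 -0.8371 -0.0912 1.0301 0.9175 0.4505 -0.2547 0.3568 0.5283 0.3138 -0.7397
2.5996 -0.9552 -0.6625 0.6841 0.2753 -0.9938 -0.5806 0.8085 -0.6441 0.4657 1.0774 -0.1523 0.3393 0.3682 0.0272 0.8907 0.3349
1.5433 0.3355 -0.2034 0.5916 -0.1975 -0.5763 0.9588 0.1515 -0.9263 1.0299 0.0268 -0.2792 -0.1076 -0.7625 0.7556 -0.1802 -0.5422
0.3006 -0.4480 -0.6562 -0.4864 -0.1389 -0.4163 -0.7516 0.8076 -1.1816 -0.4384 -0.0586 0.8537 0.8005 0.9533 0.4370 0.0824 -0.3859
0.5443 -0.4398 -0.4003 -0.8587 0.1024 -0.7692 -0.7797 0.3381 -0.9221 0.6265 -0.7598 0.3027 -1.0755 0.8635 -0.8534 0.0740 -0.7462
1.2046 -0.6341 0.3949 0.2949 -0.0074 1.0891 0.3266 0.6937 -1.2099 -0.5910 -0.3737 -0.8537 -1.1373 -0.6702 -0.2335 1.0057 1.0412
0.8039 0.3254 -1.1183 -0.9075 0.9645 -0.6656 0.8144 -0.9859 0.7908 -0.8362 0.3005 -1.1398 -0.8621 0.6471 0.1135 -0.7613 -1.1906
1.9425 -0.0765 0.4977 0.6575 -0.2365 0.3545 0.3037 0.0675 0.3144 -0.6991 -0.1284 0.3520 -0.1917 0.8875 0.4458 1.2798 0.7816
3.3990 0.5631 -1.0584 0.7851 0.9132 1.2111 1.0663 0.6461 0.9690 0.6089 1.0361 1.3116 0.5554 -0.0895 -0.8398 0.9198 -0.3761
1.1812 -1.1888 0.2389 -0.7478 0.0856 -0.9482 0.4038 -0.1398 -0.5496 -0.2332 -1.1034 -0.0459 -0.5768 0.3246 -0.1756 0.9763 -0.4574
0.6902 0.4523 0.4499 -0.6451 0.1693 0.1208 1.0342 0.0904 -0.0714 0.7592 -0.0598 -0.7098 -0.3367 0.2257 1.1379 -0.2981 -0.2852
1.0571 -0.6162 -0.5068 -1.0256 -0.1029 -0.9256 0.1746 -1.0149 -0.3888 0.0819 -0.5433 -0.1432 -0.4409 -0.7199 0.1160 0.0509 0.2938
1.0808 -0.9413 0.1215 -1.2361 -0.0094 0.1238 -0.0850 -0.3479 -0.3395 -1.0339 -1.3508 -0.2073 0.4107 -0.9436 -0.7933 -0.6040 0.4724
"""

N = 16  # node size: slot 0 is the IO value, slots 1..15 are the "bulge"

# Assumption inferred from the log: columns 0..15 are weights, column 16 is a bias.
ROWS = [[float(v) for v in line.split()] for line in LEARN_RESULT.strip().splitlines()]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def step(state):
    """One forward pass of a node: out[i] = sigmoid(sum_j W[i][j] * in[j] + b[i])."""
    return [sigmoid(sum(w * s for w, s in zip(row[:N], state)) + row[N]) for row in ROWS]


def next_input(io_value, prev_out):
    """Next input: the true IO value in slot 0, slots 1..15 fed back unchanged."""
    return [io_value] + prev_out[1:]


# Input of the last printed training step.
last_in = [1.0000, 0.9469, 0.8376, 0.9347, 0.9596, 0.6568, 0.3479, 0.1244,
           0.6791, 0.1482, 0.9866, 0.9987, 0.1674, 0.5473, 0.1711, 0.1162]
out = step(last_in)
print(f"{out[0]:.4f} {out[1]:.4f}")  # the log shows 0.0007 and 0.9564 for this step
```

`next_input` mirrors how each `in` line of the log is built from the previous `out` line: the true sequence value goes into slot 0 and the 15 bulge values are fed back unchanged.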

Each training step adjusts the matrix weights according to the error vector of the desired output. In the same step, an error vector for the input is computed so that errors can be passed further down the network. Every error induces a slight change of the weights in the right direction (we only ensure that the sign of the correction is correct and that its magnitude does not cause an overshoot).
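The post does not spell out its update rule. A minimal sketch, assuming the textbook delta rule for a sigmoid layer (a swapped-in standard technique, not necessarily Walnut's own rule), shows both halves of such a training step: the weights move along the output error, and an error vector for the input is produced for the layer below. The clamp on the step size is one hypothetical way to keep the sign of the correction while bounding its magnitude; `train_step`, `lr` and `max_step` are illustrative names.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(W, b, x):
    """out[i] = sigmoid(sum_j W[i][j] * x[j] + b[i])"""
    return [sigmoid(sum(w * xj for w, xj in zip(row, x)) + bi) for row, bi in zip(W, b)]


def train_step(W, b, x, out, err_out, lr=0.5, max_step=0.1):
    """Delta-rule update for one sigmoid layer (assumed rule, sketch only).

    err_out[i] = target[i] - out[i]. W and b are updated in place. The returned
    vector is the error for the input x in the same "target minus output" sense,
    so it can serve as err_out for the layer below.
    """
    # Output error times the sigmoid's derivative out * (1 - out).
    delta = [e * o * (1.0 - o) for e, o in zip(err_out, out)]
    # Input error, computed with the weights *before* they are changed.
    err_in = [sum(W[i][j] * delta[i] for i in range(len(W))) for j in range(len(x))]
    for i, d in enumerate(delta):
        # Clamp the step: the sign stays correct, the magnitude is bounded.
        s = max(-max_step, min(max_step, lr * d))
        for j, xj in enumerate(x):
            W[i][j] += s * xj
        b[i] += s
    return err_in


# Tiny 2x2 example: one step moves both outputs toward the targets.
W, b, x, target = [[0.1, -0.2], [0.3, 0.4]], [0.05, -0.1], [1.0, 0.5], [1.0, 0.0]
out = forward(W, b, x)
err_in = train_step(W, b, x, out, [t - o for t, o in zip(target, out)])
```

The input error is computed before the weights change, so it is exactly the negative gradient of the squared output error with respect to the input, the quantity a lower layer would treat as its own output error.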

The benchmarked performance on an AMD Ryzen 9 3900X (Zen 2) is around 166,000 learn propagations of a 16×16 matrix per second on a single CPU core.
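For scale, the benchmark figure converts to time per propagation as follows. The cycle count additionally assumes the core runs near the 3900X's 4.6 GHz maximum boost clock during the benchmark, which the post does not state:

```latex
\begin{aligned}
t &= \frac{1}{166\,000\ \text{s}^{-1}} \approx 6.0\ \mu\text{s per learn propagation} \\
\text{cycles per propagation} &\approx \frac{4.6 \times 10^{9}\ \text{s}^{-1}}{166\,000\ \text{s}^{-1}} \approx 27\,700
\end{aligned}
```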

Our next steps are to build a network of these matrices and start feeding it data. Walnut AI targets sequential data, such as language modeling or request-response patterns for classification.
