Our brain is shaped like a walnut. And that’s for a reason.
The wrinkles and bulges of our brain are nature's clever way of providing an interesting structure for processing data.
At the intersection of two bulges, data can be exchanged, while inside a bulge the "thinking" stays largely separate.
Let me introduce Walnut AI: an AI pattern that works on a massive number of small, fixed-size matrices. Why fixed-size matrices? Because compilers can unroll the loops of every algorithm on the matrix as long as its size is small and known at compile time. This lets the network take advantage of the SSE and AVX extensions of modern CPUs.
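To make the fixed-size argument concrete, here is a minimal C++ sketch of a matrix-vector product whose loop bounds are compile-time constants. The names `N`, `Vec`, `Mat` and `matvec` are made up for this illustration and are not taken from the Walnut AI code. Whether the compiler actually unrolls and vectorizes the loops depends on the compiler and its flags (for example `-O3 -mavx2`).

```cpp
#include <array>
#include <cstddef>

// Hypothetical names for illustration only.
constexpr std::size_t N = 16;    // matrix size, fixed at compile time

using Vec = std::array<float, N>;
using Mat = std::array<Vec, N>;  // row-major: w[row][col]

// y = W * x. Both trip counts are compile-time constants, so an optimizing
// compiler is free to unroll the loops and to use SIMD instructions.
inline Vec matvec(const Mat& w, const Vec& x) {
    Vec y{};
    for (std::size_t r = 0; r < N; ++r) {
        float acc = 0.0f;
        for (std::size_t c = 0; c < N; ++c) {
            acc += w[r][c] * x[c];
        }
        y[r] = acc;
    }
    return y;
}
```

Because `std::array` carries its size in the type, no size parameter is passed around at run time and the whole 16×16 matrix (1 KiB of floats) fits comfortably in the L1 cache.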
Of course, a small matrix cannot process as many inputs as a typical AI scenario requires. But this is no obstacle. One can simply connect multiple small matrices into a big network, as in a bitonic sorting network. The bitonic pattern shuffles data around so that every value can be compared, at least indirectly, with every other value. It has a depth of O(log² N) (in the parallel case) and requires O(N·log² N) sorter nodes.
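The wiring of a bitonic network can be generated with a few loops. The following C++ sketch uses the standard iterative bitonic construction; the names `Link`, `bitonic_stages` and `apply_as_sorter` are hypothetical. It only produces the connection pattern and, as a sanity check, uses it as a plain sorter. How a matrix node would replace each two-input comparator in Walnut AI is not specified here and is left open.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// One two-input node of the network: it connects wires lo and hi.
struct Link {
    std::size_t lo;
    std::size_t hi;
    bool ascending;  // sort direction when the node acts as a comparator
};

// Returns the network stage by stage. All links within one stage touch
// disjoint wires and could run in parallel. n must be a power of two.
std::vector<std::vector<Link>> bitonic_stages(std::size_t n) {
    assert(n >= 2 && (n & (n - 1)) == 0);
    std::vector<std::vector<Link>> stages;
    for (std::size_t k = 2; k <= n; k *= 2) {         // size of merged blocks
        for (std::size_t j = k / 2; j > 0; j /= 2) {  // partner distance
            std::vector<Link> stage;
            for (std::size_t i = 0; i < n; ++i) {
                std::size_t partner = i ^ j;
                if (partner > i) {
                    stage.push_back({i, partner, (i & k) == 0});
                }
            }
            stages.push_back(std::move(stage));
        }
    }
    return stages;
}

// Sanity check: run the network with ordinary compare-and-swap nodes.
void apply_as_sorter(const std::vector<std::vector<Link>>& stages,
                     std::vector<int>& v) {
    for (const auto& stage : stages) {
        for (const Link& l : stage) {
            bool out_of_order = l.ascending ? v[l.lo] > v[l.hi]
                                            : v[l.lo] < v[l.hi];
            if (out_of_order) std::swap(v[l.lo], v[l.hi]);
        }
    }
}
```

For N = 16 this yields log₂N·(log₂N + 1)/2 = 10 stages of N/2 = 8 links each, 80 nodes in total, which matches the stated depth and node-count bounds.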
We have already trained a single walnut node to learn sequences of numbers, much as a language model like ChatGPT learns sequences of tokens. This is the output of a 16×16 matrix learning the pattern 1,1,0,1,1,0,1,1,0,1,1,0:
in = 1.0000 0.5856 0.7723 0.7887 0.4882 0.2105 0.1576 0.0354 0.3160 0.0931 0.9431 0.9918 0.0214 0.4702 0.0218 0.0164
out = 0.9997 0.9470 0.8375 0.9348 0.9596 0.6570 0.3475 0.1244 0.6792 0.1485 0.9865 0.9987 0.1671 0.5460 0.1714 0.1164
err = 0.0003 0.0013 0.0000 -0.0003 0.0004 0.0001 -0.0004 0.0010 0.0002 0.0004 0.0014 0.0015 -0.0009 -0.0003 0.0009 -0.0011

in = 1.0000 0.9470 0.8375 0.9348 0.9596 0.6570 0.3475 0.1244 0.6792 0.1485 0.9865 0.9987 0.1671 0.5460 0.1714 0.1164
out = 0.0007 0.9565 0.9029 0.9496 0.9265 0.6060 0.1951 0.0385 0.6349 0.1823 0.9919 0.9998 0.0632 0.6765 0.0597 0.0489
err = -0.0007 -0.0000 -0.0000 0.0003 -0.0010 -0.0002 0.0013 0.0012 0.0006 -0.0013 -0.0004 -0.0001 0.0005 -0.0001 -0.0004 -0.0000

in = 0.0000 0.9565 0.9029 0.9496 0.9265 0.6060 0.1951 0.0385 0.6349 0.1823 0.9919 0.9998 0.0632 0.6765 0.0597 0.0489
out = 0.9998 0.5857 0.7724 0.7888 0.4879 0.2104 0.1576 0.0354 0.3158 0.0931 0.9432 0.9918 0.0214 0.4704 0.0218 0.0164
err = 0.0002 0.0009 -0.0005 -0.0004 0.0003 -0.0015 0.0001 0.0003 -0.0009 -0.0006 0.0000 -0.0013 0.0004 -0.0011 0.0003 0.0002

in = 1.0000 0.5857 0.7724 0.7888 0.4879 0.2104 0.1576 0.0354 0.3158 0.0931 0.9432 0.9918 0.0214 0.4704 0.0218 0.0164
out = 0.9997 0.9470 0.8377 0.9348 0.9596 0.6569 0.3475 0.1245 0.6793 0.1484 0.9865 0.9987 0.1672 0.5461 0.1711 0.1164
err = 0.0003 -0.0003 0.0008 0.0011 -0.0000 -0.0013 0.0001 0.0001 -0.0002 -0.0005 0.0001 0.0004 -0.0003 -0.0011 -0.0003 0.0012

in = 1.0000 0.9470 0.8377 0.9348 0.9596 0.6569 0.3475 0.1245 0.6793 0.1484 0.9865 0.9987 0.1672 0.5461 0.1711 0.1164
out = 0.0007 0.9565 0.9030 0.9495 0.9265 0.6056 0.1951 0.0386 0.6348 0.1824 0.9919 0.9998 0.0632 0.6765 0.0596 0.0489
err = -0.0007 0.0004 0.0005 0.0001 0.0005 0.0003 -0.0009 -0.0003 -0.0003 -0.0002 0.0002 0.0015 -0.0008 -0.0002 -0.0005 0.0003

in = 0.0000 0.9565 0.9030 0.9495 0.9265 0.6056 0.1951 0.0386 0.6348 0.1824 0.9919 0.9998 0.0632 0.6765 0.0596 0.0489
out = 0.9998 0.5855 0.7724 0.7888 0.4882 0.2103 0.1577 0.0354 0.3157 0.0931 0.9433 0.9918 0.0214 0.4706 0.0218 0.0164
err = 0.0002 -0.0002 -0.0000 -0.0006 0.0005 -0.0001 0.0003 -0.0011 0.0000 -0.0006 0.0008 0.0005 0.0002 0.0001 0.0010 0.0003

in = 1.0000 0.5855 0.7724 0.7888 0.4882 0.2103 0.1577 0.0354 0.3157 0.0931 0.9433 0.9918 0.0214 0.4706 0.0218 0.0164
out = 0.9997 0.9470 0.8378 0.9348 0.9596 0.6568 0.3474 0.1245 0.6793 0.1484 0.9866 0.9987 0.1671 0.5464 0.1712 0.1164
err = 0.0003 0.0012 0.0008 -0.0003 -0.0004 -0.0004 0.0002 -0.0007 0.0007 -0.0006 0.0006 0.0000 0.0009 0.0002 -0.0005 0.0004

in = 1.0000 0.9470 0.8378 0.9348 0.9596 0.6568 0.3474 0.1245 0.6793 0.1484 0.9866 0.9987 0.1671 0.5464 0.1712 0.1164
out = 0.0007 0.9565 0.9030 0.9495 0.9266 0.6056 0.1952 0.0386 0.6349 0.1825 0.9919 0.9998 0.0632 0.6768 0.0596 0.0489
err = -0.0007 -0.0000 0.0000 0.0002 -0.0004 0.0003 -0.0004 0.0012 0.0001 -0.0006 0.0012 0.0008 -0.0015 -0.0010 -0.0012 -0.0001

in = 0.0000 0.9565 0.9030 0.9495 0.9266 0.6056 0.1952 0.0386 0.6349 0.1825 0.9919 0.9998 0.0632 0.6768 0.0596 0.0489
out = 0.9998 0.5857 0.7725 0.7889 0.4883 0.2103 0.1576 0.0354 0.3157 0.0932 0.9433 0.9918 0.0214 0.4709 0.0218 0.0164
err = 0.0002 0.0007 0.0001 -0.0013 0.0001 -0.0003 -0.0000 0.0001 -0.0002 0.0003 0.0006 0.0002 -0.0005 0.0008 -0.0008 -0.0014

in = 1.0000 0.5857 0.7725 0.7889 0.4883 0.2103 0.1576 0.0354 0.3157 0.0932 0.9433 0.9918 0.0214 0.4709 0.0218 0.0164
out = 0.9997 0.9470 0.8378 0.9348 0.9596 0.6567 0.3474 0.1244 0.6794 0.1485 0.9866 0.9987 0.1671 0.5466 0.1711 0.1163
err = 0.0003 -0.0005 0.0007 0.0005 -0.0001 -0.0005 -0.0013 0.0005 -0.0006 0.0003 -0.0006 -0.0000 0.0001 0.0000 -0.0005 -0.0011

in = 1.0000 0.9470 0.8378 0.9348 0.9596 0.6567 0.3474 0.1244 0.6794 0.1485 0.9866 0.9987 0.1671 0.5466 0.1711 0.1163
out = 0.0007 0.9565 0.9030 0.9495 0.9266 0.6055 0.1951 0.0385 0.6347 0.1823 0.9919 0.9998 0.0632 0.6772 0.0596 0.0489
err = -0.0007 -0.0003 -0.0000 0.0003 0.0009 0.0002 0.0008 -0.0005 0.0001 -0.0014 0.0003 -0.0000 0.0000 0.0006 -0.0012 -0.0003

in = 0.0000 0.9565 0.9030 0.9495 0.9266 0.6055 0.1951 0.0385 0.6347 0.1823 0.9919 0.9998 0.0632 0.6772 0.0596 0.0489
out = 0.9998 0.5856 0.7726 0.7887 0.4883 0.2103 0.1577 0.0354 0.3156 0.0932 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164
err = 0.0002 0.0006 -0.0013 -0.0008 -0.0001 0.0012 -0.0004 -0.0012 0.0008 -0.0005 0.0009 -0.0010 -0.0004 -0.0007 -0.0005 -0.0007

in = 1.0000 0.5856 0.7726 0.7887 0.4883 0.2103 0.1577 0.0354 0.3156 0.0932 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164
out = 0.9997 0.9470 0.8378 0.9347 0.9596 0.6568 0.3474 0.1244 0.6793 0.1483 0.9866 0.9987 0.1673 0.5470 0.1710 0.1162
err = 0.0003 0.0008 -0.0009 -0.0002 0.0006 -0.0005 0.0012 0.0002 0.0002 -0.0005 -0.0004 -0.0004 -0.0002 -0.0005 0.0003 0.0001

in = 1.0000 0.9470 0.8378 0.9347 0.9596 0.6568 0.3474 0.1244 0.6793 0.1483 0.9866 0.9987 0.1673 0.5470 0.1710 0.1162
out = 0.0007 0.9565 0.9030 0.9494 0.9265 0.6056 0.1952 0.0385 0.6344 0.1823 0.9920 0.9998 0.0633 0.6773 0.0595 0.0489
err = -0.0007 -0.0001 -0.0005 0.0002 -0.0011 0.0009 0.0019 0.0003 -0.0014 -0.0008 -0.0005 0.0005 0.0002 0.0006 0.0004 0.0008

in = 0.0000 0.9565 0.9030 0.9494 0.9265 0.6056 0.1952 0.0385 0.6344 0.1823 0.9920 0.9998 0.0633 0.6773 0.0595 0.0489
out = 0.9998 0.5857 0.7725 0.7886 0.4878 0.2103 0.1578 0.0354 0.3153 0.0931 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164
err = 0.0002 -0.0012 -0.0005 0.0002 -0.0009 -0.0011 -0.0002 -0.0005 -0.0003 -0.0006 -0.0009 0.0008 -0.0006 0.0008 -0.0007 0.0003

in = 1.0000 0.5857 0.7725 0.7886 0.4878 0.2103 0.1578 0.0354 0.3153 0.0931 0.9434 0.9918 0.0214 0.4712 0.0217 0.0164
out = 0.9997 0.9469 0.8378 0.9347 0.9596 0.6567 0.3476 0.1243 0.6790 0.1482 0.9866 0.9987 0.1673 0.5473 0.1710 0.1162
err = 0.0003 0.0005 -0.0007 -0.0001 -0.0009 -0.0007 -0.0006 0.0003 0.0005 0.0008 -0.0004 -0.0008 -0.0004 -0.0006 0.0001 -0.0006

in = 1.0000 0.9469 0.8378 0.9347 0.9596 0.6567 0.3476 0.1243 0.6790 0.1482 0.9866 0.9987 0.1673 0.5473 0.1710 0.1162
out = 0.0007 0.9564 0.9030 0.9494 0.9264 0.6053 0.1955 0.0385 0.6341 0.1821 0.9920 0.9998 0.0632 0.6774 0.0595 0.0488
err = -0.0007 0.0002 -0.0004 0.0007 -0.0001 0.0011 -0.0004 0.0009 -0.0004 -0.0011 -0.0011 0.0016 -0.0001 -0.0005 -0.0004 -0.0001

in = 0.0000 0.9564 0.9030 0.9494 0.9264 0.6053 0.1955 0.0385 0.6341 0.1821 0.9920 0.9998 0.0632 0.6774 0.0595 0.0488
out = 0.9998 0.5853 0.7725 0.7886 0.4880 0.2104 0.1580 0.0354 0.3153 0.0931 0.9434 0.9918 0.0215 0.4712 0.0217 0.0164
err = 0.0002 0.0007 0.0014 0.0003 -0.0002 0.0006 0.0013 0.0005 -0.0003 0.0002 -0.0001 0.0005 0.0014 0.0007 0.0002 0.0009

in = 1.0000 0.5853 0.7725 0.7886 0.4880 0.2104 0.1580 0.0354 0.3153 0.0931 0.9434 0.9918 0.0215 0.4712 0.0217 0.0164
out = 0.9997 0.9469 0.8376 0.9347 0.9596 0.6568 0.3479 0.1244 0.6791 0.1482 0.9866 0.9987 0.1674 0.5473 0.1711 0.1162
err = 0.0003 0.0002 -0.0009 0.0004 0.0004 0.0002 -0.0007 -0.0007 -0.0007 0.0005 0.0003 -0.0008 -0.0000 -0.0010 0.0007 0.0002

in = 1.0000 0.9469 0.8376 0.9347 0.9596 0.6568 0.3479 0.1244 0.6791 0.1482 0.9866 0.9987 0.1674 0.5473 0.1711 0.1162
out = 0.0007 0.9564 0.9029 0.9494 0.9265 0.6058 0.1955 0.0385 0.6340 0.1822 0.9920 0.9998 0.0632 0.6778 0.0595 0.0488
err = -0.0007 0.0005 -0.0017 0.0007 0.0003 0.0008 -0.0008 -0.0002 -0.0005 0.0003 -0.0005 0.0012 -0.0002 0.0001 -0.0004 -0.0001

learn result:
-9.1130 -4.6952 10.3669 1.1419 -6.2698 -10.0712 -2.9436 -4.3510 -1.4037 -10.0334 5.1694 4.8659 -6.3518 5.0172 -18.3967 -18.4768
6.8743 2.5601 -0.8417 0.3259 -0.4397 -0.1577 0.9781 -0.4246 0.5796 0.1861 0.7129 -0.1788 0.7377 0.0229 -0.5120 0.4339
0.8746 0.3554 0.8910 0.8962 0.1511 -0.6213 -0.3017 0.7336 0.7339 -0.0504 -0.5286 -1.0849 0.0232 -0.7474 -0.5905 1.2326
1.2608 0.8690 0.8541 1.6852 0.0285 -0.6908 0.9214 -0.2172 0.4253 -0.0445 -0.8371 -0.0912 1.0301 0.9175 0.4505 -0.2547
0.3568 0.5283 0.3138 -0.7397 2.5996 -0.9552 -0.6625 0.6841 0.2753 -0.9938 -0.5806 0.8085 -0.6441 0.4657 1.0774 -0.1523
0.3393 0.3682 0.0272 0.8907 0.3349 1.5433 0.3355 -0.2034 0.5916 -0.1975 -0.5763 0.9588 0.1515 -0.9263 1.0299 0.0268
-0.2792 -0.1076 -0.7625 0.7556 -0.1802 -0.5422 0.3006 -0.4480 -0.6562 -0.4864 -0.1389 -0.4163 -0.7516 0.8076 -1.1816 -0.4384
-0.0586 0.8537 0.8005 0.9533 0.4370 0.0824 -0.3859 0.5443 -0.4398 -0.4003 -0.8587 0.1024 -0.7692 -0.7797 0.3381 -0.9221
0.6265 -0.7598 0.3027 -1.0755 0.8635 -0.8534 0.0740 -0.7462 1.2046 -0.6341 0.3949 0.2949 -0.0074 1.0891 0.3266 0.6937
-1.2099 -0.5910 -0.3737 -0.8537 -1.1373 -0.6702 -0.2335 1.0057 1.0412 0.8039 0.3254 -1.1183 -0.9075 0.9645 -0.6656 0.8144
-0.9859 0.7908 -0.8362 0.3005 -1.1398 -0.8621 0.6471 0.1135 -0.7613 -1.1906 1.9425 -0.0765 0.4977 0.6575 -0.2365 0.3545
0.3037 0.0675 0.3144 -0.6991 -0.1284 0.3520 -0.1917 0.8875 0.4458 1.2798 0.7816 3.3990 0.5631 -1.0584 0.7851 0.9132
1.2111 1.0663 0.6461 0.9690 0.6089 1.0361 1.3116 0.5554 -0.0895 -0.8398 0.9198 -0.3761 1.1812 -1.1888 0.2389 -0.7478
0.0856 -0.9482 0.4038 -0.1398 -0.5496 -0.2332 -1.1034 -0.0459 -0.5768 0.3246 -0.1756 0.9763 -0.4574 0.6902 0.4523 0.4499
-0.6451 0.1693 0.1208 1.0342 0.0904 -0.0714 0.7592 -0.0598 -0.7098 -0.3367 0.2257 1.1379 -0.2981 -0.2852 1.0571 -0.6162
-0.5068 -1.0256 -0.1029 -0.9256 0.1746 -1.0149 -0.3888 0.0819 -0.5433 -0.1432 -0.4409 -0.7199 0.1160 0.0509 0.2938 1.0808
-0.9413 0.1215 -1.2361 -0.0094 0.1238 -0.0850 -0.3479 -0.3395 -1.0339 -1.3508 -0.2073 0.4107 -0.9436 -0.7933 -0.6040 0.4724
Every triple of input, output and error consists of 16-vectors in which the first element is the external input/output and the remaining 15 elements are the node's internal thinking: the walnut bulge. Every output of a bulge slot is fed back into the very same slot of the next input, while the IO element is passed in and out.
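The feedback loop can be sketched in a few lines of C++. This is only an illustration under assumptions: the names `Node`, `state` and `step` are invented here, the logistic activation is a guess based on the logged outputs lying between 0 and 1, and any bias term is left out.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

constexpr std::size_t N = 16;
using Vec = std::array<float, N>;
using Mat = std::array<Vec, N>;  // row-major: w[row][col]

// Assumption: logistic activation (the logged outputs lie in (0, 1)).
inline float sigmoid(float z) { return 1.0f / (1.0f + std::exp(-z)); }

// Hypothetical single walnut node.
struct Node {
    Mat w{};      // weights, all zero until trained
    Vec state{};  // state[0] is the IO slot, state[1..15] the internal slots

    // One time step: write the external input into slot 0, propagate, keep
    // the whole output vector as the next state and return slot 0.
    float step(float input) {
        state[0] = input;
        Vec out{};
        for (std::size_t r = 0; r < N; ++r) {
            float acc = 0.0f;
            for (std::size_t c = 0; c < N; ++c) {
                acc += w[r][c] * state[c];
            }
            out[r] = sigmoid(acc);
        }
        state = out;  // slots 1..15 feed back; slot 0 is overwritten next step
        return out[0];
    }
};
```

This mirrors the log above: slots 1 to 15 of each `in` line repeat the previous `out` line, while slot 0 carries the next element of the sequence.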
Every training step adjusts the matrix weights according to the error vector for the desired output. In the same step, the error vector for the input is constructed, so that errors can be passed down the network. Every error induces a slight change of the matrix weights in the right direction (we only ensure that the sign of the error is correct and that its magnitude does not cause an overshoot).
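The exact update rule is not spelled out above, so the following C++ sketch substitutes the textbook delta rule for a single logistic layer. It is a stand-in under assumptions, not the actual Walnut AI rule. The names `forward`, `learn` and `rate` are hypothetical. The sketch shows the two things a learn propagation has to do: nudge the weights and produce an error vector for the input.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

constexpr std::size_t N = 16;
using Vec = std::array<float, N>;
using Mat = std::array<Vec, N>;  // row-major: w[row][col]

inline float sigmoid(float z) { return 1.0f / (1.0f + std::exp(-z)); }

// Forward pass of one node: out = sigmoid(W * in).
inline Vec forward(const Mat& w, const Vec& in) {
    Vec out{};
    for (std::size_t r = 0; r < N; ++r) {
        float acc = 0.0f;
        for (std::size_t c = 0; c < N; ++c) acc += w[r][c] * in[c];
        out[r] = sigmoid(acc);
    }
    return out;
}

// One learn propagation (assumed delta rule). out_err[r] = desired[r] - out[r].
// Adjusts w in place and returns the error vector for the input.
inline Vec learn(Mat& w, const Vec& in, const Vec& out,
                 const Vec& out_err, float rate) {
    Vec delta{};
    for (std::size_t r = 0; r < N; ++r) {
        // Scale by the slope of the logistic function at this output.
        delta[r] = out_err[r] * out[r] * (1.0f - out[r]);
    }
    Vec in_err{};
    for (std::size_t r = 0; r < N; ++r) {
        for (std::size_t c = 0; c < N; ++c) {
            in_err[c] += w[r][c] * delta[r];   // uses the weights before the update
            w[r][c] += rate * delta[r] * in[c];
        }
    }
    return in_err;
}
```

A small `rate` plays the role of the overshoot guard here. In a network, the returned input error would become the output error of the node that produced the input.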
The benchmarked performance on an AMD Ryzen 9 3900X (Zen 2) is around 166,000 learn propagations of a 16×16 matrix per second on a single CPU core.
Our next steps are to build a network of these matrices and start feeding it data. The use case of Walnut AI is sequential data, such as language modeling or request-response patterns for classification.