diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..df02d62 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +dist-newstyle/ +result/ +.vscode/ \ No newline at end of file diff --git a/.hlint.yaml b/.hlint.yaml new file mode 100644 index 0000000..3f1b271 --- /dev/null +++ b/.hlint.yaml @@ -0,0 +1,30 @@ +##################################################################### +## HINTS + +- functions: + - {name: Test.Hspec.focus, within: []} # focus should only be used for debugging + - {name: Prelude.undefined, within: []} # Prelude.undefined should only be used temporarily + - {name: Clash.XException.undefined, within: []} # Clash undefined should only be used temporarily (use deepErrorX instead) + +- error: {lhs: fromString (show x), rhs: Data.String.Extra.show' x} +- error: {lhs: Data.Text.pack (show x), rhs: Data.String.Extra.show' x} +- error: {lhs: fromIntegral (Clash.Promoted.Nat.snatToInteger x), rhs: Clash.Promoted.Nat.snatToNum x} +- error: {lhs: Clash.Sized.Internal.BitVector.split# (ClaSH.Class.BitPack.pack x), rhs: Clash.Prelude.BitIndex.split} +- error: {lhs: Clash.Signal.mux p (fmap Just x) (pure Nothing), rhs: Clash.Signal.Extra.boolToMaybe p x} +- error: {lhs: Clash.Signal.mux p (Just <$> x) (pure Nothing), rhs: Clash.Signal.Extra.boolToMaybe p x} +- error: {lhs: Clash.Prelude.moore x id, rhs: Clash.Prelude.Moore.medvedev x} +- error: {lhs: Clash.Prelude.medvedev f x (pure ()), rhs: Clash.Source.source' (flip f ()) x} + +# We tend to use pure over return +- error: {lhs: return, rhs: pure} +- error: {lhs: ceiling (logBase 2 (fromIntegral x)), rhs: Numeric.Natural.Extra.fromNatural (Numeric.Log2.clog2 (fromIntegral x))} +- error: {lhs: floor (logBase 2 (fromIntegral x)), rhs: Numeric.Natural.Extra.fromNatural (Numeric.Log2.flog2 (fromIntegral x))} +- error: {lhs: ceiling (logBase 4 (fromIntegral x)), rhs: Numeric.Natural.Extra.fromNatural (Numeric.Log2.clog4 (fromIntegral x))} +- error: {lhs: floor (logBase 4 (fromIntegral x)), rhs: 
Numeric.Natural.Extra.fromNatural (Numeric.Log2.flog4 (fromIntegral x))} + +# We all know when it's appropriate to use [Char] +- ignore: {name: Use String} +- ignore: {name: Use head} +- ignore: {name: Reduce duplication} +- ignore: {name: Use tuple-section} +- ignore: {name: Use <$>} diff --git a/.stylish-haskell.yaml b/.stylish-haskell.yaml new file mode 100644 index 0000000..539a00d --- /dev/null +++ b/.stylish-haskell.yaml @@ -0,0 +1,171 @@ +# stylish-haskell configuration file +# ================================== + +# The stylish-haskell tool is mainly configured by specifying steps. These steps +# are a list, so they have an order, and one specific step may appear more than +# once (if needed). Each file is processed by these steps in the given order. +steps: + # Convert some ASCII sequences to their Unicode equivalents. This is disabled + # by default. + # - unicode_syntax: + # # In order to make this work, we also need to insert the UnicodeSyntax + # # language pragma. If this flag is set to true, we insert it when it's + # # not already present. You may want to disable it if you configure + # # language extensions using some other method than pragmas. Default: + # # true. + # add_language_pragma: true + + # Align the right hand side of some elements. This is quite conservative + # and only applies to statements where each element occupies a single + # line. + - simple_align: + cases: false + top_level_patterns: false + records: true + + # Import cleanup + - imports: + # There are different ways we can align names and lists. + # + # - global: Align the import names and import list throughout the entire + # file. + # + # - file: Like global, but don't add padding when there are no qualified + # imports in the file. + # + # - group: Only align the imports per group (a group is formed by adjacent + # import lines). + # + # - none: Do not perform any alignment. + # + # Default: global. + align: group + + # Folowing options affect only import list alignment. 
+ # + # List align has following options: + # + # - after_alias: Import list is aligned with end of import including + # 'as' and 'hiding' keywords. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # > init, last, length) + # + # - with_alias: Import list is aligned with start of alias or hiding. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # > init, last, length) + # + # - new_line: Import list starts always on new line. + # + # > import qualified Data.List as List + # > (concat, foldl, foldr, head, init, last, length) + # + # Default: after_alias + list_align: after_alias + + # Long list align style takes effect when import is too long. This is + # determined by 'columns' setting. + # + # - inline: This option will put as much specs on same line as possible. + # + # - new_line: Import list will start on new line. + # + # - new_line_multiline: Import list will start on new line when it's + # short enough to fit to single line. Otherwise it'll be multiline. + # + # - multiline: One line per import list entry. + # Type with contructor list acts like single import. + # + # > import qualified Data.Map as M + # > ( empty + # > , singleton + # > , ... + # > , delete + # > ) + # + # Default: inline + long_list_align: new_line + + # List padding determines indentation of import list on lines after import. + # This option affects 'list_align' and 'long_list_align'. + list_padding: 2 + + # Separate lists option affects formating of import list for type + # or class. The only difference is single space between type and list + # of constructors, selectors and class functions. + # + # - true: There is single space between Foldable type and list of it's + # functions. + # + # > import Data.Foldable (Foldable (fold, foldl, foldMap)) + # + # - false: There is no space between Foldable type and list of it's + # functions. 
+ # + # > import Data.Foldable (Foldable(fold, foldl, foldMap)) + # + # Default: true + separate_lists: true + + # Language pragmas + - language_pragmas: + # We can generate different styles of language pragma lists. + # + # - vertical: Vertical-spaced language pragmas, one per line. + # + # - compact: A more compact style. + # + # - compact_line: Similar to compact, but wrap each line with + # `{-#LANGUAGE #-}'. + # + # Default: vertical. + style: vertical + + # Align affects alignment of closing pragma brackets. + # + # - true: Brackets are aligned in same collumn. + # + # - false: Brackets are not aligned together. There is only one space + # between actual import and closing bracket. + # + # Default: true + align: false + + # stylish-haskell can detect redundancy of some language pragmas. If this + # is set to true, it will remove those redundant pragmas. Default: true. + remove_redundant: true + + # Replace tabs by spaces. This is disabled by default. + # - tabs: + # # Number of spaces to use for each tab. Default: 8, as specified by the + # # Haskell report. + # spaces: 8 + + # Remove trailing whitespace + - trailing_whitespace: {} + +# A common setting is the number of columns (parts of) code will be wrapped +# to. Different steps take this into account. Default: 80. +columns: 100 + +# By default, line endings are converted according to the OS. You can override +# preferred format here. +# +# - native: Native newline format. CRLF on Windows, LF on other OSes. +# +# - lf: Convert to LF ("\n"). +# +# - crlf: Convert to CRLF ("\r\n"). +# +# Default: native. +newline: lf + +# Sometimes, language extensions are specified in a cabal file or from the +# command line instead of using language pragmas in the file. stylish-haskell +# needs to be aware of these, so it can parse the file correctly. +# +# No language extensions are enabled by default. 
+language_extensions: + - MultiParamTypeClasses + - FlexibleContexts diff --git a/DenseTest.hs b/DenseTest.hs deleted file mode 100644 index 9a83149..0000000 --- a/DenseTest.hs +++ /dev/null @@ -1,12 +0,0 @@ -module DenseTest where - -import Dense - -class Typeable a => Dtype a where - add :: a -> a -> a - mult :: a -> a -> a - eq :: a -> a -> Bool - ... - -data NdArray where - NdArray :: Dtype a => Array DynIx a -> NdArray \ No newline at end of file diff --git a/README.md b/README.md index 27dfc78..1afa131 100644 --- a/README.md +++ b/README.md @@ -1 +1,54 @@ -# rowan-ndarray +# rowan-numskull + +## Using Numskull + +This is a Summer Internship project from 2023. Numskull is a NumPy-like library for Haskell, featuring NdArrays which can be created and manipulated to store many different types (of the DType class). + +Numskull was designed for purposes of integration into an [Onnx](https://onnx.ai/) backend, but it can be used anywhere you need to operate on arrays of unspecified type and shape. + +For more information, have a look at my talk: [slides](demo/presentation-slides.pdf). + +To run the demo you need +1) jupyter +2) iHaskell (https://github.com/IHaskell/IHaskell) to put Numskull +code into a jupyter notebook. +3) nix-shell +4) cd demo/notebook/ +5) ./start.sh + +Note that the work in main is Numskull 1.0. +Numskull 2.0 can be found in the so-called branch! The second version is less well tested and complete, but should be more efficient since it makes use of strides. I didn't have time to integrate that into the Onnx backend, but it shouldn't be at all difficult to do so. There is an open pull request so it's easy to find. + +## Development + +### Using Cabal + +This builds like any Cabal project with `cabal build`, `cabal repl`, etc. + +### Using Nix (and Cabal) + +There is a `default.nix` so the project can be built with `nix-build`, and a +`shell.nix` for a development shell with `nix-shell`. 
+ +#### Niv + +Dependencies are maintained using `niv` in `nix/sources.json`. +Source repositories specified in `cabal.project` should be kept up to date with +`nix/sources.json`. + +#### cabal2nix + +`numskull.nix` should be updated with + +```sh +$ cabal2nix . > numskull.nix +``` + +whenever the Cabal file is updated with e.g. new dependencies. + +#### Nix shell + +Within a Nix shell, ideally you build the project with +`cabal build --project-file=cabal-nix.project` to avoid fetching and building +dependencies specifid in `cabal.project` which are only intended for the +non-Nix build. diff --git a/TypeableTest.hs b/TypeableTest.hs deleted file mode 100644 index 8899c54..0000000 --- a/TypeableTest.hs +++ /dev/null @@ -1,115 +0,0 @@ ---language extension list --- https://downloads.haskell.org/ghc/latest/docs/users_guide/exts/table.html - -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE TypeApplications #-} - - - -module TypeableTest where - ---import Data.Dense -import Data.Dynamic -import Data.Typeable -import Data.Vector -import Data.Array.Repa as R ---import Onnx.Representation.GraphDef - ---instance Typeable (Array U s ) - -{- -data NDArray typ where - NDArray :: (Typeable typ, Show typ, Shape sh) => - Array U sh typ -> NDArray typ -newtype NDArray a = NDArray (R.Array U s t) --} ---data NDArray2 = NDArray2 (Z :. Int :. ...) Dynamic - --- to fix w/type variable? https://downloads.haskell.org/~ghc/6.4/docs/html/users_guide/type-extensions.html ---data NDArray = NDAConstr { --- fromND :: (Shape sh) => Array U sh Dynamic ---} - ---deriving instance Show (NDArray a) ---instance Show (NDArray a) where --- show (NDArray x) = show x - ---t = NDArray (fromListUnboxed (Z :. (3::Int) :. 
(3::Int)) [1..9::Int]) - ---dynTypeRep (toDyn 4) == dynTypeRep (toDyn 2) ---DArray a + NDArray b = - - - - --- Dynamic type, a dype, have an instance of each supported type --- which defines how all the standard ops work --- Can you have optional functions defined here? e.g. bitshift -class Typeable a => Dype a where - add :: a -> a -> a - mult :: a -> a -> a - eq :: a -> a -> Bool - -- ... - --- You can provide a shape for the arrays but you can also not --- Need a function which works this out to convert a Dshape to a Sshape --- given the array -data Dshape = Dshape | Sshape [Int] deriving Show - --- Dshape is a hacky instance of Repa's shape so the NdArray definition --- won't complain. -instance Eq Dshape where - Dshape == Dshape = False - Sshape x == Sshape y = x == y - Dshape == Sshape _ = False - Sshape _ == Dshape = False -instance R.Shape Dshape where - rank _ = undefined - ---data NdArray where --- NdArray :: Dype a => Array U Dshape a -> NdArray - ---unwrapND :: Dype a => NdArray -> Array U Dshape a ---unwrapND (NdArray x) = x --- where intarr = --- if typeRep a == "Int" then --- arr :: Just (Array U Dshape Int) --- else Nothing - ---instance Show NdArray where --- show x | = show $ cast (Array U Dshape Int) - -data TypedNdArray type shape = TypedNdArray String [Int] - -data NdArray where - NdArray :: Dype a => Array U (Z :. Int) a -> NdArray - -instance Dype Int where - add x y = x + y - mult x y = x * y - eq x y = x == y - -unwrapND :: NdArray -> Maybe (Array U (Z :. Int) Int) -unwrapND (NdArray x) = Just x - -arr = fromListUnboxed (Z :. (2::Int)) [1,2::Int] -nd = NdArray arr -arr2 = unwrapND nd - - ---deriving instance Show (NdArray) - --- instance of Dypes for Float Int etc - -{- -addArrays :: NdArray -> NdArray -> NdArray -addArrays (NdArray x) (NdArray y) = case typeOf x `eqTypeRep` typeOf b of - Just HRefl -> NdArray (zipWith add x y) - Nothing -> throw ValueError -- or should this automatically convert? 
--} \ No newline at end of file diff --git a/app/Main.hs b/app/Main.hs deleted file mode 100644 index 22d7dd8..0000000 --- a/app/Main.hs +++ /dev/null @@ -1,11 +0,0 @@ -module Main where - -import qualified MyLib (someFunc) -import Dense - -x = V2 1 2 ^+^ V2 3 4 - -main :: IO () -main = do - putStrLn "Hello, Haskell!" - MyLib.someFunc diff --git a/cabal-nix.project b/cabal-nix.project new file mode 100644 index 0000000..0fe1351 --- /dev/null +++ b/cabal-nix.project @@ -0,0 +1,5 @@ +-- To use this file, use `cabal --project-file=cabal-nix.project` + +packages: . + +-- Nix should be providing dense, so no need for it here diff --git a/cabal.project b/cabal.project new file mode 100644 index 0000000..b764c34 --- /dev/null +++ b/cabal.project @@ -0,0 +1,2 @@ +packages: . + diff --git a/default.nix b/default.nix new file mode 100644 index 0000000..052e731 --- /dev/null +++ b/default.nix @@ -0,0 +1,2 @@ +{ nixpkgs ? import nix/nixpkgs.nix {} }: +nixpkgs.pkgs.haskellPackages.callPackage ./numskull.nix { } \ No newline at end of file diff --git a/demo/default.nix b/demo/default.nix new file mode 100644 index 0000000..b4fc5d3 --- /dev/null +++ b/demo/default.nix @@ -0,0 +1,10 @@ +let + pkgs = import ./pkgs.nix; + nixpkgs = import pkgs.nixpkgs {}; + notebooks = map (folder: { + name = folder; + path = import (./. + "/${folder}"); + }); +in nixpkgs.linkFarm "notebooks" (notebooks [ + "Rowan-Presentation" +]) diff --git a/demo/notebook/.ipynb_checkpoints/Rowan-Presentation-checkpoint.ipynb b/demo/notebook/.ipynb_checkpoints/Rowan-Presentation-checkpoint.ipynb new file mode 100644 index 0000000..678c23f --- /dev/null +++ b/demo/notebook/.ipynb_checkpoints/Rowan-Presentation-checkpoint.ipynb @@ -0,0 +1,502 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Numskull Demo!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "{-# LANGUAGE TypeApplications #-}\n", + "{-# LANGUAGE TemplateHaskell #-}\n", + "{-# LANGUAGE QuasiQuotes #-}\n", + "import Numskull\n", + "import Data.Maybe (fromJust)\n", + "import Type.Reflection\n", + "\n", + "p = printArray" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's easy to make arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2.0 4.0 6.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "3.14" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 1 2 3 4 5 \n", + " 6 7 8 9 10 \n", + "11 12 13 14 15 \n", + "16 17 18 19 20" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0 0 0 \n", + "0 0 0 \n", + "0 0 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 2 \n", + "3 4 \n", + "5 6" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "p $ fromList [3] [2,4,6]\n", + "\n", + "p $ singleton 3.14\n", + "\n", + "p.fromJust $ reshape [4,5] $ arange 1 (20::Int)\n", + "\n", + "p $ zeros (typeRep @Int) [3,3]\n", + "\n", + "l :: TreeMatrix Int\n", + "l = A [A [B 1, B 2],\n", + " A [B 3, B 4],\n", + " A [B 5, B 6]]\n", + "p $ fromMatrix l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or take slices of them..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3D Array:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "3 1 4 \n", + "1 5 9 \n", + "2 6 5 \n", + "\n", + "3 5 8 \n", + "9 7 9 \n", + "3 2 3" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Sliced:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 5 9 \n", + "2 6 5" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Sliced, but fancier:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 5 9 \n", + "2 6 5" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "piNd = fromList [2,3,3] [3,1,4,1,5,9,2,6,5,3,5,8,9,7,9,3,2,3::Int]\n", + "\n", + "putStrLn \"3D Array:\"\n", + "p piNd\n", + "\n", + "putStrLn \"Sliced:\"\n", + "p $ slice [(0,0), (1,2)] piNd\n", + "\n", + "putStrLn \"Sliced, but fancier:\"\n", + "p $ piNd /! 
[q|0,1:3|]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And switch values or even types" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "100 3 6 10 15 21 28 36 45" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 1.0 3.0 6.0 10.0 15.0 21.0 28.0 36.0 45.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 1 0 1" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "intNd = fromListFlat [1, 3, 6, 10, 15, 21, 28, 36, 45 :: Int]\n", + "boolNd = fromListFlat [True, True, False, True]\n", + "\n", + "p $ update intNd [0] 100\n", + "\n", + "p $ convertDTypeTo (typeRep @Double) intNd\n", + "p $ matchDType intNd boolNd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And do all sorts of fun maths!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Numeracy:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "10 11 2 \n", + "13 14 5 \n", + " 6 7 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "10 11 2 \n", + "13 14 5 \n", + " 6 7 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 0 10 0 \n", + "30 40 0 \n", + " 0 0 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Powers/logs:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 0 1 4 \n", + " 9 16 25 \n", + "36 49 64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Average:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + 
"text/plain": [ + "5 5 1 \n", + "6 7 2 \n", + "3 3 4" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Transpose & diagonal:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0 3 6 \n", + "1 4 7 \n", + "2 5 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0 4 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Matrix multiplication:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 2.0 3.0 \n", + " 6.0 11.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "16.0 23.0 \n", + "24.0 37.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Determinant:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[40.0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "nd1 = fromList [3,3] [0..8::Int]\n", + "nd2 = padShape [3,3] $ fromList [2,2] [10,10,10,10::Int]\n", + "\n", + "putStrLn \"Numeracy:\"\n", + "p $ nd1 + nd2\n", + "p $ Numskull.sum [nd1, nd2]\n", + "p $ nd1 * nd2\n", + "\n", + "putStrLn \"Powers/logs:\"\n", + "p $ elemPow nd1 (fromList [3,3] $ replicate 9 (2::Int))\n", + "\n", + "putStrLn \"Average:\"\n", + "p $ mean [nd1, nd2]\n", + "\n", + "putStrLn \"Transpose & diagonal:\"\n", + "p $ transpose nd1\n", + "p $ diagonal nd1\n", + "\n", + "putStrLn \"Matrix multiplication:\"\n", + "nd3 = fromList [2,2] [0..3::Float]\n", + "nd4 = fromList [2,2] [4..7::Float]\n", + "p $ matMul nd3 nd3\n", + "m = fromJust (gemm nd3 nd3 nd4 True False 3 1)\n", + "p m\n", + "\n", + "putStrLn \"Determinant:\"\n", + "print (determinant m :: [Float])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the built-in 
numskull operations aren't good enough for you, and you don't want to write your own, just use NumPy.\n", + "\n", + "NumSkull will serialise most standard DType arrays to NumPy .npy files and back. But you're just going to have to trust me a bit here..." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 1 2 \n", + "3 4 5 \n", + "6 7 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "saveNpy \"./serialisationdemo.npy\" nd1\n", + "loadNpy \"./serialisationdemo.npy\" >>= p" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Haskell", + "language": "haskell", + "name": "haskell" + }, + "language_info": { + "codemirror_mode": "ihaskell", + "file_extension": ".hs", + "mimetype": "text/x-haskell", + "name": "haskell", + "pygments_lexer": "Haskell", + "version": "9.0.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demo/notebook/Rowan-Presentation.ipynb b/demo/notebook/Rowan-Presentation.ipynb new file mode 100644 index 0000000..678c23f --- /dev/null +++ b/demo/notebook/Rowan-Presentation.ipynb @@ -0,0 +1,502 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Numskull Demo!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "{-# LANGUAGE TypeApplications #-}\n", + "{-# LANGUAGE TemplateHaskell #-}\n", + "{-# LANGUAGE QuasiQuotes #-}\n", + "import Numskull\n", + "import Data.Maybe (fromJust)\n", + "import Type.Reflection\n", + "\n", + "p = printArray" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's easy to make arrays." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2.0 4.0 6.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "3.14" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 1 2 3 4 5 \n", + " 6 7 8 9 10 \n", + "11 12 13 14 15 \n", + "16 17 18 19 20" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0 0 0 \n", + "0 0 0 \n", + "0 0 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 2 \n", + "3 4 \n", + "5 6" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "p $ fromList [3] [2,4,6]\n", + "\n", + "p $ singleton 3.14\n", + "\n", + "p.fromJust $ reshape [4,5] $ arange 1 (20::Int)\n", + "\n", + "p $ zeros (typeRep @Int) [3,3]\n", + "\n", + "l :: TreeMatrix Int\n", + "l = A [A [B 1, B 2],\n", + " A [B 3, B 4],\n", + " A [B 5, B 6]]\n", + "p $ fromMatrix l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or take slices of them..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3D Array:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "3 1 4 \n", + "1 5 9 \n", + "2 6 5 \n", + "\n", + "3 5 8 \n", + "9 7 9 \n", + "3 2 3" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Sliced:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 5 9 \n", + "2 6 5" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Sliced, but fancier:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 5 9 \n", + "2 6 5" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "piNd = fromList [2,3,3] [3,1,4,1,5,9,2,6,5,3,5,8,9,7,9,3,2,3::Int]\n", + "\n", + "putStrLn \"3D Array:\"\n", + "p piNd\n", + "\n", + "putStrLn \"Sliced:\"\n", + "p $ slice [(0,0), (1,2)] piNd\n", + "\n", + "putStrLn \"Sliced, but fancier:\"\n", + "p $ piNd /! 
[q|0,1:3|]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And switch values or even types" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "100 3 6 10 15 21 28 36 45" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 1.0 3.0 6.0 10.0 15.0 21.0 28.0 36.0 45.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1 1 0 1" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "intNd = fromListFlat [1, 3, 6, 10, 15, 21, 28, 36, 45 :: Int]\n", + "boolNd = fromListFlat [True, True, False, True]\n", + "\n", + "p $ update intNd [0] 100\n", + "\n", + "p $ convertDTypeTo (typeRep @Double) intNd\n", + "p $ matchDType intNd boolNd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And do all sorts of fun maths!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Numeracy:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "10 11 2 \n", + "13 14 5 \n", + " 6 7 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "10 11 2 \n", + "13 14 5 \n", + " 6 7 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 0 10 0 \n", + "30 40 0 \n", + " 0 0 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Powers/logs:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 0 1 4 \n", + " 9 16 25 \n", + "36 49 64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Average:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + 
"text/plain": [ + "5 5 1 \n", + "6 7 2 \n", + "3 3 4" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Transpose & diagonal:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0 3 6 \n", + "1 4 7 \n", + "2 5 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0 4 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Matrix multiplication:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + " 2.0 3.0 \n", + " 6.0 11.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "16.0 23.0 \n", + "24.0 37.0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "Determinant:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[40.0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "nd1 = fromList [3,3] [0..8::Int]\n", + "nd2 = padShape [3,3] $ fromList [2,2] [10,10,10,10::Int]\n", + "\n", + "putStrLn \"Numeracy:\"\n", + "p $ nd1 + nd2\n", + "p $ Numskull.sum [nd1, nd2]\n", + "p $ nd1 * nd2\n", + "\n", + "putStrLn \"Powers/logs:\"\n", + "p $ elemPow nd1 (fromList [3,3] $ replicate 9 (2::Int))\n", + "\n", + "putStrLn \"Average:\"\n", + "p $ mean [nd1, nd2]\n", + "\n", + "putStrLn \"Transpose & diagonal:\"\n", + "p $ transpose nd1\n", + "p $ diagonal nd1\n", + "\n", + "putStrLn \"Matrix multiplication:\"\n", + "nd3 = fromList [2,2] [0..3::Float]\n", + "nd4 = fromList [2,2] [4..7::Float]\n", + "p $ matMul nd3 nd3\n", + "m = fromJust (gemm nd3 nd3 nd4 True False 3 1)\n", + "p m\n", + "\n", + "putStrLn \"Determinant:\"\n", + "print (determinant m :: [Float])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the built-in 
numskull operations aren't good enough for you, and you don't want to write your own, just use NumPy.\n", + "\n", + "NumSkull will serialise most standard DType arrays to NumPy .npy files and back. But you're just going to have to trust me a bit here..." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 1 2 \n", + "3 4 5 \n", + "6 7 8" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "saveNpy \"./serialisationdemo.npy\" nd1\n", + "loadNpy \"./serialisationdemo.npy\" >>= p" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Haskell", + "language": "haskell", + "name": "haskell" + }, + "language_info": { + "codemirror_mode": "ihaskell", + "file_extension": ".hs", + "mimetype": "text/x-haskell", + "name": "haskell", + "pygments_lexer": "Haskell", + "version": "9.0.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demo/notebook/default.nix b/demo/notebook/default.nix new file mode 100644 index 0000000..b4e7327 --- /dev/null +++ b/demo/notebook/default.nix @@ -0,0 +1,10 @@ +let + pkgs = import ../pkgs.nix; +in import "${pkgs.ihaskell}/release.nix" rec { + nixpkgs = import pkgs.nixpkgs {}; + compiler = "ghc902"; + packages = self: with self; [ + (import ../../default.nix + { inherit nixpkgs; }) + ]; +} diff --git a/demo/notebook/serialisationdemo.npy b/demo/notebook/serialisationdemo.npy new file mode 100644 index 0000000..9f22838 Binary files /dev/null and b/demo/notebook/serialisationdemo.npy differ diff --git a/demo/notebook/start.sh b/demo/notebook/start.sh new file mode 100755 index 0000000..827bb99 --- /dev/null +++ b/demo/notebook/start.sh @@ -0,0 +1,2 @@ +#!/bin/bash +$(nix-build)/bin/jupyter-notebook diff --git a/demo/pkgs.nix b/demo/pkgs.nix new file mode 100644 index 0000000..56f07ef --- /dev/null +++ 
b/demo/pkgs.nix @@ -0,0 +1,12 @@ + +let + fetcher = { owner, repo, rev, sha256, ... }: builtins.fetchTarball { + inherit sha256; + url = "https://github.com/${owner}/${repo}/tarball/${rev}"; + }; + nixpkgs = import (fetcher (builtins.fromJSON (builtins.readFile ./versions.json)).nixpkgs) { overlays = [ ]; }; + lib = nixpkgs.lib; + versions = lib.mapAttrs + (_: fetcher) + (builtins.fromJSON (builtins.readFile ./versions.json)); +in versions diff --git a/demo/presentation-slides.pdf b/demo/presentation-slides.pdf new file mode 100644 index 0000000..5184178 Binary files /dev/null and b/demo/presentation-slides.pdf differ diff --git a/demo/rise.nix b/demo/rise.nix new file mode 100644 index 0000000..9c9e077 --- /dev/null +++ b/demo/rise.nix @@ -0,0 +1,10 @@ +pythonPackages: pythonPackages.buildPythonPackage rec { + pname = "rise"; + version = "5.6.0"; + name = "${pname}-${version}"; + src = builtins.fetchurl { + url = "https://files.pythonhosted.org/packages/source/r/${pname}/${name}.tar.gz"; + sha256 = "09lfcm2zdi5k11af5c5nx4bnx2vr36z90skw0jp3mri7pqymrr1b"; + }; + propagatedBuildInputs = [ pythonPackages.notebook ]; +} diff --git a/demo/updater b/demo/updater new file mode 100644 index 0000000..5d01cc5 --- /dev/null +++ b/demo/updater @@ -0,0 +1,26 @@ +#! /usr/bin/env nix-shell +#! nix-shell -i bash +#! 
nix-shell -p curl jq nix + +set -eufo pipefail + +FILE=$1 +PROJECT=$2 + +OWNER=$(jq -r '.[$project].owner' --arg project "$PROJECT" < "$FILE") +REPO=$(jq -r '.[$project].repo' --arg project "$PROJECT" < "$FILE") +DEFAULT_BRANCH=$(jq -r '.[$project].branch // "master"' --arg project "$PROJECT" < "$FILE") + +BRANCH=${3:-$DEFAULT_BRANCH} + +REV=$(curl "https://api.github.com/repos/$OWNER/$REPO/branches/$BRANCH" | jq -r '.commit.sha') +SHA256=$(nix-prefetch-url --unpack "https://github.com/$OWNER/$REPO/tarball/$REV") +TJQ=$(jq '.[$project] = {owner: $owner, repo: $repo, branch: $branch, rev: $rev, sha256: $sha256}' \ + --arg project "$PROJECT" \ + --arg owner "$OWNER" \ + --arg repo "$REPO" \ + --arg branch "$BRANCH" \ + --arg rev "$REV" \ + --arg sha256 "$SHA256" \ + < "$FILE") +[[ $? == 0 ]] && echo "${TJQ}" >| "$FILE" diff --git a/demo/versions.json b/demo/versions.json new file mode 100644 index 0000000..c568e8d --- /dev/null +++ b/demo/versions.json @@ -0,0 +1,16 @@ +{ + "ihaskell": { + "owner": "IHaskell", + "repo": "IHaskell", + "branch": "master", + "rev": "8afa4e22c5724da89fec85a599ee129ab5b4cb9a", + "sha256": "0rkvqrpnsyp33x8mzh1v48vm96bpmza14nl6ah1sgjfbp86ihi8p" + }, + "nixpkgs": { + "owner": "NixOS", + "repo": "nixpkgs", + "branch": "nixos-22.11", + "rev": "da26ae9f6ce2c9ab380c0f394488892616fc5a6a", + "sha256": "1l3xhsnj0msvrf2qz86j4lmbpisvinf4cf1d89qm73zjh5qigzq4" + } +} diff --git a/docs/Numskull.html b/docs/Numskull.html new file mode 100644 index 0000000..d7b8d54 --- /dev/null +++ b/docs/Numskull.html @@ -0,0 +1,188 @@ + +Numskull
numskull-0.1.0.0
Safe HaskellSafe-Inferred
LanguageHaskell2010

Numskull

Synopsis

Metadata

class (Typeable a, Storable a, Show a, Eq a, Ord a) => DType a #

All types storable within an NdArray must implement DType. + This defines some basic properties, mathematical operations and standards for conversion.

Instances

Instances details
DType Int32 # 
Instance details

Defined in DType

DType Int64 # 
Instance details

Defined in DType

DType Bool # 
Instance details

Defined in DType

DType Char # 
Instance details

Defined in DType

DType Double # 
Instance details

Defined in DType

DType Float # 
Instance details

Defined in DType

DType Int # 
Instance details

Defined in DType

Methods

addId :: Int #

multId :: Int #

add :: Int -> Int -> Int #

subtract :: Int -> Int -> Int #

multiply :: Int -> Int -> Int #

divide :: Int -> Int -> Int #

div :: Int -> Int -> Int #

power :: Int -> Double -> Double #

pow :: Int -> Int -> Int #

log :: Int -> Int -> Int #

mod :: Int -> Int -> Int #

abs :: Int -> Int #

signum :: Int -> Int #

ceil :: Int -> Int #

floor :: Int -> Int #

sin :: Int -> Int #

cos :: Int -> Int #

tan :: Int -> Int #

invert :: Int -> Int #

shiftleft :: Int -> Int #

shiftright :: Int -> Int #

dtypeToRational :: Int -> Rational #

rationalToDtype :: Rational -> Int #

size :: [Integer] -> Int #

Gets the total number of elements in a given array shape. + >>> size [2,3] + 6

shape :: NdArray -> [Integer] #

Returns the shape list of an array.

getVector :: forall a. DType a => NdArray -> Vector a #

Gets the vector of an array. Requires a type specification to output safely.

ndType :: NdArray -> String #

Gets the TypeRep String representation of the NdArray elements

checkNdType :: forall a b. (DType a, DType b) => NdArray -> TypeRep a -> Maybe (a :~~: b) #

Compares the type of the array elements to the given TypeRep.

isEmpty :: NdArray -> Bool #

Checks if the underlying vector has any elements.

Creation

data NdArray #

The core of this module. NdArrays can be of any DType a and size/shape (list of dimensions). + These are hidden by the type.

Instances

Instances details
Num NdArray # 
Instance details

Defined in Numskull

Show NdArray #

By default arrays are printed flat with the shape as metadata. + For a tidier representation, use printArray.

Instance details

Defined in NdArray

Eq NdArray # 
Instance details

Defined in Numskull

Ord NdArray # 
Instance details

Defined in Numskull

fromList :: DType a => [Integer] -> [a] -> NdArray #

Creates an NdArray from a given shape and list. The number of elements must match. + >>> printArray $ fromList [2,2] [1,2,3,4::Int] + 1 2 + 3 4

fromListFlat :: DType a => [a] -> NdArray #

Creates a 1xn NdArray from a list. + >>> printArray $ fromListFlat [1,2,3,4::Int] + 1 2 3 4

data TreeMatrix a #

This type is specifically for pretty explicit definitions of NdArrays. +The A constructor is for Array - a set of values and B is the value. +-- Example 2x3x2 +l :: TreeMatrix Int +l = A [A [A [B 1, B 2], + A [B 3, B 4], + A [B 5, B 6]],

A [A [B 7, B 8], + A [B 9, B 10], + A [B 11, B 12]]]

Constructors

B a 
A [TreeMatrix a] 

fromMatrix :: DType a => TreeMatrix a -> NdArray #

Creates an NdArray from an explicitly given matrix such as the 2x3x2 TreeMatrix example.

fromVector :: DType a => [Integer] -> Vector a -> Maybe NdArray #

The safe standard constructor. Returns Nothing if the + shape does not match the given vector length.

singleton :: DType a => a -> NdArray #

Creates a 1x1 matrix + >>> printArray $ singleton (3::Int) + 3

arange :: (Enum a, DType a) => a -> a -> NdArray #

Creates a flat array over the specified range.

zeros :: forall a. DType a => TypeRep a -> [Integer] -> NdArray #

Creates an array of the given shape of the identity element for the given type.

squareArr :: forall a. DType a => [a] -> NdArray #

Creates the smallest possible square matrix from the given list, +padding out any required space with the identity element for the DType

Modification

update :: forall a. DType a => NdArray -> [Integer] -> a -> NdArray #

General Mapping, Folding & Zipping

foldrA :: forall a b. DType a => (a -> b -> b) -> b -> NdArray -> b #

Near identical to a standard foldr instance, except NdArrays do not have an explicit type. +Folds in row-major order.

mapA :: forall a. forall b. (DType a, DType b) => (a -> b) -> NdArray -> NdArray #

Near identical to a standard map implementation in row-major order.

mapTransform :: (forall a. DType a => a -> a) -> NdArray -> NdArray #

Maps functions which return the same type.

pointwiseZip :: (forall t. DType t => t -> t -> t) -> NdArray -> NdArray -> NdArray #

The generic function for operating on two matching DType arrays with the same shape + in an element-wise/pointwise way. Errors if mismatching + >>> x = fromList [2,2] [1,2,3,4 :: Int] + >>> y = fromList [2,2] [5,2,2,2 :: Int] + >>> printArray $ pointwiseZip (DType.multiply) x y + 5 4 + 6 8

pointwiseBool :: (forall t. DType t => t -> t -> Bool) -> NdArray -> NdArray -> NdArray #

A slightly specialised version of pointwise zip intended for comparative functions.

zipArrayWith :: forall a b c. (DType a, DType b, DType c) => (a -> b -> c) -> NdArray -> NdArray -> NdArray #

Completely generic zip on two NdArrays. If the shapes mismatch, they are truncated as with + standard zips. Function inputs must match the DTypes.

Summaries

origin :: forall a. DType a => NdArray -> a #

Returns the element at the 0th position of the array.

maxElem :: forall a. DType a => NdArray -> a #

Returns the largest element.

minElem :: forall a. DType a => NdArray -> a #

Returns the smallest element.

Mathematical constant

scale :: forall a. DType a => a -> NdArray -> NdArray #

Multiplies all elements by a scalar.

absA :: NdArray -> NdArray #

Takes the absolute value of all elements.

signumA :: NdArray -> NdArray #

Replaces all elements by their signum. + >>> printArray $ signumA (fromList [5] [-50, -25, 0, 1, 10::Int]) + -1 -1 0 1 1

ceilA :: NdArray -> NdArray #

Mathematical ceiling of each element (preserving DType).

floorA :: NdArray -> NdArray #

Mathematical floor of each element (preserving DType).

sinA :: NdArray -> NdArray #

Sine of each element (preserving DType).

cosA :: NdArray -> NdArray #

Cosine of each element (preserving DType).

tanA :: NdArray -> NdArray #

Tangent of each element (preserving DType).

invertA :: NdArray -> NdArray #

Either elementwise NOT or NEG depending on the DType.

shiftleftA :: NdArray -> NdArray #

Multiply each element by 2.

shiftrightA :: NdArray -> NdArray #

Divide each element by 2.

Mathematical pointwise

elemDivide :: NdArray -> NdArray -> NdArray #

Pointwise division

elemDiv :: NdArray -> NdArray -> NdArray #

Pointwise integer division. Will return an NdArray of type Int.

elemPow :: NdArray -> NdArray -> NdArray #

Pointwise exponentiation (preserving DType)

elemPower :: NdArray -> NdArray -> NdArray #

Pointwise exponentiation which forces precision. + Takes some NdArray of bases, an array of Double exponents and returns an array of Doubles.

sum :: [NdArray] -> NdArray #

Takes the pointwise sum over all the given NdArrays. If they are different shapes, + the smaller dimensions are padded out with the identity element. + The sum of the empty list is the singleton 0.

mean :: [NdArray] -> NdArray #

Finds the mean pointwise over the list of arrays. Smaller arrays are padded out with + the identity element.

Bounds

clip :: forall a. DType a => Maybe a -> Maybe a -> NdArray -> NdArray #

Constrains all elements of the array to the range specified by [mini, maxi]. + If they are given as Nothing, the range is infinite in that direction. + NB: must still specify type for Nothing i.e. clip (Nothing :: Maybe Int) Nothing myNd

Type Conversions

convertDTypeTo :: forall a. DType a => TypeRep a -> NdArray -> NdArray #

Converting between the standard dtypes and changing the shapes of arrays. +NB the difference between size and shape. The shape is an Integer list +describing the width of each dimension. Size refers to the total number of +elements in the array, i.e. the product of the shape.

Converts an NdArray of one type to any other with a DType instance.

matchDType :: NdArray -> NdArray -> NdArray #

Converts the second NdArray to be the same DType as the first.

Size Conversions

resize :: Integer -> NdArray -> NdArray #

Truncate or pad the NdArray to match the new given size. +The shape will be collapsed to 1xn.

Shape Conversions/Manipulations

reshape :: [Integer] -> NdArray -> Maybe NdArray #

Shape-shift one array to another of the same size (Nothing otherwise). + >>> x = fromList [2,3] [1,2,3,4,5,6 :: Int] + >>> printArray x + 1 2 + 3 4 + 5 6 + >>> printArray $ fromJust $ reshape [3,2] x + 1 2 3 + 4 5 6

padShape :: [Integer] -> NdArray -> NdArray #

Adds zero-rows to an array. Will error if you map to a smaller shape. + >>> x = fromList [2,2] [1,2,3,4 :: Int] + >>> printArray $ padShape [4,3] x + 1 2 0 0 + 3 4 0 0 + 0 0 0 0

constrainShape :: [Integer] -> NdArray -> NdArray #

Truncates the array to be no larger than the specified dimensions.

broadcast :: (NdArray, NdArray) -> Maybe (NdArray, NdArray) #

Takes a pair of NdArrays and attempts to copy slices so that they are size matched. + Arrays are broadcastable if they either match in corresponding dimensions or one is + of dimension size 1 e.g. [2,5,1] and [2,1,6]. Missing dimensions are padded with 1s + e.g. [1,2,3] and [3] are broadcastable.

concatAlong :: Int -> [NdArray] -> Maybe NdArray #

Concatenate a list of tensors into a single tensor. All input tensors must have the + same shape, except for the dimension size of the axis to concatenate on. + Returns Nothing if the arrays are not all of the same type or matching shapes.

gather :: NdArray -> [Integer] -> Integer -> NdArray #

Takes an array, set of sub-indices and axis and repeatedly takes slices + of the array restricted to that index along the specified axis. + The slices are then concatenated into the final array.

Matrix Manipulation

swapRows :: Integer -> Integer -> NdArray -> NdArray #

Switches the rows at the two given indices over. +NB: designed for 2D matrices so will only make swaps in the front matrix of a tensor.

diagonal :: NdArray -> NdArray #

Gets the flat array of the leading diagonal of the front matrix of the tensor.

transpose :: NdArray -> NdArray #

Reverses the order of axes and switches the elements accordingly.

transposePerm :: [Int] -> NdArray -> NdArray #

Transposes the axes of an array according to the given permutation (e.g. [2,0,1])

Matrix Multiplication

dot :: forall a. DType a => NdArray -> NdArray -> a #

Dot product over matrices of the same shape.

matMul :: NdArray -> NdArray -> NdArray #

Standard matrix multiplication following NumPy conventions. + 1D arrays have the extra dimension pre/appended + 2D arrays are multiplied as expected + ND-arrays are broadcast to match each other where possible and treated as stacks of nxm/pxq arrays.

upperTriangle :: NdArray -> NdArray #

Converts a nxn matrix to upper triangle form. O(n^3).

determinant :: forall a. DType a => NdArray -> [a] #

Finds the determinant(s) of a tensor. Over matrices of more than two dimensions +each 2D matrix's determinant is individually calculated and concatenated together (as in numpy: +https://numpy.org/doc/stable/reference/generated/numpy.linalg.det.html ). +If the matrix is non-square it is assumed to be padded out and will have a determinant of 0.

determinant2D :: forall a. DType a => NdArray -> a #

Calculates the determinant of a 2D matrix using LU decomposition as described in the +below paper. O(n^3). +https://informatika.stei.itb.ac.id/~rinaldi.munir/Matdis/2016-2017/Makalah2016/Makalah-Matdis-2016-051.pdf

gemm :: (DType a, DType b) => NdArray -> NdArray -> NdArray -> Bool -> Bool -> a -> b -> Maybe NdArray #

General matrix multiplication. Calculates alpha*AB + beta*C with the option +to transpose A and B first. +Takes A, B, C, A transpose?, B transpose?, alpha, beta +Returns nothing if the matrix types/sizes do not match. +Will attempt to broadcast the shape of C and convert the types of alpha & beta.

For more information see: +https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3 +NB: if the matrices are integers the scalars will also become integers, so you should convert the matrices first.

Indexing

collapseInd :: [Integer] -> [Integer] -> Integer #

Converts a shape and multi-index to a 1D index.

expandInd :: [Integer] -> Integer -> [Integer] #

Converts a shape and 1D index to a multi-index.

map1DIndex :: [Integer] -> [Integer] -> Integer -> Integer #

Converts the multi-index for one shape to another

validIndex :: NdArray -> [Integer] -> Bool #

Checks an index does not exceed the shape.

(#!) :: DType a => NdArray -> [Integer] -> a #

Takes a multi-dimensional index and returns the value in the NdArray at that position. +Indices can be negative, where -1 is the last row in that dimension. +If an index exceeds the size of its dimension, a value will still be returned, the identity +value for the array e.g. 0. To avoid this use #?.

(#?) :: forall a. DType a => NdArray -> [Integer] -> Maybe a #

The safer version of #! which returns Nothing if an index exceeds the shape bounds.

slice :: [(Integer, Integer)] -> NdArray -> NdArray #

Takes a series of ranges corresponding to each dimension in the array and returns +the sub-array. Indices are inclusive and can be negative.

(/!) :: NdArray -> QuasiSlice -> NdArray #

The concise operator for slicing. Instead of providing an IndexRange, + you may QuasiQuote a NumPy-like index e.g. myArray /! [q|5,2:6,:3|]. + Unspecified values in ranges denote the start/end.

Pretty Printing

printArray :: NdArray -> IO () #

Prints out the pretty NdArray representation.

prettyShowArray :: NdArray -> String #

Converts an NdArray to its pretty representation. + Values along a row are separated by whitespace. Along a column, newlines. + For higher dimensions, an additional newline is added to separate the nxm matrices.

(=@=) :: (Typeable a, Typeable b) => a -> b -> Maybe (a :~~: b) #

eqTypeRep synonym, returning Just HRefl in the case of type equality. + >>> case True =@= False of + >>> Just HRefl -> putStrLn "Two Booleans will match" + >>> Nothing -> putStrLn "Mismatching types" + Two Booleans will match

loadNpy :: FilePath -> IO NdArray #

Loads an NdArray from a .npy file

Orphan instances

\ No newline at end of file diff --git a/nix/nixpkgs.nix b/nix/nixpkgs.nix new file mode 100644 index 0000000..78da415 --- /dev/null +++ b/nix/nixpkgs.nix @@ -0,0 +1,18 @@ +{ sources ? import ./sources.nix }: + +let overlay = _: pkgs: { + niv = (import sources.niv {}).niv; + haskellPackages = pkgs.haskellPackages.override { + overrides = self: super: { + dense = pkgs.haskell.lib.markUnbroken ( + pkgs.haskell.lib.dontCheck ( # doctests are broken + pkgs.haskell.lib.overrideSrc super.dense { + src = sources.dense; + } + ) + ); + }; + }; +}; +in +import sources.nixpkgs { overlays = [ overlay ] ; config = {}; } \ No newline at end of file diff --git a/nix/sources.json b/nix/sources.json new file mode 100644 index 0000000..8a66ea7 --- /dev/null +++ b/nix/sources.json @@ -0,0 +1,14 @@ +{ + "nixpkgs": { + "branch": "nixos-23.05", + "description": "Nix Packages collection", + "homepage": "https://github.com/NixOS/nixpkgs", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "8163a64662b43848802092d52015ef60777d6129", + "sha256": "12dlinlsmn9m4597bbyd9wjlsi5n1fd33kwrlf5v8n9357cciq54", + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/8163a64662b43848802092d52015ef60777d6129.tar.gz", + "url_template": "https://github.com///archive/.tar.gz" + } +} \ No newline at end of file diff --git a/sources.nix b/nix/sources.nix similarity index 99% rename from sources.nix rename to nix/sources.nix index 7a9a2ae..1a3a74e 100644 --- a/sources.nix +++ b/nix/sources.nix @@ -86,4 +86,4 @@ mapAttrs (_: spec: spec // { outPath = callFunctionWith spec (getFetcher spec) { }; } else spec - ) sources + ) sources \ No newline at end of file diff --git a/nixpkgs.nix b/nixpkgs.nix deleted file mode 100644 index ef19037..0000000 --- a/nixpkgs.nix +++ /dev/null @@ -1,17 +0,0 @@ -{ sources ? 
import ./sources.nix }: - -let clash-compiler = import sources.clash-compiler {}; - - overlay = _: pkgs: { - niv = (import sources.niv {}).niv; - haskellPackages = pkgs.haskellPackages.override { - overrides = self: super: { - yoda = - pkgs.haskell.lib.doJailbreak - (self.callCabal2nix "yoda" sources.yoda {}); - } // clash-compiler; - }; - }; -in -import sources.nixpkgs - { overlays = [ overlay ] ; config = { }; } diff --git a/numskull.cabal b/numskull.cabal new file mode 100644 index 0000000..5a3f306 --- /dev/null +++ b/numskull.cabal @@ -0,0 +1,47 @@ +cabal-version: 2.4 +name: numskull +version: 0.2.0.0 +-- synopsis: +-- description: +license: MIT +license-file: LICENSE +author: Rowan Mather +maintainer: rowan@myrtle.ai +-- copyright: +build-type: Simple +-- extra-doc-files: CHANGELOG.md +-- extra-source-files: + +library + exposed-modules: Numskull + , NdArray + , NdArrayException + , DType + , MatrixForm + , Indexing + , QuasiSlice + , QuasiSlice.Quote + , Typing + , Serialisation + -- other-modules: + -- other-extensions: + build-depends: base >=4.13.0.0 && <5 + , vector + , split + , containers + , parsec + , template-haskell + , deepseq + hs-source-dirs: src + default-language: Haskell2010 + ghc-options: -O2 -Wall + +test-suite tests + type: exitcode-stdio-1.0 + main-is: Test.hs + build-depends: base >=4.13.0.0 && <5 + , numskull + , QuickCheck + , hspec + hs-source-dirs: test + ghc-options: -O2 -Wall diff --git a/numskull.nix b/numskull.nix new file mode 100644 index 0000000..ad4a018 --- /dev/null +++ b/numskull.nix @@ -0,0 +1,13 @@ +{ mkDerivation, base, containers, deepseq, hspec, lib, parsec +, QuickCheck, split, template-haskell, vector +}: +mkDerivation { + pname = "numskull"; + version = "0.1.0.0"; + src = ./.; + libraryHaskellDepends = [ + base containers deepseq parsec split template-haskell vector + ]; + testHaskellDepends = [ base hspec QuickCheck ]; + license = lib.licenses.mit; +} diff --git a/rowan-ndarray.cabal b/rowan-ndarray.cabal deleted 
file mode 100644 index 8f7197c..0000000 --- a/rowan-ndarray.cabal +++ /dev/null @@ -1,130 +0,0 @@ -cabal-version: 2.4 --- The cabal-version field refers to the version of the .cabal specification, --- and can be different from the cabal-install (the tool) version and the --- Cabal (the library) version you are using. As such, the Cabal (the library) --- version used must be equal or greater than the version stated in this field. --- Starting from the specification version 2.2, the cabal-version field must be --- the first thing in the cabal file. - --- Initial package description 'rowan-ndarray' generated by --- 'cabal init'. For further documentation, see: --- http://haskell.org/cabal/users-guide/ --- --- The name of the package. -name: rowan-ndarray - --- The package version. --- See the Haskell package versioning policy (PVP) for standards --- guiding when and how versions should be incremented. --- https://pvp.haskell.org --- PVP summary: +-+------- breaking API changes --- | | +----- non-breaking API additions --- | | | +--- code changes with no API change -version: 0.1.0.0 - --- A short (one-line) description of the package. --- synopsis: - --- A longer description of the package. --- description: - --- The license under which the package is released. -license: MIT - --- The file containing the license text. -license-file: LICENSE - --- The package author(s). -author: Rowan Mather - --- An email address to which users can send suggestions, bug reports, and patches. -maintainer: rowan@myrtle.ai - --- A copyright notice. --- copyright: -category: Math -build-type: Simple - --- Extra doc files to be distributed with the package, such as a CHANGELOG or a README. -extra-doc-files: CHANGELOG.md - --- Extra source files to be distributed with the package, such as examples, or a tutorial module. --- extra-source-files: - -common warnings - ghc-options: -Wall - -library - -- Import common warning flags. - import: warnings - - -- Modules exported by the library. 
- exposed-modules: MyLib - - -- Modules included in this library but not exported. - -- other-modules: - - -- LANGUAGE extensions used by modules in this package. - -- other-extensions: - - -- Other library packages from which modules are imported. - build-depends: base ^>=4.9.1.0 - - -- Directories containing source files. - hs-source-dirs: src - - -- Base language which the package is written in. - default-language: Haskell2010 - -executable rowan-ndarray - -- Import common warning flags. - import: warnings - - -- .hs or .lhs file containing the Main module. - main-is: Main.hs - - -- Modules included in this executable, other than Main. - -- other-modules: - - -- LANGUAGE extensions used by modules in this package. - -- other-extensions: - - -- Other library packages from which modules are imported. - build-depends: - base ^>=4.9.1.0, - rowan-ndarray, - dense - - -- Directories containing source files. - hs-source-dirs: app - - -- Base language which the package is written in. - default-language: Haskell2010 - -test-suite rowan-ndarray-test - -- Import common warning flags. - import: warnings - - -- Base language which the package is written in. - default-language: Haskell2010 - - -- Modules included in this executable, other than Main. - -- other-modules: - - -- LANGUAGE extensions used by modules in this package. - -- other-extensions: - - -- The interface type and version of the test suite. - type: exitcode-stdio-1.0 - - -- Directories containing source files. - hs-source-dirs: test - - -- The entrypoint to the test suite. - main-is: Main.hs - - -- Test dependencies. - build-depends: - base ^>=4.9.1.0, - rowan-ndarray, - dense diff --git a/shell.nix b/shell.nix index 0c88006..de66c7c 100644 --- a/shell.nix +++ b/shell.nix @@ -1,17 +1,7 @@ -{ nixpkgs ? 
import ./nixpkgs.nix {} -}: - -with nixpkgs; - -let ghc = haskellPackages.ghcWithPackages (pkgs: with pkgs; [ - clash-ghc - clash-prelude - QuickCheck - ghcid - repa - ]); -in -mkShell { - name = "clash-exercises"; - buildInputs = [ ghc ]; -} +{ nixpkgs ? import nix/nixpkgs.nix {} }: +(import ./default.nix { inherit nixpkgs; }).env.overrideAttrs (finalAttrs: prevAttrs: { + buildInputs = with nixpkgs.haskellPackages; prevAttrs.buildInputs ++ [ + haskell-language-server + ghcid.bin + ]; +}) \ No newline at end of file diff --git a/sources.json b/sources.json deleted file mode 100644 index ecb4275..0000000 --- a/sources.json +++ /dev/null @@ -1,158 +0,0 @@ -{ - "clash-compiler": { - "branch": "1.4", - "description": "Haskell to VHDL/Verilog/SystemVerilog compiler", - "homepage": "https://www.clash-lang.org/", - "owner": "clash-lang", - "repo": "clash-compiler", - "rev": "fa01fd98799cd7c00cae44a9df847142410f1618", - "sha256": "0b42gschkb7sk958a7j529lcw9kmzikhcqjxblm43lggwwy24hns", - "type": "tarball", - "url": "https://github.com/clash-lang/clash-compiler/archive/fa01fd98799cd7c00cae44a9df847142410f1618.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "clash-ffi-sources": { - "branch": "master", - "description": "Haskell to VHDL/Verilog/SystemVerilog compiler", - "homepage": "https://clash-lang.org/", - "owner": "clash-lang", - "repo": "clash-compiler", - "rev": "47918754513298acd06ed4f85eef459e291bd644", - "sha256": "0q89gk7a7y856pv6c0w1lgn18n3i35rgy3caxn45g2z0w4lg7c8n", - "type": "tarball", - "url": "https://github.com/clash-lang/clash-compiler/archive/47918754513298acd06ed4f85eef459e291bd644.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "dense": { - "branch": "master", - "description": "dense Haskell library", - "homepage": "", - "owner": "cchalmers", - "repo": "dense", - "rev": "a84e7e43c7efca0ddfe1a5f60ee18cf006dc04fa", - "sha256": "1rg7s1ybdac6139j8l511x2f9jg5793wr20vph54j9p6ybybigsd", - "type": "tarball", - 
"url": "https://github.com/cchalmers/dense/archive/a84e7e43c7efca0ddfe1a5f60ee18cf006dc04fa.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "derive-storable": { - "branch": "master", - "description": "Deriving Storable instances using GHC.Generics", - "homepage": "https://hackage.haskell.org/package/derive-storable", - "owner": "mkloczko", - "repo": "derive-storable", - "rev": "b378aa4cec9ce2fc1cebdb9fb617f6a21872d99e", - "sha256": "0bzl0v47y42n9awfp7nh9kvdz049shdr4aldqrjx40b5izwsll91", - "type": "tarball", - "url": "https://github.com/mkloczko/derive-storable/archive/b378aa4cec9ce2fc1cebdb9fb617f6a21872d99e.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "derive-storable-plugin": { - "branch": "master", - "description": null, - "homepage": null, - "owner": "mkloczko", - "repo": "derive-storable-plugin", - "rev": "6082e149e2b8f575394f7e1972d4d2c1d5b0ddb6", - "sha256": "1334w8vcvxg1929b9v1yy5x52in28mr8y05z3vmfz3z7drxm9c83", - "type": "tarball", - "url": "https://github.com/mkloczko/derive-storable-plugin/archive/6082e149e2b8f575394f7e1972d4d2c1d5b0ddb6.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "ghc-typelits-natnormalise": { - "branch": "master", - "description": "Normalise GHC.TypeLits.Nat equations", - "homepage": "", - "owner": "clash-lang", - "repo": "ghc-typelits-natnormalise", - "rev": "def05130ae7b5fc64772755ad7da1320265d5672", - "sha256": "12bnzi1lkv0dg5gvrymwksn6zdv72zb20p0sy2wpk0j99sygic01", - "type": "tarball", - "url": "https://github.com/clash-lang/ghc-typelits-natnormalise/archive/def05130ae7b5fc64772755ad7da1320265d5672.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "mach-nix": { - "branch": "master", - "description": "Create highly reproducible python environments", - "homepage": "", - "owner": "DavHau", - "repo": "mach-nix", - "rev": "8d903072c7b5426d90bc42a008242c76590af916", - "sha256": "1xmz1rzip6cwk7zhrakigl7zg04mrmsvlarcvhwk38zz0x7kbi10", - 
"type": "tarball", - "url": "https://github.com/DavHau/mach-nix/archive/8d903072c7b5426d90bc42a008242c76590af916.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "maturin-nix": { - "branch": "buildRustPackage", - "description": "Build pyo3 rust wheels in nix", - "homepage": "", - "owner": "cchalmers", - "repo": "maturin-nix", - "rev": "b4e6ad8d045dc1fb6bfc57e3da5e064c8d9d30ca", - "sha256": "1zgz0ccwfz9p925drp00a8469fy0dg7yv2swf7snb3g21j15cs5c", - "type": "tarball", - "url": "https://github.com/cchalmers/maturin-nix/archive/b4e6ad8d045dc1fb6bfc57e3da5e064c8d9d30ca.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "niv": { - "branch": "master", - "description": "Easy dependency management for Nix projects", - "homepage": "https://github.com/nmattia/niv", - "owner": "nmattia", - "repo": "niv", - "rev": "ba57d5a29b4e0f2085917010380ef3ddc3cf380f", - "sha256": "1kpsvc53x821cmjg1khvp1nz7906gczq8mp83664cr15h94sh8i4", - "type": "tarball", - "url": "https://github.com/nmattia/niv/archive/ba57d5a29b4e0f2085917010380ef3ddc3cf380f.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "nixpkgs": { - "branch": "nixos-20.09", - "description": "Nix Packages collection", - "homepage": "https://github.com/NixOS/nixpkgs", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "896270d629efd47d14972e96f4fbb79fc9f45c80", - "sha256": "0xmjjayg19wm6cn88sh724mrsdj6mgrql6r3zc0g4x9bx4y342p7", - "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/896270d629efd47d14972e96f4fbb79fc9f45c80.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "nixpkgs-2205": { - "branch": "release-22.05", - "description": "Nix Packages collection", - "homepage": "https://github.com/NixOS/nixpkgs", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "284079785291494969f790d01a4296b3d8b63741", - "sha256": "1a7is9ism6qhqbrwmwmp1dhrqhxlmzb83df04m8z27lhqd6al356", - "type": "tarball", - "url": 
"https://github.com/NixOS/nixpkgs/archive/284079785291494969f790d01a4296b3d8b63741.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "nixpkgs-2305": { - "branch": "release-23.05", - "description": "Nix Packages collection", - "homepage": "https://github.com/NixOS/nixpkgs", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "dd5ea4aa63eafe67a8f394856095f867a001a5ca", - "sha256": "0qzr64bgq326zscnimfcm99fp43n31z3rj7s09z8g03v5v5g02yi", - "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/dd5ea4aa63eafe67a8f394856095f867a001a5ca.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - }, - "stylish-haskell": { - "branch": "main", - "description": "Haskell code prettifier", - "homepage": "", - "owner": "jaspervdj", - "repo": "stylish-haskell", - "rev": "4ec5c509290d8c2b94045189a122d3fca5e45a4e", - "sha256": "0xkrah3cjg6h87vzh16nanj8spgjvz9js5jwf99lwfmwp35wp92r", - "type": "tarball", - "url": "https://github.com/jaspervdj/stylish-haskell/archive/4ec5c509290d8c2b94045189a122d3fca5e45a4e.tar.gz", - "url_template": "https://github.com///archive/.tar.gz" - } -} \ No newline at end of file diff --git a/src/DType.hs b/src/DType.hs new file mode 100644 index 0000000..cde5717 --- /dev/null +++ b/src/DType.hs @@ -0,0 +1,281 @@ +{-# LANGUAGE TypeApplications #-} + +module DType where + +import Prelude as P +import Data.Vector.Storable +import Type.Reflection +import GHC.Float (float2Double) +import Data.Int +import Data.Char +import Control.DeepSeq (NFData) + +-- | All types storable within an NdArray must implement DType. +-- This defines some basic properties, mathematical operations and standards for conversion. 
+class (Typeable a, Storable a, Show a, Eq a, Ord a, NFData a) => DType a where + -- | Additive identity + addId :: a + -- | Multiplicative identity + multId :: a + -- | Standard numeric operations + -- NB: + -- divide preserves DType + -- div is specifically for integer division and returns an Int + -- pow preserves DType + -- power is for precision and uses Doubles + -- mod returns an Int + add :: a -> a -> a + subtract :: a -> a -> a + multiply :: a -> a -> a + divide :: a -> a -> a + div :: a -> a -> Int + power :: a -> Double -> Double + pow :: a -> a -> a + -- Log base x of y + log :: a -> a -> a + mod :: a -> a -> Int + abs :: a -> a + signum :: a -> a + ceil :: a -> a + floor :: a -> a + -- Trig + sin :: a -> a + cos :: a -> a + tan :: a -> a + -- | Most logical operations are simply defined in the numeric section on Booleans. + -- Invert is naturally defined as -x numerically and NOT x logically. + invert :: a -> a + shiftleft :: a -> a + shiftright :: a -> a + -- | Dtypes are converted between via the intermediate type of rational + dtypeToRational :: a -> Rational + rationalToDtype :: Rational -> a + +instance DType Int where + addId = 0 + multId = 1 + -- Numeric + add x y = x + y + subtract x y = x - y + multiply x y = x * y + divide = P.div + div x y = (fromIntegral $ P.div x y) :: Int + power x d = fromIntegral x ** d + pow x y = x ^ y + log x y = (P.floor $ logBase xd yd) :: Int + where xd = fromIntegral @Int @Double x + yd = fromIntegral @Int @Double y + mod = P.mod + abs = P.abs + signum = P.signum + ceil x = x + floor x = x + -- Trig + sin = roundIntFunc P.sin + cos = roundIntFunc P.cos + tan = roundIntFunc P.tan + -- Logical + invert x = -x + shiftleft x = x * 2 + shiftright x = x `P.div` 2 + -- Conversion + dtypeToRational = toRational + rationalToDtype = P.floor . 
fromRational @Double
+
+roundIntFunc :: (Float -> Float) -> Int -> Int
+roundIntFunc f x = (round $ f $ fromIntegral @Int @Float x) :: Int
+
+instance DType Int32 where
+  addId = 0
+  multId = 1
+  -- Numeric
+  add x y = x + y
+  subtract x y = x - y
+  multiply x y = x * y
+  divide = P.div
+  div x y = fromIntegral @Int32 @Int $ P.div x y
+  power x d = fromIntegral x ** d
+  pow x y = x ^ y
+  log x y = (P.floor $ logBase xd yd) :: Int32
+    where xd = fromIntegral @Int32 @Double x
+          yd = fromIntegral @Int32 @Double y
+  mod x y = fromIntegral @Int32 @Int $ P.mod x y
+  abs = P.abs
+  signum = P.signum
+  ceil x = x
+  floor x = x
+  -- Trig (evaluated in Float, then converted back with 'round'; each method uses its own function)
+  sin x = (round $ P.sin $ fromIntegral @Int32 @Float x) :: Int32
+  cos x = (round $ P.cos $ fromIntegral @Int32 @Float x) :: Int32
+  tan x = (round $ P.tan $ fromIntegral @Int32 @Float x) :: Int32
+  -- Logical
+  invert x = -x
+  shiftleft x = x * 2
+  shiftright x = x `P.div` 2
+  -- Conversion
+  dtypeToRational = toRational
+  rationalToDtype = P.floor . fromRational @Double
+
+instance DType Int64 where
+  addId = 0
+  multId = 1
+  -- Numeric
+  add x y = x + y
+  subtract x y = x - y
+  multiply x y = x * y
+  divide = P.div
+  div x y = fromIntegral @Int64 @Int $ P.div x y
+  power x d = fromIntegral x ** d
+  pow x y = x ^ y
+  log x y = (P.floor $ logBase xd yd) :: Int64
+    where xd = fromIntegral @Int64 @Double x
+          yd = fromIntegral @Int64 @Double y
+  mod x y = fromIntegral @Int64 @Int $ P.mod x y
+  abs = P.abs
+  signum = P.signum
+  ceil x = x
+  floor x = x
+  -- Trig (evaluated in Float, then converted back with 'round'; each method uses its own function)
+  sin x = (round $ P.sin $ fromIntegral @Int64 @Float x) :: Int64
+  cos x = (round $ P.cos $ fromIntegral @Int64 @Float x) :: Int64
+  tan x = (round $ P.tan $ fromIntegral @Int64 @Float x) :: Int64
+  -- Logical
+  invert x = -x
+  shiftleft x = x * 2
+  shiftright x = x `P.div` 2
+  -- Conversion
+  dtypeToRational = toRational
+  rationalToDtype = P.floor .
fromRational @Double + +instance DType Float where + addId = 0.0 + multId = 1.0 + -- Numeric + add x y = x + y + subtract x y = x - y + multiply x y = x * y + divide x y = x/y + div x y = P.floor x `P.div` P.floor y + power x d = float2Double x ** d + pow x y = x ** y + log = logBase + mod x y = P.floor x `P.mod` P.floor y + abs = P.abs + signum = P.signum + ceil = fromIntegral @Integer @Float . P.ceiling + floor = fromIntegral @Integer @Float . P.floor + -- Trig + sin = P.sin + cos = P.cos + tan = P.tan + -- Logical + invert x = -x + shiftleft x = x * 2 + shiftright x = x / 2 + -- Conversion + dtypeToRational = toRational + rationalToDtype = fromRational @Float + +instance DType Double where + addId = 0.0 + multId = 1.0 + -- Numeric + add x y = x + y + subtract x y = x - y + multiply x y = x * y + divide x y = x/y + div x y = P.floor x `P.div` P.floor y + power x d = x ** d + pow x y = x ** y + log = logBase + mod x y = P.floor x `P.mod` P.floor y + abs = P.abs + signum = P.signum + ceil = fromIntegral @Integer @Double . P.ceiling + floor = fromIntegral @Integer @Double . 
P.floor
+  -- Trig
+  sin = P.sin
+  cos = P.cos
+  tan = P.tan
+  -- Logical
+  invert x = -x
+  shiftleft x = x * 2
+  shiftright x = x / 2
+  -- Conversion
+  dtypeToRational = toRational
+  rationalToDtype = fromRational @Double
+
+instance DType Bool where
+  addId = False
+  multId = True
+  -- | Logical OR
+  add x y = x || y
+  -- | Logical NOR
+  subtract x y = not (x || y)
+  -- | Logical AND
+  multiply x y = x && y
+  -- | Logical NAND
+  divide x y = not (x && y)
+  div x y = fromEnum $ DType.divide x y
+  -- | Numeric power
+  power x d = fromIntegral (fromEnum x) ** d
+  -- | Logical reverse implication
+  pow x y = not y || x
+  -- | Logical implication
+  log x y = not x || y
+  -- | Logical XOR, but Int result
+  mod x y = fromEnum $ (x || y) && not (x && y)
+  abs _ = True
+  signum = id
+  ceil = id
+  floor = id
+  -- Trig (False = 0, True = 1 or /=0)
+  sin False = False
+  sin True = True
+  cos False = True
+  cos True = True
+  tan False = False
+  tan True = True
+  -- Logical
+  -- Logical NOT
+  invert = not
+  shiftleft _ = False
+  shiftright _ = False
+  -- Conversions
+  dtypeToRational False = 0
+  dtypeToRational True = 1
+  rationalToDtype 0 = False
+  rationalToDtype _ = True
+
+instance DType Char where
+  addId = '\NUL'
+  multId = 'a'
+  -- Numeric (on code points; subtract clamps at 0 so the result never drops below '\NUL')
+  add x y = chr $ ord x + ord y
+  subtract x y = chr $ max 0 $ ord x - ord y
+  multiply _ _ = undefined
+  divide _ _ = undefined
+  div _ _ = undefined
+  power _ _ = undefined
+  pow _ _ = undefined
+  log _ _ = undefined
+  mod _ _ = undefined
+  abs = undefined
+  signum c
+    | isUpper c = 'A'
+    | isLower c = 'a'
+    | isDigit c = '0'
+    | otherwise = c
+  ceil = toUpper
+  floor = toLower
+  -- Trig
+  sin = undefined
+  cos = undefined
+  tan = undefined
+  -- Logical
+  invert c = if isUpper c then toLower c else toUpper c
+  shiftleft x = chr $ ord x + 1
+  shiftright x = chr $ ord x - 1
+  -- Conversion
+  dtypeToRational = toRational . ord
+  rationalToDtype = chr . P.floor.
fromRational @Double \ No newline at end of file diff --git a/src/Indexing.hs b/src/Indexing.hs new file mode 100644 index 0000000..a888573 --- /dev/null +++ b/src/Indexing.hs @@ -0,0 +1,244 @@ +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE GADTs #-} + +module Indexing where + + +import Control.Exception +import qualified Data.Vector.Storable as V +import Data.Vector.Storable (Vector) +import qualified Data.Map as M +import Type.Reflection + +import NdArray +import Typing +import qualified DType +import DType (DType) +import NdArrayException + +preciseDiv x y = fromIntegral @Int @Float x / fromIntegral @Int @Float y + +-- Takes the shape assuming you want to move directly between dimensions. +defStride :: Vector Int -> Vector Int +defStride sh = V.scanr' (*) 1 $ V.drop 1 sh + +stride :: NdArray -> NdArray +stride (NdArray sh st v) = + let + -- shape + dimAcc = V.scanr' (*) 1 sh + --newshape = V.map (i -> ((dim' V.!(i+1)) * (sh V.!i -1) +1) / st V.! i) (enumFromN 0 (V.length sh)) + newshape = + V.map (\i -> + ceiling $ preciseDiv (1 + (dimAcc V.! (i+1)) * (sh V.!i -1)) (st V.! i) :: Int) + (V.enumFromN 0 (V.length sh)) + --V.zipWith (\d s -> ceiling (preciseDiv d s) :: Int) [0..] st + -- stride + singleStride = defStride newshape + -- vector + grab i = V.and $ V.zipWith (\t h -> i `mod` t < h) st (V.drop 1 dimAcc) + newV = V.force (V.ifilter (\i _ -> grab i) v) + in + if V.all (==1) st then (NdArray sh st v) + else NdArray newshape singleStride newV + +{- | Arrays are stored as vectors with a shape. Since vectors only have one dimension, +we convert between the vector index, i, and multi-dimension index, [x,y,z,...], using the +shape of the array, [sx,sy,sz,...], as follows: + + i = x + y*sx + z*sx*sy + ... + + x = i/(1) % sx + y = i/(sx) % sy + z = i/(sx*sy) % sz + ... 
+-} + +-- * INDEXING + +-- | Generates the list of all multi-dimensional indices for a given shape +generateIndices :: [Int] -> [[Int]] +generateIndices = foldr (\x xs -> [ i:t | i <- [0..(x-1)], t <- xs]) [[]] + +{- | Generates two maps to convert between the single dimension index of the +underlying vector and the multi-dimensional index of the NdArray and back, +given the NdArray shape. +-} +mapIndices :: [Int] -> (M.Map Int [Int], M.Map [Int] Int) +mapIndices sh = (M.fromList oneDkey, M.fromList twoDkey) + where + twoDinds = generateIndices sh + oneDkey = zip [0..] twoDinds + twoDkey = zip twoDinds [0..] + +-- Indexes a vector with an NdArray multi-index using a mapping (unsafe). +--vecInd :: forall a . DType a => M.Map [Int] Int -> Vector a -> [Int] -> a- +--vecInd mapp v i = v V.! (mapp M.! i) + +{- +-- | Converts a shape and multi-index to a 1D index. +collapseInd :: [Integer] -> [Integer] -> Integer +collapseInd sh indices = collapseRun (reverse sh) (reverse indices) 1 + +-- Helper for collapseInd +collapseRun :: [Integer] -> [Integer] -> Integer -> Integer +collapseRun _ [] _ = 0 +collapseRun [] _ _ = 0 +collapseRun (s:ss) (x:xs) runSize = x*runSize + collapseRun ss xs (s*runSize) + +-- | Converts a shape and 1D index to a multi-index. +expandInd :: [Integer] -> Integer -> [Integer] +expandInd sh i = reverse $ expandRun (reverse sh) i 1 + +-- Helper for expandInd +expandRun :: [Integer] -> Integer -> Integer -> [Integer] +expandRun [] _ _ = [] +expandRun (s:ss) i runSize = x : expandRun ss i (s*runSize) + where x = (i `div` runSize) `mod` s +-} +vGet :: DType a => Vector a -> Vector Int -> [Int] -> a +vGet v t is = v V.! 
(collapseInd t $ V.fromList is)
+
+collapseInd :: Vector Int -> Vector Int -> Int
+collapseInd st ind = V.sum $ V.zipWith (*) st ind
+
+expandInd :: Vector Int -> Int -> Vector Int
+expandInd st ind =
+  let st' = V.toList st
+  in V.fromList $ expandRun st' ind
+
+expandRun :: [Int] -> Int -> [Int]
+expandRun [] _ = []
+expandRun (s:sts) x =
+  if s == 0 then (0 : expandRun sts x)
+  else x `div` s : expandRun sts (x `mod` s)
+  {-
+  let st' = V.map (/= 0) st
+  V.zipwith (/) (V.drop 1 $ V.scanl mod ind st) st
+
+scanl :: (b -> a -> b) -> b -> [a] -> [b]
+scanl f z xs = foldr go (const []) xs z
+  where
+    go x continue acc = let next = f acc x in next : continue next
+-}
+
+-- | Converts the multi-index for one shape to another
+map1DIndex :: Vector Int -> Vector Int -> Int -> Int
+map1DIndex t d i = collapseInd d (expandInd t i)
+
+-- | Checks an index has one entry per dimension, each within the shape (negatives count from the end).
+validIndex :: NdArray -> [Int] -> Bool
+validIndex (NdArray sh _ _) i = (length i == length s) && and (zipWith lessAbs i s)
+  where
+    s = V.toList sh
+    lessAbs x y = (0 <= x && x < y) || (0 < -x && -x <= y)
+
+{- | Takes a multi-dimensional index and returns the value in the NdArray at that position.
+Indices can be negative, where -1 is the last row in that dimension.
+If an index exceeds the size of its dimension, a value will still be returned, the identity
+value for the array e.g. 0. To avoid this use !?.
+-}
+-- >>> m = fromListFlat [2,4,8 :: Int]
+-- >>> m #! [1] :: Int
+-- 4
+-- >>> m #! [50] :: Int
+-- 0
+(#!) :: DType a => NdArray -> [Int] -> a
+nd #! i = case nd !? i of
+  Just val -> val
+  Nothing -> DType.addId :: DType a => a
+
+{- | The safer version of #! which returns Nothing if an index exceeds the shape bounds. -}
+-- >>> m = fromListFlat [2,4,8 :: Int]
+-- >>> m !? [1] :: Maybe Int
+-- Just 4
+-- >>> m !? [50] :: Maybe Int
+-- Nothing
+(!?) :: forall a . DType a => NdArray -> [Int] -> Maybe a
+(NdArray sh st v) !? i =
+  let
+    -- Converts any negative indices to their equivalent positives (relative to the shape)
+    positives = V.zipWith positiveInd sh (V.fromList i)
+    flatInd = collapseInd st positives
+  in
+    -- The type comparison should always hold
+    if validIndex (NdArray sh st v) i then
+      case ty v `eqTypeRep` typeRep @(Vector a) of
+        Just HRefl -> Just (v V.! flatInd) :: Maybe a -- Indexing the vector
+        Nothing -> Nothing
+    else Nothing
+
+-- * SLICING
+{-
+-- | Type which allows you to provide only a single index or a range of indices.
+data IndexRange = I Integer | R Integer Integer deriving (Show, Eq)
+
+-- | Integrated indexing and slicing. For each dimension you can provide either a single value
+-- or a range of values where a slice will be taken.
+(#!+) :: NdArray -> [IndexRange] -> NdArray
+(#!+) (NdArray sh v) irs = sliceWithMap m 0 (map forceRange irs) (NdArray sh v)
+  where (m,_) = mapIndices sh
+
+-- Converts an IndexRange to a range of indices in the standard pair form.
+forceRange :: IndexRange -> (Integer, Integer)
+forceRange (I i) = (i,i)
+forceRange (R s t) = (s,t)
+-}
+-- Converts negative indices to their positive equivalents, counting back
+-- from the end of the array (i.e. -1 is the last element).
+positiveInd :: (Ord a, Num a) => a -> a -> a
+positiveInd s i = if i < 0 then s+i else i
+
+{- | Takes a series of ranges corresponding to each dimension in the array and returns
+the sub-array. Indices are inclusive and can be negative. -}
+slice :: [(Int, Int)] -> NdArray -> NdArray
+slice sl (NdArray sh st v) =
+  let -- todo: -ve indices
+    ranges = zipWith (\(i,j) t -> [t*i, t*(i+1) .. t*j]) sl (V.toList st)
+    indices = V.fromList $ (map sum . sequence) ranges
+    sh' = V.fromList $ map (\(i,j) -> j-i+1) sl
+    v' = V.map (v V.!)
indices + in + NdArray sh' (defStride sh') v' +--slice ss (NdArray sh v) = sliceWithMap m 0 ss (NdArray sh v) +-- where (m,_) = mapIndices sh +-- https://rosettacode.org/wiki/Cartesian_product_of_two_or_more_lists#Haskell + +{- +-- | Equivalent slicing operator. +(!/) :: NdArray -> [(Integer, Integer)] -> NdArray +(!/) nd ss = slice ss nd + +-- Takes a slice on an NdArray given the mapping from the vector index to NdArray index. +-- Iterates through each dimension of the slice one at a time. +sliceWithMap :: M.Map Int [Integer] -> Int -> [(Integer, Integer)] -> NdArray -> NdArray +sliceWithMap _ _ [] nd = nd +sliceWithMap _ d _ (NdArray sh v) | d >= length sh = NdArray sh v +sliceWithMap m d (s : ss) (NdArray sh v) = sliceWithMap m (d+1) ss $ + sliceDim s d m (NdArray sh v) + +-- Takes a slice of an NdArray at a particular dimension. +sliceDim :: (Integer, Integer) -> Int -> M.Map Int [Integer] -> NdArray -> NdArray +sliceDim (x,y) d m (NdArray sh v) = + if d >= length sh then throw (ExceededShape (fromIntegral d) sh) + else NdArray + (if y' < x' then [] else shrinkNth d (y'-x'+1) sh) + (V.ifilter + (\i _ -> + let dimInd = (m M.! i) !! d + in x' <= dimInd && dimInd <= y') + v + ) + where + dimSize = sh !! d + (x', y') = (positiveInd dimSize x, positiveInd dimSize y) + +-- Replaces the nth value of an array if the newValue is smaller. 
+-- https://stackoverflow.com/questions/5852722/replace-individual-list-elements-in-haskell +shrinkNth :: Ord a => Int -> a -> [a] -> [a] +shrinkNth _ _ [] = [] +shrinkNth n newVal (x:xs) + | n == 0 = if newVal < x then newVal:xs else x:xs + | otherwise = x:shrinkNth (n-1) newVal xs + -} \ No newline at end of file diff --git a/src/MatrixForm.hs b/src/MatrixForm.hs new file mode 100644 index 0000000..29a718e --- /dev/null +++ b/src/MatrixForm.hs @@ -0,0 +1,74 @@ +module MatrixForm where + + +import NdArray +import Indexing +import Data.Tree +import qualified Data.Vector.Storable as V +{- +-- * READING MATRICIES + +{- | This type is specifically for pretty explicit definitions of NdArrays. +The A constructor is for Array - a set of values and B is the value. +-- Example 2x3x2 +l :: TreeMatrix Int +l = A [A [A [B 1, B 2], + A [B 3, B 4], + A [B 5, B 6]], + + A [A [B 7, B 8], + A [B 9, B 10], + A [B 11, B 12]]] +-} +data TreeMatrix a = B a | A [TreeMatrix a] + +-- Converts a TreeMatrix to a Tree of lists +matrixToTree :: TreeMatrix a -> Tree [a] +matrixToTree (B x) = Node [x] [] +matrixToTree (A xs) = Node [] (map matrixToTree xs) + +-- Converts a Tree of lists to a single ordered list. +flattenToList :: Tree [a] -> [a] +flattenToList = concat . flatten + +-- Calculates the shape of the NdArray corresponding to the Tree. +treeShape :: Tree [a] -> [Integer] +treeShape t = zipWith (\x y -> fromIntegral $ div x y ::Integer) (drop 1 levelLen) levelLen + where levelLen = map length $ levels t + +-- Calculates the shape of the NdArray corresponding to the TreeMatrix. +matrixShape :: TreeMatrix a -> [Integer] +matrixShape = treeShape . matrixToTree +-} +-- * WRITING MATRICIES + +-- | Prints out the pretty NdArray representation. +printArray :: NdArray -> IO () +printArray nd = putStr $ prettyShowArray nd + +-- | Converts an NdArray to its pretty representation. +-- Values along a row are separated whitespace. Along a column, newlines. 
+-- For higher dimensions, an additional newline is added to separate the nxm matrices. +prettyShowArray :: NdArray -> String +prettyShowArray nd = + case stride nd of + (NdArray s _ v) -> + conc <> "\n" + where + vl = map show (V.toList v) + largest = maximum $ map length vl + newlines = scanr1 (*) (V.toList s) + spaced = zipWith (\i x -> (i, padStringTo largest x)) [0..] vl + lined = addNewlines newlines spaced + conc = concatMap snd lined + +-- Separates values along a row by whitespace. +padStringTo :: Int -> String -> String +padStringTo i s = replicate (i - length s) ' ' ++ s ++ " " + +-- Separates columns and higher dimensions by newlines. +addNewlines :: [Int] -> [(Int, String)] -> [(Int, String)] +addNewlines ls xs = foldr (\l -> + map (\(i, x) -> if i /= 0 && i `mod` l == 0 + then (i, "\n" ++ x) + else (i, x))) xs ls diff --git a/src/MyLib.hs b/src/MyLib.hs deleted file mode 100644 index e657c44..0000000 --- a/src/MyLib.hs +++ /dev/null @@ -1,4 +0,0 @@ -module MyLib (someFunc) where - -someFunc :: IO () -someFunc = putStrLn "someFunc" diff --git a/src/NdArray.hs b/src/NdArray.hs new file mode 100644 index 0000000..da990b2 --- /dev/null +++ b/src/NdArray.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE GADTs #-} + +module NdArray where + +import DType +import Data.Vector.Storable + +-- * NdArray +-- | The core of this module. NdArrays can be of any DType a and size/shape (list of dimensions) +-- These are hidden by the type. +data NdArray where + -- shape stride elements + NdArray :: DType a => Vector Int -> Vector Int -> Vector a -> NdArray + +-- | By default arrays are printed flat with the shape as metadata. +-- For a tidier representation, use printArray. 
+instance Show NdArray where
+  show (NdArray sh st v) = "{elements: " <> show v <> ", shape: " <> show sh <> ", stride: " <> show st <> "}"
\ No newline at end of file
diff --git a/src/NdArrayException.hs b/src/NdArrayException.hs
new file mode 100644
index 0000000..e56dbbc
--- /dev/null
+++ b/src/NdArrayException.hs
@@ -0,0 +1,57 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+module NdArrayException where
+
+import Control.Exception
+import Type.Reflection
+import Data.Vector.Storable (Vector)
+
+import DType
+import NdArray
+
+-- | The main type of exception thrown from Numskull functions when the user
+-- tries to perform illegal operations given the size and shape of the array.
+data NdArrayException
+  = DTypeMismatch NdArray NdArray String
+  | ShapeMismatch NdArray NdArray String
+  | CreationSize Int (Vector Int)
+  | TypeMismatch String
+  | ExceededShape Int (Vector Int)
+  | NotBroadcastable NdArray NdArray String
+
+instance Exception NdArrayException
+
+instance Show NdArrayException where
+  show (DTypeMismatch (NdArray _ _ v) (NdArray _ _ u) extra) =
+    if extra == "" then
+      "Cannot match NdArrays of type '" <> showType v <>
+      "' and type '" <> showType u <> "'."
+    else
+      "Cannot perform " <> extra <> " on mismatching NdArrays of type '" <> showType v <>
+      "' and type '" <> showType u <> "'."
+
+  show (ShapeMismatch (NdArray s t _) (NdArray r d _) extra) =
+    if extra == "" then
+      "Cannot match NdArrays of shape " <> show s <> " and stride " <> show t <>
+      ", and shape " <> show r <> " and stride " <> show d <> "."
+    else
+      "Cannot perform " <> extra <> " on mismatching NdArrays of shape " <>
+      show s <> " and stride " <> show t <>
+      " and shape " <> show r <> " and stride " <> show d <> "."
+
+  show (CreationSize sz sh) =
+    "Cannot create array of size " <> show sz <> " and shape " <> show sh <> "."
+
+  show (TypeMismatch str) = str
+
+  show (ExceededShape dim sh) =
+    "Cannot index into dimension " <> show dim <> " in NdArray of shape " <> show sh <> "."
+
+  show (NotBroadcastable (NdArray s _ _) (NdArray r _ _) str) =
+    "Cannot broadcast NdArrays of shape " <> show s <> " and shape " <> show r <> str <> "."
+
+-- | Returns the name of the vector's element type, as shown by its 'TypeRep'.
+showType :: forall a . DType a => Vector a -> String
+showType _ = show (typeRep @a)
\ No newline at end of file
diff --git a/src/Numskull.hs b/src/Numskull.hs
new file mode 100644
index 0000000..39b0dc7
--- /dev/null
+++ b/src/Numskull.hs
@@ -0,0 +1,1316 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+
+module Numskull (
+  {-
+  -- Metadata
+  DType
+  , size
+  , shape
+  , getVector
+  , ndType
+  , checkNdType
+  , isEmpty
+
+  -- Creation
+  , NdArray
+  , fromList
+  , fromListFlat
+  , TreeMatrix
+  , fromMatrix
+  , fromVector
+  , singleton
+  , arange
+  , zeros
+  , squareArr
+
+  -- General mapping, folding & zipping
+  , foldrA
+  , mapA
+  , mapTransform
+  , pointwiseZip
+  , pointwiseBool
+  , zipArrayWith
+
+  -- Summaries
+  , origin
+  , maxElem
+  , minElem
+
+  -- Mathematical constant
+  , scale
+  , absA
+  , signumA
+  , ceilA
+  , floorA
+  , sinA
+  , cosA
+  , tanA
+  , invertA
+  , shiftleftA
+  , shiftrightA
+
+  -- Mathematical pointwise
+  , elemDivide
+  , elemDiv
+  , elemPow
+  , elemPower
+  , Numskull.sum
+  , mean
+
+  -- Bounds
+  , clip
+
+  -- Type Conversions
+  , convertDTypeTo
+  , matchDType
+
+  -- Size conversions
+  , resize
+
+  -- Shape conversions/manipulations
+  , reshape
+  , padShape
+  , constrainShape
+  , broadcast
+  , concatAlong
+  , gather
+
+  -- Matrix manipulation
+  , swapRows
+  , diagonal
+  , transpose
+  , transposePerm
+
+  --Matrix multiplication
+  , dot
+  , matMul
+  , upperTriangle
+  , determinant
+  , determinant2D
+  , swapRowsWith0Pivot
+  , gemm
+
+  --
Indexing + , IndexRange + , collapseInd + , expandInd + , map1DIndex + , validIndex + , (#!) + , (!?) + , (#!+) + , slice + , (!/) + + -- Pretty printing + , printArray + , prettyShowArray + + -- typing + , (=@=) + + -- numpy serialisation + , saveNpy + , loadNpy +-} +) where + +import qualified DType +import DType (DType) +import Indexing +import MatrixForm +import NdArray +import NdArrayException +import Serialisation +import Typing + +import Control.Exception +import Control.Monad (zipWithM) +import Data.List (elemIndex, intersect, sort, zipWith6) +import qualified Data.Map as M +import Data.Maybe (fromJust, isNothing) +import Data.Vector.Storable (Vector) +import qualified Data.Vector.Storable as V +import Type.Reflection + +-- $setup +-- >>> import Numskull as N +-- >>> import qualified Vector + +-- * Numeric & Comparative NdArray instances: +-------------------------------------------------------------------------------- +--calculateNewshape :: Vector Int -> Vector Int -> Vector Int +--c-alculateNewshape sh st = V.generate (V.length sh) $ i -> +-- (sh V.! i + 1) / (st V.! 
i) + +--V.scanr' (*) 1 (V.drop 1 $ V.fromList [10,5,5::Int]) + + + +--force $ and also check for emptyness + +{- +grab :: Int -> Vector Int -> Vector Int -> Bool +grab i sh st = + let + dimAccSml = V.scanr' (*) 1 $ V.drop 1 sh + p = V.zipWith (\t h -> i `mod` t < h) st dimAccSml + in V.and p + + +stride :: NdArray -> Vector Int +stride (NdArray sh st v) = --force $ and also check for emptyness + let + dimAcc = V.scanr1' (*) sh + fl = fromIntegral @Int @Float + newshape = V.zipWith (\d s -> ceiling (fl d / fl s) ::Int) dimAcc st + coef = V.drop 1 dimAcc + in + newshape -} + {-V.generate (V.head dimAcc) $ i -> + let + x = 4 + V.zipWith + in + undefined + -} + +n1 = NdArray (V.fromList [2,2]) (V.fromList [2,1]) (V.fromList [1,2,3,4::Int]) +n2 = NdArray (V.fromList [3,3]) (V.fromList [3,1]) (V.fromList [1..9::Int]) +n3 = NdArray (V.fromList [4,4]) (V.fromList [4,2]) (V.fromList [1..16::Int]) +n4 = NdArray (V.fromList [4,2]) (V.fromList [2,1]) (V.fromList [1,3,5,7,9,11,13,15::Int]) +n5 = NdArray (V.fromList [2]) (V.fromList [1]) (V.fromList [1,2::Int]) + +instance Eq NdArray where + -- | Arrays are equal if their elements and shape exactly match. + nd1 == nd2 = + case (stride nd1, stride nd2) of + (NdArray s _ v, NdArray r _ u) -> + case v =@= u of + Just HRefl -> s == r && v == u + Nothing -> False + nd1 /= nd2 = + case (stride nd1, stride nd2) of + (NdArray s _ v, NdArray r _ u) -> + case v =@= u of + Just HRefl -> s /= r || v /= u + Nothing -> True + +instance Ord NdArray where + {- | Arrays are only comparable when they are the same shape. Then they are + ordered by pointwise comparison. 
+ -} + nd1 `compare` nd2 = + case (stride nd1, stride nd2) of + (NdArray s t v, NdArray r d u) -> + if s == r then case v =@= u of + Just HRefl -> compare v u + Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "compare") + else throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "compare") + + nd1 <= nd2 = + case (stride nd1, stride nd2) of + (NdArray s t v, NdArray r d u) -> + if s == r then case v =@= u of + Just HRefl -> v <= u + Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "<=") + else throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "<=") + +instance Num NdArray where + -- | Adds elements pointwise + (+) = broadcastZipTyped DType.add + -- | Subtracts elements pointwise + (-) = broadcastZipTyped DType.subtract + -- | Multiplies elements pointwise + (*) = broadcastZipTyped DType.multiply + -- | Inverts all elements according to their DType instance + negate (NdArray sh st v) = NdArray sh st (V.map DType.invert v) + -- | Absolute value of each element + abs (NdArray sh st v) = NdArray sh st (V.map DType.abs v) + -- | Signum of each element + signum (NdArray sh st v) = NdArray sh st (V.map DType.signum v) + -- Creates a singleton array. NB: must be converted to a storable Int. + fromInteger = singleton . fromInteger @Int + + +-- * General & Creation +-------------------------------------------------------------------------------- + +-- | Gets the total number of elements in a given array shape. +-- >>> size [2,3] +-- 6 +size :: Vector Int -> Int +size sh = V.product sh + +-- | Returns the shape list of an array. +shape :: NdArray -> Vector Int +shape (NdArray s _ _) = s + +-- | Gets the vector of an array. Requires a type specification to output safely. +getVector :: forall a . 
DType a => NdArray -> Vector a +getVector (NdArray _ _ v) = v <-@ typeRep @(Vector a) + +-- | Gets the TypeRep String representation of the NdArray elements +ndType :: NdArray -> String +ndType (NdArray _ _ v) = show $ vecType v + +-- | Compares the type of the array elements to the given TypeRep. +checkNdType :: forall a b . (DType a, DType b) => NdArray -> TypeRep a -> Maybe (a :~~: b) +checkNdType (NdArray _ _ v) _ = + let tv = vecType v + in case eqTypeRep tv (typeRep @b) of + Just HRefl -> eqTypeRep (typeRep @a) (tv :: TypeRep b) + _ -> error "Impossibly mismatching types." + +-- | Helper to get the vector typeRep. +vecType :: forall a . DType a => Vector a -> TypeRep a +vecType _ = typeRep @a + +-- | Checks if the undelying vector has any elements. +isEmpty :: NdArray -> Bool +isEmpty (NdArray _ _ v) = V.null v + +-- | Convert a list of arrays to a list of vectors, provided they are all of the specified type. +unpackArrays :: forall a . DType a => [NdArray] -> TypeRep a -> Maybe ([Vector Int],[Vector Int],[Vector a]) +unpackArrays [] _ = Just ([],[],[]) +unpackArrays ((NdArray sh st v) : nds) t = + case v =@ typeRep @(Vector a) of + Just HRefl -> + case unpackArrays nds t of + Just (shs, sts, vs) -> Just (sh:shs, st:sts, v:vs) + _ -> Nothing + _ -> Nothing + +-- Gets the DType additive identity matching the element type of a vector. +identityElem :: forall a . DType a => Vector a -> a +identityElem _ = DType.addId :: DType a => a + +-- | Creates an NdArray from a given shape and list. The number of elements must match. +-- >>> printArray $ fromList [2,2] [1,2,3,4::Int] +-- 1 2 +-- 3 4 +fromList :: DType a => [Int] -> [a] -> NdArray +fromList sh l = + if length l /= product sh then throw $ CreationSize (length l) (V.fromList sh) + else NdArray h (defStride h) (V.fromList l) + where h = V.fromList sh + +-- | Creates a 1xn NdArray from a list. 
+-- >>> printArray $ fromListFlat [1,2,3,4::Int]
+-- 1 2 3 4
+fromListFlat :: DType a => [a] -> NdArray
+fromListFlat l = NdArray sh (defStride sh) (V.fromList l)
+  where sh = V.fromList [length l]
+
+-- Wraps a vector in an NdArray of the given shape, using the default stride
+-- for that shape. The vector length is not checked against the shape here.
+genStride :: DType a => [Int] -> Vector a -> NdArray
+genStride shL v = let sh = V.fromList shL in NdArray sh (defStride sh) v
+
+{-| Creates an NdArray from an explicitly given matrix such as the example 2x3. -}
+-- >>> m :: TreeMatrix Int
+-- >>> m = A [A [B 1, B 2],
+-- >>>        A [B 3, B 4],
+-- >>>        A [B 5, B 6]]
+-- >>> printArray $ fromMatrix m
+-- 1 2
+-- 3 4
+-- 5 6
+{-
+fromMatrix :: DType a => TreeMatrix a -> NdArray
+fromMatrix m = NdArray (matrixShape m) (V.fromList l)
+  where l = flattenToList $ matrixToTree m
+
+-- | The safe standard constructor. Returns Nothing if the
+-- shape does not match the given vector length.
+fromVector :: DType a => [Integer] -> Vector a -> Maybe NdArray
+fromVector sh v = if V.length v == fromIntegral (product sh)
+  then Just $ NdArray sh v
+  else Nothing
+-}
+
+-- | Creates a 1x1 matrix
+-- >>> printArray $ singleton (3::Int)
+-- 3
+singleton :: DType a => a -> NdArray
+singleton x = NdArray (V.singleton 1) (V.singleton 1) (V.singleton x)
+
+-- | Creates a flat array over the specified inclusive range.
+-- The shape is taken from the number of elements actually enumerated, so it
+-- always agrees with the data. This matters for fractional types, where
+-- @[mini..maxi]@ may hold a different number of elements than
+-- @fromEnum maxi - fromEnum mini + 1@ (e.g. @[1.0..2.5]@ has three).
+-- If @mini > maxi@ the result is the empty (Int-typed) array.
+arange :: (Enum a, DType a) => a -> a -> NdArray
+arange mini maxi
+  | mini <= maxi =
+      let v = V.fromList [mini..maxi]
+      in NdArray (V.singleton (V.length v)) (V.singleton 1) v
+  | otherwise = NdArray e e e
+  where e = V.empty :: Vector Int
+
+{- | Creates the smallest possible square matrix from the given list,
+padding out any required space with the identity element for the DType -}
+squareArr :: forall a .
DType a => [a] -> NdArray +squareArr [] = NdArray e e e + where e = V.empty :: Vector Int +squareArr xs = + let + l = length xs + d = ceiling (sqrt $ fromIntegral @Int @Float l) + p = V.replicate (d^(2::Int) - l) (DType.addId :: a) + sh = V.fromList [d, d] + in NdArray sh (defStride sh) (V.fromList xs V.++ p) + +{- | Creates an array of the given shape of the identity element for the given type. -} +zeros :: forall a . DType a => TypeRep a -> Vector Int -> NdArray +zeros _ s = NdArray s (defStride s) zerovec + where + ident = DType.addId :: (DType a => a) + zerovec = V.replicate (size s) ident :: DType a => Vector a + +-- * Pointwise Functions +-------------------------------------------------------------------------------- + +-- * One Argument + +{- | Near identical to a standard foldr instance, expect NdArrays do not have an explicit type. +Folds in row-major order. +-} +foldrA :: forall a b . DType a => (a -> b -> b) -> b -> NdArray -> b +foldrA f z nd = + case stride nd of + (NdArray _ _ v) -> + case v =@= (undefined :: Vector a) of + Just HRefl -> V.foldr f z v + _ -> throw $ TypeMismatch "Fold starting value type does not match array type." + +-- | Near identical to a standard map implementation in row-major order. +mapA :: forall a . forall b . (DType a, DType b) => (a -> b) -> NdArray -> NdArray +mapA f (NdArray sh st v) = case v =@= (undefined :: Vector a) of + Just HRefl -> NdArray sh st (V.map f v) + _ -> throw $ TypeMismatch "Map function input does not match array type." + +-- | Maps functions which return the same type. +mapTransform :: (forall a . DType a => a -> a) -> NdArray -> NdArray +mapTransform f (NdArray sh st v) = NdArray sh st (V.map f v) + +-- | Multiplies all elements by a scalar. +scale :: forall a . DType a => a -> NdArray -> NdArray +scale x = mapA (DType.multiply x) + +-- | Takes the absolute value of all elements. +absA :: NdArray -> NdArray +absA = mapTransform DType.abs + +-- | Replaces all elements by their signum. 
+-- >>> printArray $ signumA (fromList [5] [-50, -25, 0, 1, 10::Int]) +-- -1 -1 0 1 1 +signumA :: NdArray -> NdArray +signumA = mapTransform DType.signum + +-- | Mathematical ceiling of each element (preserving DType). +ceilA :: NdArray -> NdArray +ceilA = mapTransform DType.ceil + +-- | Mathematical floor of each element (preserving DType). +floorA :: NdArray -> NdArray +floorA = mapTransform DType.floor + +-- | Sine of each element (preserving DType). +sinA :: NdArray -> NdArray +sinA = mapTransform DType.sin + +-- | Cosine of each element (preserving DType). +cosA :: NdArray -> NdArray +cosA = mapTransform DType.cos + +-- | Tangent of each element (preserving DType). +tanA :: NdArray -> NdArray +tanA = mapTransform DType.tan + +-- | Either elementwise NOT or NEG depending on the DType. +invertA :: NdArray -> NdArray +invertA = mapTransform DType.invert + +-- | Multiply each element by 2. +shiftleftA :: NdArray -> NdArray +shiftleftA = mapTransform DType.shiftleft + +-- | Divide each element by 2. +shiftrightA :: NdArray -> NdArray +shiftrightA = mapTransform DType.shiftright + +-- | Returns the element at the 0th position of the array. +origin :: forall a . DType a => NdArray -> a +origin (NdArray _ _ v) = (v V.! 0) <-@ typeRep @a + +-- | Returns the largest element. +maxElem :: forall a . DType a => NdArray -> a +maxElem nd = foldrA max (origin nd) nd + +-- | Returns the smallest element. +minElem :: forall a . DType a => NdArray -> a +minElem nd = foldrA min (origin nd) nd + +-- | Constrains all elements of the array to the range specified by [mini, maxi]. +-- If they are given as Nothing, the range is infinite in that direction. +-- NB: must still specify type for Nothing i.e. clip (Nothing :: Maybe Int) Nothing myNd +clip :: forall a . 
DType a => Maybe a -> Maybe a -> NdArray -> NdArray +clip mini maxi (NdArray sh st v) = case v =@= (undefined :: Vector a) of + Just HRefl -> + case (mini, maxi) of + (Just mn, Just mx) -> mapA (\x -> if x <= mn then mn else if x >= mx then mx else x) (NdArray sh st v) + (Just mn, Nothing) -> mapA (\x -> if x <= mn then mn else x) (NdArray sh st v) + (Nothing, Just mx) -> mapA (\x -> if x >= mx then mx else x) (NdArray sh st v) + (Nothing, Nothing) -> NdArray sh st v + _ -> throw (TypeMismatch $ "Min and max types do not match array type of " <> show (vecType v) <> ".") + +-- * Two Arguments + +broadcastZipTyped :: (forall t . DType t => t -> t -> t) -> NdArray -> NdArray -> NdArray +broadcastZipTyped zipfunc (NdArray s t v) (NdArray r d u) = + case v =@= u of + Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "broadcastZipTyped") + Just HRefl -> + case broadcastConfig (NdArray s t v) (NdArray r d u) of + Nothing -> throw (NotBroadcastable (NdArray s t v) (NdArray r d u) " in some function") + Just (newshape, t', d') -> + let newstride = defStride newshape + in NdArray newshape newstride $ V.generate (size newshape) (\i -> + let + multi = expandInd newstride i + -- collapse the multi index over the two arrays + -- apply the operation to the fetched values + v1 = v V.! collapseInd t' multi + v2 = u V.! collapseInd d' multi + in + zipfunc v1 v2 + ) + +broadcastZipUntyped :: forall a b c . 
(DType a, DType b, DType c) => (a -> b -> c) -> NdArray -> NdArray -> NdArray +broadcastZipUntyped zipfunc (NdArray s t v) (NdArray r d u) = + case (v =@ typeRep @(Vector a), u =@ typeRep @(Vector b)) of + (Just HRefl, Just HRefl) -> + case broadcastConfig (NdArray s t v) (NdArray r d u) of + Nothing -> throw (NotBroadcastable (NdArray s t v) (NdArray r d u) " in some function") + Just (newshape, t', d') -> + let newstride = defStride newshape + in NdArray newshape newstride $ V.generate (size newshape) (\i -> + let + multi = expandInd newstride i + -- collapse the multi index over the two arrays + -- apply the operation to the fetched values + v1 = v V.! collapseInd t' multi + v2 = u V.! collapseInd d' multi + in + zipfunc v1 v2 :: c + ) + _ -> throw (TypeMismatch "Cannot zip NdArrays with different dtypes to the zip function.") + +-- | The generic function for operating on two matching DType arrays with the same shape +-- in an element-wise/pointwise way. Errors if mismatching +-- >>> x = fromList [2,2] [1,2,3,4 :: Int] +-- >>> y = fromList [2,2] [5,2,2,2 :: Int] +-- >>> printArray $ pointwiseZip (DType.multiply) x y +-- 5 4 +-- 6 8 +pointwiseZip :: (forall t . DType t => t -> t -> t) -> NdArray -> NdArray -> NdArray +pointwiseZip zipfunc nd1 nd2 = + case (stride nd1, stride nd2) of + (NdArray s t v, NdArray r d u) -> + if s == r then + case v =@= u of + Just HRefl -> NdArray s t (V.zipWith zipfunc v u) + Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "pointwiseZip") + else throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "pointwiseZip") + +-- | A slightly specialised version of pointwise zip intended for comparative functions. +pointwiseBool :: (forall t . 
DType t => t -> t -> Bool) -> NdArray -> NdArray -> NdArray
+pointwiseBool zipfunc nd1 nd2 =
+  case (stride nd1, stride nd2) of
+    (NdArray s t v, NdArray r d u) ->
+      if s == r then
+        case v =@= u of
+          Just HRefl -> NdArray s t (V.zipWith zipfunc v u)
+          Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "pointwiseBool")
+      else throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "pointwiseBool")
+
+-- | Completely generic zip on two NdArrays. If the shapes mismatch, both arrays are
+-- truncated to their common extent, as with standard zips. The element types of the
+-- arrays must match the argument types of the function, otherwise a TypeMismatch is thrown.
+zipArrayWith :: forall a b c . (DType a, DType b, DType c) => (a -> b -> c) -> NdArray -> NdArray -> NdArray
+zipArrayWith zipfunc (NdArray s t v) (NdArray r d u) =
+  let
+    nd1' = stride (NdArray s t v)
+    nd2' = stride (NdArray r d u)
+    -- Truncate each array to the shape of the other, so that both end up
+    -- with the pointwise minimum of the two shapes.
+    ndC1 = constrainShape (shape nd2') nd1'
+    ndC2 = constrainShape (shape nd1') nd2'
+    s' = shape ndC1
+  in
+    -- Type check the function
+    case (v =@ typeRep @(Vector a), u =@ typeRep @(Vector b)) of
+      (Just HRefl, Just HRefl) ->
+        let
+          v' = getVector ndC1 :: Vector a
+          u' = getVector ndC2 :: Vector b
+        in NdArray s' (defStride s') (V.zipWith zipfunc v' u' :: Vector c)
+      _ -> throw (TypeMismatch "Cannot zip NdArrays with different dtypes to the zip function.")
+
+-- | Pointwise integer division. Will return an NdArray of type Int.
+-- Throws ShapeMismatch if the shapes differ and DTypeMismatch if the element
+-- types differ. Both operands are passed through 'stride' first (as in
+-- 'pointwiseZip'), because 'elemDivVec' rebuilds them with default strides.
+elemDiv :: NdArray -> NdArray -> NdArray
+elemDiv nd1 nd2 =
+  case (stride nd1, stride nd2) of
+    (NdArray s t v, NdArray r d u) ->
+      if s == r then
+        case v =@= u of
+          Just HRefl -> elemDivVec s v r u
+          Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "elemDiv")
+      else throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "elemDiv")
+
+elemDivVec :: forall a .
DType a => Vector Int -> Vector a -> Vector Int -> Vector a -> NdArray
+-- Each operand is rebuilt with the default stride of its *own* shape.
+elemDivVec s v r u = broadcastZipUntyped (DType.div :: a -> a -> Int) (NdArray s (defStride s) v) (NdArray r (defStride r) u)
+
+-- | Pointwise division
+elemDivide :: NdArray -> NdArray -> NdArray
+elemDivide = broadcastZipTyped DType.divide
+
+-- | Pointwise exponentiation (preserving DType)
+elemPow :: NdArray -> NdArray -> NdArray
+elemPow = broadcastZipTyped DType.pow
+
+-- | Pointwise exponentiation which forces precision.
+-- Takes some NdArray of bases, an array of Double exponents and returns an array of Doubles.
+-- Throws ShapeMismatch if the shapes differ and DTypeMismatch if the exponents
+-- are not Doubles. Both operands are passed through 'stride' first (as in
+-- 'pointwiseZip'), because 'elemPowerVec' rebuilds them with default strides.
+elemPower :: NdArray -> NdArray -> NdArray
+elemPower nd1 nd2 =
+  case (stride nd1, stride nd2) of
+    (NdArray s t v, NdArray r d u) ->
+      if s == r then
+        case u =@ typeRep @(Vector Double) of
+          Just HRefl -> elemPowerVec s v r u
+          Nothing -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "elemPower")
+      else throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "elemPower")
+
+elemPowerVec :: forall a . DType a => Vector Int -> Vector a -> Vector Int -> Vector Double -> NdArray
+elemPowerVec s v r u = broadcastZipUntyped (DType.power :: a -> Double -> Double) (NdArray s (defStride s) v) (NdArray r (defStride r) u)
+
+-- * Many Arguments
+
+-- | Takes the pointwise sum over all the given NdArrays. If they are different shapes,
+-- the smaller dimensions are padded out with the identity element.
+-- The sum of the empty list is the singleton 0.
+-- The accumulator starts as an array of identity elements of the first array's
+-- element type; every array in the list (including the first) is padded to the
+-- combined shape and added exactly once.
+sum :: [NdArray] -> NdArray
+sum [] = singleton (0::Int)
+sum [nd] = nd
+sum nds@(NdArray _ _ v1 : _) = foldr (\x acc -> padShape sh x + acc) (zeros (vecType v1) sh) nds
+  where
+    -- The target shape must account for every array, the first one included.
+    sh = maximiseShape (map shape nds)
+
+sumStride :: [NdArray] -> NdArray
+sumStride xs = Numskull.sum $ map stride xs
+
+-- Takes the maximum of each element pointwise matching from the end.
+maximiseShape :: [Vector Int] -> Vector Int
+maximiseShape [] = V.empty :: Vector Int
+maximiseShape [sh] = sh
+maximiseShape (sh : shs) =
+  let
+    m = maximiseShape shs
+    diff = V.length sh - V.length m
+  in
+    -- The shorter shape is prefixed with the leading dimensions of the longer
+    -- one so that both have the same rank before taking the pointwise maximum.
+    if diff > 0
+      then V.zipWith max sh (V.take diff sh V.++ m)
+      else V.zipWith max (V.take (-diff) m V.++ sh) m
+
+-- | Finds the mean pointwise over the list of arrays. Smaller arrays are padded out with
+-- the identity element. The mean of the empty list is the empty array.
+mean :: [NdArray] -> NdArray
+mean [] = let e = V.fromList ([] :: [Int]) in NdArray e e e
+mean nds = s `elemDivide` matchDType s count
+  where
+    s = Numskull.sum nds
+    sh = shape s
+    -- The count is built as Int and then converted to the DType of the sum,
+    -- because 'elemDivide' rejects operands of different DTypes.
+    count = NdArray sh (defStride sh) (V.replicate (size sh) (length nds))
+
+meanStride :: [NdArray] -> NdArray
+meanStride xs = mean $ map stride xs
+
+
+-- * Type & Shape Conversion
+--------------------------------------------------------------------------------
+
+{- | Converting between the standard dtypes and changing the shapes of arrays.
+NB the difference between 'size' and 'shape'. The shape is a vector of Ints
+describing the width of each dimension. Size refers to the total number of
+elements in the array, i.e. the product of the shape.
+-}
+
+-- | Converts an NdArray of one type to any other with a DType instance.
+convertDTypeTo :: forall a . DType a => TypeRep a -> NdArray -> NdArray
+convertDTypeTo t (NdArray sh st v) = convertDTFromTo (vecType v) t (NdArray sh st v)
+
+-- Helper with additional typing information: the first TypeRep fixes the
+-- source element type, the second the target. Elements are converted by
+-- going through their Rational representation.
+convertDTFromTo :: forall a b . (DType a, DType b) =>
+  TypeRep a -> TypeRep b -> NdArray -> NdArray
+convertDTFromTo _ _ (NdArray sh st v) = case v =@ typeRep @(Vector a) of
+    Just HRefl -> NdArray sh st (V.map convert v)
+    _ -> error "Impossible type mismatch."
+  where
+    convert :: (DType a, DType b) => a -> b
+    convert x = DType.rationalToDtype (DType.dtypeToRational x)
+
+-- | Converts the second NdArray to be the same DType as the first.
+matchDType :: NdArray -> NdArray -> NdArray
+matchDType (NdArray _ _ v) = convertDTypeTo (vecType v)
+
+{- Helper which checks that the array isn't larger than the shape constraints.
+If it is valid the Boolean in the pair will be true and the vector is returned.
+If it is invalid the vector is truncated first.
+-}
+constrainSize :: DType a => Int -> Vector a -> (Bool, Vector a)
+constrainSize s v =
+  if s < V.length v then (False, V.take s v)
+  else (True, v)
+
+-- Fill out any spaces in a vector smaller than the shape with 0s (or whatever the dtype 'identity' is)
+padSize :: DType a => Int -> Vector a -> Vector a
+padSize s v = v V.++ V.replicate (s - len) DType.addId
+  where len = V.length v
+
+-- Constrain or pad the vector to match the size
+setSize :: DType a => Int -> Vector a -> Vector a
+setSize s v = let (unchanged, u) = constrainSize s v in
+  if unchanged then padSize s u else u
+
+{- | Truncate or pad the NdArray to match the new given size.
+The shape will be collapsed to 1xn.
+The array is passed through 'stride' first, as the other shape-changing
+functions here do, before its vector is truncated or padded.
+-}
+-- >>> x = fromList [2,2] [1,2,3,4 :: Int]
+-- >>> printArray $ resize 6 x
+-- 1 2 3 4 0 0
+-- >>> printArray $ resize 2 x
+-- 1 2
+resize :: Int -> NdArray -> NdArray
+resize s nd =
+  case stride nd of
+    NdArray _ _ v -> NdArray (V.singleton s) (V.singleton 1) (setSize s v)
+
+-- | Shape-shift one array to another of the same size (Nothing otherwise).
+-- The result always carries the default stride of the new shape; the old
+-- stride vector cannot be reused because its rank belongs to the old shape.
+-- >>> x = fromList [2,3] [1,2,3,4,5,6 :: Int]
+-- >>> printArray x
+-- 1 2
+-- 3 4
+-- 5 6
+-- >>> printArray $ fromJust $ reshape [3,2] x
+-- 1 2 3
+-- 4 5 6
+reshape :: Vector Int -> NdArray -> Maybe NdArray
+reshape r nd =
+  case stride nd of
+    NdArray sh _ v
+      | V.product sh == V.product r -> Just $ NdArray r (defStride r) v
+      | otherwise -> Nothing
+
+-- Checks that the first shape is smaller or equal to the second.
+smallerShape :: Vector Int -> Vector Int -> Bool
+smallerShape s r = (V.length s <= V.length r) && V.and (V.zipWith (<=) s r)
+
+-- | Adds zero-rows to an array. Will error if you map to a smaller shape.
+-- >>> x = fromList [2,2] [1,2,3,4 :: Int] +-- >>> printArray $ padShape [4,3] x +-- 1 2 0 0 +-- 3 4 0 0 +-- 0 0 0 0 + +padShape :: Vector Int -> NdArray -> NdArray +padShape r nd = + case stride nd of + NdArray sh st v -> + let + nullVec = V.replicate (size r) (identityElem v) + newIndices = V.imap (\i _ -> map1DIndex st (defStride r) i) v + in + if smallerShape sh r + then NdArray r (defStride r) (V.unsafeUpdate_ nullVec newIndices v) + else error "Cannot map to a smaller shape." + +{- +setDimensions :: Int -> Vector Int -> Vector Int +setDimensions d ind = let diff = d - V.length ind in + if diff >= 0 then V.replicate diff 0 V.++ ind + else V.drop (-diff) ind +-} + +strictSmallerShape :: Vector Int -> Vector Int -> Bool +strictSmallerShape s r = (V.length s <= V.length r) && V.and (V.zipWith (<) s r) + +addDimensions :: Int -> Vector Int -> Vector Int +addDimensions d ind = V.replicate (d - V.length ind) 1 V.++ ind + +padShape' :: Vector Int -> NdArray -> NdArray +padShape' r (NdArray sh st v) = + NdArray r d $ V.generate (size r) (\i -> + let + multi = expandInd d i + in + if V.and $ V.zipWith (<) multi p + then v V.! (collapseInd st $ V.drop (V.length multi - V.length st) multi) + else identityElem v + ) + where + d = defStride r + p = addDimensions (V.length r) sh + + +-- | Truncates the array to be no larger than the specified dimensions. 
+constrainShape :: Vector Int -> NdArray -> NdArray
+constrainShape r nd =
+  case stride nd of
+    NdArray sh _ v ->
+      let
+        -- Pointwise minimum of the requested and the actual shape.
+        s' = V.zipWith min r sh
+        -- If the array has more dimensions than requested, the extra
+        -- trailing dimensions are limited to width 1.
+        sPad = s' V.++ V.replicate (V.length sh - V.length r) 1
+        -- expandInd is given a stride vector at every other call site, so
+        -- the flat index is expanded with the default stride of the shape.
+        st = defStride sh
+      in NdArray s' (defStride s') $
+        V.ifilter (\i _ -> V.and $ V.zipWith (<) (expandInd st i) sPad) v
+
+
+-- Generate the strides & new shape for two maybe broadcastable arrays.
+-- Returns Nothing if some pair of dimensions neither matches nor contains a 1.
+-- Otherwise returns (new shape, stride for the first array, stride for the
+-- second array). Missing leading dimensions are treated as width 1, and any
+-- dimension that is stretched from width 1 gets stride 0 so the same element
+-- is re-read along it.
+broadcastConfig :: NdArray -> NdArray -> Maybe (Vector Int, Vector Int, Vector Int)
+broadcastConfig (NdArray s t _) (NdArray r d _) =
+  let
+    s' = V.replicate (V.length r - V.length s) 1 V.++ s
+    t' = V.replicate (V.length r - V.length s) 0 V.++ t
+    r' = V.replicate (V.length s - V.length r) 1 V.++ r
+    d' = V.replicate (V.length s - V.length r) 0 V.++ d
+    newshape = V.zipWithM (\x y ->
+      if x == y || x == 1 || y == 1
+        then Just (max x y) else Nothing) s' r'
+    t'' = V.zipWith3 (\sx rx tx ->
+      if sx == 1 && rx /= 1 then 0 else tx) s' r' t'
+    d'' = V.zipWith3 (\rx sx dx ->
+      if rx == 1 && sx /= 1 then 0 else dx) r' s' d'
+  in
+    (\x->(x,t'',d'')) <$> newshape
+
+{-}
+-- | Takes a pair of NdArrays and attempts to copy slices so that they are size matched.
+-- Arrays are broadcastable if they either match in corresponding dimensions or one is
+-- of dimension size 1 e.g. [2,5,1] and [2,1,6]. Missing dimensions are padded with 1s
+-- e.g. [1,2,3] and [3] are broadcastable.
+broadcast :: (NdArray, NdArray) -> Maybe (NdArray, NdArray)
+broadcast (NdArray s v, NdArray r u) =
+  let
+    (s',v',r',u') = broadcastDimensions s v r u
+    newshape = zipWithM (\x y -> if x == y || x == 1 || y == 1
+      then Just (max x y) else Nothing) s' r'
+  in
+    case newshape of
+      Nothing -> Nothing
+      Just ns -> Just (
+        NdArray ns $ padRepeats ns m s' v',
+        NdArray ns $ padRepeats ns m r' u')
+        where m = fst $ mapIndices ns
+-}
+-- Pads out dimensions for broadcasting if one array is dimensionally smaller than another.
+-- e.g. [1,2,3] and [3].
+broadcastDimensions :: (DType a, DType b) => + Vector Int -> Vector a -> Vector Int -> Vector b -> + (Vector Int, Vector a, Vector Int, Vector b) +broadcastDimensions s v r u + | sl == rl = (s,v, + r,u) + | sl > rl = (s,v, + sdiff V.++ r, + V.concat $ replicate (V.product sdiff) u) + | sl < rl = (rdiff V.++ s, + V.concat $ replicate (V.product rdiff) v, + r,u) + where + sl = V.length s + rl = V.length r + diff = Prelude.abs (sl - rl) + sdiff = V.take diff s + rdiff = V.take diff r + +-- Pads out a newshape with repetitions of the existing values +-- Takes the newshape, its map, the old shape and the vector. +{- +padRepeats :: DType a => + [Integer] -> M.Map Int [Integer] -> [Integer] -> Vector a -> Vector a +padRepeats newshape oneDmap s v = + let (_, multiMap) = mapIndices s + in V.generate (fromIntegral $ product newshape) (\i -> + let + multiI = oneDmap M.! i -- equivalent multi-index + multiWrap = zipWith mod multiI s -- wrap the index over dimensions of size 1 + flatWrap = multiMap M.! multiWrap -- collapse the index over the vector + in v V.! flatWrap) +-} + +-- | Concatenate a list of tensors into a single tensor. All input tensors must have the +-- same shape, except for the dimension size of the axis to concatenate on. +-- Returns Nothing if the arrays are not all of the same type or matching shapes. +concatAlong :: Int -> [NdArray] -> Maybe NdArray +concatAlong _ [] = Nothing +concatAlong _ [nd] = Just nd +concatAlong axis (NdArray sh st v : nds) = + case unpackArrays (NdArray sh st v : nds) (vecType v) of + Nothing -> Nothing + Just (shs, sts, vs) -> + case concatAlongVec shs sts vs axis of + Nothing -> Nothing + Just (csh, cst, cv) -> Just $ NdArray csh cst cv + +-- Helper for concatenation of vectors and their associated shapes. +concatAlongVec :: forall a . 
DType a => [Vector Int] -> [Vector Int] -> [Vector a] -> Int -> Maybe (Vector Int, Vector Int, Vector a) +concatAlongVec shs sts vs axis = + if not (checkShapeLengths shs) || not (checkAxis axis shs) then Nothing + else + let + -- Calculates the newshape by adding up all the dimensions along the axis + axDim = map (V.! axis) shs + newshape = head shs V.// [(axis, Prelude.sum axDim)] + newstride = defStride newshape + -- Each array to be concatenated is given a number to index it with + -- Values are indexed by array number, then by position in the array + --arrayPlot = V.fromList $ concat $ zipWith (\arr dim -> [(arr, x) | x <- [0..dim-1]]) [0..] axDim + arrayNums = V.concat $ zipWith (V.replicate) [0..] axDim + arrayAxInds = V.concat $ map (V.enumFromN 0) axDim + --(newMultiInds, _) = mapIndices newshape + --subArrayMaps = map (snd . mapIndices) shs + in + Just (newshape, newstride, + V.generate (V.product newshape) (\i -> + let + -- Generating the new vector by converting the new flat index to a multi-index + -- then mapping it to a sub-array and index and reading the value. + multiI = expandInd newstride i + --arrayMultiI = multiI + --multiI = newMultiInds M.! i + arrNum = arrayNums V.! (multiI V.! axis) + arrAxInd = arrayAxInds V.! (multiI V.! axis) + arr = vs !! arrNum + arrStr = sts !! arrNum + in + arr V.! 
(collapseInd arrStr (multiI V.// [(axis, arrAxInd)]))
+            --in
+            -- vecInd arrayMap array arrayMultiI <-@ typeRep @a
+          )
+        )
+
+-- Swaps in a value at the given index
+--replaceNth :: Int -> a -> [a] -> [a]
+--replaceNth n x l = take n l ++ [x] ++ drop (n+1) l
+
+-- Checks for the same number of dimensions
+checkShapeLengths :: [Vector Int] -> Bool
+checkShapeLengths [] = False -- same #dimensions but also invalid
+checkShapeLengths (sh : shs) = all (\other -> V.length other == baseLen) shs
+  where baseLen = V.length sh
+
+-- Checks that each dimension is the same save perhaps the axis one
+{-
+checkAxis :: Int -> [Vector Int] -> Bool
+checkAxis _ [] = False
+checkAxis axis shapes =
+  let
+    dropAxis = map (\sh -> take axis sh ++ drop (axis+1) sh) shapes
+    base = head dropAxis
+  in 0 <= axis && axis <= length base &&
+    foldr intersect base dropAxis == base
+-}
+-- The axis must also be a valid index into the shapes: without that check a
+-- negative or too-large axis would pass here and later fail on indexing
+-- instead of being rejected.
+checkAxis :: Int -> [Vector Int] -> Bool
+checkAxis _ [] = False
+checkAxis axis shapes@(base : _) =
+  0 <= axis && axis < V.length base && all matchesOffAxis shapes
+  where
+    (preAx, postAx) = (V.take axis base, V.drop (axis+1) base)
+    matchesOffAxis s = V.take axis s == preAx && V.drop (axis+1) s == postAx
+
+-- | Takes an array, set of sub-indices and axis and repeatedly takes slices
+-- of the array restricted to that index along the specified axis.
+-- The slices are then concatenated into the final array.
+gather :: NdArray -> [Int] -> Int -> NdArray
+gather nd is axis = fromJust $ concatAlong axis $ map (\i -> slice (sliceLead ++ [(i,i)]) nd) is
+  where sliceLead = replicate axis (0,-1)
+
+
+-- * Matrix Operations
+--------------------------------------------------------------------------------
+
+-- * Rows, Columns and Diagonals
+
+{- | Switches the rows at the two given indices over.
+NB: designed for 2x2 matrices so will only make swaps in the 'front' matrix of a tensor.
+-}
+swapRows :: Int -> Int -> NdArray -> NdArray
+swapRows r1 r2 (NdArray sh st v)
+  | r1 == r2 = NdArray sh st v
+  | V.length sh < 2 = error "Too few rows to make swaps."
+ | r1 >= numRows || r2 >= numRows = error "Row index exceeds number of rows." + | otherwise = + let + rowInds1 = V.iterateN lenRows (+ V.last st) (st V.! colI * r1) + rowInds2 = V.iterateN lenRows (+ V.last st) (st V.! colI * r2) + row1 = V.map (v V.!) rowInds1 + row2 = V.map (v V.!) rowInds2 + in + NdArray sh st $ V.force $ V.update_ v (rowInds2 V.++ rowInds1) (row1 V.++ row2) + where + colI = V.length sh - 2 + numRows = sh V.! colI + lenRows = V.last sh + +{- | Gets the flat array of the leading diagonal of the 'front' matrix of the tensor. -} +diagonal :: NdArray -> NdArray +diagonal (NdArray sh st v) = + let + rows = V.last sh; cols = sh V.! (V.length sh - 2) + rStr = V.last st; cStr = st V.! (V.length st - 2) + v' = V.generate (min rows cols) (\i -> v V.! (i * (cStr + rStr))) + sh' = V.singleton $ V.length v' + in NdArray sh' (defStride sh') v' +{- +diagonal (NdArray sh st v) = NdArray sh' st' v' + where + v' = V.force $ diagonalVec sh v + sh' = V.singleton (V.length v') + st' = defStride sh' +-} + +{- +-- Helper to take the leading diagonal in the vector form. +diagonalVec :: forall a . DType a => Vector Int -> Vector a -> Vector a +diagonalVec s = V.ifilter (\i _ -> i `mod` (rowLen+1) == 0 && i < rowLen*columns) + where + rowLen = s V.! (V.length s - 1) + columns = s V.! (V.length s - 2) +-} +{- +diagonalVec :: forall a . DType a => Vector Int -> Vector Int -> Vector a -> Vector a +diagonalVec sh st v = + let + rows = V.last sh + cols = sh V.! (V.length sh - 2) + rStr = V.last st + cStr = st V.! (V.length st - 2) + in + V.generate (min rows cols) ((V.!) v . (cStr + rStr) * ) +-} + +-- * Transposition + +-- | Reverses the order of axes and switches the elements accordingly. +{-transpose :: NdArray -> NdArray +transpose (NdArray sh v) = transposePerm dec (NdArray sh v) + where + l = length sh + dec = [l-1, l-2 .. 
0] +-} + +transpose :: NdArray -> NdArray +transpose (NdArray sh st v) = NdArray (V.reverse sh) (V.reverse st) v + +-- | Transposes the axes of an array according to the given permutation (e.g. [2,0,1]) + +transposePerm :: [Int] -> NdArray -> NdArray +transposePerm perm (NdArray sh st v) = + let + sh' = V.fromList $ permuteList perm $ V.toList sh + st' = V.fromList $ permuteList perm $ V.toList st + in NdArray sh' st' v + +-- Applies a permutation to a list +permuteList :: [Int] -> [a] -> [a] +permuteList perm l = if sort perm /= [0 .. length l -1] + then error "Invalid permutation given." + else map (l!!) perm + +-- Finds the inverse of a permutation +invertPermutation :: [Int] -> [Int] +invertPermutation perm = map (\i -> fromJust $ elemIndex i perm) [0..length perm -1] + +-- * Multiplication + +-- | Dot product over matrices of the same shape. +dot :: DType a => NdArray -> NdArray -> a +dot nd1 nd2 = foldrA DType.add DType.addId (nd1*nd2) + +-- | Standard matrix multiplication following NumPy conventions. +-- 1D arrays have the extra dimension pre/appended +-- 2D arrays are multiplied as expected +-- ND-arrays are broadcast to match each other where possible and treated as stacks of nxm/pxq arrays. +matMul :: NdArray -> NdArray -> NdArray +matMul (NdArray s t v) (NdArray r d u) = + case v =@= u of + Just HRefl -> + case (reverse $ V.toList s, reverse $ V.toList r) of + -- Standard matrix multiplication + ([m, n], [q, p]) | m == p -> genStride [n,q] (matMulVec s t v r d u) + -- 1D arrays have the extra dimension pre/appended then result collapses back to 1D + ([m], [q, p]) | m == p -> genStride [q] (matMulVec (V.fromList [1,m]) (V.fromList [0,1]) v r d u) + ([m, n], [p]) | m == p -> genStride [n] (matMulVec s t v (V.fromList [p,1]) (V.fromList [1,0]) u) + -- ND-arrays are broadcast to match each other where possible and treated as + -- stacks of nxm/pxq arrays. 
+ (m : n : _, q : p : _) | m == p -> case (stride (NdArray s t v), stride (NdArray r d u)) of + (NdArray ss st sv, NdArray sr sd su) -> + let (s', v', _r', u') = broadcastDimensions ss sv sr su + in + case v' =@= u' of + Just HRefl -> + let + -- here is where you need to care about the strides + stackA = vectorChunksOf (m * n) v' + stackB = vectorChunksOf (q * p) u' + dimA = V.fromList [n,m] + dimB = V.fromList [p,q] + stackAB = zipWith6 matMulVec (repeat dimA) (repeat $ defStride dimA) stackA + (repeat dimB) (repeat $ defStride dimB) stackB + in + genStride (take (V.length s' -2) (V.toList s') ++ [n,q]) $ V.concat stackAB + _ -> throw (ShapeMismatch (NdArray s t v) (NdArray r d u) "matMul") + _ -> throw (DTypeMismatch (NdArray s t v) (NdArray r d u) "matMul") + +-- Splits a vector into a list of vectors of the given size. +vectorChunksOf :: V.Storable a => Int -> Vector a -> [Vector a] +vectorChunksOf _ v | V.null v = [] +vectorChunksOf n v = first : vectorChunksOf n rest + where (first, rest) = V.splitAt n v + +-- Returning the vector result of the standard nxm matMul +matMulVec :: forall a . DType a => + Vector Int -> Vector Int -> Vector a -> Vector Int -> Vector Int -> Vector a -> Vector a +matMulVec s t v r d u = + let + oneDkey = fst $ mapIndices [s V.!0, r V.!1] + sz = M.size oneDkey + --map1 = vecInd (snd $ mapIndices s) v + map1 is = v V.! collapseInd t (V.fromList is) + map2 is = u V.! collapseInd d (V.fromList is) + --map2 = vecInd (snd $ mapIndices r) u + ks = [0 .. (s V.! 1 -1)] + in + V.generate sz (matMulElem map1 map2 ks . (M.!) oneDkey) + +-- Calculates the element at position [i,j] in the resultant nxp matrix of a matMul +matMulElem :: forall a . DType a => + ([Int] -> a) -> ([Int] -> a) -> [Int] -> [Int] -> a +matMulElem m1 m2 ks (i:j:_) = + foldr (\k acc -> + DType.add acc $ DType.multiply (m1 [i,k]) (m2 [k,j]) + ) DType.addId ks +matMulElem _ _ _ _ = DType.multId :: a + + +{- | General matrix multiplication. 
Calculates alpha*AB + beta*C with the option
+to transpose A and B first.
+Takes A, B, C, A transpose?, B transpose?, alpha, beta.
+Returns Nothing if the element types differ, if A, B or C is not 2D, if the
+inner dimensions of A and B disagree, or if a dimension of C is neither 1 nor
+equal to the matching dimension of AB.
+The types of alpha & beta are converted to the element type of the matrices.
+
+For more information see:
+https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
+NB: if the matrices are integers the scalars will also become integers so you should convert the matrices first
+-}
+gemm :: (DType a, DType b) =>
+  NdArray -> NdArray -> NdArray -> Bool -> Bool -> a -> b -> Maybe NdArray
+gemm (NdArray sA dA vA) (NdArray sB dB vB) (NdArray sC dC vC) transA transB alpha beta =
+  -- Check all the types match (and bring the scalars to the matrix element type)
+  case gemmTyping vA vB vC alpha beta of
+    Just (vA', vB', vC', alpha', beta')
+      -- A, B and C must all be 2D, with A of shape (M,K) and B of shape (K,N)
+      | all ((== 2) . V.length) [sAT, sBT, sC]
+      , sAT V.! 1 == sBT V.! 0
+      , let alphaAB = scale alpha' (matMul (NdArray sAT dAT vA') (NdArray sBT dBT vB'))
+      , let sAB = shape alphaAB
+      -- Each dimension of C has to match AB or be 1
+      , fitsAxis sAB 0
+      , fitsAxis sAB 1
+      -- Finally, combine the two; C is handed to the NdArray (+) as it is
+      -> Just (alphaAB + scale beta' (NdArray sC dC vC'))
+    _ -> Nothing
+  where
+    -- Reverses shape and strides together when the transpose flag is set.
+    orient flipped sh st
+      | flipped = (V.reverse sh, V.reverse st)
+      | otherwise = (sh, st)
+    (sAT, dAT) = orient transA sA dA
+    (sBT, dBT) = orient transB sB dB
+    -- True when C's extent along the axis is 1 or equals that of the product.
+    fitsAxis sAB ax = sC V.! ax == 1 || sC V.! ax == sAB V.! ax
+
+{-
+-- Transpose the shape-vector pair if the boolean is true, otherwise return the original.
+applyTransposition :: forall a . 
DType a => ([Integer], Vector a) -> Bool -> ([Integer], Vector a)
+applyTransposition (s, v) b =
+  let
+    ndT = Numskull.transpose (NdArray s v)
+    sT = shape ndT
+    vT = getVector ndT :: Vector a
+  in
+    if b then (sT, vT) else (s, v)
+-}
+
+-- Checks that all three matrices share one element type and converts the scalars if necessary.
+-- Returns Nothing when the element type of B or C differs from that of A. A scalar that
+-- already has the element type is passed through; any other scalar is converted by
+-- going through 'DType.dtypeToRational' and 'DType.rationalToDtype'.
+gemmTyping :: forall a b c d e . (DType a, DType b, DType c, DType d, DType e) =>
+  Vector a -> Vector b -> Vector c -> d -> e ->
+  Maybe (Vector a, Vector a, Vector a, a, a)
+gemmTyping vA vB vC alpha beta =
+  case (vA =@= vB, vA =@= vC) of
+    -- All matrices match types
+    (Just HRefl, Just HRefl) -> Just (vA, vB, vC, toElem alpha, toElem beta)
+    _ -> Nothing
+  where
+    -- Convert a scalar to the matrix element type. The comparison is made against the
+    -- TypeRep of a directly, so no placeholder value (undefined) is needed.
+    toElem :: forall x . DType x => x -> a
+    toElem x =
+      case x =@ typeRep @a of
+        Just HRefl -> x
+        _ -> DType.rationalToDtype (DType.dtypeToRational x)
+
+-- * Determinants and Inverses
+
+-- | Converts a nxn matrix to upper triangle form using O(n^3) elimination steps.
+-- An array with an empty shape is returned unchanged.
+upperTriangle :: NdArray -> NdArray
+upperTriangle (NdArray s t v) | V.null s = NdArray s t v
+upperTriangle (NdArray s t v) =
+  let
+    c = V.head s
+    -- For every pivot row i and every row j below it, visit each column k.
+    traversals = [(i,j,k) | i <- [0..c-1], j <- [i+1..c-1], k <- [0..c-1]]
+  -- triangulateVec reads and writes v through the strides t, so the result keeps the
+  -- layout described by t; labelling it with the default strides was only correct
+  -- when t already was the default.
+  in NdArray s t $ triangulateVec t v traversals (identityElem v)
+
+-- Upper triangle form on the hidden vector.
+-- Applies the elimination steps (i,j,k) in order: at each step element [j,k] has
+-- ratio * element [i,k] subtracted from it. The ratio [j,i] / [i,i] is recomputed
+-- whenever k == 0 and otherwise carried over from the previous step through r.
+-- Every step rebuilds the vector with V.//.
+triangulateVec :: DType a => Vector Int -> Vector a -> [(Int,Int,Int)] -> a -> Vector a
+triangulateVec _ v [] _ = v
+triangulateVec t v ((i,j,k) : trv) r =
+  let
+    vSet x y e = v V.// [(collapseInd t (V.fromList [x,y]), e)]
+    ratio = if k == 0 then DType.divide (vGet v t [j,i]) (vGet v t [i,i]) else r
+    scaled = DType.multiply ratio (vGet v t [i,k])
+    newVjk = DType.subtract (vGet v t [j,k]) scaled
+  in
+    triangulateVec t (vSet j k newVjk) trv ratio
+
+{- | Finds the determinant(s) of a tensor. Over matrices of more than two dimensions
+each 2D matrix's determinant is individually calculated and concatenated together (as in numpy:
+https://numpy.org/doc/stable/reference/generated/numpy.linalg.det.html ).
+If the matrix is non-square it is assumed to be padded out and will have determinant of 0
+-}
+determinant :: forall a . DType a => NdArray -> [a]
+determinant (NdArray s t v) = case V.length s of
+  0 -> []
+  1 -> [DType.addId :: a]
+  2 -> [determinant2D (NdArray s t v)]
+  -- Higher ranks: stop once the data vector has been used up.
+  _ | V.null v -> []
+  l ->
+    let
+      -- Shape and strides of a single 2D matrix: the last two dimensions.
+      colrow = V.drop (l-2) s
+      crt = V.drop (l-2) t
+      -- Peel the front matrix off the data and recurse on what is left.
+      (twoDim, rest) = V.splitAt (V.product colrow) v
+    in (determinant2D (NdArray colrow crt twoDim) : determinant (NdArray s t rest))
+
+{- | Calculates the determinant of a 2D matrix using LU decomposition as described in the
+below paper. O(n^3).
+https://informatika.stei.itb.ac.id/~rinaldi.munir/Matdis/2016-2017/Makalah2016/Makalah-Matdis-2016-051.pdf
+-}
+-- NOTE(review): the row swaps made by swapRowsWith0Pivot are not reflected in the sign
+-- of the product computed here -- confirm whether a sign correction is needed.
+determinant2D :: forall a . DType a => NdArray -> a
+determinant2D nd =
+  case V.toList $ shape nd of
+    -- 2x2 matrices are calculated quickly with the standard ad-bc
+    [2,2] -> determinant2x2 nd
+    -- nxn matrices are row-swapped to find an arrangement with no zeros/identity elements
+    -- in the leading diagonal (pivots) then put into upper triangle form
+    -- determinant is the product of the new pivots
+    [c,r] | c == r && not (zeroRow nd) -> case swapRowsWith0Pivot nd of
+      Just (NdArray s t v) ->
+        let pivots = getVector $ diagonal $ upperTriangle (NdArray s t v) :: Vector a
+        in V.foldr DType.multiply (DType.multId :: a) pivots
+      -- If the matrix is non-square or has a zero-row/column, it is singular.
+      Nothing -> DType.addId
+    [_,_] -> DType.addId
+    _ -> error "Given matrix is not 2D."
+
+-- 2x2 quick determinant calculation of ad-bc
+determinant2x2 :: forall a . DType a => NdArray -> a
+determinant2x2 (NdArray _ t v) =
+  let
+    ad = DType.multiply (vGet v t [0,0]) (vGet v t [1,1])
+    bc = DType.multiply (vGet v t [0,1]) (vGet v t [1,0])
+    det = ad `DType.subtract` bc
+  in det <-@ (typeRep @a)
+
+-- | Checks the whole array for the presence of a zero-row.
+-- A 0D array has none, a 1D array counts as a single row, and for higher ranks the
+-- rows of the front 2D matrix are inspected. Elements are compared against
+-- 'identityElem' of the data vector and are located through the strides, so each row
+-- is examined on its own (previously the whole vector was filtered for every row).
+zeroRow :: NdArray -> Bool
+zeroRow (NdArray s t v) =
+  case V.length s of
+    0 -> False
+    1 -> isZeroLine 0 (V.last t) (V.head s)
+    n ->
+      let
+        rowLen = V.last s
+        numRows = s V.! (n - 2)
+        -- Step between neighbouring elements of one row.
+        rowStride = V.last t
+        -- Step between the starts of consecutive rows.
+        colStride = t V.! (V.length t - 2)
+      in
+        any (\r -> isZeroLine (r * colStride) rowStride rowLen) [0..numRows-1]
+  where
+    zero = identityElem v
+    -- True when each of the len elements starting at offset, step apart, equals zero.
+    isZeroLine offset step len =
+      all (\c -> v V.! (offset + c * step) == zero) [0..len-1]
+
+-- Checks the array in vector form for a zero-row.
+{-
+zeroRowVec :: forall a . 
DType a => Int -> Vector a -> Bool
+zeroRowVec r v =
+  let
+    ident = DType.addId :: a
+    (row, rest) = V.splitAt r v
+  in
+    not (V.null v) &&
+    (V.all (==ident) row ||
+    zeroRowVec r rest)
+-}
+
+{- Repeatedly swaps rows until the matrix is found to be singular or
+there are no pivots which are zero/identity elem. If singular, returns Nothing.
+Note: hangs if given a matrix with a zero-row.
+-}
+swapRowsWith0Pivot :: NdArray -> Maybe NdArray
+swapRowsWith0Pivot (NdArray sh st v) =
+  let
+    diag = getVector $ diagonal (NdArray sh st v)
+    ident = identityElem diag
+  in
+    case V.elemIndex ident diag of
+      -- c is the index of the first 0 pivot on the diagonal
+      Just c -> case V.findIndex (/= ident) (frontColumn c sh st v) of
+        -- x is the first row with a non-0 entry in column c:
+        -- swap 0-pivot and non-0 rows & try again
+        Just x -> swapRowsWith0Pivot $
+          swapRows x c (NdArray sh st v)
+        -- The matrix is singular
+        Nothing -> Nothing
+      -- There is no 0-pivot
+      Nothing -> Just (NdArray sh st v)
+
+-- Extracts column c of the front 2D matrix as a vector with one entry per row.
+-- sh and st are the shape and strides of the array and v its data.
+frontColumn :: forall a . DType a =>
+  Int -> Vector Int -> Vector Int -> Vector a -> Vector a
+frontColumn c sh st v =
+  let
+    -- Offset of column c within a row.
+    col = c * V.last st
+    -- Distance between the starts of consecutive rows: the stride stored at the
+    -- second-to-last position (not that position's index, as used before).
+    rowGap = st V.! (V.length st - 2)
+  in V.generate (sh V.! (V.length sh -2)) (\i -> v V.! (rowGap*i+col))
+
+{-}
+{- Extracts the indexed column from the front matrix of a tensor given its shape and vector. -}
+frontColumn :: forall a . 
DType a => Int -> [Integer] -> Vector a -> Vector a
+frontColumn col s v = V.ifilter
+  (\i _ -> i `mod` rowLen == col && i < rowLen*columns) $
+  v <-@ (typeRep @(Vector a))
+  where
+    rowLen = fromIntegral @Integer @Int $ s!!(length s -1)
+    columns = fromIntegral @Integer @Int $ s!!(length s -2)
+-}
\ No newline at end of file
diff --git a/src/QuasiSlice.hs b/src/QuasiSlice.hs
new file mode 100644
index 0000000..e1683c6
--- /dev/null
+++ b/src/QuasiSlice.hs
@@ -0,0 +1,101 @@
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE QuasiQuotes #-}
+
+module QuasiSlice (IndexRange(..),
+                   QuasiSlice(..),
+                   evalSlice,
+                   parseSlice)
+where
+
+import Text.ParserCombinators.Parsec
+import Data.Typeable
+import Data.Data
+
+-- | Type which allows you to provide only a single index ('I') or a range of indices ('R').
+data IndexRange = I Integer | R Integer Integer deriving (Show, Eq)
+
+-- QuasiQuoted slices are converted to this to be evaluated.
+data QuasiSlice =
+    NoIndexEx                      -- nothing written, e.g. either side of a bare ":"
+  | IndexEx Integer                -- a literal index
+  | NegIndexEx Integer             -- an index that is negated on evaluation
+  | AntiIndexEx Bool String        -- a variable name; the flag is False after a leading "-"
+  | SliceEx QuasiSlice QuasiSlice  -- lower:upper
+  | CommaEx QuasiSlice QuasiSlice  -- two comma-separated entries
+  deriving(Show, Typeable, Data)
+
+-- Evaluates one side of a range. A bound that was left out (e.g. the left of ":4")
+-- defaults to 0 for a lower bound (False) and to -1 for an upper bound (True).
+evalBound :: Bool -> QuasiSlice -> Integer
+evalBound False NoIndexEx = 0
+evalBound True NoIndexEx = -1
+evalBound _ (IndexEx i) = i
+evalBound _ (NegIndexEx i) = -i
+-- Variables, nested slices and comma lists are not bounds; say so explicitly instead
+-- of failing with an anonymous pattern-match error.
+evalBound _ other = error ("evalBound: not a valid range bound: " ++ show other)
+
+-- Converts the Quasi slice to an IndexRange which can be operated on as usual in Indexing.
+evalSlice :: QuasiSlice -> [IndexRange]
+evalSlice x = case x of
+  NoIndexEx -> [R 0 (-1)]
+  IndexEx i -> [I i]
+  NegIndexEx i -> [I (-i)]
+  SliceEx l r -> [R (evalBound False l) (evalBound True r)]
+  CommaEx ex1 ex2 -> evalSlice ex1 ++ evalSlice ex2
+  -- A variable that was never substituted cannot be evaluated; report it by name
+  -- instead of failing with an anonymous pattern-match error.
+  AntiIndexEx _ v -> error ("evalSlice: unresolved anti-quoted variable: " ++ v)
+
+------------ PARSER
+
+-- Runs a parser and then skips any trailing whitespace.
+lexeme :: CharParser st a -> CharParser st a
+lexeme p = p <* spaces
+
+-- Parses the given literal string followed by optional whitespace.
+symbol :: String -> CharParser st String
+symbol name = lexeme (string name)
+
+-- Parses "," and yields the constructor joining the entries on either side.
+comma :: CharParser st (QuasiSlice -> QuasiSlice -> QuasiSlice)
+comma = CommaEx <$ symbol ","
+
+-- One or more slice entries separated by commas, combined left to right.
+indicesExpr :: CharParser st QuasiSlice
+indicesExpr = sliceIndex `chainl1` comma
+
+-- Parses an optionally negated integer literal or variable name.
+-- When neither digits nor a variable follow, NoIndexEx is returned without error.
+-- NOTE(review): this also accepts a lone "-" as NoIndexEx -- confirm that is intended.
+number :: CharParser st QuasiSlice
+number = do
+  m <- optionMaybe $ symbol "-"
+  ds <- many digit
+  case (m, ds) of
+    (Nothing, []) -> try antiIntExpr <|> pure NoIndexEx
+    -- ds holds only digits here, so read cannot fail
+    (Nothing, _) -> pure $ IndexEx (read ds)
+    (Just _, []) -> try (fmap antiNeg antiIntExpr) <|> pure NoIndexEx
+    (Just _, _) -> pure $ IndexEx (negate $ read ds)
+  where
+    -- A leading "-" before a variable marks the anti-quote as negated.
+    antiNeg (AntiIndexEx _ x) = AntiIndexEx False x
+    antiNeg other = other
+
+-- Parses a single entry: either one index or a "lower:upper" range.
+-- The upper bound is only parsed after a ":"; previously a second number was always
+-- parsed and then thrown away, so input such as "1-2" or "a b" was accepted with the
+-- trailing token silently dropped.
+sliceIndex :: CharParser st QuasiSlice
+sliceIndex = lexeme $ do
+  l <- number
+  s <- optionMaybe $ symbol ":"
+  case s of
+    Nothing -> pure l
+    Just _ -> SliceEx l <$> number
+
+-- First character of a variable name: lowercase letter or underscore.
+small :: CharParser st Char
+small = lower <|> char '_'
+
+large :: CharParser st Char
+large = upper
+
+-- Any character allowed after the first one in a variable name.
+idchar :: CharParser st Char
+idchar = small <|> large <|> digit <|> char '\''
+
+ident :: CharParser s String
+ident = (:) <$> small <*> many idchar
+
+-- To include variables in scope, not just integers
+antiIntExpr :: CharParser st QuasiSlice
+antiIntExpr = lexeme $ AntiIndexEx True <$> ident
+---------------
+
+-- Parses the text of a slice. The (file, line, column) triple is installed as the
+-- parser's starting position so that error messages point into the quoting file.
+-- A parse error is reported through 'fail' with the shown Parsec error.
+parseSlice :: (Monad m, MonadFail m) => (String, Int, Int) -> String -> m QuasiSlice
+parseSlice (file, line, col) s =
+  case runParser p () "" s of
+    Left err -> fail $ show err
+    Right e -> pure e
+  where
+    p = do
+      pos <- getPosition
+      setPosition $
+        (flip setSourceName) file $
+        (flip setSourceLine) line $
+        (flip setSourceColumn) col $
+        pos
+      -- Leading whitespace is allowed; the whole input must be consumed.
+      spaces *> indicesExpr <* eof
\ No newline at end of file
diff --git a/src/QuasiSlice/Quote.hs b/src/QuasiSlice/Quote.hs
new file mode 100644
index 0000000..784d6b9
--- /dev/null
+++ b/src/QuasiSlice/Quote.hs
@@ -0,0 +1,58 @@
+{-# LANGUAGE QuasiQuotes #-}
+
+module QuasiSlice.Quote (q) where
+
+--import Data.Generics
+import qualified Language.Haskell.TH as TH
+import Language.Haskell.TH.Quote
+import Data.Typeable
+--import Language.Haskell.TH.Syntax(liftData)
+
+import QuasiSlice
+
+-- Extends a generic function f with a special case g that is used whenever the
+-- argument can be 'cast' to g's argument type.
+extQ :: ( Typeable a, Typeable b) => (a -> r) -> (b -> r) -> a -> r
+extQ f g a = maybe (f a) g (cast a)
+
+-- | Quasiquoter for slices. Only expression quotes are supported; using it in a
+-- pattern, type or declaration position fails at compile time with a message saying
+-- so (these fields used to be left out of the record altogether).
+q :: QuasiQuoter
+q = QuasiQuoter
+  { quoteExp = quoteExprExp
+  , quotePat = unsupported "pattern"
+  , quoteType = unsupported "type"
+  , quoteDec = unsupported "declaration"
+  }
+  where
+    unsupported :: String -> String -> TH.Q a
+    unsupported ctx _ =
+      fail ("QuasiSlice.Quote.q: slices cannot be used in a " ++ ctx ++ " context.")
+-------
+
+-- Parses the quoted text, using the quote's source location for error positions, and
+-- lifts the result to an expression; anti-quoted variables go through antiExprExp.
+quoteExprExp :: String -> TH.ExpQ
+quoteExprExp s = do
+  loc <- TH.location
+  let pos = ( TH.loc_filename loc
+            , fst (TH.loc_start loc)
+            , snd (TH.loc_start loc) )
+  expr <- parseSlice pos s
+  dataToExpQ (const Nothing `extQ` antiExprExp) expr
+
+-- Replaces an anti-quoted variable v by the expression IndexEx v, or NegIndexEx v when
+-- it was written with a leading "-". Both constructors and the variable are referred
+-- to by name via mkName. Every other node is left to the default lifting.
+antiExprExp :: QuasiSlice -> Maybe (TH.Q TH.Exp)
+antiExprExp (AntiIndexEx True v) = Just $ TH.appE (TH.conE (TH.mkName "IndexEx"))
+                                                  (TH.varE (TH.mkName v))
+antiExprExp (AntiIndexEx False v) = Just $ TH.appE (TH.conE (TH.mkName "NegIndexEx"))
+                                                   (TH.varE (TH.mkName v))
+--antiExprExp (AntiExpr v) = Just $ TH.varE (TH.mkName v)
+antiExprExp _ = Nothing
+
+-------
+{-
+quoteExprPat :: String -> TH.PatQ
+quoteExprPat s = do loc <- TH.location
+                    let pos = (TH.loc_filename loc,
+                               fst (TH.loc_start loc),
+                               snd (TH.loc_start loc))
+                    expr <- parseSlice pos s
+                    dataToPatQ (const Nothing `extQ` antiExprPat) expr
+
+antiExprPat :: QuasiSlice -> Maybe (TH.Q TH.Pat)
+antiExprPat (AntiIndexEx v) = Just $ TH.conP (TH.mkName "IndexEx")
+                                              [TH.varP (TH.mkName v)]
+--antiExprPat (AntiExpr v) = Just $ TH.varP (TH.mkName v)
+antiExprPat _ = Nothing
+-}
\ No newline at end of file
diff --git a/src/Serialisation.hs b/src/Serialisation.hs
new file mode 100644
index 0000000..5cfb196
--- /dev/null
+++ b/src/Serialisation.hs
@@ -0,0 +1,148 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE 
ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Serialisation where +{- +import DType +import NdArray +import Typing + +import Data.Int +import Data.List as List +import Data.List.Split +import qualified Data.Map as M +import Data.Maybe (isJust) +import qualified Data.Vector.Storable as V +import Data.Word (Word16) +import Foreign (Ptr, alloca, mallocBytes) +import Foreign.Storable (peek, poke, sizeOf) +import System.IO +import Type.Reflection + +-- * HASKELL TO PYTHON + +-- | Built in numpy serialisation descriptions +getNumpyDType :: NdArray -> String +getNumpyDType (NdArray _ v) + | isType (typeRep @Int) = " V.Vector a -> TypeRep a + vectorType _ = typeRep @a + isType :: DType a => TypeRep a -> Bool + isType t = isJust (eqTypeRep (vectorType v) t) + +-- | Converts shape list to a string of the Numpy tuple form e.g. (3,2,) +getNumpyShape :: NdArray -> String +getNumpyShape (NdArray s _) = "(" <> drop 1 (take (length lshape -1) lshape) <> ",)" + where lshape = show s + +-- | Gets the maximum memory required for any single element in an array +getElemSize :: NdArray -> Int +getElemSize (NdArray _ v) = V.maximum $ V.map sizeOf v + +-- | Saves any of the standard types defined above as a .npy +-- Thanks Chris! 
https://github.com/cchalmers/dense/blob/6eced9f5a3ab6b5026fe4f7ab4f67a8bce4d6262/src/Data/Dense/Storable.hs#L686 +-- See https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html +saveNpy :: FilePath -> NdArray -> IO () +saveNpy path (NdArray s v) = withBinaryFile path WriteMode $ \h -> do + let + -- Unpacked specs + nd = NdArray s v + dtype = getNumpyDType nd + shape = getNumpyShape nd + vectorSize = (fromInteger $ product s) :: Int + elemSize = getElemSize nd + -- Header string without length + header = + "{'descr': '"<> dtype <> "', " <> + "'fortran_order': False, "<> + "'shape': "<> shape <> " }" + -- Calculate header length (& padding) + unpaddedLen = 6 + 2 + 2 + List.length header + 1 + paddedLen = ((unpaddedLen + 63) `Prelude.div` 64) * 64 + padding = paddedLen - unpaddedLen + headerLen = List.length header + padding + 1 + -- Put header & padding + hPutStr h "\x93NUMPY\x01\x00" + alloca $ \ptr -> poke ptr (fromIntegral headerLen :: Word16) >> hPutBuf h ptr 2 + hPutStr h header + hPutStr h (List.replicate padding ' ') + hPutChar h '\n' + -- Put vector body + V.unsafeWith v (\ptr -> hPutBuf h ptr (vectorSize * elemSize)) + + +-- * PYTHON TO HASKELL + +-- Splits the metadata into a list of keys and values +listDict :: String -> [String] +listDict x = splitOn " " $ splitOneOf "{}" (filter (/='\'') x) !! 1 + +-- Pairs adjacent keys and values in the metadata +pairDict :: [String] -> [(String, String)] +pairDict [] = [] +pairDict [_] = [] +pairDict (k:v:ls) = (k, v) : pairDict ls + +-- Read in an element from the handle +buffElement :: forall a . DType a => Handle -> IO a +buffElement h = do + let elemSize = sizeOf (undefined :: a) + ptr <- mallocBytes elemSize + _ <- hGetBuf h ptr elemSize + peek ptr + +-- Read in the complete array as a list from the handle +buffArray :: forall a . 
DType a => TypeRep a -> Handle -> Integer -> [IO a] +buffArray _ _ 0 = [] +buffArray t h i = do + let buffed = buffElement h : buffArray t h (i-1) + case eqTypeRep (typeOf buffed) (typeRep @[IO a]) of + Just HRefl -> buffed + _ -> error "Given TypeRep does not match data type." + +-- Reads a buffer into an NdArray given the handle, shape and dtype +loadPayload :: forall a . DType a => Handle -> [Integer] -> TypeRep a -> IO NdArray +loadPayload h sh _ = do + l <- sequenceA (buffArray (typeRep @a) h (product sh)) + pure $ NdArray sh (V.fromList l) + +-- Todo: check unicode UTF +-- Facilitates conversion from a numpy dtype signature to a typeRep +reifyDType :: String -> (forall a . DType a => TypeRep a -> r) -> r +reifyDType dtype cont = + case dtype of + " cont (typeRep @Int64) + " cont (typeRep @Int32) + " cont (typeRep @Float) + " cont (typeRep @Double) + " cont (typeRep @Char) + " cont (typeRep @Bool) + _ -> error "Unsupported dtype." + +-- | Loads an NdArray from a .npy file +loadNpy :: FilePath -> IO NdArray +loadNpy path = withBinaryFile path ReadMode $ \h -> do + -- Unpacks and parses the header to get the array type and size + descr <- hGetLine h + let + -- Places the dtype description, fortran order and shape in a map + metadata = (M.fromList . pairDict . listDict) descr + -- Extracts the dtype description e.g. Vector Int -> Vector a -> Vector Int +stride sh st v = + let + -- shape + dim' = V.scanr' (*) 1 sh + newshape = V.map (\i -> + ceiling $ preciseDiv + ( 1 + (dim' V.! (i+1)) * (sh V.!i -1) ) + ( st V.! 
i )
+          :: Int)
+        (V.enumFromN 0 (V.length sh))
+  in
+    newshape
+
+t = stride (V.fromList [5,5,5]) (V.fromList [50, 10, 2]) (V.fromList [0..124])
+-}
+
+-- Splits x across the given steps from left to right: each non-zero step yields the
+-- quotient of the running value by that step and passes the remainder on; a zero step
+-- yields 0 and leaves the running value untouched.
+expandRun :: [Int] -> Int -> [Int]
+expandRun [] _ = []
+expandRun (s:sts) x
+  | s == 0 = 0 : expandRun sts x
+  | otherwise = quotient : expandRun sts remainder
+  where
+    (quotient, remainder) = x `divMod` s
\ No newline at end of file
diff --git a/src/Typing.hs b/src/Typing.hs
new file mode 100644
index 0000000..9c62ed0
--- /dev/null
+++ b/src/Typing.hs
@@ -0,0 +1,29 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+
+module Typing where
+
+import Type.Reflection
+
+-- * Typing Shorthand
+-- | Shorthand for 'typeOf': the type representation of a value.
+ty :: Typeable a => a -> TypeRep a
+ty value = typeOf value
+
+-- | Compares the types of two values with 'eqTypeRep', giving Just HRefl when they
+-- are the same type.
+--
+-- >>> case True =@= False of { Just HRefl -> putStrLn "Two Booleans will match"; Nothing -> putStrLn "Mismatching types" }
+-- Two Booleans will match
+(=@=) :: (Typeable a, Typeable b) => a -> b -> Maybe (a :~~: b)
+left =@= right = ty left `eqTypeRep` ty right
+
+-- | Like '=@=', but compares the type of a value against a given type representation.
+(=@) :: Typeable a => a -> TypeRep b -> Maybe (a :~~: b)
+value =@ rep = eqTypeRep (ty value) rep
+
+-- Returns the value at the requested type, or calls 'error' when its type differs.
+(<-@) ::Typeable a => a -> TypeRep b -> b
+val <-@ t
+  | Just HRefl <- eqTypeRep t (ty val) = val
+  | otherwise = error "Mismatching type."
\ No newline at end of file
diff --git a/test/DocTest.hs b/test/DocTest.hs
new file mode 100644
index 0000000..4019468
--- /dev/null
+++ b/test/DocTest.hs
@@ -0,0 +1,11 @@
+module Main where
+
+-- doctest
+import Test.DocTest
+
+-- Runs doctest over the listed source files, with src on the include path.
+main :: IO ()
+main = doctest ("-isrc" : sources)
+  where
+    sources = ["src/" <> file | file <- ["DType.hs", "Numskull.hs", "Serialisation.hs"]]
\ No newline at end of file
diff --git a/test/Main.hs b/test/Main.hs
deleted file mode 100644
index 3e2059e..0000000
--- a/test/Main.hs
+++ /dev/null
@@ -1,4 +0,0 @@
-module Main (main) where
-
-main :: IO ()
-main = putStrLn "Test suite not yet implemented."
diff --git a/test/Test.hs b/test/Test.hs
new file mode 100644
index 0000000..15d55c1
--- /dev/null
+++ b/test/Test.hs
@@ -0,0 +1,13 @@
+module Main where
+
+-- hspec
+import Test.Hspec
+
+-- ndarray (local)
+import qualified Test.Numskull
+import qualified Test.Serialisation
+
+-- Runs the spec of every test module under hspec, one labelled group per module.
+main :: IO ()
+main = hspec $ do
+  describe "Test.Numskull" Test.Numskull.spec
+  describe "Test.Serialisation" Test.Serialisation.spec
\ No newline at end of file
diff --git a/test/Test/Numskull.hs b/test/Test/Numskull.hs
new file mode 100644
index 0000000..95abfac
--- /dev/null
+++ b/test/Test/Numskull.hs
@@ -0,0 +1,33 @@
+module Test.Numskull where
+
+-- hspec
+import Test.Hspec
+
+-- QuickCheck
+import Test.QuickCheck (NonNegative(..), property)
+
+-- ndarray (local)
+import Numskull as N
+
+-- Specs for the Numskull module. The only active case checks that two arrays built
+-- from the same shape and contents compare equal; the padShape property below is
+-- commented out.
+spec :: Spec
+spec = do
+  describe "NdArray equality" $
+    it "works" $
+      N.fromList [3] [1,2,3::Int] == N.fromList [3] [1,2,3::Int]
+
+{-
+  describe "padShape" $ do
+    focus . it "works on a less simple example" $
+      property $ \content (NonNegative extra) ->
+        let n = toInteger $ length content
+        in
+          padShape (N.fromList [n] content) [n + extra] `shouldBe` N.fromList [n + extra] (content <> replicate (fromInteger extra) (0 :: Int))
+-}
+--cabal test --test-show-details=streaming
+-- ghci -isrc -itest test/Test/Numskull.hs
+-- ghci> hspec spec
+
+
+-- https://hackage.haskell.org/package/hspec-2.11.3/docs/Test-Hspec.html#v:example
+
+-- https://hspec.github.io/
diff --git a/test/Test/Serialisation.hs b/test/Test/Serialisation.hs
new file mode 100644
index 0000000..4b497c6
--- /dev/null
+++ b/test/Test/Serialisation.hs
@@ -0,0 +1,7 @@
+module Test.Serialisation where
+
+-- hspec
+import Test.Hspec
+
+-- Placeholder spec: it contains no test cases yet.
+spec :: Spec
+spec = pure ()
\ No newline at end of file
diff --git a/tests/VectorTest.hs b/tests/VectorTest.hs
deleted file mode 100644
index ea06fa1..0000000
--- a/tests/VectorTest.hs
+++ /dev/null
@@ -1,84 +0,0 @@
--- trust me bro ;)
--- :set -fdefer-type-errors
-
-{-# LANGUAGE GADTs #-}
-
-module 
VectorTest where - -import Prelude as P -import Data.Vector as V -import Data.Dynamic -import Type.Reflection -import Data.Maybe (isJust, fromJust) - -ty x = typeOf x - --- DType -- -class (Show a, Typeable a) => DType a where - add :: a -> a -> a - subtract :: a -> a -> a - multiply :: a -> a -> a - eq :: a -> a -> Bool - dtypeToInt :: a -> Int - -instance DType Int where - add x y = x + y - multiply x y = x * y - eq x y = x == y - --- NdArray -- -data NdArray where - NdArray :: (Typeable a, DType a) => Vector a -> NdArray - -instance Show NdArray where - show (NdArray x) = show x - -instance Num NdArray where - (NdArray x) + (NdArray y) = case (eqTypeRep xtype ytype, matchDType (NdArray x) (NdArray y)) of - (Just HRefl, _) -> NdArray (V.zipWith add x y) -- Types match - (_, Just casted) -> (NdArray x) + casted -- Second type can be converted to first - otherwise -> error ("Cannot convert second matrix of type '" P.++ show ytype P.++ "' to type '" P.++ show xtype P.++ "'.") - where - xtype = ty x - ytype = ty y - - --(NdArray x) - (NDArray y) = - - --(NdArray x) * (NDArray y) - - --- Helper -eqDType x y = case eqTypeRep (ty x) (ty y) of - Just HRefl -> True - otherwise -> False - -matchDType :: NdArray -> NdArray -> Maybe NdArray -matchDType (NdArray x) (NdArray y) = case eqTypeRep (ty x) (ty (fromList [1::Int])) of - Just HRefl -> Just $ NdArray (V.map dtypeToInt y) - otherwise -> Nothing - - - - --- Spaghetti -{- -instance Num NdArray where - (NdArray x) + (NdArray y) = - case typeOf x `eqTypeRep` typeOf y of - Just HRefl -> NdArray (V.zipWith add x y) -- Types match - Nothing -> case casted of Just newNd -> (V.zipWith add x newNd) - where casted = matchDType (NdArray x) (NdArray y) --} - - - ----- Testing - -unwrapND :: NdArray -> (String, Vector Dynamic) -unwrapND (NdArray x) = case typeOf x of - vecTypeInt -> ("Int", V.map toDyn x) - vecTypeBool -> ("Bool", V.map toDyn x) - -nd1 = NdArray (fromList [1,2,3::Int]) -nd2 = NdArray (fromList [10,11,12::Int]) -