diff --git a/pkg/codegen/prelude.go b/pkg/codegen/prelude.go index cb7d704..55e6a9c 100644 --- a/pkg/codegen/prelude.go +++ b/pkg/codegen/prelude.go @@ -65,7 +65,7 @@ var PreludeFunction = map[string]OpFunc{ mTypes.PRELUDE_MAP: PreludeMap, mTypes.PRELUDE_FILTER: PreludeFilter, mTypes.PRELUDE_REDUCE: PreludeReduce, - //mTypes.LIB_CORE_ASSOC: PreludeAssoc, + mTypes.PRELUDE_ASSOC: PreludeAssoc, //mTypes.LIB_CORE_POP: PreludePop, mTypes.OPERATOR_ADD: OperatorDispatcher(mTypes.OPERATOR_ADD), mTypes.OPERATOR_SUB: OperatorDispatcher(mTypes.OPERATOR_SUB), diff --git a/pkg/codegen/prelude_assoc.go b/pkg/codegen/prelude_assoc.go index 86cb389..120af7e 100644 --- a/pkg/codegen/prelude_assoc.go +++ b/pkg/codegen/prelude_assoc.go @@ -2,34 +2,108 @@ package codegen import ( "github.com/llir/llvm/ir" + "github.com/llir/llvm/ir/constant" + "github.com/llir/llvm/ir/enum" "github.com/llir/llvm/ir/types" "github.com/llir/llvm/ir/value" - "github.com/wf001/modo/pkg/log" mTypes "github.com/wf001/modo/pkg/types" ) -func PreludeAssoc(block *ir.Block, internal *mTypes.Internal, node *mTypes.Node) value.Value { - return nil -} -func PreludeAssocOld(block *ir.Block, internal *mTypes.Internal, node *mTypes.Node) value.Value { - oldArrPtr, ok := node.IRValue.(*ir.InstAlloca) - if !ok { - log.Panic("Array elements must be ir.InstAlloca: have %+v", oldArrPtr) +func PreludeAssoc(ctx *Context, n *mTypes.Node) value.Value { + srcPtr := n.IRValue + + tyPtr, isTyPtr := srcPtr.Type().(*types.PointerType) + tyStr, isTyStr := tyPtr.ElemType.(*types.StructType) + if !isTyPtr || !isTyStr { + return constant.NewNull(types.NewPointer(types.I32)) } - pos, newValue := node.Next.IRValue, node.Next.Next.IRValue - oldArr := oldArrPtr.ElemType.(*types.ArrayType) + typeInfo := ctx.prog.Declare.Type.Struct[tyStr.TypeName] + structType := typeInfo.Types + + nullStructPtr := constant.NewNull(types.NewPointer(structType)) + gepEndPtr := ctx.block.NewGetElementPtr(structType, nullStructPtr, mTypes.I32one) + mallocSize := ctx.block.NewPtrToInt(gepEndPtr, types.I64) + rawPtr := ctx.block.NewCall(ctx.internal.Cstd.Malloc, mallocSize) + destPtr := ctx.block.NewBitCast(rawPtr, types.NewPointer(structType)) + + // deep copy + for _, field := range typeInfo.Field { + fieldIndex := field.Pos + fieldType := field.Type + + origFieldPtr := ctx.block.NewGetElementPtr( + structType, + srcPtr, + mTypes.I32zero, + constant.NewInt(types.I32, int64(fieldIndex)), + ) + origVal := ctx.block.NewLoad(fieldType, origFieldPtr) + + copyFieldPtr := ctx.block.NewGetElementPtr( + structType, + destPtr, + mTypes.I32zero, + constant.NewInt(types.I32, int64(fieldIndex)), + ) + + if ptrType, ok := fieldType.(*types.PointerType); ok { + isNull := ctx.block.NewICmp(enum.IPredEQ, origVal, constant.NewNull(ptrType)) + + nullBlock := ctx.NewBlock("copy.null", n) + nonNullBlock := ctx.NewBlock("copy.non.null", n) + mergeBlock := ctx.NewBlock("copy.merge", n) + + ctx.block.NewCondBr(isNull, nullBlock, nonNullBlock) + + nullBlock.NewStore(constant.NewNull(ptrType), copyFieldPtr) + nullBlock.NewBr(mergeBlock) - newArr := CopyArrayOld(block, oldArrPtr, oldArr.ElemType, oldArr.Len, oldArr.Len) + elemType := ptrType.ElemType + nullElemPtr := constant.NewNull(types.NewPointer(elemType)) + gepEnd := nonNullBlock.NewGetElementPtr(elemType, nullElemPtr, mTypes.I32one) + allocSize := nonNullBlock.NewPtrToInt(gepEnd, types.I64) - newElemPtr := block.NewGetElementPtr( - newArr.ElemType, - newArr, + malloc := nonNullBlock.NewCall(ctx.internal.Cstd.Malloc, allocSize) + newPtr := nonNullBlock.NewBitCast(malloc, ptrType) + + loadedElem := nonNullBlock.NewLoad(elemType, origVal) + nonNullBlock.NewStore(loadedElem, newPtr) + + nonNullBlock.NewStore(newPtr, copyFieldPtr) + nonNullBlock.NewBr(mergeBlock) + + ctx.block = mergeBlock + + } else { + ctx.block.NewStore(origVal, copyFieldPtr) + } + } + + fieldName := n.Next.Val + fieldInfo := typeInfo.Field[fieldName] + fieldIndex := fieldInfo.Pos + fieldType := fieldInfo.Type + + fieldPtr := ctx.block.NewGetElementPtr( + structType, + destPtr, mTypes.I32zero, - pos, + constant.NewInt(types.I32, int64(fieldIndex)), ) - block.NewStore(newValue, newElemPtr) - return newArr + val := n.Next.Next.IRValue + + if _, ok := val.(*ir.InstBitCast); ok { + if ptrType, ok := fieldType.(*types.PointerType); ok { + nullVal := constant.NewNull(ptrType) + ctx.block.NewStore(nullVal, fieldPtr) + ctx.block.NewStore(val, fieldPtr) + return destPtr + } + } + + ctx.block.NewStore(val, fieldPtr) + return destPtr } diff --git a/pkg/codegen/prelude_update.go b/pkg/codegen/prelude_update.go new file mode 100644 index 0000000..400d1a6 --- /dev/null +++ b/pkg/codegen/prelude_update.go @@ -0,0 +1,35 @@ +package codegen + +import ( + "github.com/llir/llvm/ir" + "github.com/llir/llvm/ir/types" + "github.com/llir/llvm/ir/value" + + "github.com/wf001/modo/pkg/log" + mTypes "github.com/wf001/modo/pkg/types" +) + +func PreludeUpdate(block *ir.Block, internal *mTypes.Internal, node *mTypes.Node) value.Value { + return nil +} +func PreludeUpdateOld(block *ir.Block, internal *mTypes.Internal, node *mTypes.Node) value.Value { + oldArrPtr, ok := node.IRValue.(*ir.InstAlloca) + if !ok { + log.Panic("Array elements must be ir.InstAlloca: have %+v", oldArrPtr) + } + + pos, newValue := node.Next.IRValue, node.Next.Next.IRValue + oldArr := oldArrPtr.ElemType.(*types.ArrayType) + + newArr := CopyArrayOld(block, oldArrPtr, oldArr.ElemType, oldArr.Len, oldArr.Len) + + newElemPtr := block.NewGetElementPtr( + newArr.ElemType, + newArr, + mTypes.I32zero, + pos, + ) + block.NewStore(newValue, newElemPtr) + + return newArr +} diff --git a/script/test-full.sh b/script/test-full.sh index 25895a2..fd2a882 100755 --- a/script/test-full.sh +++ b/script/test-full.sh @@ -346,6 +346,12 @@ testexec(){ assertexec '(def f::[int] => [int] (fn[v] (conj v 78))) (def main ::int (fn[] (let [v ::[int] [42 64 90]] (prn (f v)))))' "[42 64 90 78]\\\n" assertexec "(def main :: int (fn [] (let [vec :: [[int]] [[12 34 5 6] [789]]] (prn (nth (conj vec [11 12]) 2)) (prn (nth vec 2)))))" "[11 12]\\\nnil\\\n" + echo "===================" + echo "== assoc ===" + echo "===================" + assertexec '(defschema Person {:name :: string :age :: int :isMale :: bool}) (def main :: int (fn [] (let [node :: Person {:age 20 :isMale true}] (prn (get node :name)) (prn (get (assoc node :name "fendder") :name)) (prn (get node :name)))))' "nil\\\nfendder\\\nnil\\\n" + assertexec '(defschema Country {:name :: string}) (defschema Person {:name :: string :country :: Country}) (def main :: int (fn [] (let [co :: Country {:name "deutsch"} person :: Person {:name "jimmy"}] (prn (get (get person :country) :name)) (prn (get (get (assoc person :country co) :country) :name)) (prn (get (get person :country) :name)) )))' "nil\\\ndeutsch\\\nnil\\\n" + # echo "===================" # echo "== assoc ===" # echo "==================="