From 9ec3d00cd30b2e692ab50fa69140ca24d1a30632 Mon Sep 17 00:00:00 2001 From: Mzack9999 Date: Tue, 24 Mar 2026 23:11:11 +0100 Subject: [PATCH] fixing param type --- memoize/memoize.go | 2 +- memoize/memoize_test.go | 48 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/memoize/memoize.go b/memoize/memoize.go index 2c7fa5ac..0c23c0ed 100644 --- a/memoize/memoize.go +++ b/memoize/memoize.go @@ -139,7 +139,7 @@ func Src(tpl, sourcePath string, source []byte, packageName string) ([]byte, err for _, name := range param.Names { funcParam.Name = name.String() } - funcParam.Type = fmt.Sprint(param.Type) + funcParam.Type = types.ExprString(param.Type) funcDeclaration.Params = append(funcDeclaration.Params, funcParam) } } diff --git a/memoize/memoize_test.go b/memoize/memoize_test.go index 0abbbaa5..c1074b35 100644 --- a/memoize/memoize_test.go +++ b/memoize/memoize_test.go @@ -1,6 +1,10 @@ package memoize import ( + "go/ast" + "go/parser" + "go/token" + "go/types" "testing" "time" @@ -26,3 +30,47 @@ func TestSrc(t *testing.T) { require.Nil(t, err) require.True(t, len(out) > 0) } + +func TestParamTypeContextContext(t *testing.T) { + source := `package example + +import "context" + +// @memo +func DoSomething(ctx context.Context, key string) (string, error) { + return "", nil +} +` + + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, "test.go", source, parser.ParseComments) + require.NoError(t, err) + + var params []FuncValue + ast.Inspect(node, func(n ast.Node) bool { + fn, ok := n.(*ast.FuncDecl) + if !ok || fn.Doc == nil { + return true + } + for _, comment := range fn.Doc.List { + if comment.Text == "// @memo" { + for idx, param := range fn.Type.Params.List { + var fv FuncValue + fv.Index = idx + for _, name := range param.Names { + fv.Name = name.String() + } + fv.Type = types.ExprString(param.Type) + params = append(params, fv) + } + } + } + return false + }) + + require.Len(t, params, 2) + require.Equal(t, "ctx", params[0].Name) + require.Equal(t, "context.Context", params[0].Type, "context.Context param type should be a clean string, not fmt.Sprint garbage") + require.Equal(t, "key", params[1].Name) + require.Equal(t, "string", params[1].Type) +}