diff --git a/gorfc/gorfc.go b/gorfc/gorfc.go index 7e49857..6902e90 100644 --- a/gorfc/gorfc.go +++ b/gorfc/gorfc.go @@ -49,7 +49,7 @@ package gorfc #cgo darwin CFLAGS: -Wall -O2 -Wno-uninitialized -Wcast-align #cgo darwin CFLAGS: -DSAP_UC_is_wchar -DSAPwithUNICODE -D__NO_MATH_INLINES -DSAPwithTHREADS -DSAPonDARW -#cgo darwin CFLAGS: -fexceptions -funsigned-char -fno-strict-aliasing -fPIC -pthread -std=c17 -mmacosx-version-min=10.15 +#cgo darwin CFLAGS: -fexceptions -funsigned-char -fno-strict-aliasing -fPIC -pthread -std=c17 #cgo darwin CFLAGS: -fno-omit-frame-pointer #cgo darwin CFLAGS: -I/usr/local/sap/nwrfcsdk/include @@ -58,7 +58,6 @@ package gorfc #cgo darwin LDFLAGS: -O2 -g -pthread #cgo darwin LDFLAGS: -stdlib=libc++ -#cgo darwin LDFLAGS: -mmacosx-version-min=10.15 #include @@ -74,6 +73,8 @@ static unsigned GoStrlenU(SAP_UTF16 *str) { import "C" import ( + "context" + "errors" "fmt" "reflect" "runtime" @@ -1140,8 +1141,41 @@ func (conn *Connection) GetFunctionDescription(goFuncName string) (goFuncDesc Fu return wrapFunctionDescription(funcDesc) } -// Call calls the given function with the given parameters and wraps the results returned. -func (conn *Connection) Call(goFuncName string, params interface{}) (result map[string]interface{}, err error) { +func setupParameter(params interface{}, funcDesc C.RFC_FUNCTION_DESC_HANDLE, funcCont C.RFC_FUNCTION_HANDLE) error { + paramsValue := reflect.ValueOf(params) + if paramsValue.Type().Kind() == reflect.Map { + keys := paramsValue.MapKeys() + if len(keys) > 0 { + if keys[0].Kind() == reflect.String { + for _, nameValue := range keys { + fieldName := nameValue.String() + fieldValue := paramsValue.MapIndex(nameValue).Interface() + err := fillFunctionParameter(funcDesc, funcCont, fieldName, fieldValue) + if err != nil { + return err + } + } + } else { + return errors.New("could not fill parameters passed as map with non-string keys") + } + } + } else if paramsValue.Type().Kind() == reflect.Struct { + for i := 0; i < paramsValue.NumField(); i++ { + fieldName := paramsValue.Type().Field(i).Name + fieldValue := paramsValue.Field(i).Interface() + + err := fillFunctionParameter(funcDesc, funcCont, fieldName, fieldValue) + if err != nil { + return err + } + } + } else { + return errors.New("parameters can only be passed as types map[string]interface{} or go-structures") + } + return nil +} + +func (conn *Connection) rfcInvoke(goFuncName string, params interface{}) (result map[string]interface{}, err error) { if !conn.alive { return nil, goRfcError("Call() method requires an open connection", nil) } @@ -1161,52 +1195,23 @@ func (conn *Connection) Call(goFuncName string, params interface{}) (result map[ } } - funcDesc := C.RfcGetFunctionDesc(conn.handle, funcName, &errorInfo) + var funcDesc C.RFC_FUNCTION_DESC_HANDLE = C.RfcGetFunctionDesc(conn.handle, funcName, &errorInfo) if funcDesc == nil { return result, rfcError(errorInfo, "Could not get function description for \"%v\"", funcName) } - funcCont := C.RfcCreateFunction(funcDesc, &errorInfo) + var funcCont C.RFC_FUNCTION_HANDLE = C.RfcCreateFunction(funcDesc, &errorInfo) if funcCont == nil { return result, rfcError(errorInfo, "Could not create function") } - defer C.RfcDestroyFunction(funcCont, nil) - paramsValue := reflect.ValueOf(params) - if paramsValue.Type().Kind() == reflect.Map { - keys := paramsValue.MapKeys() - if len(keys) > 0 { - if keys[0].Kind() == reflect.String { - for _, nameValue := range keys { - fieldName := nameValue.String() - fieldValue := paramsValue.MapIndex(nameValue).Interface() - - err = fillFunctionParameter(funcDesc, funcCont, fieldName, fieldValue) - if err != nil { - return - } - } - } else { - return result, rfcError(errorInfo, "Could not fill parameters passed as map with non-string keys") - } - } - } else if paramsValue.Type().Kind() == reflect.Struct { - for i := 0; i < paramsValue.NumField(); i++ { - fieldName := paramsValue.Type().Field(i).Name - fieldValue := paramsValue.Field(i).Interface() - - err = fillFunctionParameter(funcDesc, funcCont, fieldName, fieldValue) - if err != nil { - return - } - } - } else { - return result, rfcError(errorInfo, "Parameters can only be passed as types map[string]interface{} or go-structures") + err = setupParameter(params, funcDesc, funcCont) + if err != nil { + return } rc := C.RfcInvoke(conn.handle, funcCont, &errorInfo) - if rc != C.RFC_OK { return result, rfcError(errorInfo, "Could not invoke function \"%v\"", goFuncName) } @@ -1216,3 +1221,41 @@ func (conn *Connection) Call(goFuncName string, params interface{}) (result map[ } return wrapResult(funcDesc, funcCont, C.RFC_IMPORT, conn.rstrip) } + +func (conn *Connection) rfcCancel(goFuncName string) error { + var errorInfo C.RFC_ERROR_INFO + rc := C.RfcCancel(conn.handle, &errorInfo) + if rc != C.RFC_OK { + return rfcError(errorInfo, "Could not invoke function \"%v\"", goFuncName) + } + conn.alive = false + return nil +} + +// Call calls the given function with the given parameters and wraps the results returned. +func (conn *Connection) Call(goFuncName string, params interface{}) (result map[string]interface{}, err error) { + return conn.CallContext(context.Background(), goFuncName, params) +} + +// CallContext calls the given function with the given parameters and wraps the results returned. +func (conn *Connection) CallContext(ctx context.Context, goFuncName string, params interface{}) (result map[string]interface{}, err error) { + if ctx.Done() == nil { + return conn.rfcInvoke(goFuncName, params) + } + + done := make(chan struct{}) + go func() { + defer close(done) + result, err = conn.rfcInvoke(goFuncName, params) + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + conn.alive = false + err = errors.Join(ctx.Err(), conn.rfcCancel(goFuncName), conn.Open()) + case <-done: + } + + return +} diff --git a/gorfc/gorfc_test.go b/gorfc/gorfc_test.go index 633b496..0e87a5c 100644 --- a/gorfc/gorfc_test.go +++ b/gorfc/gorfc_test.go @@ -1,7 +1,9 @@ package gorfc import ( + "context" "fmt" + "github.com/stretchr/testify/require" "os" "reflect" "strconv" @@ -14,9 +16,7 @@ import ( "github.com/sap/gorfc/gorfc/testutils" ) -// // NW RFC Lib Version -// func TestNWRFCLibVersion(t *testing.T) { major, minor, patchlevel := GetNWRFCLibVersion() assert.Equal(t, uint(7500), major) // adapt to your NW RFC Lib version @@ -24,9 +24,7 @@ func TestNWRFCLibVersion(t *testing.T) { assert.Greater(t, patchlevel, uint(4)) } -// // Connection Tests -// func TestConnect(t *testing.T) { fmt.Println("Connection test: Open and Close") c, err := ConnectionFromParams(abapSystem()) @@ -36,7 +34,7 @@ func TestConnect(t *testing.T) { assert.NotNil(t, c) assert.Nil(t, err) assert.True(t, c.Alive()) - c.Close() + assert.NoError(t, c.Close()) assert.False(t, c.Alive()) } @@ -361,6 +359,22 @@ func TestConfigParameter(t *testing.T) { c.Close() } +func TestCancelCall(t *testing.T) { + c, err := ConnectionFromParams(abapSystem()) + require.Nil(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err = c.CallContext(ctx, "RFC_PING_AND_WAIT", map[string]interface{}{ + "SECONDS": 4, + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + _, err = c.Call("RFC_PING", map[string]interface{}{}) + assert.NoError(t, err) + assert.NoError(t, c.Close()) +} + func TestInvalidParameterFunctionCall(t *testing.T) { fmt.Println("STFC: Invalid RFM parameter") c, err := ConnectionFromParams(abapSystem())