From 0e7469f1231902b127b66c547d0fb5197ae6b9d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 20 Oct 2023 02:05:53 +0200 Subject: [PATCH 01/12] pam: Add comments for main package to please CI --- pam/authentication.go | 1 + pam/pam.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pam/authentication.go b/pam/authentication.go index d21b26ef4c..002024d732 100644 --- a/pam/authentication.go +++ b/pam/authentication.go @@ -1,3 +1,4 @@ +// Package main is the package for the PAM library package main import ( diff --git a/pam/pam.go b/pam/pam.go index 0c6c3c673a..4e23fba871 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -1,3 +1,4 @@ +// Package main is the package for the PAM library. package main /* From 1ad064df5d2a2d10840182097a7a67f82f7c505d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Sat, 14 Oct 2023 07:23:47 +0200 Subject: [PATCH 02/12] pam: Address some linter warnings --- pam/pam.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pam/pam.go b/pam/pam.go index 4e23fba871..8548c6e4e0 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -151,7 +151,7 @@ func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C. return C.PAM_SUCCESS } -// newClient returns a new GRPC client ready to emit requests +// newClient returns a new GRPC client ready to emit requests. func newClient(argc C.int, argv **C.char) (client authd.PAMClient, close func(), err error) { conn, err := grpc.Dial("unix://"+getSocketPath(argc, argv), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -179,17 +179,17 @@ func pam_sm_setcred(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.in return C.PAM_IGNORE } -// go_pam_cleanup_module is called by the go-loader PAM module during onload +// go_pam_cleanup_module is called by the go-loader PAM module during onload. // //export go_pam_cleanup_module func go_pam_cleanup_module() { runtime.GC() } -// Simulating pam on the CLI for manual testing +// Simulating pam on the CLI for manual testing. func main() { log.SetLevel(log.DebugLevel) - f, err := os.OpenFile("/tmp/logdebug", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644) + f, err := os.OpenFile("/tmp/logdebug", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0600) if err != nil { panic(err) } From 781de205cabff3cccbc2c41dbd45c551a55d50fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 21 Jul 2023 04:49:32 +0200 Subject: [PATCH 03/12] pam: Fail if no broker has been found --- pam/pam.go | 7 ++++++- pam/return.go | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pam/pam.go b/pam/pam.go index 8548c6e4e0..5600584764 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -55,7 +55,7 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) client, closeConn, err := newClient(argc, argv) if err != nil { log.Debug(context.TODO(), err) - return C.PAM_IGNORE + return C.PAM_AUTHINFO_UNAVAIL } defer closeConn() @@ -102,6 +102,11 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) logErrMsg = fmt.Sprintf("authentication: %s", exitMsg) } errCode = C.PAM_AUTH_ERR + case pamAuthInfoUnavailable: + if exitMsg.String() != "" { + logErrMsg = fmt.Sprintf("missing authentication data: %s", exitMsg) + } + errCode = C.PAM_AUTHINFO_UNAVAIL case pamSystemError: if exitMsg.String() != "" { logErrMsg = fmt.Sprintf("system: %s", exitMsg) diff --git a/pam/return.go b/pam/return.go index 1eea1fe628..e09fcb4f33 100644 --- a/pam/return.go +++ b/pam/return.go @@ -52,3 +52,13 @@ type pamAuthError struct { func (err pamAuthError) String() string { return err.msg } + +// pamAuthInfoUnavailable signals PAM module to return PAM_AUTHINFO_UNAVAIL and Quit tea.Model. +type pamAuthInfoUnavailable struct { + msg string +} + +// String returns the string of pamAuthInfoUnavailable message. +func (err pamAuthInfoUnavailable) String() string { + return err.msg +} From 41b84c8b0d4c88720dfcae827ae418c757c2e079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 13 Oct 2023 20:52:43 +0200 Subject: [PATCH 04/12] pam: Use pam-go to generate module stubs and handle PAM APIs Avoid doing all the C manual work inside authd pam, but handle this in a separate library that handles this all and comes with fully tested operations. We could so mock the ModuleTransaction here if we want so that we can make things more testable as they are now. As per this change, the module can be simply be generated with: go generate -C pam --- go.mod | 6 +++ go.sum | 2 + pam/.gitignore | 2 + pam/model.go | 5 +- pam/pam.go | 108 +++++++++++++++++++++++++------------------ pam/pam_module.go | 93 +++++++++++++++++++++++++++++++++++++ pam/userselection.go | 16 +++++-- pam/utils_c.go | 79 ------------------------------- 8 files changed, 180 insertions(+), 131 deletions(-) create mode 100644 pam/.gitignore create mode 100644 pam/pam_module.go delete mode 100644 pam/utils_c.go diff --git a/go.mod b/go.mod index 6446b89af6..8f386c22bc 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf github.com/godbus/dbus/v5 v5.1.0 github.com/google/uuid v1.4.0 + github.com/msteinert/pam v1.2.0 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/spf13/cobra v1.8.0 @@ -62,3 +63,8 @@ require ( golang.org/x/text v0.13.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect ) + +// FIXME: Use released version once we have one! +// The branch below includes changes from this upstream PR: +// - https://github.com/msteinert/pam/pull/13 +replace github.com/msteinert/pam => github.com/3v1n0/go-pam v0.0.0-20231130030658-0f1cc6f16d45 diff --git a/go.sum b/go.sum index d34d2b04fc..61cdbe1954 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/3v1n0/go-pam v0.0.0-20231130030658-0f1cc6f16d45 h1:EmfxQhrTNAVT4rzbv6TcP2XScDl16Q8J0ZlCVbazl8c= +github.com/3v1n0/go-pam v0.0.0-20231130030658-0f1cc6f16d45/go.mod h1:d2n0DCUK8rGecChV3JzvmsDjOY4R7AYbsNxAT+ftQl0= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= diff --git a/pam/.gitignore b/pam/.gitignore new file mode 100644 index 0000000000..271d077f69 --- /dev/null +++ b/pam/.gitignore @@ -0,0 +1,2 @@ +./*.h +*.so diff --git a/pam/model.go b/pam/model.go index fb8fb2ecc6..8052ff8f2a 100644 --- a/pam/model.go +++ b/pam/model.go @@ -6,6 +6,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -35,7 +36,7 @@ type sessionInfo struct { // model is the global models orchestrator. type model struct { - pamh pamHandle + pamMTx pam.ModuleTransaction client authd.PAMClient height int @@ -87,7 +88,7 @@ type SessionEnded struct{} // Init initializes the main model orchestrator. func (m *model) Init() tea.Cmd { - m.userSelectionModel = newUserSelectionModel(m.pamh) + m.userSelectionModel = newUserSelectionModel(m.pamMTx) var cmds []tea.Cmd cmds = append(cmds, m.userSelectionModel.Init()) diff --git a/pam/pam.go b/pam/pam.go index 5600584764..b9fe114348 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -1,15 +1,10 @@ +//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -libname "pam_authd.so" -no-main -type pamModule +//go:generate go generate --skip="pam_module.go" +//go:generate sh -c "cc -o go-loader/pam_go_loader.so go-loader/module.c -Wl,--as-needed -Wl,--allow-shlib-undefined -shared -fPIC -Wl,--unresolved-symbols=report-all -lpam && chmod 600 go-loader/pam_go_loader.so" + // Package main is the package for the PAM library. package main -/* -#cgo LDFLAGS: -lpam -fPIC -#include -#include -#include -#include - -char *string_from_argv(int i, char **argv); -*/ import "C" import ( @@ -20,6 +15,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" "github.com/sirupsen/logrus" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/consts" @@ -29,21 +25,24 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +// pamModule is the structure that implements the pam.ModuleHandler interface +// that is called during pam operations. +type pamModule struct { +} + var ( // brokerIDUsedToAuthenticate global variable is for the second stage authentication to select the default broker for the current user. brokerIDUsedToAuthenticate string ) -//go:generate sh -c "cc -o go-loader/pam_go_loader.so go-loader/module.c -Wl,--as-needed -Wl,--allow-shlib-undefined -shared -fPIC -Wl,--unresolved-symbols=report-all -lpam && chmod 600 go-loader/pam_go_loader.so" -//go:generate sh -c "go build -ldflags='-extldflags -Wl,-soname,pam_authd.so' -buildmode=c-shared -o pam_authd.so" - /* + FIXME: provide instructions using pam-auth-update instead! Add to /etc/pam.d/common-auth auth [success=3 default=die ignore=ignore] pam_authd.so */ -//export pam_sm_authenticate -func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.int { +// Authenticate is the method that is invoked during pam_authenticate request. +func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, args []string) error { // Initialize localization // TODO @@ -52,15 +51,15 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) interactiveTerminal := term.IsTerminal(int(os.Stdin.Fd())) - client, closeConn, err := newClient(argc, argv) + client, closeConn, err := newClient(args) if err != nil { log.Debug(context.TODO(), err) - return C.PAM_AUTHINFO_UNAVAIL + return pam.ErrAuthinfoUnavail } defer closeConn() appState := model{ - pamh: pamh, + pamMTx: mTx, client: client, interactiveTerminal: interactiveTerminal, } @@ -74,16 +73,16 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) p := tea.NewProgram(&appState, opts...) if _, err := p.Run(); err != nil { log.Errorf(context.TODO(), "Cancelled authentication: %v", err) - return C.PAM_ABORT + return pam.ErrAbort } logErrMsg := "unknown" - var errCode C.int = C.PAM_SYSTEM_ERR + errCode := pam.ErrSystem switch exitMsg := appState.exitMsg.(type) { case pamSuccess: brokerIDUsedToAuthenticate = exitMsg.brokerID - return C.PAM_SUCCESS + return nil case pamIgnore: // localBrokerID is only set on pamIgnore if the user has chosen local broker. brokerIDUsedToAuthenticate = exitMsg.localBrokerID @@ -91,27 +90,27 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) log.Debugf(context.TODO(), "Ignoring authd authentication: %s", exitMsg) } logErrMsg = "" - errCode = C.PAM_IGNORE + errCode = pam.ErrIgnore case pamAbort: if exitMsg.String() != "" { logErrMsg = fmt.Sprintf("cancelled authentication: %s", exitMsg) } - errCode = C.PAM_ABORT + errCode = pam.ErrAbort case pamAuthError: if exitMsg.String() != "" { logErrMsg = fmt.Sprintf("authentication: %s", exitMsg) } - errCode = C.PAM_AUTH_ERR + errCode = pam.ErrAuth case pamAuthInfoUnavailable: if exitMsg.String() != "" { logErrMsg = fmt.Sprintf("missing authentication data: %s", exitMsg) } - errCode = C.PAM_AUTHINFO_UNAVAIL + errCode = pam.ErrAuthinfoUnavail case pamSystemError: if exitMsg.String() != "" { logErrMsg = fmt.Sprintf("system: %s", exitMsg) } - errCode = C.PAM_SYSTEM_ERR + errCode = pam.ErrSystem } if logErrMsg != "" { @@ -121,26 +120,28 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) return errCode } -// pam_sm_acct_mgmt sets any used brokerID as default for the user. -// -//export pam_sm_acct_mgmt -func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.int { +// AcctMgmt sets any used brokerID as default for the user. +func (h *pamModule) AcctMgmt(mTx pam.ModuleTransaction, flags pam.Flags, args []string) error { // Only set the brokerID as default if we stored one after authentication. if brokerIDUsedToAuthenticate == "" { - return C.PAM_IGNORE + return pam.ErrIgnore } // Get current user for broker - user := getPAMUser(pamh) + user, err := mTx.GetItem(pam.User) + if err != nil { + return err + } + if user == "" { log.Infof(context.TODO(), "can't get user from PAM") - return C.PAM_IGNORE + return pam.ErrIgnore } - client, closeConn, err := newClient(argc, argv) + client, closeConn, err := newClient(args) if err != nil { log.Debugf(context.TODO(), "%s", err) - return C.PAM_IGNORE + return pam.ErrIgnore } defer closeConn() @@ -150,15 +151,15 @@ func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C. } if _, err := client.SetDefaultBrokerForUser(context.TODO(), &req); err != nil { log.Infof(context.TODO(), "Can't set default broker (%q) for %q: %v", brokerIDUsedToAuthenticate, user, err) - return C.PAM_IGNORE + return pam.ErrIgnore } - return C.PAM_SUCCESS + return nil } // newClient returns a new GRPC client ready to emit requests. -func newClient(argc C.int, argv **C.char) (client authd.PAMClient, close func(), err error) { - conn, err := grpc.Dial("unix://"+getSocketPath(argc, argv), grpc.WithTransportCredentials(insecure.NewCredentials())) +func newClient(args []string) (client authd.PAMClient, close func(), err error) { + conn, err := grpc.Dial("unix://"+getSocketPath(args), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, nil, fmt.Errorf("could not connect to authd: %v", err) } @@ -166,9 +167,9 @@ func newClient(argc C.int, argv **C.char) (client authd.PAMClient, close func(), } // getSocketPath returns the socket path to connect to which can be overridden manually. -func getSocketPath(argc C.int, argv **C.char) string { +func getSocketPath(args []string) string { socketPath := consts.DefaultSocketPath - for _, arg := range sliceFromArgv(argc, argv) { + for _, arg := range args { opt, optarg, _ := strings.Cut(arg, "=") switch opt { case "socket": @@ -179,9 +180,24 @@ func getSocketPath(argc C.int, argv **C.char) string { return socketPath } -//export pam_sm_setcred -func pam_sm_setcred(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.int { - return C.PAM_IGNORE +// SetCred is the method that is invoked during pam_setcred request. +func (h *pamModule) SetCred(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore +} + +// ChangeAuthTok is the method that is invoked during pam_chauthtok request. +func (h *pamModule) ChangeAuthTok(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore +} + +// OpenSession is the method that is invoked during pam_open_session request. +func (h *pamModule) OpenSession(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore +} + +// CloseSession is the method that is invoked during pam_close_session request. +func (h *pamModule) CloseSession(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore } // go_pam_cleanup_module is called by the go-loader PAM module during onload. @@ -201,10 +217,12 @@ func main() { defer f.Close() logrus.SetOutput(f) - authResult := pam_sm_authenticate(nil, 0, 0, nil) + module := &pamModule{} + + authResult := module.Authenticate(nil, pam.Flags(0), nil) fmt.Println("Auth return:", authResult) // Simulate setting auth broker as default. - accMgmtResult := pam_sm_acct_mgmt(nil, 0, 0, nil) + accMgmtResult := module.AcctMgmt(nil, pam.Flags(0), nil) fmt.Println("Acct mgmt return:", accMgmtResult) } diff --git a/pam/pam_module.go b/pam/pam_module.go new file mode 100644 index 0000000000..4dfe8e9c77 --- /dev/null +++ b/pam/pam_module.go @@ -0,0 +1,93 @@ +// Code generated by "pam-moduler -libname pam_authd.so -no-main -type pamModule"; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_authd.so" -buildmode=c-shared -o pam_authd.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam" + "os" + "unsafe" +) + +var pamModuleHandler pam.ModuleHandler = &pamModule{} + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} diff --git a/pam/userselection.go b/pam/userselection.go index aa5954277a..d0b743a62d 100644 --- a/pam/userselection.go +++ b/pam/userselection.go @@ -3,13 +3,14 @@ package main import ( "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" ) // userSelectionModel allows selecting from PAM or interactively an user. type userSelectionModel struct { textinput.Model - pamh pamHandle + pamMTx pam.ModuleTransaction } // userSelected events to select a new username. @@ -25,7 +26,7 @@ func sendUserSelected(username string) tea.Cmd { } // newUserSelectionModel returns an initialized userSelectionModel. -func newUserSelectionModel(pamh pamHandle) userSelectionModel { +func newUserSelectionModel(pamMTx pam.ModuleTransaction) userSelectionModel { u := textinput.New() u.Prompt = "Username: " // TODO: i18n u.Placeholder = "user name" @@ -34,13 +35,16 @@ func newUserSelectionModel(pamh pamHandle) userSelectionModel { return userSelectionModel{ Model: u, - pamh: pamh, + pamMTx: pamMTx, } } // Init initializes userSelectionModel, by getting it from PAM if prefilled. func (m *userSelectionModel) Init() tea.Cmd { - pamUser := getPAMUser(m.pamh) + pamUser, err := m.pamMTx.GetItem(pam.User) + if err != nil { + return sendEvent(pamAbort{err.Error()}) + } if pamUser != "" { return sendUserSelected(pamUser) } @@ -54,7 +58,9 @@ func (m userSelectionModel) Update(msg tea.Msg) (userSelectionModel, tea.Cmd) { if msg.username != "" { // synchronise our internal validated field and the text one. m.SetValue(msg.username) - setPAMUser(m.pamh, msg.username) + if err := m.pamMTx.SetItem(pam.User, msg.username); err != nil { + return m, sendEvent(pamAbort{err.Error()}) + } return m, sendEvent(UsernameOrBrokerListReceived{}) } return m, nil diff --git a/pam/utils_c.go b/pam/utils_c.go deleted file mode 100644 index dc79be28f5..0000000000 --- a/pam/utils_c.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -/* -#include -#include -#include -#include -#include - -char *string_from_argv(int i, char **argv) { - return strdup(argv[i]); -} - -char *get_user(pam_handle_t *pamh) { - if (!pamh) - return NULL; - int pam_err = 0; - const char *user; - if ((pam_err = pam_get_item(pamh, PAM_USER, (const void**)&user)) != PAM_SUCCESS) - return NULL; - return strdup(user); -} - -char *set_user(pam_handle_t *pamh, char *username) { - if (!pamh) - return NULL; - int pam_err = 0; - if ((pam_err = pam_set_item(pamh, PAM_USER, (const void*)username)) != PAM_SUCCESS) - return NULL; - return NULL; -} -*/ -import "C" - -import ( - "unsafe" -) - -// pamHandle allows to pass C.pam_handle_t to this package. -type pamHandle = *C.pam_handle_t - -// sliceFromArgv returns a slice of strings given to the PAM module. -func sliceFromArgv(argc C.int, argv **C.char) []string { - r := make([]string, 0, argc) - for i := 0; i < int(argc); i++ { - s := C.string_from_argv(C.int(i), argv) - defer C.free(unsafe.Pointer(s)) - r = append(r, C.GoString(s)) - } - return r -} - -// mockPamUser mocks the PAM user item in absence of pamh for manual testing. -var mockPamUser = "user1" // TODO: remove assignement once ok with debugging - -// getPAMUser returns the user from PAM. -func getPAMUser(pamh *C.pam_handle_t) string { - if pamh == nil { - return mockPamUser - } - cUsername := C.get_user(pamh) - if cUsername == nil { - return "" - } - defer C.free(unsafe.Pointer(cUsername)) - return C.GoString(cUsername) -} - -// setPAMUser set current user to PAM. -func setPAMUser(pamh *C.pam_handle_t, username string) { - if pamh == nil { - mockPamUser = username - return - } - cUsername := C.CString(username) - defer C.free(unsafe.Pointer(cUsername)) - - C.set_user(pamh, cUsername) -} From 9b6d55bd92a7071aa11dc30daeb5fa31a898fac0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 20 Nov 2023 19:43:37 +0100 Subject: [PATCH 05/12] log: Add Warning stub definition --- internal/log/log.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/log/log.go b/internal/log/log.go index 7dbeb0cc11..f0d3036946 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -43,6 +43,11 @@ func Info(_ context.Context, args ...interface{}) { logrus.Info(args...) } +// Warning is a temporary placeholder. +func Warning(_ context.Context, args ...interface{}) { + logrus.Warning(args...) +} + // Warningf is a temporary placeholder. func Warningf(_ context.Context, format string, args ...interface{}) { logrus.Warningf(format, args...) From 7dc0c29324a06fc651093ce9c95876b7cac9afaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Sat, 14 Oct 2023 00:59:49 +0200 Subject: [PATCH 06/12] pam: Simplify the return code handling by using pam.StatusError Instead of using our own types for return status, let's just use pam.StatusError to return more complex error values back to the user --- pam/authentication.go | 44 +++++++++++++--------- pam/authmodeselection.go | 11 +++--- pam/brokerselection.go | 6 ++- pam/commands.go | 21 ++++++----- pam/model.go | 30 ++++++--------- pam/pam.go | 43 ++++----------------- pam/return.go | 80 ++++++++++++++++++++-------------------- pam/userselection.go | 4 +- 8 files changed, 110 insertions(+), 129 deletions(-) diff --git a/pam/authentication.go b/pam/authentication.go index 002024d732..ad96f40c43 100644 --- a/pam/authentication.go +++ b/pam/authentication.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/brokers/responses" "github.com/ubuntu/authd/internal/log" @@ -33,8 +34,9 @@ func sendIsAuthenticated(ctx context.Context, client authd.PAMClient, sessionID, access: responses.AuthCancelled, } } - return pamSystemError{ - msg: fmt.Sprintf("Authentication status failure: %v", err), + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("authentication status failure: %v", err), } } @@ -127,15 +129,22 @@ func (m *authenticationModel) Update(msg tea.Msg) (authenticationModel, tea.Cmd) return *m, sendEvent(pamSuccess{brokerID: m.currentBrokerID}) case responses.AuthRetry: - m.errorMsg = dataToMsg(msg.msg) + errorMsg, err := dataToMsg(msg.msg) + if err != nil { + return *m, sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) + } + m.errorMsg = errorMsg return *m, sendEvent(startAuthentication{}) case responses.AuthDenied: - errMsg := "Access denied" - if err := dataToMsg(msg.msg); err != "" { - errMsg = err + errMsg, err := dataToMsg(msg.msg) + if err != nil { + return *m, sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) + } + if errMsg == "" { + errMsg = "Access denied" } - return *m, sendEvent(pamAuthError{msg: errMsg}) + return *m, sendEvent(pamError{status: pam.ErrAuth, msg: errMsg}) case responses.AuthNext: return *m, sendEvent(GetAuthenticationModesRequested{}) @@ -205,7 +214,7 @@ func (m *authenticationModel) Compose(brokerID, sessionID string, layout *authd. case "qrcode": qrcodeModel, err := newQRCodeModel(layout.GetContent(), layout.GetLabel(), layout.GetButton(), layout.GetWait() == "true") if err != nil { - return sendEvent(pamSystemError{msg: err.Error()}) + return sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) } m.currentModel = qrcodeModel @@ -214,7 +223,10 @@ func (m *authenticationModel) Compose(brokerID, sessionID string, layout *authd. m.currentModel = newPasswordModel default: - return sendEvent(pamSystemError{msg: fmt.Sprintf("unknown layout type: %q", layout.Type)}) + return sendEvent(pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("unknown layout type: %q", layout.Type), + }) } return sendEvent(startAuthentication{}) @@ -247,24 +259,22 @@ func (m *authenticationModel) Reset() { } // dataToMsg returns the data message from a given JSON message. -func dataToMsg(data string) string { +func dataToMsg(data string) (string, error) { if data == "" { - return "" + return "", nil } v := make(map[string]string) if err := json.Unmarshal([]byte(data), &v); err != nil { - log.Infof(context.TODO(), "Invalid json data from provider: %v", data) - return "" + return "", fmt.Errorf("invalid json data from provider: %v", err) } if len(v) == 0 { - return "" + return "", nil } r, ok := v["message"] if !ok { - log.Debugf(context.TODO(), "No message entry in json data from provider: %v", data) - return "" + return "", fmt.Errorf("no message entry in json data from provider: %v", v) } - return r + return r, nil } diff --git a/pam/authmodeselection.go b/pam/authmodeselection.go index f03d962055..05f210476a 100644 --- a/pam/authmodeselection.go +++ b/pam/authmodeselection.go @@ -9,6 +9,7 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -241,17 +242,15 @@ func getAuthenticationModes(client authd.PAMClient, sessionID string, uiLayouts gamResp, err := client.GetAuthenticationModes(context.Background(), gamReq) if err != nil { - return pamSystemError{ - msg: fmt.Sprintf("could not get authentication modes: %v", err), + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("could not get authentication modes: %v", err), } } authModes := gamResp.GetAuthenticationModes() if len(authModes) == 0 { - return pamIgnore{ - // TODO: probably go back to broker selection here - msg: "no supported authentication mode available for this provider", - } + return pamIgnore{msg: "no supported authentication mode available for this provider"} } log.Info(context.TODO(), authModes) diff --git a/pam/brokerselection.go b/pam/brokerselection.go index a69e42fa45..64005cc994 100644 --- a/pam/brokerselection.go +++ b/pam/brokerselection.go @@ -9,6 +9,7 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -232,8 +233,9 @@ func getAvailableBrokers(client authd.PAMClient) tea.Cmd { return func() tea.Msg { brokersInfo, err := client.AvailableBrokers(context.TODO(), &authd.Empty{}) if err != nil { - return pamSystemError{ - msg: fmt.Sprintf("could not get current available brokers: %v", err), + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("could not get current available brokers: %v", err), } } diff --git a/pam/commands.go b/pam/commands.go index 18f2d0e348..cfaea5ff80 100644 --- a/pam/commands.go +++ b/pam/commands.go @@ -7,6 +7,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -44,16 +45,16 @@ func startBrokerSession(client authd.PAMClient, brokerID, username string) tea.C sbResp, err := client.SelectBroker(context.TODO(), sbReq) if err != nil { - return pamSystemError{msg: fmt.Sprintf("can't select broker: %v", err)} + return pamError{status: pam.ErrSystem, msg: fmt.Sprintf("can't select broker: %v", err)} } sessionID := sbResp.GetSessionId() if sessionID == "" { - return pamSystemError{msg: "no session ID returned by broker"} + return pamError{status: pam.ErrSystem, msg: "no session ID returned by broker"} } encryptionKey := sbResp.GetEncryptionKey() if encryptionKey == "" { - return pamSystemError{msg: "no encryption key returned by broker"} + return pamError{status: pam.ErrSystem, msg: "no encryption key returned by broker"} } return SessionStarted{ @@ -73,16 +74,18 @@ func getLayout(client authd.PAMClient, sessionID, authModeID string) tea.Cmd { } uiInfo, err := client.SelectAuthenticationMode(context.TODO(), samReq) if err != nil { - return pamSystemError{ - // TODO: probably go back to broker selection here - msg: fmt.Sprintf("can't select authentication mode: %v", err), + // TODO: probably go back to broker selection here + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("can't select authentication mode: %v", err), } } if uiInfo.UiLayoutInfo == nil { - return pamSystemError{ - // TODO: probably go back to broker selection here - msg: "invalid empty UI Layout information from broker", + // TODO: probably go back to broker selection here + return pamError{ + status: pam.ErrSystem, + msg: "invalid empty UI Layout information from broker", } } diff --git a/pam/model.go b/pam/model.go index 8052ff8f2a..c16cc1c6d6 100644 --- a/pam/model.go +++ b/pam/model.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "strings" tea "github.com/charmbracelet/bubbletea" @@ -50,7 +49,7 @@ type model struct { authModeSelectionModel authModeSelectionModel authenticationModel authenticationModel - exitMsg fmt.Stringer + exitStatus pamReturnStatus } /* global events */ @@ -88,6 +87,7 @@ type SessionEnded struct{} // Init initializes the main model orchestrator. func (m *model) Init() tea.Cmd { + m.exitStatus = pamError{status: pam.ErrSystem, msg: "model did not return anything"} m.userSelectionModel = newUserSelectionModel(m.pamMTx) var cmds []tea.Cmd cmds = append(cmds, m.userSelectionModel.Init()) @@ -114,7 +114,10 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.KeyMsg: switch msg.String() { case "ctrl+c": - return m, sendEvent(pamAbort{msg: "cancel requested"}) + return m, sendEvent(pamError{ + status: pam.ErrAbort, + msg: "cancel requested", + }) case "esc": if m.brokerSelectionModel.WillCaptureEscape() || m.authModeSelectionModel.WillCaptureEscape() { break @@ -138,20 +141,8 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.brokerSelectionModel.SetWidth(m.width) // Exit cases - case pamIgnore: - m.exitMsg = msg - return m, m.quit() - case pamAbort: - m.exitMsg = msg - return m, m.quit() - case pamSystemError: - m.exitMsg = msg - return m, m.quit() - case pamAuthError: - m.exitMsg = msg - return m, m.quit() - case pamSuccess: - m.exitMsg = msg + case pamReturnStatus: + m.exitStatus = msg return m, m.quit() // Events @@ -195,7 +186,10 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { msg.ID = m.authModeSelectionModel.currentAuthModeSelectedID } if msg.ID == "" { - return m, sendEvent(pamSystemError{msg: "reselection of current auth mode without current ID"}) + return m, sendEvent(pamError{ + status: pam.ErrSystem, + msg: "reselection of current auth mode without current ID", + }) } return m, getLayout(m.client, m.currentSession.sessionID, msg.ID) diff --git a/pam/pam.go b/pam/pam.go index b9fe114348..24346ff37c 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -76,48 +76,21 @@ func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, arg return pam.ErrAbort } - logErrMsg := "unknown" - errCode := pam.ErrSystem - - switch exitMsg := appState.exitMsg.(type) { + switch exitStatus := appState.exitStatus.(type) { case pamSuccess: - brokerIDUsedToAuthenticate = exitMsg.brokerID + brokerIDUsedToAuthenticate = exitStatus.brokerID return nil + case pamIgnore: // localBrokerID is only set on pamIgnore if the user has chosen local broker. - brokerIDUsedToAuthenticate = exitMsg.localBrokerID - if exitMsg.String() != "" { - log.Debugf(context.TODO(), "Ignoring authd authentication: %s", exitMsg) - } - logErrMsg = "" - errCode = pam.ErrIgnore - case pamAbort: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("cancelled authentication: %s", exitMsg) - } - errCode = pam.ErrAbort - case pamAuthError: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("authentication: %s", exitMsg) - } - errCode = pam.ErrAuth - case pamAuthInfoUnavailable: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("missing authentication data: %s", exitMsg) - } - errCode = pam.ErrAuthinfoUnavail - case pamSystemError: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("system: %s", exitMsg) - } - errCode = pam.ErrSystem - } + brokerIDUsedToAuthenticate = exitStatus.localBrokerID + return fmt.Errorf("%w: %s", exitStatus.Status(), exitStatus.Message()) - if logErrMsg != "" { - fmt.Fprintf(os.Stderr, "Error: %v\n", logErrMsg) + case pamReturnError: + return fmt.Errorf("%w: %s", exitStatus.Status(), exitStatus.Message()) } - return errCode + return fmt.Errorf("%w: unknown exit code", pam.ErrSystem) } // AcctMgmt sets any used brokerID as default for the user. diff --git a/pam/return.go b/pam/return.go index e09fcb4f33..fe93925c1c 100644 --- a/pam/return.go +++ b/pam/return.go @@ -1,64 +1,64 @@ package main +import ( + "github.com/msteinert/pam" +) + // Various signalling return messaging to PAM. -// pamSuccess signals PAM module to return PAM_SUCCESS and Quit tea.Model. +// pamReturnStatus is the interface that all PAM return types should implement. +type pamReturnStatus interface { + Message() string +} + +// pamReturnError is an interface that PAM errors return types should implement. +type pamReturnError interface { + pamReturnStatus + Status() pam.Error +} + +// pamSuccess signals PAM module to return with provided pam.Success and Quit tea.Model. type pamSuccess struct { brokerID string + msg string } -// String returns the string of pamSuccess. -func (err pamSuccess) String() string { - return "" +// Message returns the message that should be sent to pam as info message. +func (p pamSuccess) Message() string { + return p.msg } -// pamIgnore signals PAM module to return PAM_IGNORE and Quit tea.Model. +// pamIgnore signals PAM module to return pam.Ignore and Quit tea.Model. type pamIgnore struct { localBrokerID string // Only set for local broker to store it globally. msg string } -// String returns the string of pamIgnore message. -func (err pamIgnore) String() string { - return err.msg -} - -// pamAbort signals PAM module to return PAM_ABORT and Quit tea.Model. -type pamAbort struct { - msg string -} - -// String returns the string of pamAbort message. -func (err pamAbort) String() string { - return err.msg -} - -// pamSystemError signals PAM module to return PAM_SYSTEM_ERROR and Quit tea.Model. -type pamSystemError struct { - msg string -} - -// String returns the string of pamSystemError message. -func (err pamSystemError) String() string { - return err.msg +// Status returns [pam.ErrIgnore]. +func (p pamIgnore) Status() pam.Error { + return pam.ErrIgnore } -// pamAuthError signals PAM module to return PAM_AUTH_ERROR and Quit tea.Model. -type pamAuthError struct { - msg string +// Message returns the message that should be sent to pam as info message. +func (p pamIgnore) Message() string { + return p.msg } -// String returns the string of pamAuthError message. -func (err pamAuthError) String() string { - return err.msg +// pamIgnore signals PAM module to return the provided error message and Quit tea.Model. +type pamError struct { + status pam.Error + msg string } -// pamAuthInfoUnavailable signals PAM module to return PAM_AUTHINFO_UNAVAIL and Quit tea.Model. -type pamAuthInfoUnavailable struct { - msg string +// Status returns the PAM exit status code. +func (p pamError) Status() pam.Error { + return p.status } -// String returns the string of pamAuthInfoUnavailable message. -func (err pamAuthInfoUnavailable) String() string { - return err.msg +// Message returns the message that should be sent to pam as error message. +func (p pamError) Message() string { + if p.msg != "" { + return p.msg + } + return p.status.Error() } diff --git a/pam/userselection.go b/pam/userselection.go index d0b743a62d..85cecca876 100644 --- a/pam/userselection.go +++ b/pam/userselection.go @@ -43,7 +43,7 @@ func newUserSelectionModel(pamMTx pam.ModuleTransaction) userSelectionModel { func (m *userSelectionModel) Init() tea.Cmd { pamUser, err := m.pamMTx.GetItem(pam.User) if err != nil { - return sendEvent(pamAbort{err.Error()}) + return sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) } if pamUser != "" { return sendUserSelected(pamUser) @@ -59,7 +59,7 @@ func (m userSelectionModel) Update(msg tea.Msg) (userSelectionModel, tea.Cmd) { // synchronise our internal validated field and the text one. m.SetValue(msg.username) if err := m.pamMTx.SetItem(pam.User, msg.username); err != nil { - return m, sendEvent(pamAbort{err.Error()}) + return m, sendEvent(pamError{status: pam.ErrAbort, msg: err.Error()}) } return m, sendEvent(UsernameOrBrokerListReceived{}) } From 17e85855bc17dbe4aa3a6cbf16d86bfa3a1facc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Sat, 14 Oct 2023 01:28:23 +0200 Subject: [PATCH 07/12] pam: Send error or info message to PAM if any --- pam/pam.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pam/pam.go b/pam/pam.go index 24346ff37c..364ff21128 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -41,6 +41,23 @@ var ( auth [success=3 default=die ignore=ignore] pam_authd.so */ +func sendReturnMessageToPam(mTx pam.ModuleTransaction, retStatus pamReturnStatus) { + msg := retStatus.Message() + if msg == "" { + return + } + + style := pam.ErrorMsg + switch retStatus.(type) { + case pamIgnore, pamSuccess: + style = pam.TextInfo + } + + if _, err := mTx.StartStringConv(style, msg); err != nil { + log.Errorf(context.TODO(), "Failed sending message to pam: %v", err) + } +} + // Authenticate is the method that is invoked during pam_authenticate request. func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, args []string) error { // Initialize localization @@ -76,6 +93,8 @@ func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, arg return pam.ErrAbort } + sendReturnMessageToPam(mTx, appState.exitStatus) + switch exitStatus := appState.exitStatus.(type) { case pamSuccess: brokerIDUsedToAuthenticate = exitStatus.brokerID From c8749d98ebcc5b14a2f5a650056b8c35c166b8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Sat, 14 Oct 2023 01:18:54 +0200 Subject: [PATCH 08/12] pam: Save broker authentication ID in the module data We used to store the authentication ID as global value, but pam can handle this natively now, allowing us to store it as module data. --- pam/pam.go | 51 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/pam/pam.go b/pam/pam.go index 364ff21128..c9f79a0319 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -9,6 +9,7 @@ import "C" import ( "context" + "errors" "fmt" "os" "runtime" @@ -30,9 +31,11 @@ import ( type pamModule struct { } -var ( - // brokerIDUsedToAuthenticate global variable is for the second stage authentication to select the default broker for the current user. - brokerIDUsedToAuthenticate string +const ( + // authenticationBrokerIDKey is the Key used to store the data in the + // PAM module for the second stage authentication to select the default + // broker for the current user. + authenticationBrokerIDKey = "authentication-broker-id" ) /* @@ -41,6 +44,19 @@ var ( auth [success=3 default=die ignore=ignore] pam_authd.so */ +func showPamMessage(mTx pam.ModuleTransaction, style pam.Style, msg string) error { + switch style { + case pam.TextInfo | pam.ErrorMsg: + default: + return fmt.Errorf("message style not supported: %v", style) + } + if _, err := mTx.StartStringConv(style, msg); err != nil { + log.Errorf(context.TODO(), "Failed sending message to pam: %v", err) + return err + } + return nil +} + func sendReturnMessageToPam(mTx pam.ModuleTransaction, retStatus pamReturnStatus) { msg := retStatus.Message() if msg == "" { @@ -53,9 +69,7 @@ func sendReturnMessageToPam(mTx pam.ModuleTransaction, retStatus pamReturnStatus style = pam.TextInfo } - if _, err := mTx.StartStringConv(style, msg); err != nil { - log.Errorf(context.TODO(), "Failed sending message to pam: %v", err) - } + _ = showPamMessage(mTx, style, msg) } // Authenticate is the method that is invoked during pam_authenticate request. @@ -81,6 +95,10 @@ func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, arg interactiveTerminal: interactiveTerminal, } + if err := mTx.SetData(authenticationBrokerIDKey, nil); err != nil { + return err + } + //tea.WithInput(nil) //tea.WithoutRenderer() var opts []tea.ProgramOption @@ -97,12 +115,16 @@ func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, arg switch exitStatus := appState.exitStatus.(type) { case pamSuccess: - brokerIDUsedToAuthenticate = exitStatus.brokerID + if err := mTx.SetData(authenticationBrokerIDKey, exitStatus.brokerID); err != nil { + return err + } return nil case pamIgnore: // localBrokerID is only set on pamIgnore if the user has chosen local broker. - brokerIDUsedToAuthenticate = exitStatus.localBrokerID + if err := mTx.SetData(authenticationBrokerIDKey, exitStatus.localBrokerID); err != nil { + return err + } return fmt.Errorf("%w: %s", exitStatus.Status(), exitStatus.Message()) case pamReturnError: @@ -114,6 +136,19 @@ func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, arg // AcctMgmt sets any used brokerID as default for the user. func (h *pamModule) AcctMgmt(mTx pam.ModuleTransaction, flags pam.Flags, args []string) error { + brokerData, err := mTx.GetData(authenticationBrokerIDKey) + if err != nil && errors.Is(err, pam.ErrNoModuleData) { + return pam.ErrIgnore + } + + brokerIDUsedToAuthenticate, ok := brokerData.(string) + if !ok { + msg := fmt.Sprintf("broker data as an invalid type %#v", brokerData) + log.Errorf(context.TODO(), msg) + _ = showPamMessage(mTx, pam.ErrorMsg, msg) + return pam.ErrIgnore + } + // Only set the brokerID as default if we stored one after authentication. if brokerIDUsedToAuthenticate == "" { return pam.ErrIgnore From 178ed275c2f13eced46a2ecdc5b65bd69763d852 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Sat, 14 Oct 2023 05:14:43 +0200 Subject: [PATCH 09/12] pam: Add pam_test.ModuleTransactionDummy interface definition We have a base interface that the pam.ModuleTransaction should implement and so use and pass this interface instead of relying on the actual module transaction implementation, so that we can mock it. As per this introduce a new dummy implementation for it that can be used in local tests --- pam/pam_test/module-transaction-dummy.go | 362 ++++++++ pam/pam_test/module-transaction-dummy_test.go | 817 ++++++++++++++++++ pam/pam_test/utils.go | 30 + 3 files changed, 1209 insertions(+) create mode 100644 pam/pam_test/module-transaction-dummy.go create mode 100644 pam/pam_test/module-transaction-dummy_test.go create mode 100644 pam/pam_test/utils.go diff --git a/pam/pam_test/module-transaction-dummy.go b/pam/pam_test/module-transaction-dummy.go new file mode 100644 index 0000000000..a9142c7e11 --- /dev/null +++ b/pam/pam_test/module-transaction-dummy.go @@ -0,0 +1,362 @@ +// Package pam_test includes test tools for the PAM module +package pam_test + +import ( + "errors" + "fmt" + "runtime" + "strings" + + "github.com/msteinert/pam" +) + +// ModuleTransactionDummy is an implementation of PamModuleTransaction for +// testing purposes. +type ModuleTransactionDummy struct { + Items map[pam.Item]string + Env map[string]string + Data map[string]any + convHandler pam.ConversationHandler +} + +// NewModuleTransactionDummy returns a new PamModuleTransactionDummy. +func NewModuleTransactionDummy(convHandler pam.ConversationHandler) pam.ModuleTransaction { + return &ModuleTransactionDummy{ + convHandler: convHandler, + Data: make(map[string]any), + Env: make(map[string]string), + Items: make(map[pam.Item]string), + } +} + +// InvokeHandler is called by the C code to invoke the proper handler. +func (m *ModuleTransactionDummy) InvokeHandler(handler pam.ModuleHandlerFunc, + flags pam.Flags, args []string) error { + return pam.ErrAbort +} + +// SetItem sets a PAM information item. +func (m *ModuleTransactionDummy) SetItem(item pam.Item, value string) error { + if item <= 0 { + return pam.ErrBadItem + } + + m.Items[item] = value + return nil +} + +// GetItem retrieves a PAM information item. +func (m *ModuleTransactionDummy) GetItem(item pam.Item) (string, error) { + if item <= 0 { + return "", pam.ErrBadItem + } + return m.Items[item], nil +} + +// PutEnv adds or changes the value of PAM environment variables. +// +// NAME=value will set a variable to a value. +// NAME= will set a variable to an empty value. +// NAME (without an "=") will delete a variable. +func (m *ModuleTransactionDummy) PutEnv(nameVal string) error { + env, value, found := strings.Cut(nameVal, "=") + if !found { + delete(m.Env, env) + return nil + } + if env == "" { + return pam.ErrBadItem + } + m.Env[env] = value + return nil +} + +// GetEnv is used to retrieve a PAM environment variable. +func (m *ModuleTransactionDummy) GetEnv(name string) string { + return m.Env[name] +} + +// GetEnvList returns a copy of the PAM environment as a map. +func (m *ModuleTransactionDummy) GetEnvList() (map[string]string, error) { + return m.Env, nil +} + +// GetUser is similar to GetItem(User), but it would start a conversation if +// no user is currently set in PAM. +func (m *ModuleTransactionDummy) GetUser(prompt string) (string, error) { + user, err := m.GetItem(pam.User) + if err != nil { + return "", err + } + if user != "" { + return user, nil + } + + resp, err := m.StartStringConv(pam.PromptEchoOn, prompt) + if err != nil { + return "", err + } + + return resp.Response(), nil +} + +// SetData allows to save any value in the module data that is preserved +// during the whole time the module is loaded. +func (m *ModuleTransactionDummy) SetData(key string, data any) error { + if data == nil { + delete(m.Data, key) + return nil + } + + m.Data[key] = data + return nil +} + +// GetData allows to get any value from the module data saved using SetData +// that is preserved across the whole time the module is loaded. +func (m *ModuleTransactionDummy) GetData(key string) (any, error) { + data, found := m.Data[key] + if !found { + return nil, pam.ErrNoModuleData + } + return data, nil +} + +// StartStringConv starts a text-based conversation using the provided style +// and prompt. +func (m *ModuleTransactionDummy) StartStringConv(style pam.Style, prompt string) ( + pam.StringConvResponse, error) { + if style == pam.BinaryPrompt { + return nil, fmt.Errorf("%w: binary style is not supported", pam.ErrConv) + } + + res, err := m.StartConv(pam.NewStringConvRequest(style, prompt)) + if err != nil { + return nil, err + } + + stringRes, ok := res.(pam.StringConvResponse) + if !ok { + return nil, fmt.Errorf("%w: can't convert to pam.StringConvResponse", pam.ErrConv) + } + return stringRes, nil +} + +// StartStringConvf allows to start string conversation with formatting support. +func (m *ModuleTransactionDummy) StartStringConvf(style pam.Style, format string, args ...interface{}) ( + pam.StringConvResponse, error) { + return m.StartStringConv(style, fmt.Sprintf(format, args...)) +} + +// StartBinaryConv starts a binary conversation using the provided bytes. +func (m *ModuleTransactionDummy) StartBinaryConv(bytes []byte) ( + pam.BinaryConvResponse, error) { + res, err := m.StartConv(NewBinaryRequestDummyFromBytes(bytes)) + if err != nil { + return nil, err + } + + binaryRes, ok := res.(pam.BinaryConvResponse) + if !ok { + return nil, fmt.Errorf("%w: can't convert to pam.BinaryConvResponse", pam.ErrConv) + } + return binaryRes, nil +} + +// StartConv initiates a PAM conversation using the provided ConvRequest. +func (m *ModuleTransactionDummy) StartConv(req pam.ConvRequest) ( + pam.ConvResponse, error) { + resp, err := m.StartConvMulti([]pam.ConvRequest{req}) + if err != nil { + return nil, err + } + if len(resp) != 1 { + return nil, fmt.Errorf("%w: not enough values returned", pam.ErrConv) + } + return resp[0], nil +} + +func (m *ModuleTransactionDummy) handleStringRequest(req pam.ConvRequest) (pam.StringConvResponse, error) { + msgStyle := req.Style() + if m.convHandler == nil { + return nil, fmt.Errorf("no conversation handler provided for style %v", msgStyle) + } + reply, err := m.convHandler.RespondPAM(msgStyle, + req.(pam.StringConvRequest).Prompt()) + if err != nil { + return nil, err + } + + return StringResponseDummy{ + msgStyle, + reply, + }, nil +} + +func (m *ModuleTransactionDummy) handleBinaryRequest(req pam.ConvRequest) (pam.BinaryConvResponse, error) { + if m.convHandler == nil { + return nil, errors.New("no binary handler provided") + } + + //nolint:forcetypeassert + // req must be a pam.BinaryConvRequester, if that's not the case we should + // just panic since this code is only expected to run in tests. + binReq := req.(pam.BinaryConvRequester) + + switch handler := m.convHandler.(type) { + case pam.BinaryConversationHandler: + r, err := handler.RespondPAMBinary(binReq.Pointer()) + if err != nil { + return nil, err + } + return binReq.CreateResponse(pam.BinaryPointer(&r)), nil + + case pam.BinaryPointerConversationHandler: + r, err := handler.RespondPAMBinary(binReq.Pointer()) + if err != nil { + if r != nil { + resp := binReq.CreateResponse(r) + resp.Release() + } + return nil, err + } + return binReq.CreateResponse(r), nil + + default: + return nil, fmt.Errorf("unsupported conversation handler %#v", handler) + } +} + +// StartConvMulti initiates a PAM conversation with multiple ConvRequest's. +func (m *ModuleTransactionDummy) StartConvMulti(requests []pam.ConvRequest) ( + responses []pam.ConvResponse, err error) { + defer func() { + if err != nil { + err = errors.Join(pam.ErrConv, err) + } + }() + + if len(requests) == 0 { + return nil, errors.New("no requests defined") + } + + responses = make([]pam.ConvResponse, 0, len(requests)) + for _, req := range requests { + msgStyle := req.Style() + switch msgStyle { + case pam.PromptEchoOff: + fallthrough + case pam.PromptEchoOn: + fallthrough + case pam.ErrorMsg: + fallthrough + case pam.TextInfo: + response, err := m.handleStringRequest(req) + if err != nil { + return nil, err + } + responses = append(responses, response) + case pam.BinaryPrompt: + response, err := m.handleBinaryRequest(req) + if err != nil { + return nil, err + } + responses = append(responses, response) + default: + return nil, fmt.Errorf("unsupported conversation type %v", msgStyle) + } + } + + return responses, nil +} + +// BinaryRequestDummy is a dummy pam.BinaryConvRequester implementation. +type BinaryRequestDummy struct { + ptr pam.BinaryPointer +} + +// NewBinaryRequestDummy creates a new BinaryConvRequest with finalizer +// for response BinaryResponse. +func NewBinaryRequestDummy(ptr pam.BinaryPointer) *BinaryRequestDummy { + return &BinaryRequestDummy{ptr} +} + +// NewBinaryRequestDummyFromBytes creates a new BinaryConvRequestDummy from +// an array of bytes. +func NewBinaryRequestDummyFromBytes(bytes []byte) *BinaryRequestDummy { + if bytes == nil { + return &BinaryRequestDummy{} + } + return NewBinaryRequestDummy(pam.BinaryPointer(&bytes)) +} + +// Style returns the response style for the request, so always BinaryPrompt. +func (b BinaryRequestDummy) Style() pam.Style { + return pam.BinaryPrompt +} + +// Pointer returns the conversation style of the StringConvRequest. +func (b BinaryRequestDummy) Pointer() pam.BinaryPointer { + return b.ptr +} + +// CreateResponse creates a new BinaryConvResponse from the request. +func (b BinaryRequestDummy) CreateResponse(ptr pam.BinaryPointer) pam.BinaryConvResponse { + bcr := &BinaryResponseDummy{ptr} + runtime.SetFinalizer(bcr, func(bcr *BinaryResponseDummy) { + bcr.Release() + }) + return bcr +} + +// Release releases the resources allocated by the request. +func (b *BinaryRequestDummy) Release() { + b.ptr = nil +} + +// StringResponseDummy is a simple implementation of pam.StringConvResponse. +type StringResponseDummy struct { + style pam.Style + content string +} + +// Style returns the conversation style of the StringResponseDummy. +func (s StringResponseDummy) Style() pam.Style { + return s.style +} + +// Response returns the string response of the StringResponseDummy. +func (s StringResponseDummy) Response() string { + return s.content +} + +// BinaryResponseDummy is an implementation of pam.BinaryConvResponse. +type BinaryResponseDummy struct { + ptr pam.BinaryPointer +} + +// Style returns the response style for the response, so always BinaryPrompt. +func (b BinaryResponseDummy) Style() pam.Style { + return pam.BinaryPrompt +} + +// Data returns the response native pointer, it's up to the protocol to parse +// it accordingly. +func (b BinaryResponseDummy) Data() pam.BinaryPointer { + return b.ptr +} + +// Release releases the memory associated with the pointer. +func (b *BinaryResponseDummy) Release() { + b.ptr = nil + runtime.SetFinalizer(b, nil) +} + +// Decode decodes the binary data using the provided decoder function. +func (b BinaryResponseDummy) Decode(decoder pam.BinaryDecoder) ( + []byte, error) { + if decoder == nil { + return nil, errors.New("nil decoder provided") + } + return decoder(b.Data()) +} diff --git a/pam/pam_test/module-transaction-dummy_test.go b/pam/pam_test/module-transaction-dummy_test.go new file mode 100644 index 0000000000..35096e78f5 --- /dev/null +++ b/pam/pam_test/module-transaction-dummy_test.go @@ -0,0 +1,817 @@ +package pam_test + +import ( + "fmt" + "testing" + + "github.com/msteinert/pam" + "github.com/stretchr/testify/require" +) + +func ptrValue[T any](value T) *T { + return &value +} + +func bytesPointerDecoder(ptr pam.BinaryPointer) ([]byte, error) { + if ptr == nil { + return nil, nil + } + return *(*[]byte)(ptr), nil +} + +func TestSetGetItem(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + item pam.Item + value *string + + wantValue *string + wantGetError error + wantSetError error + }{ + "Set user": { + item: pam.User, + value: ptrValue("an user"), + }, + + "Returns empty when getting an unset user": { + item: pam.User, + wantValue: ptrValue(""), + }, + + "Setting and getting an user": { + item: pam.User, + value: ptrValue("the-user"), + wantValue: ptrValue("the-user"), + }, + + // Error cases + "Error when setting invalid item": { + item: pam.Item(-1), + value: ptrValue("some value"), + wantSetError: pam.ErrBadItem, + }, + + "Error when getting invalid item": { + item: pam.Item(-1), + wantGetError: pam.ErrBadItem, + wantValue: ptrValue(""), + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + tx := NewModuleTransactionDummy(nil) + + if tc.value != nil { + err := tx.SetItem(tc.item, *tc.value) + require.ErrorIs(t, err, tc.wantSetError) + } + + if tc.wantValue != nil { + value, err := tx.GetItem(tc.item) + require.Equal(t, value, *tc.wantValue) + require.ErrorIs(t, err, tc.wantGetError) + } + }) + } +} + +func TestSetPutEnv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + env string + value *string + presetValues map[string]string + skipPut bool + + wantValue *string + wantPutError error + }{ + "Put var": { + env: "AN_ENV", + value: ptrValue("value"), + }, + + "Unset a not-previously set value": { + env: "NEVER_SET_ENV", + wantValue: ptrValue(""), + }, + + "Unset a preset value": { + presetValues: map[string]string{"PRESET_ENV": "hey!"}, + env: "PRESET_ENV", + wantValue: ptrValue(""), + }, + + "Changes a preset var": { + presetValues: map[string]string{"PRESET_ENV": "hey!"}, + env: "PRESET_ENV", + value: ptrValue("hello!"), + wantValue: ptrValue("hello!"), + }, + + "Get an unset env": { + skipPut: true, + env: "AN_UNSET_ENV", + wantValue: ptrValue(""), + }, + + "Gets an invalid env name": { + env: "", + value: ptrValue("Invalid Value"), + wantValue: ptrValue(""), + skipPut: true, + }, + + // Error cases + "Error when putting an invalid env name": { + env: "", + value: ptrValue("Invalid Value"), + wantPutError: pam.ErrBadItem, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + tx := NewModuleTransactionDummy(nil) + envList, err := tx.GetEnvList() + require.NoErrorf(t, err, "Setup: GetEnvList should not return an error") + require.Lenf(t, envList, 0, "Setup: GetEnvList should have elements") + + if tc.presetValues != nil && !tc.skipPut { + for env, value := range tc.presetValues { + err := tx.PutEnv(env + "=" + value) + require.NoError(t, err) + } + envList, err = tx.GetEnvList() + require.NoError(t, err) + require.Equal(t, tc.presetValues, envList) + } + + if !tc.skipPut { + var env string + if tc.value != nil { + env = tc.env + "=" + *tc.value + } else { + env = tc.env + } + err := tx.PutEnv(env) + require.ErrorIs(t, err, tc.wantPutError) + + wantEnv := map[string]string{} + if tc.wantPutError == nil { + if tc.value != nil { + wantEnv = map[string]string{tc.env: *tc.value} + } + if tc.value != nil && tc.wantValue != nil { + wantEnv = map[string]string{tc.env: *tc.wantValue} + } + } + gotEnv, err := tx.GetEnvList() + require.NoError(t, err, "tx.GetEnvList should not return an error") + require.Equal(t, wantEnv, gotEnv, "returned env lits should match expected") + } + + if tc.wantValue != nil { + value := tx.GetEnv(tc.env) + require.Equal(t, value, *tc.wantValue) + } + }) + } +} + +func TestSetGetData(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + key string + data any + presetData map[string]any + skipSet bool + skipGet bool + + wantData any + wantSetError error + wantGetError error + }{ + "Sets and gets data": { + presetData: map[string]any{"some-data": []any{"hey! That's", true}}, + key: "data", + data: []any{"hey! That's", true}, + wantData: []any{"hey! That's", true}, + }, + + "Set replaces data": { + presetData: map[string]any{"some-data": []any{"hey! That's", true}}, + key: "some-data", + data: ModuleTransactionDummy{ + Items: map[pam.Item]string{pam.Tty: "yay"}, + Env: map[string]string{"foo": "bar"}, + }, + wantData: ModuleTransactionDummy{ + Items: map[pam.Item]string{pam.Tty: "yay"}, + Env: map[string]string{"foo": "bar"}, + }, + }, + + // Error cases + "Error when getting data that has never been set": { + skipSet: true, + key: "not set", + wantGetError: pam.ErrNoModuleData, + }, + + "Error when getting data that has been removed": { + presetData: map[string]any{"some-data": []any{"hey! That's", true}}, + key: "some-data", + data: nil, + wantGetError: pam.ErrNoModuleData, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + tx := NewModuleTransactionDummy(nil) + + if tc.presetData != nil && !tc.skipSet { + for key, value := range tc.presetData { + err := tx.SetData(key, value) + require.NoError(t, err) + } + } + + if !tc.skipSet { + err := tx.SetData(tc.key, tc.data) + require.ErrorIs(t, err, tc.wantSetError) + } + + if !tc.skipGet { + data, err := tx.GetData(tc.key) + require.ErrorIs(t, err, tc.wantGetError) + require.Equal(t, tc.wantData, data) + } + }) + } +} + +func TestGetUser(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + presetUser string + convHandler pam.ConversationHandler + + want string + wantError error + }{ + "Getting a previously set user does not require conversation handler": { + presetUser: "an-user", + want: "an-user", + }, + + "Getting a previously set user does not use conversation handler": { + presetUser: "an-user", + want: "an-user", + convHandler: pam.ConversationFunc(func(s pam.Style, msg string) (string, error) { + return "another-user", pam.ErrConv + }), + }, + + "Getting the user uses conversation handler if none was set": { + want: "provided-user", + convHandler: pam.ConversationFunc( + func(s pam.Style, msg string) (string, error) { + require.Equal(t, msg, "Who are you?") + if msg != "Who are you?" { + return "", pam.ErrConv + } + if s == pam.PromptEchoOn { + return "provided-user", nil + } + return "", pam.ErrConv + }), + }, + + // Error cases + "Error when no conversation is set": { + want: "", + wantError: pam.ErrConv, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + tx := NewModuleTransactionDummy(tc.convHandler) + + if tc.presetUser != "" { + err := tx.SetItem(pam.User, tc.presetUser) + require.NoError(t, err) + } + + prompt := "Who are you?" + user, err := tx.GetUser(prompt) + require.ErrorIs(t, err, tc.wantError) + require.Equal(t, tc.want, user) + }) + } +} + +func TestStartStringConv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + prompt string + promptFormat string + promptFormatArgs []interface{} + convStyle pam.Style + convError error + convHandler *pam.ConversationFunc + convShouldNotBeCalled bool + + want string + wantError error + }{ + "Messages with error style are handled by conversation": { + prompt: "This is an error!", + convStyle: pam.ErrorMsg, + want: "I'm handling it fine though", + }, + + "Conversation prompt can be formatted": { + promptFormat: "Sending some %s, right? %v", + promptFormatArgs: []interface{}{"info", true}, + convStyle: pam.TextInfo, + want: "And returning some text back", + }, + + // Error cases + "Error if no conversation handler is set": { + convHandler: ptrValue(pam.ConversationFunc(nil)), + wantError: pam.ErrConv, + }, + + "Error if the conversation handler fails": { + prompt: "Tell me your secret!", + convStyle: pam.PromptEchoOff, + convError: pam.ErrBuf, + wantError: pam.ErrBuf, + }, + + "Error when conversation uses binary content style": { + prompt: "I am a binary content\xff!", + convStyle: pam.BinaryPrompt, + convError: pam.ErrConv, + wantError: pam.ErrConv, + convShouldNotBeCalled: true, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + convFunCalled := false + tx := NewModuleTransactionDummy(func() pam.ConversationFunc { + if tc.convHandler != nil { + return *tc.convHandler + } + prompt := tc.prompt + if tc.promptFormat != "" { + prompt = fmt.Sprintf(tc.promptFormat, tc.promptFormatArgs...) + } + return pam.ConversationFunc( + func(style pam.Style, msg string) (string, error) { + convFunCalled = true + require.Equal(t, prompt, msg) + require.Equal(t, tc.convStyle, style) + return tc.want, tc.convError + }) + }()) + + var reply pam.StringConvResponse + var err error + + if tc.promptFormat != "" { + reply, err = tx.StartStringConvf(tc.convStyle, tc.promptFormat, + tc.promptFormatArgs...) + } else { + reply, err = tx.StartStringConv(tc.convStyle, tc.prompt) + } + + wantConFuncCalled := !tc.convShouldNotBeCalled && tc.convHandler == nil + require.Equal(t, wantConFuncCalled, convFunCalled) + require.ErrorIs(t, err, tc.wantError) + + if tc.wantError != nil { + require.Zero(t, reply) + return + } + + require.NotNil(t, reply) + require.Equal(t, tc.want, reply.Response()) + require.Equal(t, tc.convStyle, reply.Style()) + }) + } +} + +func TestStartBinaryConv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + request []byte + convError error + convHandler *pam.ConversationHandler + + want []byte + wantError error + }{ + "Simple binary conversation": { + request: []byte{0x01, 0x02, 0x03}, + want: []byte{0x00, 0x01, 0x02, 0x03, 0x4}, + }, + + // Error cases + "Error if no conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(nil)), + wantError: pam.ErrConv, + }, + + "Error if no binary conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(pam.ConversationFunc( + func(s pam.Style, msg string) (string, error) { + return "", nil + }))), + wantError: pam.ErrConv, + }, + + "Error if the conversation handler fails": { + request: []byte{0x03, 0x02, 0x01}, + convError: pam.ErrBuf, + wantError: pam.ErrBuf, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + convFunCalled := false + tx := NewModuleTransactionDummy(func() pam.ConversationHandler { + if tc.convHandler != nil { + return *tc.convHandler + } + + return pam.BinaryConversationFunc( + func(ptr pam.BinaryPointer) ([]byte, error) { + convFunCalled = true + require.NotNil(t, ptr) + bytes := *(*[]byte)(ptr) + require.Equal(t, tc.request, bytes) + return tc.want, tc.convError + }) + }()) + + response, err := tx.StartBinaryConv(tc.request) + require.ErrorIs(t, err, tc.wantError) + require.Equal(t, tc.convHandler == nil, convFunCalled) + + if tc.wantError != nil { + require.Nil(t, response) + return + } + + defer response.Release() + require.NotNil(t, response) + require.Equal(t, pam.BinaryPrompt, response.Style()) + require.NotNil(t, response.Data()) + bytes, err := response.Decode(bytesPointerDecoder) + require.NoError(t, err) + require.Equal(t, tc.want, bytes) + + bytes, err = response.Decode(nil) + require.ErrorContains(t, err, "nil decoder provided") + require.Nil(t, bytes) + }) + } +} + +func TestStartBinaryPointerConv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + request []byte + convError error + convHandler *pam.ConversationHandler + + want []byte + wantError error + }{ + "With nil argument": { + request: nil, + want: nil, + }, + + "With empty argument": { + request: []byte{}, + want: []byte{}, + }, + + "With simple argument": { + request: []byte{0x01, 0x02, 0x03}, + want: []byte{0x00, 0x01, 0x02, 0x03, 0x4}, + }, + + // Error cases + "Error if no conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(nil)), + wantError: pam.ErrConv, + }, + + "Error if no binary conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(pam.ConversationFunc( + func(s pam.Style, msg string) (string, error) { + return "", nil + }))), + wantError: pam.ErrConv, + }, + + "Error if the conversation handler fails": { + request: []byte{0xde, 0xad, 0xbe, 0xef, 0xf}, + convError: pam.ErrBuf, + wantError: pam.ErrBuf, + }, + + "Error if no conversation handler is set handles allocated data": { + convError: pam.ErrSystem, + want: []byte{0xde, 0xad, 0xbe, 0xef, 0xf}, + wantError: pam.ErrSystem, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + convFunCalled := false + tx := NewModuleTransactionDummy(func() pam.ConversationHandler { + if tc.convHandler != nil { + return *tc.convHandler + } + + return pam.BinaryPointerConversationFunc( + func(ptr pam.BinaryPointer) (pam.BinaryPointer, error) { + convFunCalled = true + if tc.request == nil { + require.Nil(t, ptr) + } else { + require.NotNil(t, ptr) + } + bytes := cBytesToBytes(ptr, len(tc.request)) + require.Equal(t, tc.request, bytes) + return allocateCBytes(tc.want), tc.convError + }) + }()) + res, err := tx.StartConv(pam.NewBinaryConvRequest( + allocateCBytes(tc.request), releaseCBytesPointer)) + require.ErrorIs(t, err, tc.wantError) + require.Equal(t, tc.convHandler == nil, convFunCalled) + + if tc.wantError != nil { + require.Nil(t, res) + return + } + + response, ok := res.(pam.BinaryConvResponse) + require.True(t, ok) + defer response.Release() + require.NotNil(t, response) + require.Equal(t, pam.BinaryPrompt, response.Style()) + if tc.want == nil { + require.Nil(t, response.Data()) + } else { + require.NotNil(t, response.Data()) + } + bytes, err := response.Decode(func(ptr pam.BinaryPointer) ([]byte, error) { + return cBytesToBytes(ptr, len(tc.want)), nil + }) + require.NoError(t, err) + require.Equal(t, tc.want, bytes) + + bytes, err = response.Decode(nil) + require.ErrorContains(t, err, "nil decoder provided") + require.Nil(t, bytes) + }) + } +} + +type multiConvHandler struct { + t *testing.T + responses []pam.ConvResponse + wantRequests []pam.ConvRequest + timesCalled int +} + +func (c *multiConvHandler) next() (pam.ConvRequest, pam.ConvResponse) { + i := c.timesCalled + c.timesCalled++ + + return c.wantRequests[i], c.responses[i] +} + +func (c *multiConvHandler) RespondPAM(style pam.Style, prompt string) (string, error) { + wantReq, response := c.next() + require.Equal(c.t, wantReq.Style(), style) + stringReq, ok := wantReq.(pam.StringConvRequest) + require.True(c.t, ok) + require.Equal(c.t, stringReq.Prompt(), prompt) + stringRes, ok := response.(pam.StringConvResponse) + require.True(c.t, ok) + return stringRes.Response(), nil +} + +func (c *multiConvHandler) RespondPAMBinary(ptr pam.BinaryPointer) ([]byte, error) { + wantReq, response := c.next() + require.Equal(c.t, wantReq.Style(), pam.BinaryPrompt) + + binReq, ok := wantReq.(pam.BinaryConvRequester) + require.True(c.t, ok) + wantReqBytes, err := bytesPointerDecoder(binReq.Pointer()) + require.NoError(c.t, err) + actualReqBytes, err := bytesPointerDecoder(ptr) + require.NoError(c.t, err) + require.Equal(c.t, wantReqBytes, actualReqBytes) + + bytes, err := response.(pam.BinaryConvResponse).Decode(bytesPointerDecoder) + require.NoError(c.t, err) + return bytes, nil +} + +func TestStartConvMulti(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + requests []pam.ConvRequest + + wantResponses []pam.ConvResponse + wantConvCalls *int + wantError error + }{ + "Can address multiple string requests": { + requests: []pam.ConvRequest{ + pam.NewStringConvRequest(pam.PromptEchoOff, "give some PromptEchoOff"), + pam.NewStringConvRequest(pam.PromptEchoOn, "give some PromptEchoOn"), + pam.NewStringConvRequest(pam.ErrorMsg, "give some ErrorMsg"), + pam.NewStringConvRequest(pam.TextInfo, "give some TextInfo"), + }, + wantResponses: []pam.ConvResponse{ + StringResponseDummy{pam.PromptEchoOff, "answer to PromptEchoOff"}, + StringResponseDummy{pam.PromptEchoOn, "answer to PromptEchoOn"}, + StringResponseDummy{pam.ErrorMsg, "answer to ErrorMsg"}, + StringResponseDummy{pam.TextInfo, "answer to TextInfo"}, + }, + }, + + "Can address multiple binary requests": { + requests: []pam.ConvRequest{ + NewBinaryRequestDummy(nil), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{})), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0xFF, 0x00, 0xBA, 0xAB})), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0x55})), + }, + wantResponses: []pam.ConvResponse{ + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{})}, + &BinaryResponseDummy{nil}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0x53})}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0xAF, 0x00, 0xBA, 0xAC})}, + }, + }, + + "Can address multiple mixed binary and string requests ": { + requests: []pam.ConvRequest{ + NewBinaryRequestDummy(nil), + pam.NewStringConvRequest(pam.PromptEchoOff, "PromptEchoOff"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{})), + pam.NewStringConvRequest(pam.PromptEchoOn, "PromptEchoOn"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0xFF, 0x00, 0xBA, 0xAB})), + pam.NewStringConvRequest(pam.ErrorMsg, "ErrorMsg"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0x55})), + pam.NewStringConvRequest(pam.TextInfo, "TextInfo"), + }, + wantResponses: []pam.ConvResponse{ + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{})}, + StringResponseDummy{pam.PromptEchoOff, "PromptEchoOff"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0x55})}, + StringResponseDummy{pam.PromptEchoOn, "PromptEchoOn"}, + &BinaryResponseDummy{nil}, + StringResponseDummy{pam.ErrorMsg, "ErrorMsg"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0xAF, 0x00, 0xBA, 0xAC})}, + StringResponseDummy{pam.TextInfo, "TextInfo"}, + }, + }, + + // Error cases + "Error if no request is provided": { + wantError: pam.ErrConv, + }, + + "Error if one of the multiple request fails": { + requests: []pam.ConvRequest{ + NewBinaryRequestDummy(nil), + pam.NewStringConvRequest(pam.PromptEchoOff, "PromptEchoOff"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{})), + pam.NewStringConvRequest(pam.PromptEchoOn, "PromptEchoOn"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0xFF, 0x00, 0xBA, 0xAB})), + // The case below will lead to the whole request to fail! + pam.NewStringConvRequest(pam.Style(-1), "Invalid style"), + pam.NewStringConvRequest(pam.ErrorMsg, "ErrorMsg"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0x55})), + pam.NewStringConvRequest(pam.TextInfo, "TextInfo"), + }, + wantResponses: []pam.ConvResponse{ + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{})}, + StringResponseDummy{pam.PromptEchoOff, "PromptEchoOff"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0x55})}, + StringResponseDummy{pam.PromptEchoOn, "PromptEchoOn"}, + &BinaryResponseDummy{nil}, + StringResponseDummy{pam.Style(-1), "Invalid style"}, + StringResponseDummy{pam.ErrorMsg, "ErrorMsg"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0xAF, 0x00, 0xBA, 0xAC})}, + StringResponseDummy{pam.TextInfo, "TextInfo"}, + }, + wantConvCalls: ptrValue(5), + wantError: pam.ErrConv, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + require.Equalf(t, len(tc.wantResponses), len(tc.requests), + "Setup: mismatch on expectations / requests numbers") + + convHandler := &multiConvHandler{ + t: t, + wantRequests: tc.requests, + responses: tc.wantResponses, + } + tx := NewModuleTransactionDummy(convHandler) + + responses, err := tx.StartConvMulti(tc.requests) + require.ErrorIs(t, err, tc.wantError) + + wantConvCalls := len(tc.requests) + if tc.wantConvCalls != nil { + wantConvCalls = *tc.wantConvCalls + } + require.Equal(t, convHandler.timesCalled, wantConvCalls) + + if tc.wantError != nil { + require.Nil(t, responses) + return + } + + require.NotNil(t, responses) + require.Len(t, responses, len(tc.requests)) + + for i, res := range responses { + wantRes := tc.wantResponses[i] + require.Equal(t, wantRes.Style(), res.Style()) + + switch r := res.(type) { + case pam.StringConvResponse: + require.Equal(t, wantRes, res) + case pam.BinaryConvResponse: + wantBinRes, ok := wantRes.(pam.BinaryConvResponse) + require.True(t, ok) + wb, err := wantBinRes.Decode(bytesPointerDecoder) + require.NoError(t, err) + bytes, err := r.Decode(bytesPointerDecoder) + require.NoError(t, err) + require.Equal(t, wb, bytes) + default: + t.Fatalf("conversation %#v is not handled", r) + } + } + }) + } +} diff --git a/pam/pam_test/utils.go b/pam/pam_test/utils.go new file mode 100644 index 0000000000..594ecd9b62 --- /dev/null +++ b/pam/pam_test/utils.go @@ -0,0 +1,30 @@ +// Package pam_test includes test tools for the PAM module +package pam_test + +/* +#include +*/ +import "C" +import ( + "unsafe" + + "github.com/msteinert/pam" +) + +func allocateCBytes(bytes []byte) pam.BinaryPointer { + if bytes == nil { + return nil + } + return pam.BinaryPointer(C.CBytes(bytes)) +} + +func cBytesToBytes(ptr pam.BinaryPointer, size int) []byte { + if ptr == nil { + return nil + } + return C.GoBytes(unsafe.Pointer(ptr), C.int(size)) +} + +func releaseCBytesPointer(ptr pam.BinaryPointer) { + C.free(unsafe.Pointer(ptr)) +} From 41bbd9158bf597134623e524edb930d3ded0c56e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 27 Nov 2023 20:22:34 +0100 Subject: [PATCH 10/12] pam: Use dummy transaction to run PAM module as test cli app As per this, do not define main() as an actual function when we're generating the module, to avoid adding unwanted code in the library. --- pam/main-cli.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ pam/pam.go | 23 +---------------------- pam/pam_module.go | 6 +++++- 3 files changed, 51 insertions(+), 23 deletions(-) create mode 100644 pam/main-cli.go diff --git a/pam/main-cli.go b/pam/main-cli.go new file mode 100644 index 0000000000..61692c6b37 --- /dev/null +++ b/pam/main-cli.go @@ -0,0 +1,45 @@ +//go:build pam_binary_cli + +package main + +import ( + "fmt" + "os" + + "github.com/msteinert/pam" + "github.com/sirupsen/logrus" + "github.com/ubuntu/authd/internal/log" + "github.com/ubuntu/authd/pam/pam_test" +) + +// Simulating pam on the CLI for manual testing. +func main() { + log.SetLevel(log.DebugLevel) + f, err := os.OpenFile("/tmp/logdebug", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0600) + if err != nil { + panic(err) + } + defer f.Close() + logrus.SetOutput(f) + + module := &pamModule{} + mTx := pam_test.NewModuleTransactionDummy(pam.ConversationFunc( + func(style pam.Style, msg string) (string, error) { + switch style { + case pam.TextInfo: + fmt.Fprintf(os.Stderr, "PAM INFO: %s\n", msg) + case pam.ErrorMsg: + fmt.Fprintf(os.Stderr, "PAM ERROR: %s\n", msg) + default: + return "", fmt.Errorf("pam style %d not implemented", style) + } + return "", nil + })) + + authResult := module.Authenticate(mTx, pam.Flags(0), nil) + fmt.Println("Auth return:", authResult) + + // Simulate setting auth broker as default. + accMgmtResult := module.AcctMgmt(mTx, pam.Flags(0), nil) + fmt.Println("Acct mgmt return:", accMgmtResult) +} diff --git a/pam/pam.go b/pam/pam.go index c9f79a0319..8ebc344fda 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -libname "pam_authd.so" -no-main -type pamModule +//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -libname "pam_authd.so" -type pamModule -tags !pam_binary_cli //go:generate go generate --skip="pam_module.go" //go:generate sh -c "cc -o go-loader/pam_go_loader.so go-loader/module.c -Wl,--as-needed -Wl,--allow-shlib-undefined -shared -fPIC -Wl,--unresolved-symbols=report-all -lpam && chmod 600 go-loader/pam_go_loader.so" @@ -17,7 +17,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/msteinert/pam" - "github.com/sirupsen/logrus" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/consts" "github.com/ubuntu/authd/internal/log" @@ -233,23 +232,3 @@ func (h *pamModule) CloseSession(pam.ModuleTransaction, pam.Flags, []string) err func go_pam_cleanup_module() { runtime.GC() } - -// Simulating pam on the CLI for manual testing. -func main() { - log.SetLevel(log.DebugLevel) - f, err := os.OpenFile("/tmp/logdebug", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0600) - if err != nil { - panic(err) - } - defer f.Close() - logrus.SetOutput(f) - - module := &pamModule{} - - authResult := module.Authenticate(nil, pam.Flags(0), nil) - fmt.Println("Auth return:", authResult) - - // Simulate setting auth broker as default. - accMgmtResult := module.AcctMgmt(nil, pam.Flags(0), nil) - fmt.Println("Acct mgmt return:", accMgmtResult) -} diff --git a/pam/pam_module.go b/pam/pam_module.go index 4dfe8e9c77..675d3ea834 100644 --- a/pam/pam_module.go +++ b/pam/pam_module.go @@ -1,4 +1,6 @@ -// Code generated by "pam-moduler -libname pam_authd.so -no-main -type pamModule"; DO NOT EDIT. +// Code generated by "pam-moduler -libname pam_authd.so -type pamModule -tags !pam_binary_cli"; DO NOT EDIT. + +//go:build !pam_binary_cli //go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_authd.so" -buildmode=c-shared -o pam_authd.so -tags go_pam_module @@ -91,3 +93,5 @@ func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv ** func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) } + +func main() {} From 33fbd9de255b854aa39e9161349ba01826d92d0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 20 Oct 2023 02:19:54 +0200 Subject: [PATCH 11/12] CI: Run PAM tests with address sanitizer enabled PAM code is using CGO quite a lot, so run tests using address sanitizer to catch memory issues and leaks --- .github/workflows/qa.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/qa.yaml b/.github/workflows/qa.yaml index 1eb4e83d03..0fa51001fa 100644 --- a/.github/workflows/qa.yaml +++ b/.github/workflows/qa.yaml @@ -116,6 +116,14 @@ jobs: - name: Run tests (with race detector) run: | go test -race ./... + - name: Run PAM tests (with Address Sanitizer) + env: + # Do not optimize, keep debug symbols and frame pointer for better + # stack trace information in case of ASAN errors. + CGO_CFLAGS: "-O0 -g3 -fno-omit-frame-pointer" + run: | + # Use `-dwarflocationlists` to give ASAN a better time to unwind the stack trace + go test -C pam -asan -gcflags="-dwarflocationlists=true" ./... - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: From c974f7896de44b46f71bbaebc4e9439f409eea1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 20 Oct 2023 02:55:38 +0200 Subject: [PATCH 12/12] pam_tests: Force doing a memory leak check at tests Cleanup ASAN in go does not catch memory leaks properly at the end of the test program execution, so force this using a wrapper function that is called when each test is completed. --- pam/pam_test/module-transaction-dummy_test.go | 16 ++++++++++++ pam/pam_test/utils.go | 26 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/pam/pam_test/module-transaction-dummy_test.go b/pam/pam_test/module-transaction-dummy_test.go index 35096e78f5..ba001bd2a6 100644 --- a/pam/pam_test/module-transaction-dummy_test.go +++ b/pam/pam_test/module-transaction-dummy_test.go @@ -21,6 +21,7 @@ func bytesPointerDecoder(ptr pam.BinaryPointer) ([]byte, error) { func TestSetGetItem(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { item pam.Item @@ -64,6 +65,7 @@ func TestSetGetItem(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tx := NewModuleTransactionDummy(nil) @@ -83,6 +85,7 @@ func TestSetGetItem(t *testing.T) { func TestSetPutEnv(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { env string @@ -141,6 +144,7 @@ func TestSetPutEnv(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tx := NewModuleTransactionDummy(nil) envList, err := tx.GetEnvList() @@ -191,6 +195,7 @@ func TestSetPutEnv(t *testing.T) { func TestSetGetData(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { key string @@ -242,6 +247,7 @@ func TestSetGetData(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tx := NewModuleTransactionDummy(nil) @@ -268,6 +274,7 @@ func TestSetGetData(t *testing.T) { func TestGetUser(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { presetUser string @@ -315,6 +322,7 @@ func TestGetUser(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tx := NewModuleTransactionDummy(tc.convHandler) @@ -333,6 +341,7 @@ func TestGetUser(t *testing.T) { func TestStartStringConv(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { prompt string @@ -385,6 +394,7 @@ func TestStartStringConv(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) convFunCalled := false tx := NewModuleTransactionDummy(func() pam.ConversationFunc { @@ -432,6 +442,7 @@ func TestStartStringConv(t *testing.T) { func TestStartBinaryConv(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { request []byte @@ -471,6 +482,7 @@ func TestStartBinaryConv(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) convFunCalled := false tx := NewModuleTransactionDummy(func() pam.ConversationHandler { @@ -514,6 +526,7 @@ func TestStartBinaryConv(t *testing.T) { func TestStartBinaryPointerConv(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { request []byte @@ -569,6 +582,7 @@ func TestStartBinaryPointerConv(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) convFunCalled := false tx := NewModuleTransactionDummy(func() pam.ConversationHandler { @@ -666,6 +680,7 @@ func (c *multiConvHandler) RespondPAMBinary(ptr pam.BinaryPointer) ([]byte, erro func TestStartConvMulti(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { requests []pam.ConvRequest @@ -765,6 +780,7 @@ func TestStartConvMulti(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) require.Equalf(t, len(tc.wantResponses), len(tc.requests), "Setup: mismatch on expectations / requests numbers") diff --git a/pam/pam_test/utils.go b/pam/pam_test/utils.go index 594ecd9b62..438603f2a1 100644 --- a/pam/pam_test/utils.go +++ b/pam/pam_test/utils.go @@ -3,14 +3,40 @@ package pam_test /* #include + +#ifdef __SANITIZE_ADDRESS__ +#include +#endif + +static inline void +maybe_do_leak_check (void) +{ +#ifdef __SANITIZE_ADDRESS__ + __lsan_do_leak_check(); +#endif +} */ import "C" + import ( + "runtime" + "time" "unsafe" "github.com/msteinert/pam" ) +// maybeDoLeakCheck triggers the garbage collector and if the go program is +// compiled with -asan flag, do a memory leak check. +// This is meant to be used as a test Cleanup function, to force Go detecting +// if allocated resources have been released, e.g. using +// t.Cleanup(pam_test.maybeDoLeakCheck) +func maybeDoLeakCheck() { + runtime.GC() + time.Sleep(time.Millisecond * 10) + C.maybe_do_leak_check() +} + func allocateCBytes(bytes []byte) pam.BinaryPointer { if bytes == nil { return nil