diff --git a/.github/workflows/qa.yaml b/.github/workflows/qa.yaml index 1eb4e83d03..0fa51001fa 100644 --- a/.github/workflows/qa.yaml +++ b/.github/workflows/qa.yaml @@ -116,6 +116,14 @@ jobs: - name: Run tests (with race detector) run: | go test -race ./... + - name: Run PAM tests (with Address Sanitizer) + env: + # Do not optimize, keep debug symbols and frame pointer for better + # stack trace information in case of ASAN errors. + CGO_CFLAGS: "-O0 -g3 -fno-omit-frame-pointer" + run: | + # Use `-dwarflocationlists` to give ASAN a better time to unwind the stack trace + go test -C pam -asan -gcflags="-dwarflocationlists=true" ./... - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: diff --git a/go.mod b/go.mod index 6446b89af6..8f386c22bc 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf github.com/godbus/dbus/v5 v5.1.0 github.com/google/uuid v1.4.0 + github.com/msteinert/pam v1.2.0 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/spf13/cobra v1.8.0 @@ -62,3 +63,8 @@ require ( golang.org/x/text v0.13.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect ) + +// FIXME: Use released version once we have one! +// The branch below includes changes from this upstream PR: +// - https://github.com/msteinert/pam/pull/13 +replace github.com/msteinert/pam => github.com/3v1n0/go-pam v0.0.0-20231130030658-0f1cc6f16d45 diff --git a/go.sum b/go.sum index d34d2b04fc..61cdbe1954 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/3v1n0/go-pam v0.0.0-20231130030658-0f1cc6f16d45 h1:EmfxQhrTNAVT4rzbv6TcP2XScDl16Q8J0ZlCVbazl8c= +github.com/3v1n0/go-pam v0.0.0-20231130030658-0f1cc6f16d45/go.mod h1:d2n0DCUK8rGecChV3JzvmsDjOY4R7AYbsNxAT+ftQl0= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= diff --git a/internal/log/log.go b/internal/log/log.go index 7dbeb0cc11..f0d3036946 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -43,6 +43,11 @@ func Info(_ context.Context, args ...interface{}) { logrus.Info(args...) } +// Warning is a temporary placeholder. +func Warning(_ context.Context, args ...interface{}) { + logrus.Warning(args...) +} + // Warningf is a temporary placeholder. func Warningf(_ context.Context, format string, args ...interface{}) { logrus.Warningf(format, args...) diff --git a/pam/.gitignore b/pam/.gitignore new file mode 100644 index 0000000000..271d077f69 --- /dev/null +++ b/pam/.gitignore @@ -0,0 +1,2 @@ +./*.h +*.so diff --git a/pam/authentication.go b/pam/authentication.go index d21b26ef4c..ad96f40c43 100644 --- a/pam/authentication.go +++ b/pam/authentication.go @@ -1,3 +1,4 @@ +// Package main is the package for the PAM library package main import ( @@ -7,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/brokers/responses" "github.com/ubuntu/authd/internal/log" @@ -32,8 +34,9 @@ func sendIsAuthenticated(ctx context.Context, client authd.PAMClient, sessionID, access: responses.AuthCancelled, } } - return pamSystemError{ - msg: fmt.Sprintf("Authentication status failure: %v", err), + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("authentication status failure: %v", err), } } @@ -126,15 +129,22 @@ func (m *authenticationModel) Update(msg tea.Msg) (authenticationModel, tea.Cmd) return *m, sendEvent(pamSuccess{brokerID: m.currentBrokerID}) case responses.AuthRetry: - m.errorMsg = dataToMsg(msg.msg) + errorMsg, err := dataToMsg(msg.msg) + if err != nil { + return *m, sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) + } + m.errorMsg = errorMsg return *m, sendEvent(startAuthentication{}) case responses.AuthDenied: - errMsg := "Access denied" - if err := dataToMsg(msg.msg); err != "" { - errMsg = err + errMsg, err := dataToMsg(msg.msg) + if err != nil { + return *m, sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) + } + if errMsg == "" { + errMsg = "Access denied" } - return *m, sendEvent(pamAuthError{msg: errMsg}) + return *m, sendEvent(pamError{status: pam.ErrAuth, msg: errMsg}) case responses.AuthNext: return *m, sendEvent(GetAuthenticationModesRequested{}) @@ -204,7 +214,7 @@ func (m *authenticationModel) Compose(brokerID, sessionID string, layout *authd. case "qrcode": qrcodeModel, err := newQRCodeModel(layout.GetContent(), layout.GetLabel(), layout.GetButton(), layout.GetWait() == "true") if err != nil { - return sendEvent(pamSystemError{msg: err.Error()}) + return sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) } m.currentModel = qrcodeModel @@ -213,7 +223,10 @@ func (m *authenticationModel) Compose(brokerID, sessionID string, layout *authd. m.currentModel = newPasswordModel default: - return sendEvent(pamSystemError{msg: fmt.Sprintf("unknown layout type: %q", layout.Type)}) + return sendEvent(pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("unknown layout type: %q", layout.Type), + }) } return sendEvent(startAuthentication{}) @@ -246,24 +259,22 @@ func (m *authenticationModel) Reset() { } // dataToMsg returns the data message from a given JSON message. -func dataToMsg(data string) string { +func dataToMsg(data string) (string, error) { if data == "" { - return "" + return "", nil } v := make(map[string]string) if err := json.Unmarshal([]byte(data), &v); err != nil { - log.Infof(context.TODO(), "Invalid json data from provider: %v", data) - return "" + return "", fmt.Errorf("invalid json data from provider: %v", err) } if len(v) == 0 { - return "" + return "", nil } r, ok := v["message"] if !ok { - log.Debugf(context.TODO(), "No message entry in json data from provider: %v", data) - return "" + return "", fmt.Errorf("no message entry in json data from provider: %v", v) } - return r + return r, nil } diff --git a/pam/authmodeselection.go b/pam/authmodeselection.go index f03d962055..05f210476a 100644 --- a/pam/authmodeselection.go +++ b/pam/authmodeselection.go @@ -9,6 +9,7 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -241,17 +242,15 @@ func getAuthenticationModes(client authd.PAMClient, sessionID string, uiLayouts gamResp, err := client.GetAuthenticationModes(context.Background(), gamReq) if err != nil { - return pamSystemError{ - msg: fmt.Sprintf("could not get authentication modes: %v", err), + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("could not get authentication modes: %v", err), } } authModes := gamResp.GetAuthenticationModes() if len(authModes) == 0 { - return pamIgnore{ - // TODO: probably go back to broker selection here - msg: "no supported authentication mode available for this provider", - } + return pamIgnore{msg: "no supported authentication mode available for this provider"} } log.Info(context.TODO(), authModes) diff --git a/pam/brokerselection.go b/pam/brokerselection.go index a69e42fa45..64005cc994 100644 --- a/pam/brokerselection.go +++ b/pam/brokerselection.go @@ -9,6 +9,7 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -232,8 +233,9 @@ func getAvailableBrokers(client authd.PAMClient) tea.Cmd { return func() tea.Msg { brokersInfo, err := client.AvailableBrokers(context.TODO(), &authd.Empty{}) if err != nil { - return pamSystemError{ - msg: fmt.Sprintf("could not get current available brokers: %v", err), + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("could not get current available brokers: %v", err), } } diff --git a/pam/commands.go b/pam/commands.go index 18f2d0e348..cfaea5ff80 100644 --- a/pam/commands.go +++ b/pam/commands.go @@ -7,6 +7,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -44,16 +45,16 @@ func startBrokerSession(client authd.PAMClient, brokerID, username string) tea.C sbResp, err := client.SelectBroker(context.TODO(), sbReq) if err != nil { - return pamSystemError{msg: fmt.Sprintf("can't select broker: %v", err)} + return pamError{status: pam.ErrSystem, msg: fmt.Sprintf("can't select broker: %v", err)} } sessionID := sbResp.GetSessionId() if sessionID == "" { - return pamSystemError{msg: "no session ID returned by broker"} + return pamError{status: pam.ErrSystem, msg: "no session ID returned by broker"} } encryptionKey := sbResp.GetEncryptionKey() if encryptionKey == "" { - return pamSystemError{msg: "no encryption key returned by broker"} + return pamError{status: pam.ErrSystem, msg: "no encryption key returned by broker"} } return SessionStarted{ @@ -73,16 +74,18 @@ func getLayout(client authd.PAMClient, sessionID, authModeID string) tea.Cmd { } uiInfo, err := client.SelectAuthenticationMode(context.TODO(), samReq) if err != nil { - return pamSystemError{ - // TODO: probably go back to broker selection here - msg: fmt.Sprintf("can't select authentication mode: %v", err), + // TODO: probably go back to broker selection here + return pamError{ + status: pam.ErrSystem, + msg: fmt.Sprintf("can't select authentication mode: %v", err), } } if uiInfo.UiLayoutInfo == nil { - return pamSystemError{ - // TODO: probably go back to broker selection here - msg: "invalid empty UI Layout information from broker", + // TODO: probably go back to broker selection here + return pamError{ + status: pam.ErrSystem, + msg: "invalid empty UI Layout information from broker", } } diff --git a/pam/main-cli.go b/pam/main-cli.go new file mode 100644 index 0000000000..61692c6b37 --- /dev/null +++ b/pam/main-cli.go @@ -0,0 +1,45 @@ +//go:build pam_binary_cli + +package main + +import ( + "fmt" + "os" + + "github.com/msteinert/pam" + "github.com/sirupsen/logrus" + "github.com/ubuntu/authd/internal/log" + "github.com/ubuntu/authd/pam/pam_test" +) + +// Simulating pam on the CLI for manual testing. +func main() { + log.SetLevel(log.DebugLevel) + f, err := os.OpenFile("/tmp/logdebug", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0600) + if err != nil { + panic(err) + } + defer f.Close() + logrus.SetOutput(f) + + module := &pamModule{} + mTx := pam_test.NewModuleTransactionDummy(pam.ConversationFunc( + func(style pam.Style, msg string) (string, error) { + switch style { + case pam.TextInfo: + fmt.Fprintf(os.Stderr, "PAM INFO: %s\n", msg) + case pam.ErrorMsg: + fmt.Fprintf(os.Stderr, "PAM ERROR: %s\n", msg) + default: + return "", fmt.Errorf("pam style %d not implemented", style) + } + return "", nil + })) + + authResult := module.Authenticate(mTx, pam.Flags(0), nil) + fmt.Println("Auth return:", authResult) + + // Simulate setting auth broker as default. + accMgmtResult := module.AcctMgmt(mTx, pam.Flags(0), nil) + fmt.Println("Acct mgmt return:", accMgmtResult) +} diff --git a/pam/model.go b/pam/model.go index fb8fb2ecc6..c16cc1c6d6 100644 --- a/pam/model.go +++ b/pam/model.go @@ -2,10 +2,10 @@ package main import ( "context" - "fmt" "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/log" ) @@ -35,7 +35,7 @@ type sessionInfo struct { // model is the global models orchestrator. type model struct { - pamh pamHandle + pamMTx pam.ModuleTransaction client authd.PAMClient height int @@ -49,7 +49,7 @@ type model struct { authModeSelectionModel authModeSelectionModel authenticationModel authenticationModel - exitMsg fmt.Stringer + exitStatus pamReturnStatus } /* global events */ @@ -87,7 +87,8 @@ type SessionEnded struct{} // Init initializes the main model orchestrator. func (m *model) Init() tea.Cmd { - m.userSelectionModel = newUserSelectionModel(m.pamh) + m.exitStatus = pamError{status: pam.ErrSystem, msg: "model did not return anything"} + m.userSelectionModel = newUserSelectionModel(m.pamMTx) var cmds []tea.Cmd cmds = append(cmds, m.userSelectionModel.Init()) @@ -113,7 +114,10 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.KeyMsg: switch msg.String() { case "ctrl+c": - return m, sendEvent(pamAbort{msg: "cancel requested"}) + return m, sendEvent(pamError{ + status: pam.ErrAbort, + msg: "cancel requested", + }) case "esc": if m.brokerSelectionModel.WillCaptureEscape() || m.authModeSelectionModel.WillCaptureEscape() { break @@ -137,20 +141,8 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.brokerSelectionModel.SetWidth(m.width) // Exit cases - case pamIgnore: - m.exitMsg = msg - return m, m.quit() - case pamAbort: - m.exitMsg = msg - return m, m.quit() - case pamSystemError: - m.exitMsg = msg - return m, m.quit() - case pamAuthError: - m.exitMsg = msg - return m, m.quit() - case pamSuccess: - m.exitMsg = msg + case pamReturnStatus: + m.exitStatus = msg return m, m.quit() // Events @@ -194,7 +186,10 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { msg.ID = m.authModeSelectionModel.currentAuthModeSelectedID } if msg.ID == "" { - return m, sendEvent(pamSystemError{msg: "reselection of current auth mode without current ID"}) + return m, sendEvent(pamError{ + status: pam.ErrSystem, + msg: "reselection of current auth mode without current ID", + }) } return m, getLayout(m.client, m.currentSession.sessionID, msg.ID) diff --git a/pam/pam.go b/pam/pam.go index 0c6c3c673a..8ebc344fda 100644 --- a/pam/pam.go +++ b/pam/pam.go @@ -1,25 +1,22 @@ -package main +//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -libname "pam_authd.so" -type pamModule -tags !pam_binary_cli +//go:generate go generate --skip="pam_module.go" +//go:generate sh -c "cc -o go-loader/pam_go_loader.so go-loader/module.c -Wl,--as-needed -Wl,--allow-shlib-undefined -shared -fPIC -Wl,--unresolved-symbols=report-all -lpam && chmod 600 go-loader/pam_go_loader.so" -/* -#cgo LDFLAGS: -lpam -fPIC -#include -#include -#include -#include +// Package main is the package for the PAM library. +package main -char *string_from_argv(int i, char **argv); -*/ import "C" import ( "context" + "errors" "fmt" "os" "runtime" "strings" tea "github.com/charmbracelet/bubbletea" - "github.com/sirupsen/logrus" + "github.com/msteinert/pam" "github.com/ubuntu/authd" "github.com/ubuntu/authd/internal/consts" "github.com/ubuntu/authd/internal/log" @@ -28,21 +25,54 @@ import ( "google.golang.org/grpc/credentials/insecure" ) -var ( - // brokerIDUsedToAuthenticate global variable is for the second stage authentication to select the default broker for the current user. - brokerIDUsedToAuthenticate string -) +// pamModule is the structure that implements the pam.ModuleHandler interface +// that is called during pam operations. +type pamModule struct { +} -//go:generate sh -c "cc -o go-loader/pam_go_loader.so go-loader/module.c -Wl,--as-needed -Wl,--allow-shlib-undefined -shared -fPIC -Wl,--unresolved-symbols=report-all -lpam && chmod 600 go-loader/pam_go_loader.so" -//go:generate sh -c "go build -ldflags='-extldflags -Wl,-soname,pam_authd.so' -buildmode=c-shared -o pam_authd.so" +const ( + // authenticationBrokerIDKey is the Key used to store the data in the + // PAM module for the second stage authentication to select the default + // broker for the current user. + authenticationBrokerIDKey = "authentication-broker-id" +) /* + FIXME: provide instructions using pam-auth-update instead! Add to /etc/pam.d/common-auth auth [success=3 default=die ignore=ignore] pam_authd.so */ -//export pam_sm_authenticate -func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.int { +func showPamMessage(mTx pam.ModuleTransaction, style pam.Style, msg string) error { + switch style { + case pam.TextInfo | pam.ErrorMsg: + default: + return fmt.Errorf("message style not supported: %v", style) + } + if _, err := mTx.StartStringConv(style, msg); err != nil { + log.Errorf(context.TODO(), "Failed sending message to pam: %v", err) + return err + } + return nil +} + +func sendReturnMessageToPam(mTx pam.ModuleTransaction, retStatus pamReturnStatus) { + msg := retStatus.Message() + if msg == "" { + return + } + + style := pam.ErrorMsg + switch retStatus.(type) { + case pamIgnore, pamSuccess: + style = pam.TextInfo + } + + _ = showPamMessage(mTx, style, msg) +} + +// Authenticate is the method that is invoked during pam_authenticate request. +func (h *pamModule) Authenticate(mTx pam.ModuleTransaction, flags pam.Flags, args []string) error { // Initialize localization // TODO @@ -51,19 +81,23 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) interactiveTerminal := term.IsTerminal(int(os.Stdin.Fd())) - client, closeConn, err := newClient(argc, argv) + client, closeConn, err := newClient(args) if err != nil { log.Debug(context.TODO(), err) - return C.PAM_IGNORE + return pam.ErrAuthinfoUnavail } defer closeConn() appState := model{ - pamh: pamh, + pamMTx: mTx, client: client, interactiveTerminal: interactiveTerminal, } + if err := mTx.SetData(authenticationBrokerIDKey, nil); err != nil { + return err + } + //tea.WithInput(nil) //tea.WithoutRenderer() var opts []tea.ProgramOption @@ -73,68 +107,67 @@ func pam_sm_authenticate(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) p := tea.NewProgram(&appState, opts...) if _, err := p.Run(); err != nil { log.Errorf(context.TODO(), "Cancelled authentication: %v", err) - return C.PAM_ABORT + return pam.ErrAbort } - logErrMsg := "unknown" - var errCode C.int = C.PAM_SYSTEM_ERR + sendReturnMessageToPam(mTx, appState.exitStatus) - switch exitMsg := appState.exitMsg.(type) { + switch exitStatus := appState.exitStatus.(type) { case pamSuccess: - brokerIDUsedToAuthenticate = exitMsg.brokerID - return C.PAM_SUCCESS + if err := mTx.SetData(authenticationBrokerIDKey, exitStatus.brokerID); err != nil { + return err + } + return nil + case pamIgnore: // localBrokerID is only set on pamIgnore if the user has chosen local broker. - brokerIDUsedToAuthenticate = exitMsg.localBrokerID - if exitMsg.String() != "" { - log.Debugf(context.TODO(), "Ignoring authd authentication: %s", exitMsg) - } - logErrMsg = "" - errCode = C.PAM_IGNORE - case pamAbort: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("cancelled authentication: %s", exitMsg) + if err := mTx.SetData(authenticationBrokerIDKey, exitStatus.localBrokerID); err != nil { + return err } - errCode = C.PAM_ABORT - case pamAuthError: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("authentication: %s", exitMsg) - } - errCode = C.PAM_AUTH_ERR - case pamSystemError: - if exitMsg.String() != "" { - logErrMsg = fmt.Sprintf("system: %s", exitMsg) - } - errCode = C.PAM_SYSTEM_ERR - } + return fmt.Errorf("%w: %s", exitStatus.Status(), exitStatus.Message()) - if logErrMsg != "" { - fmt.Fprintf(os.Stderr, "Error: %v\n", logErrMsg) + case pamReturnError: + return fmt.Errorf("%w: %s", exitStatus.Status(), exitStatus.Message()) } - return errCode + return fmt.Errorf("%w: unknown exit code", pam.ErrSystem) } -// pam_sm_acct_mgmt sets any used brokerID as default for the user. -// -//export pam_sm_acct_mgmt -func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.int { +// AcctMgmt sets any used brokerID as default for the user. +func (h *pamModule) AcctMgmt(mTx pam.ModuleTransaction, flags pam.Flags, args []string) error { + brokerData, err := mTx.GetData(authenticationBrokerIDKey) + if err != nil && errors.Is(err, pam.ErrNoModuleData) { + return pam.ErrIgnore + } + + brokerIDUsedToAuthenticate, ok := brokerData.(string) + if !ok { + msg := fmt.Sprintf("broker data as an invalid type %#v", brokerData) + log.Errorf(context.TODO(), msg) + _ = showPamMessage(mTx, pam.ErrorMsg, msg) + return pam.ErrIgnore + } + // Only set the brokerID as default if we stored one after authentication. if brokerIDUsedToAuthenticate == "" { - return C.PAM_IGNORE + return pam.ErrIgnore } // Get current user for broker - user := getPAMUser(pamh) + user, err := mTx.GetItem(pam.User) + if err != nil { + return err + } + if user == "" { log.Infof(context.TODO(), "can't get user from PAM") - return C.PAM_IGNORE + return pam.ErrIgnore } - client, closeConn, err := newClient(argc, argv) + client, closeConn, err := newClient(args) if err != nil { log.Debugf(context.TODO(), "%s", err) - return C.PAM_IGNORE + return pam.ErrIgnore } defer closeConn() @@ -144,15 +177,15 @@ func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C. } if _, err := client.SetDefaultBrokerForUser(context.TODO(), &req); err != nil { log.Infof(context.TODO(), "Can't set default broker (%q) for %q: %v", brokerIDUsedToAuthenticate, user, err) - return C.PAM_IGNORE + return pam.ErrIgnore } - return C.PAM_SUCCESS + return nil } -// newClient returns a new GRPC client ready to emit requests -func newClient(argc C.int, argv **C.char) (client authd.PAMClient, close func(), err error) { - conn, err := grpc.Dial("unix://"+getSocketPath(argc, argv), grpc.WithTransportCredentials(insecure.NewCredentials())) +// newClient returns a new GRPC client ready to emit requests. +func newClient(args []string) (client authd.PAMClient, close func(), err error) { + conn, err := grpc.Dial("unix://"+getSocketPath(args), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, nil, fmt.Errorf("could not connect to authd: %v", err) } @@ -160,9 +193,9 @@ func newClient(argc C.int, argv **C.char) (client authd.PAMClient, close func(), } // getSocketPath returns the socket path to connect to which can be overridden manually. -func getSocketPath(argc C.int, argv **C.char) string { +func getSocketPath(args []string) string { socketPath := consts.DefaultSocketPath - for _, arg := range sliceFromArgv(argc, argv) { + for _, arg := range args { opt, optarg, _ := strings.Cut(arg, "=") switch opt { case "socket": @@ -173,32 +206,29 @@ func getSocketPath(argc C.int, argv **C.char) string { return socketPath } -//export pam_sm_setcred -func pam_sm_setcred(pamh *C.pam_handle_t, flags, argc C.int, argv **C.char) C.int { - return C.PAM_IGNORE +// SetCred is the method that is invoked during pam_setcred request. +func (h *pamModule) SetCred(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore } -// go_pam_cleanup_module is called by the go-loader PAM module during onload -// -//export go_pam_cleanup_module -func go_pam_cleanup_module() { - runtime.GC() +// ChangeAuthTok is the method that is invoked during pam_chauthtok request. +func (h *pamModule) ChangeAuthTok(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore } -// Simulating pam on the CLI for manual testing -func main() { - log.SetLevel(log.DebugLevel) - f, err := os.OpenFile("/tmp/logdebug", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644) - if err != nil { - panic(err) - } - defer f.Close() - logrus.SetOutput(f) +// OpenSession is the method that is invoked during pam_open_session request. +func (h *pamModule) OpenSession(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore +} - authResult := pam_sm_authenticate(nil, 0, 0, nil) - fmt.Println("Auth return:", authResult) +// CloseSession is the method that is invoked during pam_close_session request. +func (h *pamModule) CloseSession(pam.ModuleTransaction, pam.Flags, []string) error { + return pam.ErrIgnore +} - // Simulate setting auth broker as default. - accMgmtResult := pam_sm_acct_mgmt(nil, 0, 0, nil) - fmt.Println("Acct mgmt return:", accMgmtResult) +// go_pam_cleanup_module is called by the go-loader PAM module during onload. +// +//export go_pam_cleanup_module +func go_pam_cleanup_module() { + runtime.GC() } diff --git a/pam/pam_module.go b/pam/pam_module.go new file mode 100644 index 0000000000..675d3ea834 --- /dev/null +++ b/pam/pam_module.go @@ -0,0 +1,97 @@ +// Code generated by "pam-moduler -libname pam_authd.so -type pamModule -tags !pam_binary_cli"; DO NOT EDIT. + +//go:build !pam_binary_cli + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_authd.so" -buildmode=c-shared -o pam_authd.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam" + "os" + "unsafe" +) + +var pamModuleHandler pam.ModuleHandler = &pamModule{} + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/pam/pam_test/module-transaction-dummy.go b/pam/pam_test/module-transaction-dummy.go new file mode 100644 index 0000000000..a9142c7e11 --- /dev/null +++ b/pam/pam_test/module-transaction-dummy.go @@ -0,0 +1,362 @@ +// Package pam_test includes test tools for the PAM module +package pam_test + +import ( + "errors" + "fmt" + "runtime" + "strings" + + "github.com/msteinert/pam" +) + +// ModuleTransactionDummy is an implementation of PamModuleTransaction for +// testing purposes. +type ModuleTransactionDummy struct { + Items map[pam.Item]string + Env map[string]string + Data map[string]any + convHandler pam.ConversationHandler +} + +// NewModuleTransactionDummy returns a new PamModuleTransactionDummy. +func NewModuleTransactionDummy(convHandler pam.ConversationHandler) pam.ModuleTransaction { + return &ModuleTransactionDummy{ + convHandler: convHandler, + Data: make(map[string]any), + Env: make(map[string]string), + Items: make(map[pam.Item]string), + } +} + +// InvokeHandler is called by the C code to invoke the proper handler. +func (m *ModuleTransactionDummy) InvokeHandler(handler pam.ModuleHandlerFunc, + flags pam.Flags, args []string) error { + return pam.ErrAbort +} + +// SetItem sets a PAM information item. +func (m *ModuleTransactionDummy) SetItem(item pam.Item, value string) error { + if item <= 0 { + return pam.ErrBadItem + } + + m.Items[item] = value + return nil +} + +// GetItem retrieves a PAM information item. +func (m *ModuleTransactionDummy) GetItem(item pam.Item) (string, error) { + if item <= 0 { + return "", pam.ErrBadItem + } + return m.Items[item], nil +} + +// PutEnv adds or changes the value of PAM environment variables. +// +// NAME=value will set a variable to a value. +// NAME= will set a variable to an empty value. +// NAME (without an "=") will delete a variable. +func (m *ModuleTransactionDummy) PutEnv(nameVal string) error { + env, value, found := strings.Cut(nameVal, "=") + if !found { + delete(m.Env, env) + return nil + } + if env == "" { + return pam.ErrBadItem + } + m.Env[env] = value + return nil +} + +// GetEnv is used to retrieve a PAM environment variable. +func (m *ModuleTransactionDummy) GetEnv(name string) string { + return m.Env[name] +} + +// GetEnvList returns a copy of the PAM environment as a map. +func (m *ModuleTransactionDummy) GetEnvList() (map[string]string, error) { + return m.Env, nil +} + +// GetUser is similar to GetItem(User), but it would start a conversation if +// no user is currently set in PAM. +func (m *ModuleTransactionDummy) GetUser(prompt string) (string, error) { + user, err := m.GetItem(pam.User) + if err != nil { + return "", err + } + if user != "" { + return user, nil + } + + resp, err := m.StartStringConv(pam.PromptEchoOn, prompt) + if err != nil { + return "", err + } + + return resp.Response(), nil +} + +// SetData allows to save any value in the module data that is preserved +// during the whole time the module is loaded. +func (m *ModuleTransactionDummy) SetData(key string, data any) error { + if data == nil { + delete(m.Data, key) + return nil + } + + m.Data[key] = data + return nil +} + +// GetData allows to get any value from the module data saved using SetData +// that is preserved across the whole time the module is loaded. +func (m *ModuleTransactionDummy) GetData(key string) (any, error) { + data, found := m.Data[key] + if !found { + return nil, pam.ErrNoModuleData + } + return data, nil +} + +// StartStringConv starts a text-based conversation using the provided style +// and prompt. +func (m *ModuleTransactionDummy) StartStringConv(style pam.Style, prompt string) ( + pam.StringConvResponse, error) { + if style == pam.BinaryPrompt { + return nil, fmt.Errorf("%w: binary style is not supported", pam.ErrConv) + } + + res, err := m.StartConv(pam.NewStringConvRequest(style, prompt)) + if err != nil { + return nil, err + } + + stringRes, ok := res.(pam.StringConvResponse) + if !ok { + return nil, fmt.Errorf("%w: can't convert to pam.StringConvResponse", pam.ErrConv) + } + return stringRes, nil +} + +// StartStringConvf allows to start string conversation with formatting support. +func (m *ModuleTransactionDummy) StartStringConvf(style pam.Style, format string, args ...interface{}) ( + pam.StringConvResponse, error) { + return m.StartStringConv(style, fmt.Sprintf(format, args...)) +} + +// StartBinaryConv starts a binary conversation using the provided bytes. +func (m *ModuleTransactionDummy) StartBinaryConv(bytes []byte) ( + pam.BinaryConvResponse, error) { + res, err := m.StartConv(NewBinaryRequestDummyFromBytes(bytes)) + if err != nil { + return nil, err + } + + binaryRes, ok := res.(pam.BinaryConvResponse) + if !ok { + return nil, fmt.Errorf("%w: can't convert to pam.BinaryConvResponse", pam.ErrConv) + } + return binaryRes, nil +} + +// StartConv initiates a PAM conversation using the provided ConvRequest. +func (m *ModuleTransactionDummy) StartConv(req pam.ConvRequest) ( + pam.ConvResponse, error) { + resp, err := m.StartConvMulti([]pam.ConvRequest{req}) + if err != nil { + return nil, err + } + if len(resp) != 1 { + return nil, fmt.Errorf("%w: not enough values returned", pam.ErrConv) + } + return resp[0], nil +} + +func (m *ModuleTransactionDummy) handleStringRequest(req pam.ConvRequest) (pam.StringConvResponse, error) { + msgStyle := req.Style() + if m.convHandler == nil { + return nil, fmt.Errorf("no conversation handler provided for style %v", msgStyle) + } + reply, err := m.convHandler.RespondPAM(msgStyle, + req.(pam.StringConvRequest).Prompt()) + if err != nil { + return nil, err + } + + return StringResponseDummy{ + msgStyle, + reply, + }, nil +} + +func (m *ModuleTransactionDummy) handleBinaryRequest(req pam.ConvRequest) (pam.BinaryConvResponse, error) { + if m.convHandler == nil { + return nil, errors.New("no binary handler provided") + } + + //nolint:forcetypeassert + // req must be a pam.BinaryConvRequester, if that's not the case we should + // just panic since this code is only expected to run in tests. + binReq := req.(pam.BinaryConvRequester) + + switch handler := m.convHandler.(type) { + case pam.BinaryConversationHandler: + r, err := handler.RespondPAMBinary(binReq.Pointer()) + if err != nil { + return nil, err + } + return binReq.CreateResponse(pam.BinaryPointer(&r)), nil + + case pam.BinaryPointerConversationHandler: + r, err := handler.RespondPAMBinary(binReq.Pointer()) + if err != nil { + if r != nil { + resp := binReq.CreateResponse(r) + resp.Release() + } + return nil, err + } + return binReq.CreateResponse(r), nil + + default: + return nil, fmt.Errorf("unsupported conversation handler %#v", handler) + } +} + +// StartConvMulti initiates a PAM conversation with multiple ConvRequest's. +func (m *ModuleTransactionDummy) StartConvMulti(requests []pam.ConvRequest) ( + responses []pam.ConvResponse, err error) { + defer func() { + if err != nil { + err = errors.Join(pam.ErrConv, err) + } + }() + + if len(requests) == 0 { + return nil, errors.New("no requests defined") + } + + responses = make([]pam.ConvResponse, 0, len(requests)) + for _, req := range requests { + msgStyle := req.Style() + switch msgStyle { + case pam.PromptEchoOff: + fallthrough + case pam.PromptEchoOn: + fallthrough + case pam.ErrorMsg: + fallthrough + case pam.TextInfo: + response, err := m.handleStringRequest(req) + if err != nil { + return nil, err + } + responses = append(responses, response) + case pam.BinaryPrompt: + response, err := m.handleBinaryRequest(req) + if err != nil { + return nil, err + } + responses = append(responses, response) + default: + return nil, fmt.Errorf("unsupported conversation type %v", msgStyle) + } + } + + return responses, nil +} + +// BinaryRequestDummy is a dummy pam.BinaryConvRequester implementation. +type BinaryRequestDummy struct { + ptr pam.BinaryPointer +} + +// NewBinaryRequestDummy creates a new BinaryConvRequest with finalizer +// for response BinaryResponse. +func NewBinaryRequestDummy(ptr pam.BinaryPointer) *BinaryRequestDummy { + return &BinaryRequestDummy{ptr} +} + +// NewBinaryRequestDummyFromBytes creates a new BinaryConvRequestDummy from +// an array of bytes. +func NewBinaryRequestDummyFromBytes(bytes []byte) *BinaryRequestDummy { + if bytes == nil { + return &BinaryRequestDummy{} + } + return NewBinaryRequestDummy(pam.BinaryPointer(&bytes)) +} + +// Style returns the response style for the request, so always BinaryPrompt. +func (b BinaryRequestDummy) Style() pam.Style { + return pam.BinaryPrompt +} + +// Pointer returns the conversation style of the StringConvRequest. +func (b BinaryRequestDummy) Pointer() pam.BinaryPointer { + return b.ptr +} + +// CreateResponse creates a new BinaryConvResponse from the request. +func (b BinaryRequestDummy) CreateResponse(ptr pam.BinaryPointer) pam.BinaryConvResponse { + bcr := &BinaryResponseDummy{ptr} + runtime.SetFinalizer(bcr, func(bcr *BinaryResponseDummy) { + bcr.Release() + }) + return bcr +} + +// Release releases the resources allocated by the request. +func (b *BinaryRequestDummy) Release() { + b.ptr = nil +} + +// StringResponseDummy is a simple implementation of pam.StringConvResponse. +type StringResponseDummy struct { + style pam.Style + content string +} + +// Style returns the conversation style of the StringResponseDummy. +func (s StringResponseDummy) Style() pam.Style { + return s.style +} + +// Response returns the string response of the StringResponseDummy. +func (s StringResponseDummy) Response() string { + return s.content +} + +// BinaryResponseDummy is an implementation of pam.BinaryConvResponse. +type BinaryResponseDummy struct { + ptr pam.BinaryPointer +} + +// Style returns the response style for the response, so always BinaryPrompt. +func (b BinaryResponseDummy) Style() pam.Style { + return pam.BinaryPrompt +} + +// Data returns the response native pointer, it's up to the protocol to parse +// it accordingly. +func (b BinaryResponseDummy) Data() pam.BinaryPointer { + return b.ptr +} + +// Release releases the memory associated with the pointer. +func (b *BinaryResponseDummy) Release() { + b.ptr = nil + runtime.SetFinalizer(b, nil) +} + +// Decode decodes the binary data using the provided decoder function. +func (b BinaryResponseDummy) Decode(decoder pam.BinaryDecoder) ( + []byte, error) { + if decoder == nil { + return nil, errors.New("nil decoder provided") + } + return decoder(b.Data()) +} diff --git a/pam/pam_test/module-transaction-dummy_test.go b/pam/pam_test/module-transaction-dummy_test.go new file mode 100644 index 0000000000..ba001bd2a6 --- /dev/null +++ b/pam/pam_test/module-transaction-dummy_test.go @@ -0,0 +1,833 @@ +package pam_test + +import ( + "fmt" + "testing" + + "github.com/msteinert/pam" + "github.com/stretchr/testify/require" +) + +func ptrValue[T any](value T) *T { + return &value +} + +func bytesPointerDecoder(ptr pam.BinaryPointer) ([]byte, error) { + if ptr == nil { + return nil, nil + } + return *(*[]byte)(ptr), nil +} + +func TestSetGetItem(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + item pam.Item + value *string + + wantValue *string + wantGetError error + wantSetError error + }{ + "Set user": { + item: pam.User, + value: ptrValue("an user"), + }, + + "Returns empty when getting an unset user": { + item: pam.User, + wantValue: ptrValue(""), + }, + + "Setting and getting an user": { + item: pam.User, + value: ptrValue("the-user"), + wantValue: ptrValue("the-user"), + }, + + // Error cases + "Error when setting invalid item": { + item: pam.Item(-1), + value: ptrValue("some value"), + wantSetError: pam.ErrBadItem, + }, + + "Error when getting invalid item": { + item: pam.Item(-1), + wantGetError: pam.ErrBadItem, + wantValue: ptrValue(""), + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tx := NewModuleTransactionDummy(nil) + + if tc.value != nil { + err := tx.SetItem(tc.item, *tc.value) + require.ErrorIs(t, err, tc.wantSetError) + } + + if tc.wantValue != nil { + value, err := tx.GetItem(tc.item) + require.Equal(t, value, *tc.wantValue) + require.ErrorIs(t, err, tc.wantGetError) + } + }) + } +} + +func TestSetPutEnv(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + env string + value *string + presetValues map[string]string + skipPut bool + + wantValue *string + wantPutError error + }{ + "Put var": { + env: "AN_ENV", + value: ptrValue("value"), + }, + + "Unset a not-previously set value": { + env: "NEVER_SET_ENV", + wantValue: ptrValue(""), + }, + + "Unset a preset value": { + presetValues: map[string]string{"PRESET_ENV": "hey!"}, + env: "PRESET_ENV", + wantValue: ptrValue(""), + }, + + "Changes a preset var": { + presetValues: map[string]string{"PRESET_ENV": "hey!"}, + env: "PRESET_ENV", + value: ptrValue("hello!"), + wantValue: ptrValue("hello!"), + }, + + "Get an unset env": { + skipPut: true, + env: "AN_UNSET_ENV", + wantValue: ptrValue(""), + }, + + "Gets an invalid env name": { + env: "", + value: ptrValue("Invalid Value"), + wantValue: ptrValue(""), + skipPut: true, + }, + + // Error cases + "Error when putting an invalid env name": { + env: "", + value: ptrValue("Invalid Value"), + wantPutError: pam.ErrBadItem, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tx := NewModuleTransactionDummy(nil) + envList, err := tx.GetEnvList() + require.NoErrorf(t, err, "Setup: GetEnvList should not return an error") + require.Lenf(t, envList, 0, "Setup: GetEnvList should have elements") + + if tc.presetValues != nil && !tc.skipPut { + for env, value := range tc.presetValues { + err := tx.PutEnv(env + "=" + value) + require.NoError(t, err) + } + envList, err = tx.GetEnvList() + require.NoError(t, err) + require.Equal(t, tc.presetValues, envList) + } + + if !tc.skipPut { + var env string + if tc.value != nil { + env = tc.env + "=" + *tc.value + } else { + env = tc.env + } + err := tx.PutEnv(env) + require.ErrorIs(t, err, tc.wantPutError) + + wantEnv := map[string]string{} + if tc.wantPutError == nil { + if tc.value != nil { + wantEnv = map[string]string{tc.env: *tc.value} + } + if tc.value != nil && tc.wantValue != nil { + wantEnv = map[string]string{tc.env: *tc.wantValue} + } + } + gotEnv, err := tx.GetEnvList() + require.NoError(t, err, "tx.GetEnvList should not return an error") + require.Equal(t, wantEnv, gotEnv, "returned env lits should match expected") + } + + if tc.wantValue != nil { + value := tx.GetEnv(tc.env) + require.Equal(t, value, *tc.wantValue) + } + }) + } +} + +func TestSetGetData(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + key string + data any + presetData map[string]any + skipSet bool + skipGet bool + + wantData any + wantSetError error + wantGetError error + }{ + "Sets and gets data": { + presetData: map[string]any{"some-data": []any{"hey! That's", true}}, + key: "data", + data: []any{"hey! That's", true}, + wantData: []any{"hey! That's", true}, + }, + + "Set replaces data": { + presetData: map[string]any{"some-data": []any{"hey! That's", true}}, + key: "some-data", + data: ModuleTransactionDummy{ + Items: map[pam.Item]string{pam.Tty: "yay"}, + Env: map[string]string{"foo": "bar"}, + }, + wantData: ModuleTransactionDummy{ + Items: map[pam.Item]string{pam.Tty: "yay"}, + Env: map[string]string{"foo": "bar"}, + }, + }, + + // Error cases + "Error when getting data that has never been set": { + skipSet: true, + key: "not set", + wantGetError: pam.ErrNoModuleData, + }, + + "Error when getting data that has been removed": { + presetData: map[string]any{"some-data": []any{"hey! That's", true}}, + key: "some-data", + data: nil, + wantGetError: pam.ErrNoModuleData, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tx := NewModuleTransactionDummy(nil) + + if tc.presetData != nil && !tc.skipSet { + for key, value := range tc.presetData { + err := tx.SetData(key, value) + require.NoError(t, err) + } + } + + if !tc.skipSet { + err := tx.SetData(tc.key, tc.data) + require.ErrorIs(t, err, tc.wantSetError) + } + + if !tc.skipGet { + data, err := tx.GetData(tc.key) + require.ErrorIs(t, err, tc.wantGetError) + require.Equal(t, tc.wantData, data) + } + }) + } +} + +func TestGetUser(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + presetUser string + convHandler pam.ConversationHandler + + want string + wantError error + }{ + "Getting a previously set user does not require conversation handler": { + presetUser: "an-user", + want: "an-user", + }, + + "Getting a previously set user does not use conversation handler": { + presetUser: "an-user", + want: "an-user", + convHandler: pam.ConversationFunc(func(s pam.Style, msg string) (string, error) { + return "another-user", pam.ErrConv + }), + }, + + "Getting the user uses conversation handler if none was set": { + want: "provided-user", + convHandler: pam.ConversationFunc( + func(s pam.Style, msg string) (string, error) { + require.Equal(t, msg, "Who are you?") + if msg != "Who are you?" { + return "", pam.ErrConv + } + if s == pam.PromptEchoOn { + return "provided-user", nil + } + return "", pam.ErrConv + }), + }, + + // Error cases + "Error when no conversation is set": { + want: "", + wantError: pam.ErrConv, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tx := NewModuleTransactionDummy(tc.convHandler) + + if tc.presetUser != "" { + err := tx.SetItem(pam.User, tc.presetUser) + require.NoError(t, err) + } + + prompt := "Who are you?" + user, err := tx.GetUser(prompt) + require.ErrorIs(t, err, tc.wantError) + require.Equal(t, tc.want, user) + }) + } +} + +func TestStartStringConv(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + prompt string + promptFormat string + promptFormatArgs []interface{} + convStyle pam.Style + convError error + convHandler *pam.ConversationFunc + convShouldNotBeCalled bool + + want string + wantError error + }{ + "Messages with error style are handled by conversation": { + prompt: "This is an error!", + convStyle: pam.ErrorMsg, + want: "I'm handling it fine though", + }, + + "Conversation prompt can be formatted": { + promptFormat: "Sending some %s, right? %v", + promptFormatArgs: []interface{}{"info", true}, + convStyle: pam.TextInfo, + want: "And returning some text back", + }, + + // Error cases + "Error if no conversation handler is set": { + convHandler: ptrValue(pam.ConversationFunc(nil)), + wantError: pam.ErrConv, + }, + + "Error if the conversation handler fails": { + prompt: "Tell me your secret!", + convStyle: pam.PromptEchoOff, + convError: pam.ErrBuf, + wantError: pam.ErrBuf, + }, + + "Error when conversation uses binary content style": { + prompt: "I am a binary content\xff!", + convStyle: pam.BinaryPrompt, + convError: pam.ErrConv, + wantError: pam.ErrConv, + convShouldNotBeCalled: true, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + convFunCalled := false + tx := NewModuleTransactionDummy(func() pam.ConversationFunc { + if tc.convHandler != nil { + return *tc.convHandler + } + prompt := tc.prompt + if tc.promptFormat != "" { + prompt = fmt.Sprintf(tc.promptFormat, tc.promptFormatArgs...) + } + return pam.ConversationFunc( + func(style pam.Style, msg string) (string, error) { + convFunCalled = true + require.Equal(t, prompt, msg) + require.Equal(t, tc.convStyle, style) + return tc.want, tc.convError + }) + }()) + + var reply pam.StringConvResponse + var err error + + if tc.promptFormat != "" { + reply, err = tx.StartStringConvf(tc.convStyle, tc.promptFormat, + tc.promptFormatArgs...) + } else { + reply, err = tx.StartStringConv(tc.convStyle, tc.prompt) + } + + wantConFuncCalled := !tc.convShouldNotBeCalled && tc.convHandler == nil + require.Equal(t, wantConFuncCalled, convFunCalled) + require.ErrorIs(t, err, tc.wantError) + + if tc.wantError != nil { + require.Zero(t, reply) + return + } + + require.NotNil(t, reply) + require.Equal(t, tc.want, reply.Response()) + require.Equal(t, tc.convStyle, reply.Style()) + }) + } +} + +func TestStartBinaryConv(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + request []byte + convError error + convHandler *pam.ConversationHandler + + want []byte + wantError error + }{ + "Simple binary conversation": { + request: []byte{0x01, 0x02, 0x03}, + want: []byte{0x00, 0x01, 0x02, 0x03, 0x4}, + }, + + // Error cases + "Error if no conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(nil)), + wantError: pam.ErrConv, + }, + + "Error if no binary conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(pam.ConversationFunc( + func(s pam.Style, msg string) (string, error) { + return "", nil + }))), + wantError: pam.ErrConv, + }, + + "Error if the conversation handler fails": { + request: []byte{0x03, 0x02, 0x01}, + convError: pam.ErrBuf, + wantError: pam.ErrBuf, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + convFunCalled := false + tx := NewModuleTransactionDummy(func() pam.ConversationHandler { + if tc.convHandler != nil { + return *tc.convHandler + } + + return pam.BinaryConversationFunc( + func(ptr pam.BinaryPointer) ([]byte, error) { + convFunCalled = true + require.NotNil(t, ptr) + bytes := *(*[]byte)(ptr) + require.Equal(t, tc.request, bytes) + return tc.want, tc.convError + }) + }()) + + response, err := tx.StartBinaryConv(tc.request) + require.ErrorIs(t, err, tc.wantError) + require.Equal(t, tc.convHandler == nil, convFunCalled) + + if tc.wantError != nil { + require.Nil(t, response) + return + } + + defer response.Release() + require.NotNil(t, response) + require.Equal(t, pam.BinaryPrompt, response.Style()) + require.NotNil(t, response.Data()) + bytes, err := response.Decode(bytesPointerDecoder) + require.NoError(t, err) + require.Equal(t, tc.want, bytes) + + bytes, err = response.Decode(nil) + require.ErrorContains(t, err, "nil decoder provided") + require.Nil(t, bytes) + }) + } +} + +func TestStartBinaryPointerConv(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + request []byte + convError error + convHandler *pam.ConversationHandler + + want []byte + wantError error + }{ + "With nil argument": { + request: nil, + want: nil, + }, + + "With empty argument": { + request: []byte{}, + want: []byte{}, + }, + + "With simple argument": { + request: []byte{0x01, 0x02, 0x03}, + want: []byte{0x00, 0x01, 0x02, 0x03, 0x4}, + }, + + // Error cases + "Error if no conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(nil)), + wantError: pam.ErrConv, + }, + + "Error if no binary conversation handler is set": { + convHandler: ptrValue(pam.ConversationHandler(pam.ConversationFunc( + func(s pam.Style, msg string) (string, error) { + return "", nil + }))), + wantError: pam.ErrConv, + }, + + "Error if the conversation handler fails": { + request: []byte{0xde, 0xad, 0xbe, 0xef, 0xf}, + convError: pam.ErrBuf, + wantError: pam.ErrBuf, + }, + + "Error if no conversation handler is set handles allocated data": { + convError: pam.ErrSystem, + want: []byte{0xde, 0xad, 0xbe, 0xef, 0xf}, + wantError: pam.ErrSystem, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + convFunCalled := false + tx := NewModuleTransactionDummy(func() pam.ConversationHandler { + if tc.convHandler != nil { + return *tc.convHandler + } + + return pam.BinaryPointerConversationFunc( + func(ptr pam.BinaryPointer) (pam.BinaryPointer, error) { + convFunCalled = true + if tc.request == nil { + require.Nil(t, ptr) + } else { + require.NotNil(t, ptr) + } + bytes := cBytesToBytes(ptr, len(tc.request)) + require.Equal(t, tc.request, bytes) + return allocateCBytes(tc.want), tc.convError + }) + }()) + res, err := tx.StartConv(pam.NewBinaryConvRequest( + allocateCBytes(tc.request), releaseCBytesPointer)) + require.ErrorIs(t, err, tc.wantError) + require.Equal(t, tc.convHandler == nil, convFunCalled) + + if tc.wantError != nil { + require.Nil(t, res) + return + } + + response, ok := res.(pam.BinaryConvResponse) + require.True(t, ok) + defer response.Release() + require.NotNil(t, response) + require.Equal(t, pam.BinaryPrompt, response.Style()) + if tc.want == nil { + require.Nil(t, response.Data()) + } else { + require.NotNil(t, response.Data()) + } + bytes, err := response.Decode(func(ptr pam.BinaryPointer) ([]byte, error) { + return cBytesToBytes(ptr, len(tc.want)), nil + }) + require.NoError(t, err) + require.Equal(t, tc.want, bytes) + + bytes, err = response.Decode(nil) + require.ErrorContains(t, err, "nil decoder provided") + require.Nil(t, bytes) + }) + } +} + +type multiConvHandler struct { + t *testing.T + responses []pam.ConvResponse + wantRequests []pam.ConvRequest + timesCalled int +} + +func (c *multiConvHandler) next() (pam.ConvRequest, pam.ConvResponse) { + i := c.timesCalled + c.timesCalled++ + + return c.wantRequests[i], c.responses[i] +} + +func (c *multiConvHandler) RespondPAM(style pam.Style, prompt string) (string, error) { + wantReq, response := c.next() + require.Equal(c.t, wantReq.Style(), style) + stringReq, ok := wantReq.(pam.StringConvRequest) + require.True(c.t, ok) + require.Equal(c.t, stringReq.Prompt(), prompt) + stringRes, ok := response.(pam.StringConvResponse) + require.True(c.t, ok) + return stringRes.Response(), nil +} + +func (c *multiConvHandler) RespondPAMBinary(ptr pam.BinaryPointer) ([]byte, error) { + wantReq, response := c.next() + require.Equal(c.t, wantReq.Style(), pam.BinaryPrompt) + + binReq, ok := wantReq.(pam.BinaryConvRequester) + require.True(c.t, ok) + wantReqBytes, err := bytesPointerDecoder(binReq.Pointer()) + require.NoError(c.t, err) + actualReqBytes, err := bytesPointerDecoder(ptr) + require.NoError(c.t, err) + require.Equal(c.t, wantReqBytes, actualReqBytes) + + bytes, err := response.(pam.BinaryConvResponse).Decode(bytesPointerDecoder) + require.NoError(c.t, err) + return bytes, nil +} + +func TestStartConvMulti(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + tests := map[string]struct { + requests []pam.ConvRequest + + wantResponses []pam.ConvResponse + wantConvCalls *int + wantError error + }{ + "Can address multiple string requests": { + requests: []pam.ConvRequest{ + pam.NewStringConvRequest(pam.PromptEchoOff, "give some PromptEchoOff"), + pam.NewStringConvRequest(pam.PromptEchoOn, "give some PromptEchoOn"), + pam.NewStringConvRequest(pam.ErrorMsg, "give some ErrorMsg"), + pam.NewStringConvRequest(pam.TextInfo, "give some TextInfo"), + }, + wantResponses: []pam.ConvResponse{ + StringResponseDummy{pam.PromptEchoOff, "answer to PromptEchoOff"}, + StringResponseDummy{pam.PromptEchoOn, "answer to PromptEchoOn"}, + StringResponseDummy{pam.ErrorMsg, "answer to ErrorMsg"}, + StringResponseDummy{pam.TextInfo, "answer to TextInfo"}, + }, + }, + + "Can address multiple binary requests": { + requests: []pam.ConvRequest{ + NewBinaryRequestDummy(nil), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{})), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0xFF, 0x00, 0xBA, 0xAB})), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0x55})), + }, + wantResponses: []pam.ConvResponse{ + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{})}, + &BinaryResponseDummy{nil}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0x53})}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0xAF, 0x00, 0xBA, 0xAC})}, + }, + }, + + "Can address multiple mixed binary and string requests ": { + requests: []pam.ConvRequest{ + NewBinaryRequestDummy(nil), + pam.NewStringConvRequest(pam.PromptEchoOff, "PromptEchoOff"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{})), + pam.NewStringConvRequest(pam.PromptEchoOn, "PromptEchoOn"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0xFF, 0x00, 0xBA, 0xAB})), + pam.NewStringConvRequest(pam.ErrorMsg, "ErrorMsg"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0x55})), + pam.NewStringConvRequest(pam.TextInfo, "TextInfo"), + }, + wantResponses: []pam.ConvResponse{ + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{})}, + StringResponseDummy{pam.PromptEchoOff, "PromptEchoOff"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0x55})}, + StringResponseDummy{pam.PromptEchoOn, "PromptEchoOn"}, + &BinaryResponseDummy{nil}, + StringResponseDummy{pam.ErrorMsg, "ErrorMsg"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0xAF, 0x00, 0xBA, 0xAC})}, + StringResponseDummy{pam.TextInfo, "TextInfo"}, + }, + }, + + // Error cases + "Error if no request is provided": { + wantError: pam.ErrConv, + }, + + "Error if one of the multiple request fails": { + requests: []pam.ConvRequest{ + NewBinaryRequestDummy(nil), + pam.NewStringConvRequest(pam.PromptEchoOff, "PromptEchoOff"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{})), + pam.NewStringConvRequest(pam.PromptEchoOn, "PromptEchoOn"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0xFF, 0x00, 0xBA, 0xAB})), + // The case below will lead to the whole request to fail! + pam.NewStringConvRequest(pam.Style(-1), "Invalid style"), + pam.NewStringConvRequest(pam.ErrorMsg, "ErrorMsg"), + NewBinaryRequestDummy(pam.BinaryPointer(&[]byte{0x55})), + pam.NewStringConvRequest(pam.TextInfo, "TextInfo"), + }, + wantResponses: []pam.ConvResponse{ + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{})}, + StringResponseDummy{pam.PromptEchoOff, "PromptEchoOff"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0x55})}, + StringResponseDummy{pam.PromptEchoOn, "PromptEchoOn"}, + &BinaryResponseDummy{nil}, + StringResponseDummy{pam.Style(-1), "Invalid style"}, + StringResponseDummy{pam.ErrorMsg, "ErrorMsg"}, + &BinaryResponseDummy{pam.BinaryPointer(&[]byte{0xAF, 0x00, 0xBA, 0xAC})}, + StringResponseDummy{pam.TextInfo, "TextInfo"}, + }, + wantConvCalls: ptrValue(5), + wantError: pam.ErrConv, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + t.Cleanup(maybeDoLeakCheck) + + require.Equalf(t, len(tc.wantResponses), len(tc.requests), + "Setup: mismatch on expectations / requests numbers") + + convHandler := &multiConvHandler{ + t: t, + wantRequests: tc.requests, + responses: tc.wantResponses, + } + tx := NewModuleTransactionDummy(convHandler) + + responses, err := tx.StartConvMulti(tc.requests) + require.ErrorIs(t, err, tc.wantError) + + wantConvCalls := len(tc.requests) + if tc.wantConvCalls != nil { + wantConvCalls = *tc.wantConvCalls + } + require.Equal(t, convHandler.timesCalled, wantConvCalls) + + if tc.wantError != nil { + require.Nil(t, responses) + return + } + + require.NotNil(t, responses) + require.Len(t, responses, len(tc.requests)) + + for i, res := range responses { + wantRes := tc.wantResponses[i] + require.Equal(t, wantRes.Style(), res.Style()) + + switch r := res.(type) { + case pam.StringConvResponse: + require.Equal(t, wantRes, res) + case pam.BinaryConvResponse: + wantBinRes, ok := wantRes.(pam.BinaryConvResponse) + require.True(t, ok) + wb, err := wantBinRes.Decode(bytesPointerDecoder) + require.NoError(t, err) + bytes, err := r.Decode(bytesPointerDecoder) + require.NoError(t, err) + require.Equal(t, wb, bytes) + default: + t.Fatalf("conversation %#v is not handled", r) + } + } + }) + } +} diff --git a/pam/pam_test/utils.go b/pam/pam_test/utils.go new file mode 100644 index 0000000000..438603f2a1 --- /dev/null +++ b/pam/pam_test/utils.go @@ -0,0 +1,56 @@ +// Package pam_test includes test tools for the PAM module +package pam_test + +/* +#include + +#ifdef __SANITIZE_ADDRESS__ +#include +#endif + +static inline void +maybe_do_leak_check (void) +{ +#ifdef __SANITIZE_ADDRESS__ + __lsan_do_leak_check(); +#endif +} +*/ +import "C" + +import ( + "runtime" + "time" + "unsafe" + + "github.com/msteinert/pam" +) + +// maybeDoLeakCheck triggers the garbage collector and if the go program is +// compiled with -asan flag, do a memory leak check. +// This is meant to be used as a test Cleanup function, to force Go detecting +// if allocated resources have been released, e.g. using +// t.Cleanup(pam_test.maybeDoLeakCheck) +func maybeDoLeakCheck() { + runtime.GC() + time.Sleep(time.Millisecond * 10) + C.maybe_do_leak_check() +} + +func allocateCBytes(bytes []byte) pam.BinaryPointer { + if bytes == nil { + return nil + } + return pam.BinaryPointer(C.CBytes(bytes)) +} + +func cBytesToBytes(ptr pam.BinaryPointer, size int) []byte { + if ptr == nil { + return nil + } + return C.GoBytes(unsafe.Pointer(ptr), C.int(size)) +} + +func releaseCBytesPointer(ptr pam.BinaryPointer) { + C.free(unsafe.Pointer(ptr)) +} diff --git a/pam/return.go b/pam/return.go index 1eea1fe628..fe93925c1c 100644 --- a/pam/return.go +++ b/pam/return.go @@ -1,54 +1,64 @@ package main +import ( + "github.com/msteinert/pam" +) + // Various signalling return messaging to PAM. -// pamSuccess signals PAM module to return PAM_SUCCESS and Quit tea.Model. +// pamReturnStatus is the interface that all PAM return types should implement. +type pamReturnStatus interface { + Message() string +} + +// pamReturnError is an interface that PAM errors return types should implement. +type pamReturnError interface { + pamReturnStatus + Status() pam.Error +} + +// pamSuccess signals PAM module to return with provided pam.Success and Quit tea.Model. type pamSuccess struct { brokerID string + msg string } -// String returns the string of pamSuccess. -func (err pamSuccess) String() string { - return "" +// Message returns the message that should be sent to pam as info message. +func (p pamSuccess) Message() string { + return p.msg } -// pamIgnore signals PAM module to return PAM_IGNORE and Quit tea.Model. +// pamIgnore signals PAM module to return pam.Ignore and Quit tea.Model. type pamIgnore struct { localBrokerID string // Only set for local broker to store it globally. msg string } -// String returns the string of pamIgnore message. -func (err pamIgnore) String() string { - return err.msg -} - -// pamAbort signals PAM module to return PAM_ABORT and Quit tea.Model. -type pamAbort struct { - msg string -} - -// String returns the string of pamAbort message. -func (err pamAbort) String() string { - return err.msg +// Status returns [pam.ErrIgnore]. +func (p pamIgnore) Status() pam.Error { + return pam.ErrIgnore } -// pamSystemError signals PAM module to return PAM_SYSTEM_ERROR and Quit tea.Model. -type pamSystemError struct { - msg string +// Message returns the message that should be sent to pam as info message. +func (p pamIgnore) Message() string { + return p.msg } -// String returns the string of pamSystemError message. -func (err pamSystemError) String() string { - return err.msg +// pamIgnore signals PAM module to return the provided error message and Quit tea.Model. +type pamError struct { + status pam.Error + msg string } -// pamAuthError signals PAM module to return PAM_AUTH_ERROR and Quit tea.Model. -type pamAuthError struct { - msg string +// Status returns the PAM exit status code. +func (p pamError) Status() pam.Error { + return p.status } -// String returns the string of pamAuthError message. -func (err pamAuthError) String() string { - return err.msg +// Message returns the message that should be sent to pam as error message. +func (p pamError) Message() string { + if p.msg != "" { + return p.msg + } + return p.status.Error() } diff --git a/pam/userselection.go b/pam/userselection.go index aa5954277a..85cecca876 100644 --- a/pam/userselection.go +++ b/pam/userselection.go @@ -3,13 +3,14 @@ package main import ( "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" + "github.com/msteinert/pam" ) // userSelectionModel allows selecting from PAM or interactively an user. type userSelectionModel struct { textinput.Model - pamh pamHandle + pamMTx pam.ModuleTransaction } // userSelected events to select a new username. @@ -25,7 +26,7 @@ func sendUserSelected(username string) tea.Cmd { } // newUserSelectionModel returns an initialized userSelectionModel. -func newUserSelectionModel(pamh pamHandle) userSelectionModel { +func newUserSelectionModel(pamMTx pam.ModuleTransaction) userSelectionModel { u := textinput.New() u.Prompt = "Username: " // TODO: i18n u.Placeholder = "user name" @@ -34,13 +35,16 @@ func newUserSelectionModel(pamh pamHandle) userSelectionModel { return userSelectionModel{ Model: u, - pamh: pamh, + pamMTx: pamMTx, } } // Init initializes userSelectionModel, by getting it from PAM if prefilled. func (m *userSelectionModel) Init() tea.Cmd { - pamUser := getPAMUser(m.pamh) + pamUser, err := m.pamMTx.GetItem(pam.User) + if err != nil { + return sendEvent(pamError{status: pam.ErrSystem, msg: err.Error()}) + } if pamUser != "" { return sendUserSelected(pamUser) } @@ -54,7 +58,9 @@ func (m userSelectionModel) Update(msg tea.Msg) (userSelectionModel, tea.Cmd) { if msg.username != "" { // synchronise our internal validated field and the text one. m.SetValue(msg.username) - setPAMUser(m.pamh, msg.username) + if err := m.pamMTx.SetItem(pam.User, msg.username); err != nil { + return m, sendEvent(pamError{status: pam.ErrAbort, msg: err.Error()}) + } return m, sendEvent(UsernameOrBrokerListReceived{}) } return m, nil diff --git a/pam/utils_c.go b/pam/utils_c.go deleted file mode 100644 index dc79be28f5..0000000000 --- a/pam/utils_c.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -/* -#include -#include -#include -#include -#include - -char *string_from_argv(int i, char **argv) { - return strdup(argv[i]); -} - -char *get_user(pam_handle_t *pamh) { - if (!pamh) - return NULL; - int pam_err = 0; - const char *user; - if ((pam_err = pam_get_item(pamh, PAM_USER, (const void**)&user)) != PAM_SUCCESS) - return NULL; - return strdup(user); -} - -char *set_user(pam_handle_t *pamh, char *username) { - if (!pamh) - return NULL; - int pam_err = 0; - if ((pam_err = pam_set_item(pamh, PAM_USER, (const void*)username)) != PAM_SUCCESS) - return NULL; - return NULL; -} -*/ -import "C" - -import ( - "unsafe" -) - -// pamHandle allows to pass C.pam_handle_t to this package. -type pamHandle = *C.pam_handle_t - -// sliceFromArgv returns a slice of strings given to the PAM module. -func sliceFromArgv(argc C.int, argv **C.char) []string { - r := make([]string, 0, argc) - for i := 0; i < int(argc); i++ { - s := C.string_from_argv(C.int(i), argv) - defer C.free(unsafe.Pointer(s)) - r = append(r, C.GoString(s)) - } - return r -} - -// mockPamUser mocks the PAM user item in absence of pamh for manual testing. -var mockPamUser = "user1" // TODO: remove assignement once ok with debugging - -// getPAMUser returns the user from PAM. -func getPAMUser(pamh *C.pam_handle_t) string { - if pamh == nil { - return mockPamUser - } - cUsername := C.get_user(pamh) - if cUsername == nil { - return "" - } - defer C.free(unsafe.Pointer(cUsername)) - return C.GoString(cUsername) -} - -// setPAMUser set current user to PAM. -func setPAMUser(pamh *C.pam_handle_t, username string) { - if pamh == nil { - mockPamUser = username - return - } - cUsername := C.CString(username) - defer C.free(unsafe.Pointer(cUsername)) - - C.set_user(pamh, cUsername) -}