Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/baton-retool/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ var (
"skip-disabled-users",
field.WithDescription("Skip syncing disabled users"),
)
OrganizationID = field.StringField(
"organization-id",
field.WithDescription("Restrict sync to a single Retool organization by its numeric ID"),
)
)

var configurationFields = []field.SchemaField{
ConnectionString,
SkipPages,
SkipResources,
SkipDisabledUsers,
OrganizationID,
}

var configRelations = []field.SchemaFieldRelationship{}
Expand Down
3 changes: 2 additions & 1 deletion cmd/baton-retool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ func getConnector(ctx context.Context, v *viper.Viper) (types.ConnectorServer, e
skipPages := v.GetBool(SkipPages.FieldName)
skipResources := v.GetBool(SkipResources.FieldName)
skipDisabledUsers := v.GetBool(SkipDisabledUsers.FieldName)
organizationID := v.GetString(OrganizationID.FieldName)

cb, err := connector.New(ctx, connString, skipPages, skipResources, skipDisabledUsers)
cb, err := connector.New(ctx, connString, skipPages, skipResources, skipDisabledUsers, organizationID)
if err != nil {
l.Error("error creating connector builder", zap.Error(err))
return nil, err
Expand Down
30 changes: 26 additions & 4 deletions pkg/client/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,35 @@ package client

import (
"context"
"fmt"
"strconv"
"strings"

"github.com/georgysavva/scany/pgxscan"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
"go.uber.org/zap"
)

type OrgModel struct {
ID int64 `db:"id"`
Name string `db:"name"`
}

func (c *Client) GetOrganization(ctx context.Context, orgID int64) (*OrgModel, error) {
l := ctxzap.Extract(ctx)
l.Debug("getting organization", zap.Int64("org_id", orgID))

var ret OrgModel
err := pgxscan.Get(ctx, c.db, &ret, `SELECT "id", "name" FROM organizations WHERE "id"=$1`, orgID)
if err != nil {
return nil, err
}

return &ret, nil
}

// select id, domain, name, hostname, subdomain from organizations;.
func (c *Client) ListOrganizations(ctx context.Context, pager *Pager) ([]*OrgModel, string, error) {
func (c *Client) ListOrganizations(ctx context.Context, pager *Pager, organizationID *int64) ([]*OrgModel, string, error) {
l := ctxzap.Extract(ctx)
l.Debug("listing organizations")

Expand All @@ -27,13 +42,20 @@ func (c *Client) ListOrganizations(ctx context.Context, pager *Pager) ([]*OrgMod

sb := &strings.Builder{}

_, _ = sb.WriteString(`SELECT "id", "name" FROM organizations ORDER BY "id"`)
_, _ = sb.WriteString(`SELECT "id", "name" FROM organizations `)

if organizationID != nil {
args = append(args, *organizationID)
_, _ = sb.WriteString(fmt.Sprintf(`WHERE "id"=$%d `, len(args)))
}

_, _ = sb.WriteString(`ORDER BY "id" `)

_, _ = sb.WriteString("LIMIT $1 ")
args = append(args, limit+1)
_, _ = sb.WriteString(fmt.Sprintf("LIMIT $%d ", len(args)))
if offset > 0 {
_, _ = sb.WriteString("OFFSET $2")
args = append(args, offset)
_, _ = sb.WriteString(fmt.Sprintf("OFFSET $%d", len(args)))
}

var ret []*OrgModel
Expand Down
13 changes: 13 additions & 0 deletions pkg/client/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ func (u *UserModel) GetLastLoggedIn() time.Time {
return time.Time{}
}

func (c *Client) GetUser(ctx context.Context, userID int64) (*UserModel, error) {
l := ctxzap.Extract(ctx)
l.Debug("getting user", zap.Int64("user_id", userID))

var ret UserModel
err := pgxscan.Get(ctx, c.db, &ret, `SELECT "id", "email", "firstName", "lastName", "profilePhotoUrl", "enabled", "userName", "organizationId", "lastLoggedIn" FROM users WHERE "id"=$1`, userID)
if err != nil {
return nil, err
}

return &ret, nil
}

func (c *Client) ListUsersForOrg(ctx context.Context, orgID int64, pager *Pager, skipDisabledUsers bool) ([]*UserModel, string, error) {
l := ctxzap.Extract(ctx)
l.Debug("listing users for org", zap.Int64("org_id", orgID))
Expand Down
17 changes: 15 additions & 2 deletions pkg/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package connector

import (
"context"
"fmt"
"io"
"strconv"

v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
"github.com/conductorone/baton-sdk/pkg/annotations"
Expand All @@ -24,6 +26,7 @@ type ConnectorImpl struct {
skipPages bool
skipResources bool
skipDisabledUsers bool
organizationID *int64
}

func (c *ConnectorImpl) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) {
Expand All @@ -47,7 +50,7 @@ func (c *ConnectorImpl) Asset(ctx context.Context, asset *v2.AssetRef) (string,

func (c *ConnectorImpl) ResourceSyncers(ctx context.Context) []connectorbuilder.ResourceSyncer {
syncers := []connectorbuilder.ResourceSyncer{
newOrgSyncer(ctx, c.client, c.skipPages, c.skipResources, c.skipDisabledUsers),
newOrgSyncer(ctx, c.client, c.skipPages, c.skipResources, c.skipDisabledUsers, c.organizationID),
newUserSyncer(ctx, c.client, c.skipDisabledUsers),
newGroupSyncer(ctx, c.client, c.skipDisabledUsers),
}
Expand All @@ -63,16 +66,26 @@ func (c *ConnectorImpl) ResourceSyncers(ctx context.Context) []connectorbuilder.
return syncers
}

func New(ctx context.Context, dsn string, skipPages bool, skipResources bool, skipDisabledUsers bool) (*ConnectorImpl, error) {
func New(ctx context.Context, dsn string, skipPages bool, skipResources bool, skipDisabledUsers bool, organizationID string) (*ConnectorImpl, error) {
c, err := client.New(ctx, dsn)
if err != nil {
return nil, err
}

var orgID *int64
if organizationID != "" {
parsed, err := strconv.ParseInt(organizationID, 10, 64)
if err != nil {
return nil, fmt.Errorf("baton-retool: invalid organization-id %q: %w", organizationID, err)
}
orgID = &parsed
}

return &ConnectorImpl{
client: c,
skipPages: skipPages,
skipResources: skipResources,
skipDisabledUsers: skipDisabledUsers,
organizationID: orgID,
}, nil
}
33 changes: 33 additions & 0 deletions pkg/connector/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ func (o *groupSyncer) Grant(ctx context.Context, principial *v2.Resource, entitl
return nil, err
}

if err := o.validateOrgMatch(ctx, userID, groupID); err != nil {
return nil, err
}

member, err := o.client.GetGroupMember(ctx, groupID, userID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -207,6 +211,10 @@ func (o *groupSyncer) Revoke(ctx context.Context, grant *v2.Grant) (annotations.
return nil, err
}

if err := o.validateOrgMatch(ctx, userID, groupID); err != nil {
return nil, err
}

err = o.client.RemoveGroupMember(ctx, groupID, userID)
if err != nil {
l.Error(
Expand All @@ -221,6 +229,31 @@ func (o *groupSyncer) Revoke(ctx context.Context, grant *v2.Grant) (annotations.
return nil, nil
}

// validateOrgMatch checks that the user belongs to the same organization as the group.
// This prevents silent failures where a grant/revoke succeeds at the DB level but
// doesn't actually take effect in Retool because the user is in a different org.
func (o *groupSyncer) validateOrgMatch(ctx context.Context, userID, groupID int64) error {
user, err := o.client.GetUser(ctx, userID)
if err != nil {
return fmt.Errorf("baton-retool: failed to get user %d for org validation: %w", userID, err)
}

group, err := o.client.GetGroup(ctx, groupID)
if err != nil {
return fmt.Errorf("baton-retool: failed to get group %d for org validation: %w", groupID, err)
}

groupOrgID := group.GetOrgID()
if groupOrgID != 0 && user.OrganizationID != groupOrgID {
return fmt.Errorf(
"baton-retool: organization mismatch - user %d belongs to org %d but group %d belongs to org %d",
userID, user.OrganizationID, groupID, groupOrgID,
)
}

return nil
}

func newGroupSyncer(ctx context.Context, c *client.Client, skipDisabledUsers bool) *groupSyncer {
return &groupSyncer{
resourceType: resourceTypeGroup,
Expand Down
6 changes: 4 additions & 2 deletions pkg/connector/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type orgSyncer struct {
skipPages bool
skipResources bool
skipDisabledUsers bool
organizationID *int64
}

func (s *orgSyncer) ResourceType(ctx context.Context) *v2.ResourceType {
Expand All @@ -38,7 +39,7 @@ func (s *orgSyncer) List(
) ([]*v2.Resource, string, annotations.Annotations, error) {
var annos annotations.Annotations

orgs, nextPageToken, err := s.client.ListOrganizations(ctx, &client.Pager{Token: pToken.Token, Size: pToken.Size})
orgs, nextPageToken, err := s.client.ListOrganizations(ctx, &client.Pager{Token: pToken.Token, Size: pToken.Size}, s.organizationID)
if err != nil {
return nil, "", nil, err
}
Expand Down Expand Up @@ -282,12 +283,13 @@ func (s *orgSyncer) Grants(ctx context.Context, resource *v2.Resource, pToken *p
return ret, nextPageToken, nil, nil
}

func newOrgSyncer(ctx context.Context, c *client.Client, skipPages bool, skipResources bool, skipDisabledUsers bool) *orgSyncer {
func newOrgSyncer(ctx context.Context, c *client.Client, skipPages bool, skipResources bool, skipDisabledUsers bool, organizationID *int64) *orgSyncer {
return &orgSyncer{
resourceType: resourceTypeOrg,
client: c,
skipPages: skipPages,
skipResources: skipResources,
skipDisabledUsers: skipDisabledUsers,
organizationID: organizationID,
}
}
24 changes: 16 additions & 8 deletions pkg/connector/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func (s *userSyncer) List(
return nil, "", nil, err
}

org, err := s.client.GetOrganization(ctx, orgID)
if err != nil {
return nil, "", nil, fmt.Errorf("baton-retool: failed to get organization %d: %w", orgID, err)
}

users, nextPageToken, err := s.client.ListUsersForOrg(ctx, orgID, &client.Pager{Token: pToken.Token, Size: pToken.Size}, s.skipDisabledUsers)
if err != nil {
return nil, "", nil, err
Expand All @@ -61,22 +66,25 @@ func (s *userSyncer) List(

resourceID := formatObjectID(resourceTypeUser.Id, o.ID)
ut, err := resources.NewUserTrait(resources.WithEmail(o.Email, true), resources.WithStatus(utStatus), resources.WithUserProfile(map[string]interface{}{
"email": o.Email,
"first_name": o.GetFirstName(),
"last_name": o.GetLastName(),
"user_id": fmt.Sprintf("%s:%s", parentResourceID.Resource, resourceID),
"last_logged_in": o.GetLastLoggedIn().Format("2006-01-02 15:04:05.999999999 -0700 MST"),
"organization_id": o.OrganizationID,
"user_name": o.GetUserName(),
"email": o.Email,
"first_name": o.GetFirstName(),
"last_name": o.GetLastName(),
"user_id": fmt.Sprintf("%s:%s", parentResourceID.Resource, resourceID),
"last_logged_in": o.GetLastLoggedIn().Format("2006-01-02 15:04:05.999999999 -0700 MST"),
"organization_id": o.OrganizationID,
"organization_name": org.Name,
"user_name": o.GetUserName(),
}))
if err != nil {
return nil, "", nil, err
}

annos.Append(ut)

displayName := fmt.Sprintf("%s %s (%s)", o.GetFirstName(), o.GetLastName(), org.Name)

ret = append(ret, &v2.Resource{
DisplayName: fmt.Sprintf("%s %s", o.GetFirstName(), o.GetLastName()),
DisplayName: displayName,
Id: &v2.ResourceId{
ResourceType: s.resourceType.Id,
Resource: resourceID,
Expand Down
Loading