diff --git a/cmd/commit.go b/cmd/commit.go index 55c2de4..cf842a4 100644 --- a/cmd/commit.go +++ b/cmd/commit.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "io" - "os" "path/filepath" "strings" "time" @@ -32,13 +31,14 @@ import ( "github.com/spf13/cast" "github.com/spf13/cobra" + "github.com/marstr/baronial/internal/format" "github.com/marstr/baronial/internal/index" ) const ( amountFlag = "amount" amountShorthand = "a" - amountDefault = "" + amountDefault = "" amountUsage = "The magnitude of the transaction that should be displayed in logs." ) @@ -77,6 +77,13 @@ const ( forceUsage = "Ignore warnings, commit the transaction anyway." ) +const ( + dryrunFlag = "dry-run" + dryrunShorthand = "d" + dryrunDefault = false + dryrunUsage = "Generates and prints a commit without writing it or updating any references." +) + const ( bankRecordIDFlag = "bank-record-id" bankRecordIDShorthand = "b" @@ -121,16 +128,6 @@ var commitCmd = &cobra.Command{ if err != nil { return err } - } else { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - var err error - commitTransactionFromFlags.Amount, err = calculateAmount(ctx, ".") - if err != nil { - logrus.Fatalf("Failed to calculate the amount from %q because of the following error: %s", amountDefault, err) - } - } commitTransactionFromFlags.EnteredTime = time.Now() @@ -146,6 +143,14 @@ var commitCmd = &cobra.Command{ logrus.Fatal(err) } + if !cmd.Flags().Changed(amountFlag) { + var err error + commitTransactionFromFlags.Amount, err = calculateAmount(ctx, ".") + if err != nil { + logrus.Fatalf("Failed to calculate the amount from %q because of the following error: %s", amountDefault, err) + } + } + commitTransactionFromFlags.State, err = index.LoadState(ctx, targetDir) if err != nil { logrus.Fatal(err) @@ -174,8 +179,8 @@ var commitCmd = &cobra.Command{ shouldContinue, err := promptToContinue( ctx, "proceed despite imbalance?", - os.Stdout, - os.Stdin) + cmd.OutOrStdout(), + cmd.InOrStdin()) if err != nil { logrus.Fatal(err) } @@ -255,10 +260,44 @@ var commitCmd = &cobra.Command{ } commitTransactionFromFlags.RecordID = envelopes.BankRecordID(rawRecordId) - err = persist.Commit(ctx, repo, commitTransactionFromFlags, additionalParents...) + var dryrun bool + dryrun, err = cmd.Flags().GetBool(dryrunFlag) if err != nil { logrus.Fatal(err) } + + if dryrun { + var head persist.RefSpec + head, err = repo.Current(ctx) + if err != nil { + logrus.Fatal(err) + } + + var parent envelopes.ID + if head != "" { + parent, err = persist.Resolve(ctx, repo, head) + if err != nil { + logrus.Fatal(err) + } + } + + if parent.Equal(envelopes.ID{}) { + commitTransactionFromFlags.Parents = []envelopes.ID{} + } else { + commitTransactionFromFlags.Parents = append([]envelopes.ID{parent}, additionalParents...) + } + + err = format.PrettyPrintTransaction(ctx, cmd.OutOrStdout(), repo, commitTransactionFromFlags) + if err != nil { + logrus.Fatal(err) + } + + } else { + err = persist.Commit(ctx, repo, commitTransactionFromFlags, additionalParents...) + if err != nil { + logrus.Fatal(err) + } + } }, } @@ -272,6 +311,7 @@ func init() { commitCmd.Flags().StringP(amountFlag, amountShorthand, amountDefault, amountUsage) commitCmd.Flags().StringP(bankRecordIDFlag, bankRecordIDShorthand, bankRecordIDDefault, bankRecordIDUsage) commitCmd.Flags().BoolP(forceFlag, forceShorthand, forceDefault, forceUsage) + commitCmd.Flags().BoolP(dryrunFlag, dryrunShorthand, dryrunDefault, dryrunUsage) } func promptToContinue(ctx context.Context, message string, output io.Writer, input io.Reader) (bool, error) { @@ -346,9 +386,6 @@ func promptToContinue(ctx context.Context, message string, output io.Writer, inp } func calculateAmount(ctx context.Context, targetDir string) (envelopes.Balance, error) { - ctx, cancel := context.WithTimeout(ctx, 2*time.Second) - defer cancel() - targetDir, err := index.RootDirectory(targetDir) if err != nil { return envelopes.Balance{}, err