Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 87 additions & 24 deletions cpuhours/cpuhours.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cpuhours

import (
"context"
"encoding/json"
"fmt"
"time"

Expand All @@ -24,6 +25,13 @@ type CPUHours struct {
nc *nats.EncodedConn
}

// CalculationResult bundles the outcome of one CPU hours calculation together
// with the inputs that produced it. addEvent also serializes the whole value
// with json.Marshal and attaches it as the Metadata of the usage update.
type CalculationResult struct {
// CPUHours is the computed total: millicores reserved multiplied by the
// hours between BasisTime and CalcTime, converted from millicores to cores.
CPUHours *apd.Decimal
// Analysis is the analysis record the calculation was performed for.
Analysis *db.Analysis
// BasisTime is the start of the measured interval: the later of the
// analysis start date and its last usage update, in UTC.
BasisTime time.Time
// CalcTime is the end of the measured interval: the analysis end date when
// one is set, otherwise the time the calculation ran, in UTC.
CalcTime time.Time
}

func New(db *db.Database, nc *nats.EncodedConn) *CPUHours {
return &CPUHours{
db: db,
Expand All @@ -32,52 +40,76 @@ func New(db *db.Database, nc *nats.EncodedConn) *CPUHours {
}

// CPUHoursForAnalysis returns the CPU hours total for the analysis as a decimal value.
func (c *CPUHours) CPUHoursForAnalysis(context context.Context, analysisID string) (*apd.Decimal, *db.Analysis, error) {
func (c *CPUHours) CPUHoursForAnalysis(context context.Context, analysisID string) (CalculationResult, error) {
var (
endTime time.Time
analysis *db.Analysis
err error
basisTime time.Time
calcTime time.Time
analysis *db.Analysis
err error
res CalculationResult
)
log = log.WithFields(logrus.Fields{"context": "calculating CPU hours", "analysisID": analysisID})

log.Debug("getting millicores reserved")
millicoresReserved, err := c.db.MillicoresReserved(context, analysisID)
if err != nil {
return nil, nil, err
return res, err
}
log.Debug("done getting millicores reserved")

for {
log.Debug("getting analysis info")
for i := 0; i < 5; i++ { // Try five times, then use time.Now().UTC() instead
log.Debug("getting analysis info and locking row")
analysis, err = c.db.AnalysisWithoutUser(context, analysisID)
if err != nil {
return nil, nil, err
return res, err
}
log.Debug("done getting analysis info")

if !analysis.StartDate.Valid {
return nil, nil, fmt.Errorf("start date is null")
return res, fmt.Errorf("start date is null")
}

// It's possible for this to be reached before the database is updated with the actual
// end date. If that's the case, wait a bit and try again.
//
// We drop and restart the transaction here to avoid lock
// issues and allow the end date to get set by other processes
if !analysis.EndDate.Valid {
if err := c.db.Rollback(); err != nil {
log.WithError(err).Error("failed to rollback transaction")
}
time.Sleep(5 * time.Second)
c.db.Begin(context) // nolint: errcheck
continue

} else {
endTime = analysis.EndDate.Time.UTC()
calcTime = analysis.EndDate.Time.UTC()
break
}
}

startTime := analysis.StartDate.Time.UTC()
res.Analysis = analysis

if calcTime.IsZero() {
calcTime = time.Now().UTC()
}

// Measure from the later of StartDate and UsageLastUpdate (basisTime) up to
// calcTime, which is EndDate when it is set and the current time otherwise.
// That gives four cases: start -> now, last update -> now, start -> end date,
// or last update -> end date.
// calcTime is then stored as the new UsageLastUpdate.
basisTime = analysis.StartDate.Time.UTC()
if analysis.UsageLastUpdate.Valid && analysis.UsageLastUpdate.Time.UTC().After(basisTime) {
basisTime = analysis.UsageLastUpdate.Time.UTC()
}

log.Infof("start date: %s, end date: %s", startTime.String(), endTime.String())
res.BasisTime = basisTime
res.CalcTime = calcTime
log.Infof("basis date: %s, end date: %s", basisTime.String(), calcTime.String())

timeSpent, err := apd.New(0, 0).SetFloat64(endTime.Sub(startTime).Hours())
timeSpent, err := apd.New(0, 0).SetFloat64(calcTime.Sub(basisTime).Hours())
if err != nil {
return nil, nil, err
return res, err
}

mcReserved := apd.New(0, 0).SetInt64(millicoresReserved)
Expand All @@ -87,21 +119,30 @@ func (c *CPUHours) CPUHoursForAnalysis(context context.Context, analysisID strin
bc := apd.BaseContext.WithPrecision(15)
_, err = bc.Mul(cpuHours, mcReserved, timeSpent)
if err != nil {
return nil, nil, err
return res, err
}

_, err = bc.Quo(cpuHours, cpuHours, mc2cores)
if err != nil {
return nil, nil, err
return res, err
}

log.Infof("run time is %s hours; millicores reserved is %s; cpu hours is %s", timeSpent.String(), mcReserved.String(), cpuHours.String())

return cpuHours, analysis, nil
err = c.db.SetUsageLastUpdate(context, analysisID, calcTime)
if err != nil {
return res, err
}

res.CPUHours = cpuHours

return res, nil
}

func (c *CPUHours) addEvent(context context.Context, analysis *db.Analysis, cpuHours *apd.Decimal) error {
func (c *CPUHours) addEvent(context context.Context, res CalculationResult) error {
var err error
analysis := res.Analysis
cpuHours := res.CPUHours

floatValue, err := cpuHours.Float64()
if err != nil {
Expand All @@ -113,6 +154,11 @@ func (c *CPUHours) addEvent(context context.Context, analysis *db.Analysis, cpuH
return err
}

metajson, err := json.Marshal(res)
if err != nil {
return err
}

update := &qms.Update{
ValueType: "usages",
Value: floatValue,
Expand All @@ -127,6 +173,7 @@ func (c *CPUHours) addEvent(context context.Context, analysis *db.Analysis, cpuH
User: &qms.QMSUser{
Username: username,
},
Metadata: string(metajson),
}

request := pbinit.NewAddUpdateRequest(update)
Expand All @@ -147,26 +194,42 @@ func (c *CPUHours) addEvent(context context.Context, analysis *db.Analysis, cpuH

// CalculateForAnalysisByID computes the CPU hours for the analysis with the
// given ID and, when that succeeds, hands the result to addEvent. It returns
// the first error produced by either step.
func (c *CPUHours) CalculateForAnalysisByID(context context.Context, analysisID string) error {
	res, err := c.CPUHoursForAnalysis(context, analysisID)
	if err != nil {
		return err
	}

	return c.addEvent(context, res)
}

// CalculateForAnalysis resolves externalID to an analysis ID and then runs
// CalculateForAnalysisByID inside a database transaction. The transaction is
// committed when the calculation succeeds; otherwise it is rolled back and the
// calculation error is returned.
func (c *CPUHours) CalculateForAnalysis(context context.Context, externalID string) error {
	log.Debug("getting analysis id")

	// We'll do this lookup outside the transaction to limit the lock time
	analysisID, err := c.db.GetAnalysisIDByExternalID(context, externalID)
	if err != nil {
		return err
	}
	log.Debug("done getting analysis id")

	if err = c.db.Begin(context); err != nil {
		return err
	}

	// A single deferred rollback covers every exit path, including panics.
	// Its error is only reported when the calculation itself failed; after a
	// commit attempt a rollback error carries no useful information.
	calcFailed := false
	defer func() {
		rollbackErr := c.db.Rollback()
		if calcFailed && rollbackErr != nil {
			log.WithError(rollbackErr).Error("failed to rollback transaction")
		}
	}()

	if err = c.CalculateForAnalysisByID(context, analysisID); err != nil {
		calcFailed = true
		return err
	}

	return c.db.Commit()
}
53 changes: 37 additions & 16 deletions db/analyses.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ import (
)

// Analysis holds one row of analysis data as selected from the jobs table
// joined with job_types. Each field is populated from the column named in its
// `db` tag.
type Analysis struct {
	ID         string      `db:"id"`
	AppID      string      `db:"app_id"`
	StartDate  null.Time   `db:"start_date"`
	EndDate    null.Time   `db:"end_date"`
	Status     string      `db:"status"`
	Deleted    bool        `db:"deleted"`
	Submission string      `db:"submission"`
	UserID     string      `db:"user_id"`
	JobType    string      `db:"job_type"`
	SystemID   string      `db:"system_id"`
	Subdomain  null.String `db:"subdomain"`
	// UsageLastUpdate is null until SetUsageLastUpdate has stored a value
	// for the analysis.
	UsageLastUpdate null.Time `db:"usage_last_update"`
}

// GetAnalysisIDByExternalID returns the analysis ID based on the external ID
Expand All @@ -31,7 +32,7 @@ func (d *Database) GetAnalysisIDByExternalID(context context.Context, externalID
JOIN job_steps s ON s.job_id = j.id
WHERE s.external_id = $1
`
err := d.db.QueryRowxContext(context, q, externalID).Scan(&analysisID)
err := d.Q().QueryRowxContext(context, q, externalID).Scan(&analysisID)
if err != nil {
return "", err
}
Expand All @@ -50,14 +51,16 @@ func (d *Database) AnalysisWithoutUser(context context.Context, analysisID strin
j.submission,
j.user_id,
j.subdomain,
j.usage_last_update,
t.name job_type,
t.system_id
FROM jobs j
JOIN job_types t ON j.job_type_id = t.id
WHERE j.id = $1;
WHERE j.id = $1
FOR NO KEY UPDATE;
`
var analysis Analysis
err := d.db.QueryRowxContext(context, q, analysisID).StructScan(&analysis)
err := d.Q().QueryRowxContext(context, q, analysisID).StructScan(&analysis)
return &analysis, err
}

Expand All @@ -74,17 +77,35 @@ func (d *Database) Analysis(context context.Context, userID, id string) (*Analys
j.submission,
j.user_id,
j.subdomain,
j.usage_last_update,
t.name job_type,
t.system_id
FROM jobs j
JOIN job_types t ON j.job_type_id = job_types.id
WHERE j.id = $1
AND j.user_id = $2;
`
err := d.db.QueryRowxContext(context, q, id, userID).StructScan(&analysis)
err := d.Q().QueryRowxContext(context, q, id, userID).StructScan(&analysis)
return &analysis, err
}

// SetUsageLastUpdate writes usagetime into the `usage_last_update` column of
// the jobs row whose id is analysisID. Any error from executing the statement
// is returned unchanged.
func (d *Database) SetUsageLastUpdate(context context.Context, analysisID string, usagetime time.Time) error {
	const q = `
UPDATE jobs
SET usage_last_update = $2
WHERE id = $1
`

	// we store things in the DB as non-UTC time
	localTime := usagetime.Local()

	if _, err := d.Q().ExecContext(context, q, analysisID, localTime); err != nil {
		return err
	}
	return nil
}

type CalculableAnalysis struct {
ID string `db:"id"`
StartDate time.Time `db:"start_date"`
Expand All @@ -109,7 +130,7 @@ func (d *Database) AdminAllCalculableAnalyses(context context.Context, userID st
AND j.end_date <= $3::timestamp;

`
rows, err := d.db.QueryxContext(context, q, userID, from, to)
rows, err := d.Q().QueryxContext(context, q, userID, from, to)
if err != nil {
return nil, err
}
Expand Down
Loading