diff --git a/internal/logging/Logging.go b/internal/logging/Logging.go index c8357703..0fb88bf5 100644 --- a/internal/logging/Logging.go +++ b/internal/logging/Logging.go @@ -457,3 +457,40 @@ func GetIpAddress(r *http.Request) string { return ip } + +// IsFromTrustedProxy returns true if the request originates from a trusted proxy +func IsFromTrustedProxy(r *http.Request) bool { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + netIP := net.ParseIP(ip) + if netIP == nil { + return false + } + return isTrustedProxy(netIP) +} + +// GetServerUrl returns the server URL, taking into account proxy headers if the request is from a trusted proxy. +// If the request is nil or not from a trusted proxy, it returns the fallbackUrl. +func GetServerUrl(r *http.Request, fallbackUrl string) string { + if r == nil || !IsFromTrustedProxy(r) { + return fallbackUrl + } + proto := r.Header.Get("X-Forwarded-Proto") + if proto == "" { + proto = "http" + if r.TLS != nil { + proto = "https" + } + } + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + host = r.Host + } + url := proto + "://" + host + if !strings.HasSuffix(url, "/") { + url += "/" + } + return url +} diff --git a/internal/webserver/api/Api.go b/internal/webserver/api/Api.go index f643a696..ebdc2f87 100644 --- a/internal/webserver/api/Api.go +++ b/internal/webserver/api/Api.go @@ -115,7 +115,7 @@ func apiEditFile(w http.ResponseWriter, r requestParser, user models.User) { database.SaveMetaData(file) logging.LogEdit(file, user) - outputFileApiInfo(w, file) + outputFileApiInfo(w, file, request.WebRequest) } // generateNewKey generates and saves a new API key @@ -341,7 +341,7 @@ func apiRestoreFile(w http.ResponseWriter, r requestParser, user models.User) { return } logging.LogRestore(file, user) - outputFileJson(w, file) + outputFileJson(w, file, request.WebRequest) } func apiChunkAdd(w http.ResponseWriter, r requestParser, _ models.User) { @@ -468,14 +468,14 @@ func apiChunkComplete(w http.ResponseWriter, r requestParser, user models.User) request.FileSize, "") if request.IsNonBlocking { - go doBlockingPartCompleteChunk(nil, request.Uuid, request.FileHeader, user, uploadParams) + go doBlockingPartCompleteChunk(nil, request.Uuid, request.FileHeader, user, uploadParams, request.WebRequest) _, _ = io.WriteString(w, "{\"result\":\"OK\"}") return } - doBlockingPartCompleteChunk(w, request.Uuid, request.FileHeader, user, uploadParams) + doBlockingPartCompleteChunk(w, request.Uuid, request.FileHeader, user, uploadParams, request.WebRequest) } -func doBlockingPartCompleteChunk(w http.ResponseWriter, uuid string, fileHeader chunking.FileHeader, user models.User, uploadParameters models.UploadParameters) { +func doBlockingPartCompleteChunk(w http.ResponseWriter, uuid string, fileHeader chunking.FileHeader, user models.User, uploadParameters models.UploadParameters, r *http.Request) { file, err := fileupload.CompleteChunk(uuid, fileHeader, user.Id, uploadParameters) if err != nil { sendError(w, http.StatusBadRequest, errorcodes.UnspecifiedError, err.Error()) @@ -486,7 +486,7 @@ func doBlockingPartCompleteChunk(w http.ResponseWriter, uuid string, fileHeader } fr, _ := filerequest.Get(uploadParameters.FileRequestId) logging.LogUpload(file, user, fr) - outputFileJson(w, file) + outputFileJson(w, file, r) } func apiChunkUploadRequestComplete(w http.ResponseWriter, r requestParser, user models.User) { @@ -503,11 +503,11 @@ func apiChunkUploadRequestComplete(w http.ResponseWriter, r requestParser, user 0, "", true, true, false, request.FileSize, fileRequest.Id) if request.IsNonBlocking { - go doBlockingPartCompleteChunk(nil, request.Uuid, request.FileHeader, user, uploadParams) + go doBlockingPartCompleteChunk(nil, request.Uuid, request.FileHeader, user, uploadParams, request.WebRequest) _, _ = io.WriteString(w, "{\"result\":\"OK\"}") return } - doBlockingPartCompleteChunk(w, request.Uuid, request.FileHeader, user, uploadParams) + doBlockingPartCompleteChunk(w, request.Uuid, request.FileHeader, user, uploadParams, request.WebRequest) } func apiVersionInfo(w http.ResponseWriter, _ requestParser, _ models.User) { @@ -540,23 +540,24 @@ func apiList(w http.ResponseWriter, r requestParser, user models.User) { if !ok { panic("invalid parameter passed") } - validFiles := getFilesForUser(user, request.ShowFileRequests) + validFiles := getFilesForUser(user, request.ShowFileRequests, request.WebRequest) result, err := json.Marshal(validFiles) helper.Check(err) _, _ = w.Write(result) } -func getFilesForUser(user models.User, includeUploadRequests bool) []models.FileApiOutput { +func getFilesForUser(user models.User, includeUploadRequests bool, r *http.Request) []models.FileApiOutput { var validFiles []models.FileApiOutput timeNow := time.Now().Unix() config := configuration.Get() + serverUrl := logging.GetServerUrl(r, config.ServerUrl) for _, element := range database.GetAllMetadata() { if !includeUploadRequests && element.IsFileRequest() { continue } if element.UserId == user.Id || user.HasPermission(models.UserPermListOtherUploads) { if !storage.IsExpiredFile(element, timeNow) { - file, err := element.ToFileApiOutput(config.ServerUrl, config.IncludeFilename) + file, err := element.ToFileApiOutput(serverUrl, config.IncludeFilename) helper.Check(err) validFiles = append(validFiles, file) } @@ -580,7 +581,8 @@ func apiListSingle(w http.ResponseWriter, r requestParser, user models.User) { return } config := configuration.Get() - output, err := file.ToFileApiOutput(config.ServerUrl, config.IncludeFilename) + serverUrl := logging.GetServerUrl(request.WebRequest, config.ServerUrl) + output, err := file.ToFileApiOutput(serverUrl, config.IncludeFilename) helper.Check(err) result, err := json.Marshal(output) helper.Check(err) @@ -601,7 +603,7 @@ func apiDownloadSingle(w http.ResponseWriter, r requestParser, user models.User) storage.ServeFile(file, w, request.WebRequest, true, request.IncreaseCounter, true) return } - createAndOutputPresignedUrl([]string{file.Id}, w, "") + createAndOutputPresignedUrl([]string{file.Id}, w, request.WebRequest, "") } func apiDownloadZip(w http.ResponseWriter, r requestParser, user models.User) { @@ -624,7 +626,7 @@ func apiDownloadZip(w http.ResponseWriter, r requestParser, user models.User) { storage.ServeFilesAsZip(requestedFiles, request.Filename, w, request.WebRequest) return } - createAndOutputPresignedUrl(requestedFileIds, w, request.Filename) + createAndOutputPresignedUrl(requestedFileIds, w, request.WebRequest, request.Filename) } func checkDownloadAllowed(fileId string, user models.User) (models.File, int, int, string) { @@ -641,7 +643,7 @@ func checkDownloadAllowed(fileId string, user models.User) (models.File, int, in return file, 0, 0, "" } -func createAndOutputPresignedUrl(ids []string, w http.ResponseWriter, filename string) { +func createAndOutputPresignedUrl(ids []string, w http.ResponseWriter, r *http.Request, filename string) { presignUrl := models.Presign{ Id: helper.GenerateRandomString(60), FileIds: ids, @@ -649,10 +651,11 @@ func createAndOutputPresignedUrl(ids []string, w http.ResponseWriter, filename s Filename: filename, } presign.Save(presignUrl) + serverUrl := logging.GetServerUrl(r, configuration.Get().ServerUrl) response := struct { Result string `json:"Result"` DownloadUrl string `json:"downloadUrl"` - }{"OK", configuration.Get().ServerUrl + "downloadPresigned?key=" + presignUrl.Id} + }{"OK", serverUrl + "downloadPresigned?key=" + presignUrl.Id} result, err := json.Marshal(response) helper.Check(err) _, _ = w.Write(result) @@ -705,7 +708,7 @@ func apiDuplicateFile(w http.ResponseWriter, r requestParser, user models.User) sendError(w, http.StatusInternalServerError, errorcodes.InternalServer, err.Error()) return } - outputFileApiInfo(w, newFile) + outputFileApiInfo(w, newFile, request.WebRequest) } func apiChangeFileOwner(w http.ResponseWriter, r requestParser, user models.User) { @@ -729,7 +732,7 @@ func apiChangeFileOwner(w http.ResponseWriter, r requestParser, user models.User } file.UserId = request.NewOwner database.SaveMetaData(file) - outputFileApiInfo(w, file) + outputFileApiInfo(w, file, request.WebRequest) } func apiReplaceFile(w http.ResponseWriter, r requestParser, user models.User) { @@ -774,24 +777,26 @@ func apiReplaceFile(w http.ResponseWriter, r requestParser, user models.User) { return } logging.LogReplace(fileOriginal, modifiedFile, user) - outputFileApiInfo(w, modifiedFile) + outputFileApiInfo(w, modifiedFile, request.WebRequest) } -func outputFileApiInfo(w http.ResponseWriter, file models.File) { +func outputFileApiInfo(w http.ResponseWriter, file models.File, r *http.Request) { config := configuration.Get() - publicOutput, err := file.ToFileApiOutput(config.ServerUrl, config.IncludeFilename) + serverUrl := logging.GetServerUrl(r, config.ServerUrl) + publicOutput, err := file.ToFileApiOutput(serverUrl, config.IncludeFilename) helper.Check(err) result, err := json.Marshal(publicOutput) helper.Check(err) _, _ = w.Write(result) } -func outputFileJson(w http.ResponseWriter, file models.File) { +func outputFileJson(w http.ResponseWriter, file models.File, r *http.Request) { if w == nil { return } config := configuration.Get() - _, _ = io.WriteString(w, file.ToJsonResult(config.ServerUrl, config.IncludeFilename)) + serverUrl := logging.GetServerUrl(r, config.ServerUrl) + _, _ = io.WriteString(w, file.ToJsonResult(serverUrl, config.IncludeFilename)) } func apiModifyUser(w http.ResponseWriter, r requestParser, user models.User) { @@ -1035,7 +1040,7 @@ func apiLogResetTraffic(w http.ResponseWriter, _ requestParser, _ models.User) { func apiE2eGet(w http.ResponseWriter, _ requestParser, user models.User) { info := database.GetEnd2EndInfo(user.Id) // If e2e is supported for upload requests at some point, this needs to be changed - files := getFilesForUser(user, false) + files := getFilesForUser(user, false, nil) ids := make([]string, len(files)) for i, file := range files { ids[i] = file.Id diff --git a/internal/webserver/api/Api_test.go b/internal/webserver/api/Api_test.go index 763bb036..30195ca1 100644 --- a/internal/webserver/api/Api_test.go +++ b/internal/webserver/api/Api_test.go @@ -1530,7 +1530,7 @@ func TestChunkComplete(t *testing.T) { } func TestMinorFunctions(t *testing.T) { - outputFileJson(nil, models.File{}) + outputFileJson(nil, models.File{}, nil) sendError(nil, 0, 0, "none") } diff --git a/internal/webserver/api/routing.go b/internal/webserver/api/routing.go index b9f1e10c..b844329e 100644 --- a/internal/webserver/api/routing.go +++ b/internal/webserver/api/routing.go @@ -291,19 +291,23 @@ type requestParser interface { } type paramFilesListAll struct { + WebRequest *http.Request ShowFileRequests bool `header:"showFileRequests"` foundHeaders map[string]bool } -func (p *paramFilesListAll) ProcessParameter(_ *http.Request) error { +func (p *paramFilesListAll) ProcessParameter(r *http.Request) error { + p.WebRequest = r return nil } type paramFilesListSingle struct { - Id string + Id string + WebRequest *http.Request } func (p *paramFilesListSingle) ProcessParameter(r *http.Request) error { + p.WebRequest = r url := parseRequestUrl(r) p.Id = strings.TrimPrefix(url, "/files/list/") return nil @@ -352,10 +356,12 @@ func (p *paramFilesAdd) ProcessParameter(r *http.Request) error { type paramFilesChangeOwner struct { Id string `header:"id" required:"true"` NewOwner int `header:"newOwner" required:"true"` + WebRequest *http.Request foundHeaders map[string]bool } -func (p *paramFilesChangeOwner) ProcessParameter(_ *http.Request) error { +func (p *paramFilesChangeOwner) ProcessParameter(r *http.Request) error { + p.WebRequest = r return nil } @@ -369,10 +375,12 @@ type paramFilesDuplicate struct { UnlimitedDownloads bool UnlimitedTime bool RequestedChanges int + WebRequest *http.Request foundHeaders map[string]bool } func (p *paramFilesDuplicate) ProcessParameter(r *http.Request) error { + p.WebRequest = r if p.foundHeaders["allowedDownloads"] { p.RequestedChanges |= storage.ParamDownloads if p.AllowedDownloads == 0 { @@ -405,10 +413,12 @@ type paramFilesModify struct { UnlimitedDownloads bool UnlimitedExpiry bool IsPasswordSet bool + WebRequest *http.Request foundHeaders map[string]bool } -func (p *paramFilesModify) ProcessParameter(_ *http.Request) error { +func (p *paramFilesModify) ProcessParameter(r *http.Request) error { + p.WebRequest = r if p.foundHeaders["allowedDownloads"] && p.AllowedDownloads == 0 { p.UnlimitedDownloads = true } @@ -423,10 +433,14 @@ type paramFilesReplace struct { Id string `header:"id" required:"true"` IdNewContent string `header:"idNewContent" required:"true"` Delete bool `header:"deleteNewFile"` + WebRequest *http.Request foundHeaders map[string]bool } -func (p *paramFilesReplace) ProcessParameter(_ *http.Request) error { return nil } +func (p *paramFilesReplace) ProcessParameter(r *http.Request) error { + p.WebRequest = r + return nil +} type paramFilesDelete struct { Id string `header:"id" required:"true"` @@ -438,10 +452,14 @@ func (p *paramFilesDelete) ProcessParameter(_ *http.Request) error { return nil type paramFilesRestore struct { Id string `header:"id" required:"true"` + WebRequest *http.Request foundHeaders map[string]bool } -func (p *paramFilesRestore) ProcessParameter(_ *http.Request) error { return nil } +func (p *paramFilesRestore) ProcessParameter(r *http.Request) error { + p.WebRequest = r + return nil +} type paramAuthCreate struct { FriendlyName string `header:"friendlyName"` @@ -661,10 +679,12 @@ type paramChunkComplete struct { UnlimitedDownloads bool UnlimitedTime bool FileHeader chunking.FileHeader + WebRequest *http.Request foundHeaders map[string]bool } -func (p *paramChunkComplete) ProcessParameter(_ *http.Request) error { +func (p *paramChunkComplete) ProcessParameter(r *http.Request) error { + p.WebRequest = r if !p.foundHeaders["realsize"] { if !p.IsE2E { @@ -735,10 +755,12 @@ type paramChunkUploadRequestComplete struct { IsNonBlocking bool `header:"nonblocking"` ApiKey string `header:"apikey" unpublished:"true"` // not published in API documentation FileHeader chunking.FileHeader + WebRequest *http.Request foundHeaders map[string]bool } -func (p *paramChunkUploadRequestComplete) ProcessParameter(_ *http.Request) error { +func (p *paramChunkUploadRequestComplete) ProcessParameter(r *http.Request) error { + p.WebRequest = r if p.ContentType == "" { p.ContentType = "application/octet-stream" }