diff --git a/ebpf_ffi/cbpf.cc b/ebpf_ffi/cbpf.cc index fc6e1b5..4c51583 100644 --- a/ebpf_ffi/cbpf.cc +++ b/ebpf_ffi/cbpf.cc @@ -20,6 +20,8 @@ #include #include +#include "ffi.h" + bool load_cbpf_program(void *prog_buff, size_t size, std::string &error, int *socks) { if (socketpair(AF_UNIX, SOCK_DGRAM, 0, socks) < 0) { @@ -62,38 +64,25 @@ struct bpf_result validation_error(std::string error_message, return serialize_proto(*vres); } -struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size, - int coverage_enabled, - uint64_t coverage_size) { +struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size) { std::string error_message; - struct coverage_data cover; - memset(&cover, 0, sizeof(struct coverage_data)); - cover.fd = -1; - cover.coverage_size = coverage_size; - if (coverage_enabled) enable_coverage(&cover); - ValidationResult vres; int socks[2] = {-1, -1}; - if (!load_cbpf_program(prog_buff, size, error_message, socks)) { - // Return why we failed to load the program. - if (coverage_enabled) get_coverage_and_free_resources(&cover, &vres); + bool coverage_enabled = enable_coverage(); + bool cbpf_loaded = load_cbpf_program(prog_buff, size, error_message, socks); + if (coverage_enabled) { + disable_coverage(); + get_coverage(&vres); + } + if (!cbpf_loaded) { return validation_error(error_message, &vres); } - if (coverage_enabled) get_coverage_and_free_resources(&cover, &vres); - // Start building the validation result proto. vres.set_socket_write(socks[0]); vres.set_socket_read(socks[1]); - if (cover.fd != -1) { - vres.set_did_collect_coverage(true); - vres.set_coverage_size(cover.coverage_size); - vres.set_coverage_buffer(reinterpret_cast(cover.coverage_buffer)); - } else { - vres.set_did_collect_coverage(false); - } if (socks[0] < 0) { // Return why we failed to load the program. diff --git a/ebpf_ffi/cbpf.h b/ebpf_ffi/cbpf.h index c1bc374..6a647c7 100644 --- a/ebpf_ffi/cbpf.h +++ b/ebpf_ffi/cbpf.h @@ -17,7 +17,7 @@ #ifndef EBPF_FUZZER_EBPF_FFI_CBPF_H_ #define EBPF_FUZZER_EBPF_FFI_CBPF_H_ -#include "ffi.h" +#include "ebpf_ffi/ffi.h" extern "C" { @@ -29,9 +29,7 @@ bool load_cbpf_program(void *prog_buff, size_t size, std::string &error, // Loads a bpf program specified by |prog_buff| with |size| and returns struct // with a serialized ValidationResult proto. -struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size, - int coverage_enabled, - uint64_t coverage_size); +struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size); bool execute_cbpf_program(int prog_fd, uint8_t *input, uint8_t *output, int input_length, std::string &error_message); diff --git a/ebpf_ffi/ebpf.cc b/ebpf_ffi/ebpf.cc index 7651d10..9d48acb 100644 --- a/ebpf_ffi/ebpf.cc +++ b/ebpf_ffi/ebpf.cc @@ -115,40 +115,29 @@ ValidationResult load_ebpf_program(EncodedProgram program, std::string &error) { return res; } -struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size, - int coverage_enabled, - uint64_t coverage_size) { +struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size) { std::string error_message; - struct coverage_data cover; - memset(&cover, 0, sizeof(struct coverage_data)); - cover.fd = -1; - cover.coverage_size = coverage_size; - if (coverage_enabled) enable_coverage(&cover); - std::string serialized_proto_string( reinterpret_cast(serialized_proto), size); EncodedProgram program; if (!program.ParseFromString(serialized_proto_string)) { error_message = "Could not parse EncodedProgram proto"; } + + bool coverage_enabled = enable_coverage(); ValidationResult vres = load_ebpf_program(program, error_message); - if (coverage_enabled) get_coverage_and_free_resources(&cover, &vres); - - if (cover.fd != -1) { - vres.set_did_collect_coverage(true); - vres.set_coverage_size(cover.coverage_size); - vres.set_coverage_buffer(reinterpret_cast(cover.coverage_buffer)); - } else { - vres.set_did_collect_coverage(false); + vres.set_did_collect_coverage(false); + if (coverage_enabled) { + get_coverage(&vres); + disable_coverage(); } + vres.set_is_valid(true); if (vres.program_fd() < 0) { // Return why we failed to load the program. vres.set_bpf_error(error_message); vres.set_is_valid(false); - } else { - vres.set_is_valid(true); } return serialize_proto(vres); diff --git a/ebpf_ffi/ebpf.h b/ebpf_ffi/ebpf.h index 57fa3b2..ce309c8 100644 --- a/ebpf_ffi/ebpf.h +++ b/ebpf_ffi/ebpf.h @@ -34,9 +34,7 @@ ValidationResult load_ebpf_program(EncodedProgram program, std::string &error); // Loads a bpf program specified by |prog_buff| with |size| and returns struct // with a serialized ValidationResult proto. -struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size, - int coverage_enabled, - uint64_t coverage_size); +struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size); bool get_map_elements(int map_fd, size_t map_size, std::vector *res, std::string &error); diff --git a/ebpf_ffi/ffi.cc b/ebpf_ffi/ffi.cc index 51f7c3e..d46e28e 100644 --- a/ebpf_ffi/ffi.cc +++ b/ebpf_ffi/ffi.cc @@ -15,10 +15,8 @@ #include "ebpf_ffi/ffi.h" #include -#include #include #include -#include #include #include #include @@ -26,10 +24,10 @@ #include #include -#include #include "absl/container/flat_hash_set.h" #include "absl/strings/escaping.h" +#include "ffi.h" #include "google/protobuf/message.h" #include "google/protobuf/repeated_field.h" #include "proto/ffi.pb.h" @@ -46,6 +44,8 @@ using ebpf_fuzzer::ExecutionResult; using ebpf_fuzzer::MapElements; using ebpf_fuzzer::ValidationResult; +struct coverage_data *kCoverageData = nullptr; + // All the functions in this extern are FFIs intended to be invoked from go. extern "C" { @@ -64,43 +64,68 @@ bpf_result serialize_proto(const google::protobuf::Message &proto) { return res; } -void enable_coverage(struct coverage_data *coverage_info) { - int fd = open("/sys/kernel/debug/kcov", O_RDWR); - if (fd == -1) return; - /* Setup trace mode and trace size. */ - if (ioctl(fd, KCOV_INIT_TRACE, coverage_info->coverage_size)) return; - /* Mmap buffer shared between kernel- and user-space. */ - uint64_t *cover = - (uint64_t *)mmap(nullptr, coverage_info->coverage_size * sizeof(uint64_t), - PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - if ((void *)cover == MAP_FAILED) return; - /* Enable coverage collection on the current thread. */ - if (ioctl(fd, KCOV_ENABLE, KCOV_TRACE_PC)) return; - /* Reset coverage from the tail of the ioctl() call. */ - __atomic_store_n(&cover[0], 0, __ATOMIC_RELAXED); - coverage_info->fd = fd; - coverage_info->coverage_buffer = cover; -} - -void get_coverage_and_free_resources(struct coverage_data *cstruct, - ValidationResult *vres) { - if (cstruct->fd == -1) return; - uint64_t trace_size = - __atomic_load_n(&cstruct->coverage_buffer[0], __ATOMIC_RELAXED); +void get_coverage(ValidationResult *vres) { + if (kCoverageData == nullptr || kCoverageData->fd == -1) return; + uint64_t trace_size = kCoverageData->coverage_buffer[0]; auto *coverage_addresses = vres->mutable_coverage_address(); absl::flat_hash_set seen_address; for (uint64_t i = 0; i < trace_size; i++) { - uint64_t addr = cstruct->coverage_buffer[i + 1]; + uint64_t addr = kCoverageData->coverage_buffer[i + 1]; if (seen_address.find(addr) == seen_address.end()) { - coverage_addresses->Add(cstruct->coverage_buffer[i + 1]); + coverage_addresses->Add(addr); seen_address.insert(addr); } } - ioctl(cstruct->fd, KCOV_DISABLE, 0); - close(cstruct->fd); - munmap(cstruct->coverage_buffer, cstruct->coverage_size * sizeof(uint64_t)); + // reset kcov buffer. + memset(kCoverageData->coverage_buffer, 0, KCOV_SIZE * sizeof(uint64_t)); + vres->set_did_collect_coverage(true); + vres->set_coverage_size(trace_size); + return; +} + +bool enable_coverage() { + if (!kCoverageData || kCoverageData->fd == -1) return false; + return ioctl(kCoverageData->fd, KCOV_ENABLE, KCOV_TRACE_PC) == 0; +} + +void disable_coverage() { + if (kCoverageData == nullptr || kCoverageData->fd == -1) return; + (void)ioctl(kCoverageData->fd, KCOV_DISABLE, 0); +} + +int ffi_setup_coverage() { + if (!kCoverageData) { + kCoverageData = + (struct coverage_data *)malloc(sizeof(struct coverage_data)); + memset(kCoverageData, 0, sizeof(struct coverage_data)); + } + + int fd = open("/sys/kernel/debug/kcov", O_RDWR); + kCoverageData->fd = fd; + if (fd == -1) return -1; + /* Setup trace mode and trace size. */ + if (ioctl(fd, KCOV_INIT_TRACE, KCOV_SIZE)) return -1; + /* Mmap buffer shared between kernel- and user-space. */ + uint64_t *cover = (uint64_t *)mmap(nullptr, KCOV_SIZE * sizeof(uint64_t), + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if ((void *)cover == MAP_FAILED) return -1; + memset(cover, 0, KCOV_SIZE * sizeof(uint64_t)); + + kCoverageData->fd = fd; + kCoverageData->coverage_buffer = cover; + return 0; +} + +int ffi_cleanup_coverage() { + if (!kCoverageData) return 0; + + close(kCoverageData->fd); + munmap(kCoverageData->coverage_buffer, KCOV_SIZE * sizeof(uint64_t)); + free(kCoverageData); + kCoverageData = nullptr; + return 0; } bool execute_error(std::string &error_message, const char *strerr, diff --git a/ebpf_ffi/ffi.h b/ebpf_ffi/ffi.h index 5fdf37a..0e7ac52 100644 --- a/ebpf_ffi/ffi.h +++ b/ebpf_ffi/ffi.h @@ -48,6 +48,9 @@ #define KCOV_TRACE_PC 0 #define KCOV_TRACE_CMP 1 +// 64mb for kcov coverage. +#define KCOV_SIZE 1024 * 1024 * 64 + using ebpf_fuzzer::CbpfExecutionRequest; using ebpf_fuzzer::EncodedProgram; using ebpf_fuzzer::ExecutionRequest; @@ -66,11 +69,6 @@ struct bpf_result { bpf_result serialize_proto(const google::protobuf::Message &proto); -void enable_coverage(struct coverage_data *coverage_info); - -void get_coverage_and_free_resources(struct coverage_data *cstruct, - ValidationResult *vres); - bool execute_error(std::string &error_message, const char *strerr, int *sockets); @@ -83,10 +81,23 @@ int ffi_create_bpf_map(size_t size); // Closes the given file descriptor, this is to free up resources. void ffi_close_fd(int fd); +// Enable kcov coverage. +int ffi_setup_coverage(); + +// Disble kcov coverage. +int ffi_cleanup_coverage(); + +bool enable_coverage(); +void disable_coverage(); +void get_coverage(ValidationResult *vres); + struct coverage_data { int fd; uint64_t coverage_size; uint64_t *coverage_buffer; }; + +extern struct coverage_data *kCoverageData; } + #endif // EBPF_FUZZER_EBPF_FFI_FFI_H_ diff --git a/main.go b/main.go index cf0cc53..738e57e 100644 --- a/main.go +++ b/main.go @@ -27,13 +27,12 @@ import ( // Flags that the binary can accept. var ( - coverageBufferSize = flag.Uint64("coverage_buffer_size", 64<<20, "Size of the buffer passed to kcov to get coverage addresses, the higher the number, the slower coverage collection will be") - metricsThreshold = flag.Int("metrics_threshold", 200, "Collect detailed metrics (coverage) every `metrics_threshold` validated programs") - strategyName = flag.String("strategy", "playground", "Strategy to use for fuzzing") - vmLinuxPath = flag.String("vmlinux_path", "/root/vmlinux", "Path to the linux image that will be passed to addr2line to get coverage info") - sourceFilesPath = flag.String("src_path", "/root/sourceFiles", "The fuzzer will look for source files to visualize the coverage at this path") - metricsServerAddr = flag.String("metrics_server_addr", "0.0.0.0", "Address that the metrics server will listen to at") - metricsServerPort = flag.Uint("metrics_server_port", 8080, "Port that the metrics server will listen to at") + metricsThreshold = flag.Int("metrics_threshold", 200, "Collect detailed metrics (coverage) every `metrics_threshold` validated programs") + strategyName = flag.String("strategy", "playground", "Strategy to use for fuzzing") + vmLinuxPath = flag.String("vmlinux_path", "/root/vmlinux", "Path to the linux image that will be passed to addr2line to get coverage info") + sourceFilesPath = flag.String("src_path", "/root/sourceFiles", "The fuzzer will look for source files to visualize the coverage at this path") + metricsServerAddr = flag.String("metrics_server_addr", "0.0.0.0", "Address that the metrics server will listen to at") + metricsServerPort = flag.Uint("metrics_server_port", 8080, "Port that the metrics server will listen to at") ) var ( @@ -79,7 +78,7 @@ func main() { }) controlUnit := units.Control{} - metricsUnit := units.NewMetricsUnit(*metricsThreshold, *coverageBufferSize, *vmLinuxPath, *sourceFilesPath, *metricsServerAddr, uint16(*metricsServerPort), coverageManager) + metricsUnit := units.NewMetricsUnit(*metricsThreshold, *vmLinuxPath, *sourceFilesPath, *metricsServerAddr, uint16(*metricsServerPort), coverageManager) if err := controlUnit.Init(&units.FFI{ MetricsUnit: metricsUnit, diff --git a/pkg/units/control.go b/pkg/units/control.go index 32e42dd..0618142 100644 --- a/pkg/units/control.go +++ b/pkg/units/control.go @@ -86,6 +86,10 @@ func (cu *Control) IsReady() bool { // RunFuzzer kickstars the fuzzer in the mode that was specified at Init time. func (cu *Control) RunFuzzer() error { + cu.ffi.InitKcov() + defer func() { + cu.ffi.CleanupKcov() + }() for !cu.strat.IsFuzzingDone() { prog, err := cu.strat.GenerateProgram(cu.ffi) if err != nil { diff --git a/pkg/units/ffi.go b/pkg/units/ffi.go index f6a78ca..e8750a4 100644 --- a/pkg/units/ffi.go +++ b/pkg/units/ffi.go @@ -21,9 +21,9 @@ package units // char* serialized_proto; // size_t size; //}; -//struct bpf_result ffi_load_cbpf_program(void* prog_buff, size_t size, int coverage_enabled, unsigned long coverage_size); +//struct bpf_result ffi_load_cbpf_program(void* prog_buff, size_t size); //struct bpf_result ffi_execute_cbpf_program(void* serialized_proto, size_t length); -//struct bpf_result ffi_load_ebpf_program(void* serialized_proto, size_t size, int coverage_enabled, unsigned long coverage_size); +//struct bpf_result ffi_load_ebpf_program(void* serialized_proto, size_t size); //struct bpf_result ffi_execute_ebpf_program(void* serialized_proto, size_t length); //struct bpf_result ffi_get_map_elements(int map_fd, uint64_t map_size); //struct bpf_result ffi_get_map_elements_fd_array(uint64_t fd_array_addr, uint32_t idx, uint64_t map_size); @@ -31,6 +31,8 @@ package units //void ffi_close_fd(int fd); //int ffi_update_map_element(int map_fd, int key, uint64_t value); //void ffi_clean_fd_array(unsigned long long int addr, int size); +//int ffi_setup_coverage(); +//int ffi_cleanup_coverage(); import "C" import ( @@ -157,14 +159,8 @@ func (e *FFI) ValidateEbpfProgram(encodedProgram *fpb.EncodedProgram) (*fpb.Vali if len(encodedProgram.Program) == 0 && encodedProgram != nil { return nil, fmt.Errorf("cannot run empty program") } - shouldCollect, coverageSize := e.MetricsUnit.ShouldGetCoverage() - cbool := 0 - if shouldCollect { - cbool = 1 - } serializedProto, err := proto.Marshal(encodedProgram) - bpfVerifyResult := C.ffi_load_ebpf_program(unsafe.Pointer(&serializedProto[0]), C.ulong(len(serializedProto)), - C.int(cbool), C.ulong(coverageSize)) + bpfVerifyResult := C.ffi_load_ebpf_program(unsafe.Pointer(&serializedProto[0]), C.ulong(len(serializedProto))) res, err := validationProtoFromStruct(&bpfVerifyResult) if err != nil { return nil, err @@ -191,12 +187,7 @@ func (e *FFI) ValidateCbpfProgram(prog []cbpf.Filter) (*fpb.ValidationResult, er if len(prog) == 0 { return nil, fmt.Errorf("cannot run empty program") } - shouldCollect, coverageSize := e.MetricsUnit.ShouldGetCoverage() - cbool := 0 - if shouldCollect { - cbool = 1 - } - bpfVerifyResult := C.ffi_load_cbpf_program(unsafe.Pointer(&prog[0]), C.ulong(len(prog)), C.int(cbool) /*coverage_size=*/, C.ulong(coverageSize)) + bpfVerifyResult := C.ffi_load_cbpf_program(unsafe.Pointer(&prog[0]), C.ulong(len(prog))) res, err := validationProtoFromStruct(&bpfVerifyResult) if err != nil { return nil, err @@ -214,3 +205,15 @@ func (e *FFI) RunCbpfProgram(executionRequest *fpb.CbpfExecutionRequest) (*fpb.E res := C.ffi_execute_cbpf_program(unsafe.Pointer(&serializedProto[0]), C.ulong(len(serializedProto))) return executionProtoFromStruct(&res) } + +// InitKcov sets up all the required kcov structures. +func (e *FFI) InitKcov() { + if C.ffi_setup_coverage() != 0 { + fmt.Println("could not setup coverage correctly") + } +} + +// CleanupKcov destroys all the resources created for kcov. +func (e *FFI) CleanupKcov() { + C.ffi_cleanup_coverage() +} diff --git a/pkg/units/metrics_unit.go b/pkg/units/metrics_unit.go index 46c4f19..e5dd9c9 100644 --- a/pkg/units/metrics_unit.go +++ b/pkg/units/metrics_unit.go @@ -35,11 +35,6 @@ type Metrics struct { // the Metrics will collect detailed info (coverage, and verifier logs). SamplingThreshold int - // KCovSize represents the size of the coverage sample that kcov will - // collect, the bigger the sample the slower collecting coverage - // will be (but the more precies). - KCovSize uint64 - isKCovSupported bool // Since Processing coverage is a slow operation, we put all the @@ -95,28 +90,10 @@ func (mu *Metrics) validationResultProcessingRoutine() { } } -// ShouldGetCoverage has two purposes: record that a program is about -// to be passed by the verifier and return if the metrics unit wants to -// collect coverage information on it. -func (mu *Metrics) ShouldGetCoverage() (bool, uint64) { - mu.metricsCollection.recordVerifiedProgram() - if !mu.isKCovSupported { - return false, 0 - } - - if !mu.shouldCollectDetailedInfo() { - return false, 0 - } - return mu.isKCovSupported, mu.KCovSize -} - -func (mu *Metrics) shouldCollectDetailedInfo() bool { - return mu.metricsCollection.getProgramsVerified()%mu.SamplingThreshold == 0 -} - // RecordVerificationResults collects metrics from the provided // verification result proto. func (mu *Metrics) RecordVerificationResults(vr *fpb.ValidationResult) { + mu.metricsCollection.recordVerifiedProgram() if vr.GetIsValid() { mu.metricsCollection.recordValidProgram() } @@ -133,7 +110,7 @@ func (mu *Metrics) init() { } // NewMetricsUnit Creates a new Central Metrics Unit. -func NewMetricsUnit(threshold int, kcovSize uint64, vmLinuxPath, sourceFilesPath, metricsServerAddr string, metricsServerPort uint16, cm *CoverageManager) *Metrics { +func NewMetricsUnit(threshold int, vmLinuxPath, sourceFilesPath, metricsServerAddr string, metricsServerPort uint16, cm *CoverageManager) *Metrics { mc := &MetricsCollection{ coverageManager: cm, verifierVerdicts: make(map[string]int), @@ -146,7 +123,6 @@ func NewMetricsUnit(threshold int, kcovSize uint64, vmLinuxPath, sourceFilesPath } mu := &Metrics{ SamplingThreshold: threshold, - KCovSize: kcovSize, metricsCollection: mc, metricsServer: ms, } diff --git a/pkg/units/metrics_unit_test.go b/pkg/units/metrics_unit_test.go index 953ca8f..e02e05a 100644 --- a/pkg/units/metrics_unit_test.go +++ b/pkg/units/metrics_unit_test.go @@ -21,7 +21,6 @@ import ( ) func TestMetrics(t *testing.T) { - expectedKcovSize := uint64(42) cm := &CoverageManager{ coverageCache: make(map[uint64]string), coverageInfoMap: make(map[string][]int), @@ -35,29 +34,16 @@ func TestMetrics(t *testing.T) { metricsUnit := Metrics{ SamplingThreshold: 1, - KCovSize: expectedKcovSize, isKCovSupported: true, metricsCollection: metricsCollection, } - isKCovSupported, kCovSize := metricsUnit.ShouldGetCoverage() - if !isKCovSupported { - t.Errorf("isKCovSupported = %v, want = true", isKCovSupported) - } - - if kCovSize != expectedKcovSize { - t.Errorf("kCovSize = %d, want = %d", kCovSize, expectedKcovSize) - } - - if metricsUnit.metricsCollection.programsVerified != 1 { - t.Errorf("metrics unit did not advance the quantity of programs verified") - } - vr := &fpb.ValidationResult{ IsValid: true, DidCollectCoverage: true, } metricsUnit.RecordVerificationResults(vr) + if metricsUnit.metricsCollection.validPrograms != 1 { t.Errorf("metrics unit did not advance the quantity of valid programs") } diff --git a/tools/ffi.go b/tools/ffi.go index 83758a3..3f5320f 100644 --- a/tools/ffi.go +++ b/tools/ffi.go @@ -37,6 +37,7 @@ func EncodeEBPF(serializedProgram unsafe.Pointer, serializedProgramSize C.int, Program: encodedProg, Function: encodedfunc, Btf: program.Btf, + Maps: program.Maps, } // Then do magic to return it to C++ serializedProto, err := proto.Marshal(result) diff --git a/tools/loader.cc b/tools/loader.cc index c3ff9bd..f670eff 100644 --- a/tools/loader.cc +++ b/tools/loader.cc @@ -32,9 +32,6 @@ int main(int argc, char **argv) { return -1; } - const int map_size = 2; - int map_fd = bpf_create_map(BPF_MAP_TYPE_ARRAY, sizeof(uint32_t), - sizeof(uint64_t), map_size); std::string verifier_log, error_message; std::string serialized_proto_string( reinterpret_cast(serialized_proto), size); @@ -58,17 +55,7 @@ int main(int argc, char **argv) { return -1; } - std::vector map_elements; - if (!get_map_elements(map_fd, map_size, &map_elements, error_message)) { - std::cerr << "Could not get map elements: " << error_message << std::endl; - return -1; - } - - std::cout << "map elements: " << std::endl; - for (auto element : map_elements) { - std::cout << "element: " << element << std::endl; - } - + ffi_clean_fd_array(vres.fd_array_addr(), program.maps().size()); free(serialized_proto); return 0; }