Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions ebpf_ffi/cbpf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <linux/kernel.h>
#include <netpacket/packet.h>

#include "ffi.h"

bool load_cbpf_program(void *prog_buff, size_t size, std::string &error,
int *socks) {
if (socketpair(AF_UNIX, SOCK_DGRAM, 0, socks) < 0) {
Expand Down Expand Up @@ -62,38 +64,25 @@ struct bpf_result validation_error(std::string error_message,
return serialize_proto(*vres);
}

struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size,
int coverage_enabled,
uint64_t coverage_size) {
struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size) {
std::string error_message;

struct coverage_data cover;
memset(&cover, 0, sizeof(struct coverage_data));
cover.fd = -1;
cover.coverage_size = coverage_size;
if (coverage_enabled) enable_coverage(&cover);

ValidationResult vres;

int socks[2] = {-1, -1};
if (!load_cbpf_program(prog_buff, size, error_message, socks)) {
// Return why we failed to load the program.
if (coverage_enabled) get_coverage_and_free_resources(&cover, &vres);
bool coverage_enabled = enable_coverage();
bool cbpf_loaded = load_cbpf_program(prog_buff, size, error_message, socks);
if (coverage_enabled) {
disable_coverage();
get_coverage(&vres);
}
if (!cbpf_loaded) {
return validation_error(error_message, &vres);
}

if (coverage_enabled) get_coverage_and_free_resources(&cover, &vres);

// Start building the validation result proto.
vres.set_socket_write(socks[0]);
vres.set_socket_read(socks[1]);
if (cover.fd != -1) {
vres.set_did_collect_coverage(true);
vres.set_coverage_size(cover.coverage_size);
vres.set_coverage_buffer(reinterpret_cast<uint64_t>(cover.coverage_buffer));
} else {
vres.set_did_collect_coverage(false);
}

if (socks[0] < 0) {
// Return why we failed to load the program.
Expand Down
6 changes: 2 additions & 4 deletions ebpf_ffi/cbpf.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#ifndef EBPF_FUZZER_EBPF_FFI_CBPF_H_
#define EBPF_FUZZER_EBPF_FFI_CBPF_H_

#include "ffi.h"
#include "ebpf_ffi/ffi.h"

extern "C" {

Expand All @@ -29,9 +29,7 @@ bool load_cbpf_program(void *prog_buff, size_t size, std::string &error,

// Loads a bpf program specified by |prog_buff| with |size| and returns struct
// with a serialized ValidationResult proto.
struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size,
int coverage_enabled,
uint64_t coverage_size);
struct bpf_result ffi_load_cbpf_program(void *prog_buff, size_t size);

bool execute_cbpf_program(int prog_fd, uint8_t *input, uint8_t *output,
int input_length, std::string &error_message);
Expand Down
27 changes: 8 additions & 19 deletions ebpf_ffi/ebpf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,40 +115,29 @@ ValidationResult load_ebpf_program(EncodedProgram program, std::string &error) {
return res;
}

struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size,
int coverage_enabled,
uint64_t coverage_size) {
struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size) {
std::string error_message;

struct coverage_data cover;
memset(&cover, 0, sizeof(struct coverage_data));
cover.fd = -1;
cover.coverage_size = coverage_size;
if (coverage_enabled) enable_coverage(&cover);

std::string serialized_proto_string(
reinterpret_cast<const char *>(serialized_proto), size);
EncodedProgram program;
if (!program.ParseFromString(serialized_proto_string)) {
error_message = "Could not parse EncodedProgram proto";
}

bool coverage_enabled = enable_coverage();
ValidationResult vres = load_ebpf_program(program, error_message);
if (coverage_enabled) get_coverage_and_free_resources(&cover, &vres);

if (cover.fd != -1) {
vres.set_did_collect_coverage(true);
vres.set_coverage_size(cover.coverage_size);
vres.set_coverage_buffer(reinterpret_cast<uint64_t>(cover.coverage_buffer));
} else {
vres.set_did_collect_coverage(false);
vres.set_did_collect_coverage(false);
if (coverage_enabled) {
get_coverage(&vres);
disable_coverage();
}

vres.set_is_valid(true);
if (vres.program_fd() < 0) {
// Return why we failed to load the program.
vres.set_bpf_error(error_message);
vres.set_is_valid(false);
} else {
vres.set_is_valid(true);
}

return serialize_proto(vres);
Expand Down
4 changes: 1 addition & 3 deletions ebpf_ffi/ebpf.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ ValidationResult load_ebpf_program(EncodedProgram program, std::string &error);

// Loads a bpf program specified by |prog_buff| with |size| and returns struct
// with a serialized ValidationResult proto.
struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size,
int coverage_enabled,
uint64_t coverage_size);
struct bpf_result ffi_load_ebpf_program(void *serialized_proto, size_t size);

bool get_map_elements(int map_fd, size_t map_size, std::vector<uint64_t> *res,
std::string &error);
Expand Down
87 changes: 56 additions & 31 deletions ebpf_ffi/ffi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@
#include "ebpf_ffi/ffi.h"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <stdio.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/syscall.h>
#include <unistd.h>

#include <string>
#include <unordered_set>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/escaping.h"
#include "ffi.h"
#include "google/protobuf/message.h"
#include "google/protobuf/repeated_field.h"
#include "proto/ffi.pb.h"
Expand All @@ -46,6 +44,8 @@ using ebpf_fuzzer::ExecutionResult;
using ebpf_fuzzer::MapElements;
using ebpf_fuzzer::ValidationResult;

struct coverage_data *kCoverageData = nullptr;

// All the functions in this extern are FFIs intended to be invoked from go.
extern "C" {

Expand All @@ -64,43 +64,68 @@ bpf_result serialize_proto(const google::protobuf::Message &proto) {
return res;
}

void enable_coverage(struct coverage_data *coverage_info) {
int fd = open("/sys/kernel/debug/kcov", O_RDWR);
if (fd == -1) return;
/* Setup trace mode and trace size. */
if (ioctl(fd, KCOV_INIT_TRACE, coverage_info->coverage_size)) return;
/* Mmap buffer shared between kernel- and user-space. */
uint64_t *cover =
(uint64_t *)mmap(nullptr, coverage_info->coverage_size * sizeof(uint64_t),
PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if ((void *)cover == MAP_FAILED) return;
/* Enable coverage collection on the current thread. */
if (ioctl(fd, KCOV_ENABLE, KCOV_TRACE_PC)) return;
/* Reset coverage from the tail of the ioctl() call. */
__atomic_store_n(&cover[0], 0, __ATOMIC_RELAXED);
coverage_info->fd = fd;
coverage_info->coverage_buffer = cover;
}

void get_coverage_and_free_resources(struct coverage_data *cstruct,
ValidationResult *vres) {
if (cstruct->fd == -1) return;
uint64_t trace_size =
__atomic_load_n(&cstruct->coverage_buffer[0], __ATOMIC_RELAXED);
void get_coverage(ValidationResult *vres) {
if (kCoverageData == nullptr || kCoverageData->fd == -1) return;
uint64_t trace_size = kCoverageData->coverage_buffer[0];

auto *coverage_addresses = vres->mutable_coverage_address();
absl::flat_hash_set<uint64_t> seen_address;
for (uint64_t i = 0; i < trace_size; i++) {
uint64_t addr = cstruct->coverage_buffer[i + 1];
uint64_t addr = kCoverageData->coverage_buffer[i + 1];
if (seen_address.find(addr) == seen_address.end()) {
coverage_addresses->Add(cstruct->coverage_buffer[i + 1]);
coverage_addresses->Add(addr);
seen_address.insert(addr);
}
}

ioctl(cstruct->fd, KCOV_DISABLE, 0);
close(cstruct->fd);
munmap(cstruct->coverage_buffer, cstruct->coverage_size * sizeof(uint64_t));
// reset kcov buffer.
memset(kCoverageData->coverage_buffer, 0, KCOV_SIZE * sizeof(uint64_t));
vres->set_did_collect_coverage(true);
vres->set_coverage_size(trace_size);
return;
}

bool enable_coverage() {
if (!kCoverageData || kCoverageData->fd == -1) return false;
return ioctl(kCoverageData->fd, KCOV_ENABLE, KCOV_TRACE_PC) == 0;
}

void disable_coverage() {
if (kCoverageData == nullptr || kCoverageData->fd == -1) return;
(void)ioctl(kCoverageData->fd, KCOV_DISABLE, 0);
}

int ffi_setup_coverage() {
if (!kCoverageData) {
kCoverageData =
(struct coverage_data *)malloc(sizeof(struct coverage_data));
memset(kCoverageData, 0, sizeof(struct coverage_data));
}

int fd = open("/sys/kernel/debug/kcov", O_RDWR);
kCoverageData->fd = fd;
if (fd == -1) return -1;
/* Setup trace mode and trace size. */
if (ioctl(fd, KCOV_INIT_TRACE, KCOV_SIZE)) return -1;
/* Mmap buffer shared between kernel- and user-space. */
uint64_t *cover = (uint64_t *)mmap(nullptr, KCOV_SIZE * sizeof(uint64_t),
PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if ((void *)cover == MAP_FAILED) return -1;
memset(cover, 0, KCOV_SIZE * sizeof(uint64_t));

kCoverageData->fd = fd;
kCoverageData->coverage_buffer = cover;
return 0;
}

int ffi_cleanup_coverage() {
if (!kCoverageData) return 0;

close(kCoverageData->fd);
munmap(kCoverageData->coverage_buffer, KCOV_SIZE * sizeof(uint64_t));
free(kCoverageData);
kCoverageData = nullptr;
return 0;
}

bool execute_error(std::string &error_message, const char *strerr,
Expand Down
21 changes: 16 additions & 5 deletions ebpf_ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
#define KCOV_TRACE_PC 0
#define KCOV_TRACE_CMP 1

// 64mb for kcov coverage.
#define KCOV_SIZE 1024 * 1024 * 64

using ebpf_fuzzer::CbpfExecutionRequest;
using ebpf_fuzzer::EncodedProgram;
using ebpf_fuzzer::ExecutionRequest;
Expand All @@ -66,11 +69,6 @@ struct bpf_result {

bpf_result serialize_proto(const google::protobuf::Message &proto);

void enable_coverage(struct coverage_data *coverage_info);

void get_coverage_and_free_resources(struct coverage_data *cstruct,
ValidationResult *vres);

bool execute_error(std::string &error_message, const char *strerr,
int *sockets);

Expand All @@ -83,10 +81,23 @@ int ffi_create_bpf_map(size_t size);
// Closes the given file descriptor, this is to free up resources.
void ffi_close_fd(int fd);

// Enable kcov coverage.
int ffi_setup_coverage();

// Disble kcov coverage.
int ffi_cleanup_coverage();

bool enable_coverage();
void disable_coverage();
void get_coverage(ValidationResult *vres);

struct coverage_data {
int fd;
uint64_t coverage_size;
uint64_t *coverage_buffer;
};

extern struct coverage_data *kCoverageData;
}

#endif // EBPF_FUZZER_EBPF_FFI_FFI_H_
15 changes: 7 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ import (

// Flags that the binary can accept.
var (
coverageBufferSize = flag.Uint64("coverage_buffer_size", 64<<20, "Size of the buffer passed to kcov to get coverage addresses, the higher the number, the slower coverage collection will be")
metricsThreshold = flag.Int("metrics_threshold", 200, "Collect detailed metrics (coverage) every `metrics_threshold` validated programs")
strategyName = flag.String("strategy", "playground", "Strategy to use for fuzzing")
vmLinuxPath = flag.String("vmlinux_path", "/root/vmlinux", "Path to the linux image that will be passed to addr2line to get coverage info")
sourceFilesPath = flag.String("src_path", "/root/sourceFiles", "The fuzzer will look for source files to visualize the coverage at this path")
metricsServerAddr = flag.String("metrics_server_addr", "0.0.0.0", "Address that the metrics server will listen to at")
metricsServerPort = flag.Uint("metrics_server_port", 8080, "Port that the metrics server will listen to at")
metricsThreshold = flag.Int("metrics_threshold", 200, "Collect detailed metrics (coverage) every `metrics_threshold` validated programs")
strategyName = flag.String("strategy", "playground", "Strategy to use for fuzzing")
vmLinuxPath = flag.String("vmlinux_path", "/root/vmlinux", "Path to the linux image that will be passed to addr2line to get coverage info")
sourceFilesPath = flag.String("src_path", "/root/sourceFiles", "The fuzzer will look for source files to visualize the coverage at this path")
metricsServerAddr = flag.String("metrics_server_addr", "0.0.0.0", "Address that the metrics server will listen to at")
metricsServerPort = flag.Uint("metrics_server_port", 8080, "Port that the metrics server will listen to at")
)

var (
Expand Down Expand Up @@ -79,7 +78,7 @@ func main() {
})

controlUnit := units.Control{}
metricsUnit := units.NewMetricsUnit(*metricsThreshold, *coverageBufferSize, *vmLinuxPath, *sourceFilesPath, *metricsServerAddr, uint16(*metricsServerPort), coverageManager)
metricsUnit := units.NewMetricsUnit(*metricsThreshold, *vmLinuxPath, *sourceFilesPath, *metricsServerAddr, uint16(*metricsServerPort), coverageManager)

if err := controlUnit.Init(&units.FFI{
MetricsUnit: metricsUnit,
Expand Down
4 changes: 4 additions & 0 deletions pkg/units/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ func (cu *Control) IsReady() bool {

// RunFuzzer kickstars the fuzzer in the mode that was specified at Init time.
func (cu *Control) RunFuzzer() error {
cu.ffi.InitKcov()
defer func() {
cu.ffi.CleanupKcov()
}()
for !cu.strat.IsFuzzingDone() {
prog, err := cu.strat.GenerateProgram(cu.ffi)
if err != nil {
Expand Down
Loading
Loading