/*
 * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 */

#include <vmlinux.h>

#include <bpf/bpf_core_read.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

#include "syscall_trace.h"

/*
 * Per-task state stored in `tls_map`. It is filled in by
 * `nsys_trace_sys_enter` and consumed by `nsys_trace_sys_exit`.
 */
struct tls_data {
    /* `bpf_ktime_get_ns()` value taken when the syscall was entered. */
    u64 timestamp_start;
    /* Syscall number received by `nsys_trace_sys_enter`. */
    u64 syscall_nr;
    /*
     * TGID and PID reported by `get_allowed_ns_subtree_pidns_info()`. They are
     * written once, when the storage for the task is created.
     */
    u32 tgid;
    u32 pid;
    /*
     * Incremented on every traced `sys_enter`; decremented again in
     * `nsys_trace_sys_exit` when the syscall ran for less than `threshold`.
     */
    u16 seq_nr;
};

/*
 * To filter in containers, uniquely identify the PID namespace of the processes
 * to trace.
 * https://lore.kernel.org/bpf/20200304204157.58695-1-cneirabustos@gmail.com/
 * The following two variables must be set before loading the BPF programs.
 */
/* Passed as the `dev` argument of `bpf_get_ns_current_pid_tgid()`. */
volatile const __u64 allowed_pidns_dev;
/*
 * Passed as the `ino` argument of `bpf_get_ns_current_pid_tgid()` and also
 * compared against `ns.inum` while walking the PID namespace hierarchy.
 */
volatile const __u64 allowed_pidns_ino;

/*
 * The size of the buffer storing user space backtraces, it is essentially
 * <max number of entries> * sizeof(syscall_trace_data.ustack[0]).
 * The value is added to `sizeof(struct syscall_trace_data)` when a ring buffer
 * record carrying a backtrace is reserved, and it is passed as the buffer size
 * to `bpf_get_stack()`.
 * Must be set before loading `nsys_trace_sys_exit` where it is used.
 */
volatile const __u64 backtrace_buffer_size_bytes;

/*
 * The minimum running time (in ns) for a syscall to be traced. A syscall that
 * ran for less than this value produces no record, and the `seq_nr` increment
 * made for it on entry is reverted.
 * Must be set before loading `nsys_trace_sys_exit` where it is used.
 */
volatile const __u64 threshold;

/*
 * The minimum running time (in ns) for a syscall to attach its backtrace to
 * the trace. A backtrace is attached when the running time is greater than or
 * equal to this value and the additional filter in `nsys_trace_sys_exit`
 * (based on `is_pid_namespace_wide` and `traced_app_tgids`) passes.
 * Must be set before loading `nsys_trace_sys_exit` where it is used.
 */
volatile const __u64 backtrace_threshold;

/**
 * This variable must be set before loading nsys_trace_sys_enter/exit programs.
 * It defines how syscall traces are filtered, when it is:
 *   `false`: it assumes that there is a traced app, its original PID is
 *            provided in `traced_app_tgids` and then forks are traced to
 *            populate this set with sub-processes. Only syscalls made by these
 *            sub-processes are traced.
 *   `true`:  syscalls made by any task that belongs to the traced PID namespace
 *            (identified by `allowed_pidns_dev` and `allowed_pidns_ino`) and
 *            its child namespaces are traced. In this mode backtraces are
 *            attached only for tasks whose TGID is in `traced_app_tgids`.
 */
volatile const bool is_pid_namespace_wide;

/*
 * The following map is used to store the TGIDs (as seen from the default PID
 * namespace) of the processes that need to be traced. The field `max_entries`
 * must be set before loading the BPF programs. After the BPF programs have been
 * loaded, this map must be updated with the TGIDs (as seen from the default PID
 * namespace) of the processes that need to be traced.
 */
SEC(".maps")
struct {
    __uint(type, BPF_MAP_TYPE_HASH);
    /* Key: TGID of a process that is being traced. */
    __type(key, u32);
    /*
     * Value: the fork tracker stores the TGID itself; lookups in this file
     * only test whether the key is present.
     */
    __type(value, u32);
    /* `max_entries` is not defined here; it is set before loading. */
} traced_app_tgids;

/*
 * The following map is used to store information for each desired `task_struct`
 * object in local storage. The BPF iterator `nsys_delete_tls` must be triggered
 * before (re)attaching the programs `nsys_trace_sys_enter` and
 * `nsys_trace_sys_exit`. This will ensure that there is no stale data in the
 * local storage from any previous attachments of the BPF programs
 * `nsys_trace_sys_enter` and `nsys_trace_sys_exit`.
 */
SEC(".maps")
struct {
    __uint(type, BPF_MAP_TYPE_TASK_STORAGE);
    /* Task local storage maps have to be created with BPF_F_NO_PREALLOC. */
    __uint(map_flags, BPF_F_NO_PREALLOC);
    /*
     * The programs in this file never use the key directly: the storage is
     * addressed through the `task_struct` pointer given to the helpers.
     */
    __type(key, int);
    /* One `struct tls_data` per traced task. */
    __type(value, struct tls_data);
} tls_map;

/*
 * The following map is a ring buffer that is used to transfer trace data
 * records from the kernel space to the user space. The field `max_entries` must
 * be set before loading the BPF programs.
 */
SEC(".maps")
struct {
    __uint(type, BPF_MAP_TYPE_RINGBUF);
    /*
     * Records are `struct syscall_trace_data` objects, optionally followed by
     * `backtrace_buffer_size_bytes` of backtrace space (see
     * `nsys_trace_sys_exit`). `max_entries` is not defined here; it is set
     * before loading.
     */
} ring_buffer;

/**
 * Returns PID and TGID of the current task if it belongs to the monitored PID
 * namespace (defined by `allowed_pidns_dev` and `allowed_pidns_ino`).
 * The values are given as they are seen in this PID namespace. This behavior
 * is dictated by `bpf_get_ns_current_pid_tgid()`. Errors of the fucntion are
 * propagated through, see the function documentation for details.
 */
static long get_allowed_ns_pidns_info(struct bpf_pidns_info *nsdata)
{
    return bpf_get_ns_current_pid_tgid(allowed_pidns_dev, allowed_pidns_ino,
                                       nsdata, sizeof(*nsdata));
}

/*
 * Tells whether `tgid` is currently a key of the `traced_app_tgids` map, i.e.
 * whether the process is part of the traced application's process tree.
 */
static bool is_in_traced_app_process_tree(__u32 tgid)
{
    return bpf_map_lookup_elem(&traced_app_tgids, &tgid) != NULL;
}

static struct pid_namespace *get_pidns(struct task_struct *task)
{
    struct pid *pid = BPF_CORE_READ(task, thread_pid);
    unsigned int level = BPF_CORE_READ(pid, level);

    return BPF_CORE_READ(pid, numbers[level].ns);
}

static int get_tgid_vnr(struct task_struct *task)
{
    struct pid *pid = BPF_CORE_READ(task, signal, pids[PIDTYPE_TGID]);
    unsigned int level = BPF_CORE_READ(pid, level);

    return BPF_CORE_READ(pid, numbers[level].nr);
}

/**
 * Returns PID and TGID of the task as seen from the monitored PID namespace
 * (defined by `allowed_pidns_dev` and `allowed_pidns_ino`) perspective.
 * The main difference from `get_allowed_ns_pidns_info()` is that it supports
 * child PID namespaces (e.g. containers of the host), hence "subtreee" in
 * the name.
 */
static long get_allowed_ns_subtree_pidns_info(struct task_struct *task,
                                              struct bpf_pidns_info *nsdata)
{
    /*
     * To check whether the `task` runs in a sub-namespace of the monitored PID
     * namespace, this function traverses the PID namespace hierarchy starting from
     * the one the `task` belongs to and going up to the monitored PID (or the
     * top of the tree). The implementation is based on the assumption that the
     * PID namespace hierarchy is stable (e.g. it isn't changed during the
     * lifetime of the `task`), as of now, there is no way to restructure the
     * hierarchy (only adding and removing of nodes is possible).
     */
    struct pid_namespace *current_pid_ns = get_pidns(task);
    int ns_traceable_depth = 16;
    struct upid tgid_upid;
    struct upid pid_upid;
    unsigned int level;

    while (current_pid_ns != 0 && ns_traceable_depth > 0) {
        /* According to the comment on `bpf_get_ns_current_pid_tgid()`
         * https://lore.kernel.org/bpf/20200304204157.58695-1-cneirabustos@gmail.com/,
         * the device may need to be checked too in future. Currently
         * `bpf_get_ns_current_pid_tgid()` checks a global variable which we
         * have no access to.
         */
        if (allowed_pidns_ino == BPF_CORE_READ(current_pid_ns, ns.inum)) {
            level = BPF_CORE_READ(current_pid_ns, level);
            tgid_upid =
                BPF_CORE_READ(task, signal, pids[PIDTYPE_TGID], numbers[level]);
            pid_upid = BPF_CORE_READ(task, thread_pid, numbers[level]);

            if (tgid_upid.nr != 0 && pid_upid.nr != 0 &&
                tgid_upid.ns == current_pid_ns &&
                pid_upid.ns == current_pid_ns) {
                nsdata->tgid = tgid_upid.nr;
                nsdata->pid = pid_upid.nr;
                return 0;
            } else {
                break;
            }
        }
        current_pid_ns = BPF_CORE_READ(current_pid_ns, parent);
        --ns_traceable_depth;
    }
    // -ENOENT, to mimic the behavior of `bpf_get_ns_current_pid_tgid()`.
    return -2;
}

/*
 * Keeps `traced_app_tgids` up to date on forks: when a traced process creates
 * a new process in the same PID namespace, the TGID of the child is added to
 * the map. This program must stay attached at all times, even while tracing
 * is not being done, otherwise the map would miss sub-processes.
 */
SEC("tp_btf/sched_process_fork")
int BPF_PROG(nsys_track_sched_process_fork, struct task_struct *parent,
                                            struct task_struct *child)
{
    struct bpf_pidns_info parent_ids;
    u32 child_tgid;
    long rc;

    /* The current task is `parent`. */
    if (get_allowed_ns_pidns_info(&parent_ids))
        return 0;
    if (!is_in_traced_app_process_tree(parent_ids.tgid))
        return 0;
    if (get_pidns(child) != get_pidns(parent))
        return 0;

    /* A child sharing the TGID of its parent is a new thread, not a process. */
    child_tgid = get_tgid_vnr(child);
    if (child_tgid == parent_ids.tgid)
        return 0;

    rc = bpf_map_update_elem(&traced_app_tgids, &child_tgid, &child_tgid,
                             BPF_NOEXIST);
    if (rc)
        bpf_printk("%s: bpf_map_update_elem() error: %ld", __func__, -rc);

    return 0;
}

/*
 * Keeps `traced_app_tgids` up to date on exits: when the task whose PID equals
 * its TGID exits and that TGID is in the map, the entry is removed. This
 * program must stay attached at all times, even while tracing is not being
 * done, otherwise the map would keep stale entries.
 */
SEC("tp_btf/sched_process_exit")
int BPF_PROG(nsys_track_sched_process_exit, struct task_struct *task)
{
    struct bpf_pidns_info ids;
    long rc;

    /* The current task is `task`. */
    if (get_allowed_ns_pidns_info(&ids))
        return 0;
    if (!is_in_traced_app_process_tree(ids.tgid))
        return 0;
    /* Exits of tasks whose PID differs from their TGID are ignored. */
    if (ids.tgid != ids.pid)
        return 0;

    rc = bpf_map_delete_elem(&traced_app_tgids, &ids.tgid);
    if (rc)
        bpf_printk("%s: bpf_map_delete_elem() error: %ld", __func__, -rc);

    return 0;
}

/*
 * Traces the `sys_enter` tracepoint. For a task that passes the filter, the
 * start timestamp and the syscall number are recorded in the task local
 * storage and the sequence number is advanced. The storage is created on the
 * first traced syscall of the task.
 */
SEC("tp_btf/sys_enter")
int BPF_PROG(nsys_trace_sys_enter, struct pt_regs *regs, long syscall_nr)
{
    struct task_struct *current_task = bpf_get_current_task_btf();
    struct tls_data *state;

    state = bpf_task_storage_get(&tls_map, current_task, NULL, 0);
    if (!state) {
        struct bpf_pidns_info ids;

        /* No storage yet: decide whether this task has to be traced. */
        if (get_allowed_ns_subtree_pidns_info(current_task, &ids))
            return 0;
        if (!(is_in_traced_app_process_tree(ids.tgid) || is_pid_namespace_wide))
            return 0;

        state = bpf_task_storage_get(&tls_map, current_task, NULL,
                                     BPF_LOCAL_STORAGE_GET_F_CREATE);
        if (!state) {
            bpf_printk("%s: bpf_task_storage_get() error", __func__);
            return 0;
        }

        /* The ids are stored once, when the storage is created. */
        state->tgid = ids.tgid;
        state->pid = ids.pid;
    }

    state->timestamp_start = bpf_ktime_get_ns();
    state->syscall_nr = syscall_nr;
    ++state->seq_nr;

    return 0;
}

/*
 * The following program is used to trace the `sys_exit` tracepoint. Since trace
 * data records are submitted into the BPF map `ring_buffer` from this program,
 * this program must be detached before detaching the BPF program
 * `nsys_trace_sys_enter`. This will ensure that there is no incorrect data in
 * the BPF map `ring_buffer`.
 *
 * For a task that has local storage in `tls_map`, the running time of the
 * syscall is computed from the timestamp stored on entry. Syscalls shorter
 * than `threshold` are dropped; the remaining ones are submitted as a
 * `struct syscall_trace_data` record, with a user space backtrace attached when
 * the backtrace condition below is met.
 */
SEC("tp_btf/sys_exit")
int BPF_PROG(nsys_trace_sys_exit, struct pt_regs *regs, long ret)
{
    struct task_struct *task = bpf_get_current_task_btf();
    struct syscall_trace_data *data;
    __u64 timestamp_end;
    __u64 running_time;
    struct tls_data *td;

    /* The storage is only looked up here, it is created in `sys_enter`. */
    td = bpf_task_storage_get(&tls_map, task, NULL, 0);
    if (!td)
        goto out;

    timestamp_end = bpf_ktime_get_ns();
    running_time = timestamp_end - td->timestamp_start;

    if (running_time < threshold) {
        /**
         * `seq_nr` is used to identify missing (e.g. due to congestion) syscall
         * traces and highlight the region red on the timeline. Reverting
         * the increment of `seq_nr` (made in `nsys_trace_sys_enter`) allows
         * to not treat these traces as missing, as they have a valid reason
         * for it.
         */
        td->seq_nr -= 1;
        goto out;
    }

    /*
     * Check whether the syscall's been running long enough to add a backtrace.
     * Two `bpf_ringbuf_reserve` calls in both if/else branches are deliberate
     * (as an alternative to having a single call with 2 different `size`'s),
     * otherwise its `size` argument can be mistakenly restricted by
     * the verifier, leaving only the smaller value and making the whole
     * program invalid.
     *
     * In PID namespace wide mode a backtrace is attached only for tasks whose
     * TGID is in `traced_app_tgids`.
     */
    if (running_time >= backtrace_threshold &&
        (!is_pid_namespace_wide || is_in_traced_app_process_tree(td->tgid))) {
        long bytes_copied;

        /* See `syscall_trace_data.ustack` for why extra capacity is needed. */
        data = bpf_ringbuf_reserve(&ring_buffer,
                                   sizeof(*data) + backtrace_buffer_size_bytes,
                                   0);
        if (!data) {
            /* The record is dropped; `seq_nr` stays incremented. */
            bpf_printk("%s: bpf_ringbuf_reserve() error", __func__);
            goto out;
        }

        /* `ctx` is the raw program context provided by `BPF_PROG()`. */
        bytes_copied = bpf_get_stack(ctx, data->ustack,
                                     backtrace_buffer_size_bytes,
                                     BPF_F_USER_STACK);
        if (bytes_copied < 0) {
            /* The record is still submitted, just without a backtrace. */
            bpf_printk("%s: bpf_get_stack() error: %ld", __func__, -bytes_copied);
            data->ustack_byte_size = 0;
        } else {
            data->ustack_byte_size = bytes_copied;
        }
    } else {
        data = bpf_ringbuf_reserve(&ring_buffer, sizeof(*data), 0);
        if (!data) {
            /* The record is dropped; `seq_nr` stays incremented. */
            bpf_printk("%s: bpf_ringbuf_reserve() error", __func__);
            goto out;
        }
        data->ustack_byte_size = 0;
    }

    /* Copy the state recorded on entry into the record and publish it. */
    data->timestamp_start = td->timestamp_start;
    data->timestamp_end = timestamp_end;
    data->syscall_nr = td->syscall_nr;
    data->tgid = td->tgid;
    data->pid = td->pid;
    data->seq_nr = td->seq_nr;

    bpf_ringbuf_submit(data, 0);

out:
    return 0;
}

/*
 * Task iterator program that removes the local storage associated with the
 * BPF map `tls_map` from every `task_struct` object it is invoked for. Tasks
 * without such storage are skipped.
 */
SEC("iter/task")
int nsys_delete_tls(struct bpf_iter__task *ctx)
{
    struct task_struct *task = ctx->task;
    long rc;

    /* There is nothing to do when no task is given. */
    if (!task)
        return 0;
    /* Nothing to delete for a task that has no storage in `tls_map`. */
    if (!bpf_task_storage_get(&tls_map, task, NULL, 0))
        return 0;

    rc = bpf_task_storage_delete(&tls_map, task);
    if (rc)
        bpf_printk("%s: bpf_task_storage_delete() error: %ld", __func__, -rc);

    return 0;
}

/* License string of the BPF programs, placed into the "license" section. */
char LICENSE[] SEC("license") = "GPL v2";
