Creating a symbolic analyzer for JVM bytecode

    The link to the project repository will be added later, as the project has not been announced yet.

    Use case

    To develop a symbolic analyzer for JVM bytecode, one needs to get the three-address code intermediate representation (3-address code IR) for each method in a program. Here is how it can be implemented using the JacoDB and KSMT libraries.

    Implementation

    Firstly, we need to get the 3-address code IR for a method. JacoDB has its own 3-address code IR similar to Jimple.

    // The explicit types matter: a bare `val x = TODO(...)` is inferred as `Nothing`,
    // so member access such as `jcClass.declaredMethods` would not resolve.
    val jcClass: org.jacodb.api.JcClassOrInterface = TODO("obtain target class somehow")
    val targetMethodName: String = TODO("Target method name")
    // Pick the method under analysis by name
    val method = jcClass.declaredMethods.first { it.name == targetMethodName }
    val instList = method.instList // 3-address IR

    Secondly, we need a symbolic state:

    /**
     * A symbolic state: a snapshot of one explored execution path.
     *
     * @property path the execution path; its last element is the instruction to interpret next.
     * @property memory the symbolic memory, keyed by local variable name.
     * @property pathCondition the boolean constraints collected along [path].
     */
    data class State(
        val path: List<JcInst>, // Execution path
        val memory: Map<String, KExpr<out KSort>>, // Symbolic memory
        val pathCondition: List<KExpr<KBoolSort>>, // Path condition
    )

    To interpret a symbolic state, we need to take the last instruction in the path and process it according to its type:

    /**
     * Performs a single symbolic execution step.
     *
     * The instruction to execute is the last element of [state]'s path; it is dispatched
     * to the handler matching its type.
     *
     * @return the successor states produced by the handler.
     */
    fun interpret(state: State): Collection<State> =
        when (val current = state.path.last()) {
            // Assign instruction, e.g. `n = i + 42`
            is JcAssignInst -> interpretAssignInst(state, current)
            // Branching instruction, e.g. `if (n == 101)`
            is JcIfInst -> interpretIfStmt(state, current)
            else -> TODO("Process other instructions as well")
        }

    For example, to process an assign instruction, we need to resolve its right-hand operand into a symbolic expression and its left-hand operand into a memory location:

    /**
     * Symbolically executes an assign instruction, e.g. `n = i + 42`.
     *
     * The right-hand side is resolved into a symbolic expression and written to the symbolic
     * memory under the name of the assigned local variable. The instruction that follows
     * [inst] in the method's instruction list is appended to the execution path.
     *
     * @return a single-element collection containing the successor state.
     */
    private fun interpretAssignInst(state: State, inst: JcAssignInst): Collection<State> {
        // 0. Resolve the symbolic expression for [inst.rhv]
        val expr = resolveExpr(state, inst.rhv)

        // 1. Write [expr] to the symbolic memory
        val nextState = when (val lhv = inst.lhv) {
            is JcLocalVar -> {
                // Find the next instruction. The instruction list is addressed by the position
                // of the instruction (`location.index`), not by its source line number.
                val nextInst = inst.location.method.instList[inst.location.index + 1]
                state.copy(
                    // update the symbolic memory with the new write
                    memory = state.memory + (lhv.name to expr),
                    // update the execution path
                    path = state.path + nextInst
                )
            }

            else -> TODO("Process other JcValues as well")
        }
        return listOf(nextState)
    }

    Resolving JcExpr is straightforward. We match on the expression type and use KSMT builder functions.

    /**
     * Resolves [expr] in the context of [state] into a symbolic [KExpr] of some sort.
     *
     * Handles additions, int constants, equality checks and local variable reads;
     * every other expression kind is left as a `TODO`.
     */
    @Suppress("UNCHECKED_CAST")
    private fun resolveExpr(state: State, expr: JcExpr): KExpr<out KSort> =
        when (expr) {
            is JcAddExpr -> {
                // resolve the left operand recursively
                val leftOperand = resolveExpr(state, expr.lhv) as KExpr<KBvSort>
                // resolve the right operand recursively
                val rightOperand = resolveExpr(state, expr.rhv) as KExpr<KBvSort>
                // make a symbolic add expression with KSMT builder functions
                leftOperand.ctx.mkBvAddExpr(leftOperand, rightOperand)
            }

            // resolve an int constant as a 32-bit bit-vector
            is JcInt -> ctx.mkBv(expr.value, 32U)

            is JcEqExpr -> {
                // resolve the left operand recursively
                val leftOperand = resolveExpr(state, expr.lhv)
                // resolve the right operand recursively
                val rightOperand = resolveExpr(state, expr.rhv)
                // make a symbolic eq expression with KSMT builder functions;
                // the sort of the left operand decides how both operands are treated
                when (leftOperand.sort) {
                    // both operands of the boolean sort
                    is KBoolSort -> leftOperand.ctx.mkEq(
                        leftOperand as KExpr<KBoolSort>,
                        rightOperand as KExpr<KBoolSort>
                    )
                    // both operands of the bv sort
                    is KBvSort -> leftOperand.ctx.mkEq(
                        leftOperand as KExpr<KBvSort>,
                        rightOperand as KExpr<KBvSort>
                    )

                    else -> TODO("Process other sorts as well")
                }
            }

            // get the variable from the current symbolic memory
            is JcLocalVar -> state.memory.getValue(expr.name)
            else -> TODO("Process other JcExprs as well")
        }

    To process the branching instruction, we need to resolve a symbolic condition and use the solver to check both condition and !condition.

    /**
     * Symbolically executes a branching instruction, e.g. `if (n == 101)`.
     *
     * The condition is resolved into a boolean expression. The positive branch (condition
     * holds) is examined first, then the negative branch (negated condition). A branch yields
     * a successor state only when the solver reports its extended path condition as satisfiable.
     *
     * @return zero, one or two successor states, positive branch first.
     */
    private fun interpretIfStmt(state: State, inst: JcIfInst): Collection<State> {
        // The condition is expected to be of the boolean sort
        val condition = resolveExpr(state, inst.condition) as KExpr<KBoolSort>

        // Builds the successor state for one branch, or returns null when that branch is infeasible
        fun successor(constraint: KExpr<KBoolSort>, takeTrueBranch: Boolean): State? {
            val extendedCondition = state.pathCondition + constraint
            if (!solver.check(extendedCondition)) return null

            val instructions = inst.location.method.instList
            val target = if (takeTrueBranch) inst.trueBranch else inst.falseBranch
            return state.copy(
                path = state.path + instructions[target.index],
                pathCondition = extendedCondition,
            )
        }

        // 1. Positive branch: the condition holds
        val positiveState = successor(condition, takeTrueBranch = true)
        // 2. Negative branch: the negated condition holds
        val negativeState = successor(condition.ctx.mkNot(condition), takeTrueBranch = false)

        // 3. Keep only the feasible branches
        return listOfNotNull(positiveState, negativeState)
    }

    The complete code example can be found below:

    package org.jacodb.example
    
    import io.ksmt.KContext
    import io.ksmt.expr.KExpr
    import io.ksmt.sort.KBoolSort
    import io.ksmt.sort.KBvSort
    import io.ksmt.sort.KSort
    import org.jacodb.api.JcMethod
    import org.jacodb.api.cfg.JcAddExpr
    import org.jacodb.api.cfg.JcAssignInst
    import org.jacodb.api.cfg.JcEqExpr
    import org.jacodb.api.cfg.JcExpr
    import org.jacodb.api.cfg.JcIfInst
    import org.jacodb.api.cfg.JcInst
    import org.jacodb.api.cfg.JcInt
    import org.jacodb.api.cfg.JcLocalVar
    
    /**
     * A data class representing a symbolic state, i.e. a snapshot of one explored execution path.
     *
     * Symbolic expressions are stored as KSMT [KExpr] instances.
     *
     * @property path the execution path; its last element is the instruction to interpret next.
     * @property memory the symbolic memory, keyed by local variable name.
     * @property pathCondition the boolean constraints collected along [path].
     */
    data class State(
        val path: List<JcInst>, // Execution path
        val memory: Map<String, KExpr<out KSort>>, // Symbolic memory
        val pathCondition: List<KExpr<KBoolSort>>, // Path condition
    )
    
    /**
     * A solver interface used to decide whether a path condition is feasible.
     */
    interface Solver {
        /**
         * Checks the given path condition for satisfiability.
         *
         * @param pathCondition the list of boolean constraints to check
         *   (presumably interpreted as their conjunction; confirm against the implementation).
         * @return true, if [pathCondition] is satisfiable, and false otherwise.
         */
        fun check(pathCondition: List<KExpr<KBoolSort>>): Boolean
    }
    
    /**
     * Explores a method symbolically in breadth-first order, delegating each step to an [Interpreter].
     */
    class Analyzer(
        ctx: KContext,
        solver: Solver,
    ) {
        // The interpreter that performs the individual symbolic steps
        private val interpreter = Interpreter(solver = solver, ctx = ctx)

        /**
         * Analyzes [method], starting from its first instruction with an empty symbolic memory
         * and an empty path condition.
         */
        fun analyze(method: JcMethod) {
            // The entry state: only the first instruction is on the path
            val entryState = State(
                path = listOf(method.instList.first()),
                memory = emptyMap(),
                pathCondition = emptyList(),
            )

            // FIFO worklist, so states are processed in BFS order
            val worklist = ArrayDeque<State>()
            worklist.addLast(entryState)

            while (worklist.isNotEmpty()) {
                // Take the oldest pending state and schedule all of its successors
                val current = worklist.removeFirst()
                val successors = interpreter.interpret(current)
                worklist.addAll(successors)
                // here you can check **interesting** states and do something with them
                // your own code
            }
        }
    }
    
    /**
     * An example interpreter.
     *
     * Executes one instruction of a [State] symbolically. It uses [solver] to discard infeasible
     * branches and [ctx] to build constant expressions.
     */
    class Interpreter(
        val solver: Solver,
        val ctx: KContext,
    ) {
        /**
         * Takes a [state] and returns the next states.
         *
         * The instruction to execute is the last element of the state's path.
         */
        fun interpret(state: State): Collection<State> {
            // 0. Get the last instruction in the path
            val inst = state.path.last()
            // 1. Process the last instruction symbolically
            return when (inst) {
                is JcAssignInst -> interpretAssignInst(state, inst) // Assign instruction, e.g. n = i + 42;
                is JcIfInst -> interpretIfStmt(state, inst) // Branching instruction, e.g. if (n == 101)
                else -> TODO("Process other instructions as well")
            }
        }

        /**
         * Symbolically executes an assign instruction: resolves the right-hand side, writes it to
         * the symbolic memory under the assigned local variable's name, and appends the following
         * instruction to the path.
         *
         * @return a single-element collection containing the successor state.
         */
        private fun interpretAssignInst(state: State, inst: JcAssignInst): Collection<State> {
            // 0. Resolve the symbolic expression for [inst.rhv]
            val expr = resolveExpr(state, inst.rhv)

            // 1. Write [expr] to the symbolic memory
            val nextState = when (val lhv = inst.lhv) {
                is JcLocalVar -> {
                    // Find the next instruction. The instruction list is addressed by the position
                    // of the instruction (`location.index`), not by its source line number.
                    val nextInst = inst.location.method.instList[inst.location.index + 1]
                    state.copy(
                        // update the symbolic memory with the new write
                        memory = state.memory + (lhv.name to expr),
                        // update the execution path
                        path = state.path + nextInst
                    )
                }

                else -> TODO("Process other JcValues as well")
            }
            return listOf(nextState)
        }

        /**
         * Symbolically executes a branching instruction. Each branch whose extended path condition
         * the [solver] reports as satisfiable yields a successor state.
         *
         * @return zero, one or two successor states, positive branch first.
         */
        @Suppress("UNCHECKED_CAST")
        private fun interpretIfStmt(state: State, inst: JcIfInst): Collection<State> {
            val nextStates = mutableListOf<State>()

            // 0. Resolve the symbolic condition, which should be of the boolean sort
            val expr = resolveExpr(state, inst.condition) as KExpr<KBoolSort>

            // 1. Resolve the positive branch
            val positivePathCondition = state.pathCondition + expr
            // 1.0. Check for satisfiability
            if (solver.check(positivePathCondition)) {
                // 1.1. Find the positive instruction
                val positiveInst = inst.location.method.instList[inst.trueBranch.index]
                // 1.2. Add the instruction to the path
                val positivePath = state.path + positiveInst
                // 1.3. Add the positive state to the result list
                val positiveState = state.copy(path = positivePath, pathCondition = positivePathCondition)
                nextStates += positiveState
            }

            // 2. Resolve the negative branch
            val negativePathCondition = state.pathCondition + expr.ctx.mkNot(expr)
            // 2.0. Check for satisfiability
            if (solver.check(negativePathCondition)) {
                // 2.1. Find the negative instruction
                val negativeInst = inst.location.method.instList[inst.falseBranch.index]
                // 2.2. Add the instruction to the path
                val negativePath = state.path + negativeInst
                // 2.3. Add the negative state to the result list
                val negativeState = state.copy(path = negativePath, pathCondition = negativePathCondition)
                nextStates += negativeState
            }

            // 3. Return the result states
            return nextStates
        }

        /**
         * Resolves an [expr] inside of a [state] into symbolic representation [KExpr] of some sort.
         *
         * Handles additions, int constants, equality checks and local variable reads;
         * every other expression kind is left as a `TODO`.
         */
        @Suppress("UNCHECKED_CAST")
        private fun resolveExpr(state: State, expr: JcExpr): KExpr<out KSort> =
            when (expr) {
                is JcAddExpr -> {
                    // resolve the left operand recursively
                    val leftOperand = resolveExpr(state, expr.lhv) as KExpr<KBvSort>
                    // resolve the right operand recursively
                    val rightOperand = resolveExpr(state, expr.rhv) as KExpr<KBvSort>
                    // make a symbolic add expression with KSMT builder functions
                    leftOperand.ctx.mkBvAddExpr(leftOperand, rightOperand)
                }

                // resolve an int constant as a 32-bit bit-vector
                is JcInt -> ctx.mkBv(expr.value, 32U)

                is JcEqExpr -> {
                    // resolve the left operand recursively
                    val leftOperand = resolveExpr(state, expr.lhv)
                    // resolve the right operand recursively
                    val rightOperand = resolveExpr(state, expr.rhv)
                    // make a symbolic eq expression with KSMT builder functions;
                    // the sort of the left operand decides how both operands are treated
                    when (leftOperand.sort) {
                        // both operands of the boolean sort
                        is KBoolSort -> leftOperand.ctx.mkEq(
                            leftOperand as KExpr<KBoolSort>,
                            rightOperand as KExpr<KBoolSort>
                        )
                        // both operands of the bv sort
                        is KBvSort -> leftOperand.ctx.mkEq(
                            leftOperand as KExpr<KBvSort>,
                            rightOperand as KExpr<KBvSort>
                        )

                        else -> TODO("Process other sorts as well")
                    }
                }

                // get the variable from the current symbolic memory
                is JcLocalVar -> state.memory.getValue(expr.name)
                else -> TODO("Process other JcExprs as well")
            }
    }