ND4J's Scala API is called ND4S; its source code is hosted on GitHub.

Before using ND4S, please make sure you have:

The example below shows what ND4S code looks like. Note how closely its syntax resembles NumPy's.

Warning: collection-like operations such as:

    Nd4j.ones(4).map(_+1)

are currently broken.
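Until `map` works, the same element-wise transformations can usually be written with ND4J's vectorized operations instead. The sketch below assumes a recent ND4J release on the classpath (ND4S depends on it) and uses only the standard `INDArray.add` and `Transforms.exp` methods; the object and value names are illustrative and not part of ND4S.

```scala
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.ops.transforms.Transforms

// Hypothetical example object: vectorized stand-ins for map(...)
object ElementwiseWithoutMap {
  val ones: INDArray = Nd4j.ones(4)

  // Instead of ones.map(_ + 1): add a scalar to every element.
  // add() returns a new array; addi() would modify `ones` in place.
  val plusOne: INDArray = ones.add(1)

  // Instead of ones.map(math.exp): Transforms.exp applies exp to every
  // element and, in this single-argument form, returns a copy.
  val expOnes: INDArray = Transforms.exp(ones)

  def main(args: Array[String]): Unit = {
    println("ones + 1: " + plusOne)
    println("exp(ones): " + expOnes)
  }
}
```

Because these operations run over the whole array at once, they are also typically faster than applying a Scala closure element by element.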

    import java.util.Arrays

    import org.nd4j.linalg.factory.Nd4j
    import org.nd4s.Implicits._

    object ScalaMain {
      def main(args: Array[String]): Unit = {

        // Creating arrays in several ways, mirroring NumPy's constructors
        var arr = Nd4j.create(4)                         // four zeros
        var arr2 = Nd4j.ones(4)                          // four ones
        val arr3 = Nd4j.linspace(1, 10, 10)              // 1, 2, ..., 10
        val arr4 = Nd4j.linspace(1, 6, 6).reshape(2, 3)  // 2 x 3 matrix

        // In-place array addition (ND4S operators)
        arr += arr2
        arr += 2

        // In-place array multiplication
        arr2 *= 5

        // Transposing a matrix: the 2 x 3 arr4 becomes 3 x 2
        val arr4T = arr4.T
        println("Transposed matrix: " + arr4T)

        // Summing along dimension 0 gives one sum per column;
        // summing along dimension 1 gives one sum per row
        println("Sum of each column: " + Nd4j.sum(arr4, 0))
        println("Sum of each row: " + Nd4j.sum(arr4, 1))

        // Checking the array's shape
        println("Array shape: " + Arrays.toString(arr2.shape()))

        // Converting an array to a string
        println("Array as string: " + arr2.toString)

        // Filling the array with the value 5 (equivalent to NumPy's fill method)
        println("Array filled with 5: " + arr2.assign(5))

        // Reshaping the array
        println("Reshaped to 2 x 2: " + arr2.reshape(2, 2))

        // Raveling the array (returns a flattened array)
        println("Raveled array: " + arr2.ravel())

        // Flattening the array (equivalent to NumPy's flatten method)
        println("Flattened array: " + Nd4j.toFlattened(arr2))

        // Sorting along dimension 0 in ascending order
        println("Sorted array: " + Nd4j.sort(arr2, 0, true))
        // sortWithIndices returns a Java array holding both the sorted indices and the sorted values
        println("Sorted indices and values: " + Nd4j.sortWithIndices(arr2, 0, true).mkString(", "))

        // Cumulative sum
        println("Cumulative sum: " + Nd4j.cumsum(arr2))

        // Basic statistics
        println("Mean: " + Nd4j.mean(arr))
        println("Standard deviation: " + Nd4j.std(arr2))
        println("Variance: " + Nd4j.`var`(arr2))

        // Finding the maximum and minimum values
        println("Max value: " + Nd4j.max(arr3))
        println("Min value: " + Nd4j.min(arr3))
      }
    }
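To see concretely what the transpose and ravel steps above produce, here is a small sketch that checks the resulting shapes and values. It assumes a recent ND4J/ND4S release and relies on ND4S's `.T` (as used in the example) plus the standard `INDArray.ravel()`, `rows()` and `columns()` methods; `ShapeOps` and its fields are hypothetical names for illustration.

```scala
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._

// Hypothetical example object: shapes before and after transpose/ravel
object ShapeOps {
  // 2 x 3 matrix in the default row-major ('c') order:
  // [[1, 2, 3],
  //  [4, 5, 6]]
  val m: INDArray = Nd4j.linspace(1, 6, 6).reshape(2, 3)

  // ND4S's .T transposes: the 2 x 3 matrix becomes 3 x 2.
  val t: INDArray = m.T

  // ravel() flattens the matrix; for this row-major matrix the
  // elements come out row by row: [1, 2, 3, 4, 5, 6].
  val flat: INDArray = m.ravel()

  def main(args: Array[String]): Unit = {
    println(s"transposed is ${t.rows()} x ${t.columns()}")
    println("raveled: " + flat)
  }
}
```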
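Which dimension a reduction runs along is easy to mix up. This worked sketch uses the same 2 x 3 matrix as `arr4` above and the standard `INDArray.sum(int...)` method, assuming a recent ND4J release; `DimensionSums` is an illustrative name.

```scala
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

// Hypothetical example object: reductions along each dimension
object DimensionSums {
  // Same matrix as arr4:
  // [[1, 2, 3],
  //  [4, 5, 6]]
  val m: INDArray = Nd4j.linspace(1, 6, 6).reshape(2, 3)

  // Dimension 0 is collapsed, so one sum remains per column:
  // [1 + 4, 2 + 5, 3 + 6] = [5, 7, 9]
  val columnSums: INDArray = m.sum(0)

  // Dimension 1 is collapsed, so one sum remains per row:
  // [1 + 2 + 3, 4 + 5 + 6] = [6, 15]
  val rowSums: INDArray = m.sum(1)

  def main(args: Array[String]): Unit = {
    println("Column sums: " + columnSums)
    println("Row sums: " + rowSums)
  }
}
```

The rule of thumb, as in NumPy: the dimension you pass is the one that disappears from the result.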
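A cumulative sum replaces each element with the running total up to and including it. The sketch below makes that arithmetic explicit, assuming a recent ND4J release in which `linspace` returns a rank-1 array and using the standard `INDArray.cumsum(int)` method; `RunningTotals` is an illustrative name.

```scala
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

// Hypothetical example object: running totals with cumsum
object RunningTotals {
  // [1, 2, 3, 4] (assumes linspace yields a rank-1 array)
  val x: INDArray = Nd4j.linspace(1, 4, 4)

  // Running totals along dimension 0:
  // [1, 1 + 2, 1 + 2 + 3, 1 + 2 + 3 + 4] = [1, 3, 6, 10]
  // cumsum returns a new array and leaves x unchanged.
  val running: INDArray = x.cumsum(0)

  def main(args: Array[String]): Unit =
    println(s"x = $x, cumulative sum = $running")
}
```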
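The statistics calls in the example print `INDArray` results. When a plain Scala number is more convenient, ND4J's `INDArray` also offers whole-array `*Number()` reductions that return a `java.lang.Number`. A minimal sketch, assuming a recent ND4J release and using the same values as `arr3`; `ScalarStats` is an illustrative name.

```scala
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

// Hypothetical example object: scalar results from whole-array reductions
object ScalarStats {
  // Same values as arr3: 1, 2, ..., 10
  val values: INDArray = Nd4j.linspace(1, 10, 10)

  // Each *Number() call reduces over every element and returns a
  // java.lang.Number, converted here to a Scala Double.
  val mean: Double = values.meanNumber().doubleValue() // 55 / 10 = 5.5
  val max: Double = values.maxNumber().doubleValue()   // 10.0
  val min: Double = values.minNumber().doubleValue()   // 1.0

  def main(args: Array[String]): Unit =
    println(s"mean = $mean, max = $max, min = $min")
}
```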