Fork me on GitHub

ND4J, Scala & Scientific Computing

ND4J’s Scala API is called ND4S; its source code lives on GitHub.

Before using ND4S, please make sure you have:

Below is an example of how ND4S looks. Notice how similar the syntax is to NumPy.

Warning: collection-like operations such as:

    Nd4j.ones(4).map(_+1)

are currently broken.

    /**
     * Runnable tour of the ND4S/ND4J array API, ordered to mirror the
     * equivalent NumPy calls: creation, in-place arithmetic, transpose,
     * reductions, shape inspection, reshaping, flattening, sorting,
     * cumulative sum and basic statistics.
     *
     * `Nd4j` and the ND4S operators (`+=`, `*=`, `.T`) are expected to be
     * brought into scope by the surrounding file.
     */
    object ScalaMain {

      /**
       * Builds a few small arrays and prints the result of each operation,
       * followed by a short description of what was computed.
       *
       * @param args command-line arguments; not used
       */
      def main(args: Array[String]): Unit = {
        // Local import so the Arrays.toString call below resolves without
        // relying on imports elsewhere in the file.
        import java.util.Arrays

        // Creating arrays in multiple ways, all using numpy syntax.
        // NOTE(review): `arr` and `arr2` stay `var` so the compound assignments
        // below compile whether they resolve to ND4S in-place operators or to
        // plain reassignment — TODO confirm and switch to `val` if possible.
        var arr = Nd4j.create(4)
        var arr2 = Nd4j.ones(4)
        val arr3 = Nd4j.linspace(1, 10, 10)
        val arr4 = Nd4j.linspace(1, 6, 6).reshape(2, 3)

        // Array addition in place
        arr += arr2
        arr += 2

        // Array multiplication in place
        arr2 *= 5

        // Transpose matrix (kept only to demonstrate the call; result is unused)
        val arrT = arr.T

        // Row (0) and Column (1) Sums
        // NOTE(review): in NumPy terms a reduction along dimension 0 yields one
        // total per column — confirm the row/column wording of these labels.
        println(s"${Nd4j.sum(arr4, 0)}Calculate the sum for each row")
        println(s"${Nd4j.sum(arr4, 1)}Calculate the sum for each column")

        // Checking array shape
        println(s"${Arrays.toString(arr2.shape)}Checking array shape")

        // Converting array to a string
        println(s"${arr2.toString()}Array converted to string")

        // Filling the array with the value 5 (same as numpy's fill method)
        println(s"${arr2.assign(5)}Array assigned value of 5 (equivalent to fill method in numpy)")

        // Reshaping the array
        println(s"${arr2.reshape(2, 2)}Reshaping array")

        // Raveling the array (returns a flattened array)
        println(s"${arr2.ravel}Raveling array")

        // Flattening the array (same as numpy's flatten method)
        println(s"${Nd4j.toFlattened(arr2)}Flattening array (equivalent to flatten in numpy)")

        // Array sorting
        println(s"${Nd4j.sort(arr2, 0, true)}Sorting array")
        println(s"${Nd4j.sortWithIndices(arr2, 0, true)}Sorting array and returning sorted indices")

        // Cumulative sum
        println(s"${Nd4j.cumsum(arr2)}Cumulative sum")

        // Basic stats methods.
        // The variance line previously passed two arguments to println, which
        // Scala adapts into a tuple and prints as "(value,label)"; it is now
        // formatted like every other line.
        println(s"${Nd4j.mean(arr)}Calculate mean of array")
        println(s"${Nd4j.std(arr2)}Calculate standard deviation of array")
        println(s"${Nd4j.`var`(arr2)}Calculate variance")

        // Find min and max values (same two-argument println fix as above)
        println(s"${Nd4j.max(arr3)}Find max value in array")
        println(s"${Nd4j.min(arr3)}Find min value in array")
      }
    }