Kaggle competition attempt. Not successful

Greetings, everybody,

here I want to preserve, for future use, the code I used to train a model for one of the Kaggle competitions:

package org.deeplearning4j.examples.convolution;

import com.google.common.io.LittleEndianDataInputStream;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.stream.Collectors;

import org.apache.commons.io.FilenameUtils;

/**
 * Created by Yuriy Zaletskyy
 */
public class KaggleCompetition {
    private static final Logger log = LoggerFactory.getLogger(KaggleCompetition.class);

    public static void main(String[] args) throws Exception {

        int nChannels = 8;          // Number of input channels
        int outputNum = 17;         // The number of possible outcomes (body zones)
        int batchSize = 50;         // Training batch size
        int nEpochs = 1001;         // Number of training epochs
        int iterations = 1;         // Number of training iterations per minibatch
        int seed = 123;             // Random seed for reproducibility
        int learningSetSize = 800;  // Number of scans loaded for training
        int x = 256;                // First spatial dimension of one resized scan
        int y = 330;                // Second spatial dimension of one resized scan
        int z = 8;                  // Number of channels per scan (matches nChannels)
        int sizeInt = 4;            // Bytes per 32-bit float
        int sizeOfOneVideo = x * y * z;
        int numberOfZones = 17;

        String labelsFileName = "d:\\Kaggle\\stage1_labels.csv";
        List<String> labels = ReadCsvFile(labelsFileName);

        String folderName = "d:\\Kaggle\\stage1_bins_resized\\";

        INDArray input = Nd4j.zeros(learningSetSize + 1, sizeOfOneVideo);
        INDArray outputKaggle = Nd4j.zeros(learningSetSize + 1, numberOfZones);
        File folder = new File(folderName);
        List<String> fileNames = new ArrayList<String>(500);

        GetFileNames(folder, fileNames);

        int rowNumber = 0;
        Date timeMarker = new Date();
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        System.out.println("Before reading files " + sdf.format(timeMarker));
        for (String fileName : fileNames) {
            // Each file holds one scan as x * y * z little-endian 32-bit floats
            DataInputStream dataInputStream = new DataInputStream(new FileInputStream(fileName));
            LittleEndianDataInputStream lendian = new LittleEndianDataInputStream(dataInputStream);

            List<Float> listOfFloats = new ArrayList<Float>();

            int fileSize = x * y * z * sizeInt;

            byte[] fileContent = new byte[fileSize];
            lendian.readFully(fileContent);

            ReadFromFile(listOfFloats, fileSize, fileContent);

            lendian.close();
            dataInputStream.close();

            // Label rows look like "<scanId>_ZoneN,value"; pick the rows belonging to this file
            File f = new File(fileName);
            String containsFilter = FilenameUtils.removeExtension(f.getName());
            List<String> outputStrings = labels.stream().filter(a -> a.contains(containsFilter))
                .collect(Collectors.toList());

            int indexToTurnOn = getIndexToTurnOn(outputStrings);

            // One-hot encode the zone label
            float[] zoneOut = new float[17];
            for (int i = 0; i < 17; i++) {
                zoneOut[i] = 0.0f;
                if (i == indexToTurnOn) {
                    zoneOut[i] = 1.0f;
                }
            }

            float[] inputRow = new float[listOfFloats.size()];
            int j = 0;
            for (Float ff : listOfFloats) {
                inputRow[j++] = (ff != null ? ff : Float.NaN);
            }

            input.putRow(rowNumber, Nd4j.create(inputRow));
            outputKaggle.putRow(rowNumber, Nd4j.create(zoneOut));

            if (rowNumber > learningSetSize - 1) {
                break;
            }
            rowNumber++;
        }

        timeMarker = new Date();
        System.out.println("After reading files " + sdf.format(timeMarker));

        System.out.println("Learning set loaded");
        DataSet dsKaggleAll = new DataSet(input, outputKaggle);

        List<DataSet> listDs = dsKaggleAll.asList();

        Random rng = new Random(seed);
        Collections.shuffle(listDs, rng);

        DataSetIterator dsKaggle = new ListDataSetIterator(listDs, batchSize);
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.001)
                .learningRate(.0001).biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        // nIn is the number of input channels; nOut is the number of filters to be applied
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(40)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(100)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(x, y, z)) // Reshape each flat input row back into an x * y * z volume
                .backprop(true).pretrain(false).build(); // No pretrainable layers, so pretraining stays disabled

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model....");

        // To continue training from a previously saved snapshot, restore it instead of the freshly initialized model:
        //model = ModelSerializer.restoreMultiLayerNetwork("d:\\Kaggle\\models\\2017-12-1902_33_38.zip");

        // Initialize the user interface backend (the dashboard is served at http://localhost:9000)
        UIServer uiServer = UIServer.getInstance();

        // Configure where the network information (gradients, score vs. time etc.) is stored. Here: in memory.
        StatsStorage statsStorage = new InMemoryStatsStorage();   // Alternative: new FileStatsStorage(File), for saving and loading later

        // Attach the StatsStorage instance to the UI so that its contents can be visualized
        uiServer.attach(statsStorage);

        // Register both listeners in one call: setListeners replaces previously registered listeners,
        // so the score listener and the UI stats listener have to be set together
        model.setListeners(new ScoreIterationListener(1), new StatsListener(statsStorage));

        for (int i = 0; i < nEpochs; i++) {
            dsKaggle.reset();
            model.fit(dsKaggle);
            log.info("*** Completed epoch {} ***", i);

            //log.info("Evaluate model....");
            //Evaluation eval = new Evaluation(outputNum);

            // Save a snapshot of the model after every epoch
            timeMarker = new Date();
            String fileName = "d:\\Kaggle\\models\\" + (sdf.format(timeMarker) + ".zip")
                .replace(" ", "")
                .replace(":", "_");
            ModelSerializer.writeModel(model, fileName, true);
        }
        log.info("****************Example finished********************");
    }

    private static void ReadFromFile(List<Float> listOfFloats, int fileSize, byte[] fileContent) {
        // Decode little-endian 32-bit floats and min-max normalize them to [0, 1]
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE; // Float.MIN_VALUE is the smallest positive float, not the most negative one
        for (int i = 0; i < fileSize; i += 4) {
            int valueInt = (fileContent[i + 3]) << 24 |
                (fileContent[i + 2] & 0xff) << 16 |
                (fileContent[i + 1] & 0xff) << 8 |
                (fileContent[i] & 0xff);
            float value = Float.intBitsToFloat(valueInt);
            if (value > max) {
                max = value;
            }
            if (value < min) {
                min = value;
            }
            listOfFloats.add(value);
        }

        for (int i = 0; i < listOfFloats.size(); i++) {
            float normalized = (listOfFloats.get(i) - min) / (max - min);
            listOfFloats.set(i, normalized);
        }
    }

    // Parses rows of the form "<scanId>_ZoneN,value" and returns the 1-based number
    // of the first zone whose value is "1", or 0 if no zone is marked
    private static int getIndexToTurnOn(List<String> outputStrings) {
        int indexToTurnOn = 0;
        for (String outString : outputStrings) {
            String[] strings = outString.split("_");
            String[] secondPart = strings[1].split(",");

            String zoneName = secondPart[0].replace("Zone", "");
            String zoneValue = secondPart[1];

            if (zoneValue.equals("1")) {
                indexToTurnOn = Integer.parseInt(zoneName);
                break;
            }
        }
        return indexToTurnOn;
    }

    public static List<String> ReadCsvFile(String csvFile) {
        // Read the labels CSV line by line; each line is kept as a raw string
        List<String> result = new ArrayList<String>();

        try (BufferedReader br = new BufferedReader(new FileReader(csvFile))) {
            String line;
            while ((line = br.readLine()) != null) {
                result.add(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return result;
    }

    private static void GetFileNames(File folder, List<String> fileNames) {
        File[] listOfFiles = folder.listFiles();
        if (listOfFiles == null) {
            return; // Folder does not exist or is not a directory
        }
        for (File file : listOfFiles) {
            if (file.isFile()) {
                fileNames.add(file.getAbsolutePath());
            }
        }
    }
}

Among other things, this code demonstrates how to read binary input data, how to build a convolutional network, how to save the network state to a file and load it back, and how to run training. It also shows how to monitor training progress through the DL4J UI at http://localhost:9000.
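
For completeness, here is roughly how one of the saved snapshots can be restored and used for prediction later. This is a minimal sketch under my assumptions: the class name, the snapshot path, and the zero-filled input row are placeholders, while ModelSerializer.restoreMultiLayerNetwork and output are the same DL4J calls used above.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KagglePrediction {
    public static void main(String[] args) throws Exception {
        // Hypothetical path: any snapshot written by the training loop above will do
        MultiLayerNetwork restored =
            ModelSerializer.restoreMultiLayerNetwork("d:\\Kaggle\\models\\some_snapshot.zip");

        // One scan flattened to x * y * z = 256 * 330 * 8 floats, normalized the same way
        // as during training (zero-filled here just to keep the sketch self-contained)
        float[] inputRow = new float[256 * 330 * 8];
        INDArray features = Nd4j.create(inputRow).reshape(1, inputRow.length);

        // Returns a 1 x 17 row of softmax probabilities, one per body zone
        INDArray zoneProbabilities = restored.output(features);
        System.out.println(zoneProbabilities);
    }
}

Because the training loop passes true as the third argument to ModelSerializer.writeModel, the updater state is saved as well, so a restored snapshot can also continue training rather than only predict.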
