Skip to content

Commit 885722f

Browse files
Byte prediction is back! #53
1 parent 6fc01ab commit 885722f

File tree

3 files changed

+123
-3
lines changed

3 files changed

+123
-3
lines changed

neural/neuralab.pas

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
unit neuralab;
22
{
3-
U stands for Unit. A stands for Array. B stands for Bytes.
3+
A stands for Array. B stands for Bytes.
44
This unit contains "array of bytes" functions.
55
Copyright (C) 2017 Joao Paulo Schwarz Schuler
66
@@ -86,7 +86,7 @@ function ABToString(var AB: array of byte): string;
8686
// returns a string from the array
8787
function ABToStringR(var AB: array of byte): string;
8888

89-
// clears array (fills with zeros
89+
// clears array (fills with zeros)
9090
procedure ABClear(var AB: array of byte);
9191

9292
// fills with 1 the array.

neural/neuralbyteprediction.pas

+118-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
interface
4949

50-
uses neuralabfun, neuralcache;
50+
uses neuralabfun, neuralcache, neuralnetwork, neuralvolume;
5151

5252
type
5353
TCountings = array of longint;
@@ -344,10 +344,127 @@ interface
344344
function GetD(posPredictedState: longint): single;
345345
end;
346346

347+
// This class is very experimental - do not use it.
348+
TEasyLearnAndPredictDeepLearning = class(TMObject)
349+
private
350+
NN: TNNet;
351+
FActions, FStates, FPredictedStates, FOutput: TNNetVolume;
352+
aActions, aCurrentState, aPredictedState: array of byte;
353+
FCached: boolean;
354+
FUseCache: boolean;
355+
public
356+
FCache: TCacheMem;
357+
358+
constructor Create(
359+
pActionByteLen {action array size in bytes},
360+
pStateByteLen{state array size in bytes}: word;
361+
// false = creates operation/neurons for non zero entries only.
362+
NumNeurons: integer;
363+
// the higher the number, more computations are used on each step. If you don't know what number to use, give 40.
364+
pUseCache: boolean
365+
// replies the same prediction for the same given state. Use false if you aren't sure.
366+
);
367+
destructor Destroy(); override;
368+
369+
// THIS METHOD WILL PREDICT THE NEXT SATE GIVEN AN ARRAY OF ACTIONS AND STATES.
370+
// You can understand ACTIONS as a kind of "current state".
371+
// Returned value "predicted states" contains the neural network prediction.
372+
procedure Predict(var pActions, pCurrentState: array of byte;
373+
var pPredictedState: array of byte);
374+
375+
// Call this method to train the neural network so it can learn from the "found state".
376+
// Call this method and when the state of your environment changes so the neural
377+
// network can learn how the state changes from time to time.
378+
function newStateFound(stateFound: array of byte): extended;
379+
end;
380+
347381
implementation
348382

349383
uses neuralab, SysUtils, Classes;
350384

385+
constructor TEasyLearnAndPredictDeepLearning.Create(pActionByteLen,
386+
pStateByteLen: word; NumNeurons: integer;
387+
pUseCache: boolean);
388+
var
389+
NNetInputLayer1, NNetInputLayer2: TNNetLayer;
390+
begin
391+
inherited Create();
392+
NN := TNNet.Create;
393+
FActions := TNNetVolume.Create();
394+
FStates := TNNetVolume.Create();
395+
FPredictedStates := TNNetVolume.Create();
396+
FOutput := TNNetVolume.Create();
397+
SetLength(aActions, pActionByteLen);
398+
SetLength(aCurrentState, pStateByteLen);
399+
SetLength(aPredictedState, pStateByteLen);
400+
if (pUseCache)
401+
then FCache.Init(pActionByteLen, pStateByteLen)
402+
else FCache.Init(1, 1);
403+
FUseCache := pUseCache;
404+
NNetInputLayer1 := NN.AddLayer( TNNetInput.Create(pActionByteLen*8) );
405+
NNetInputLayer2 := NN.AddLayer( TNNetInput.Create(pStateByteLen*8) );
406+
NN.AddLayer( TNNetConcat.Create([NNetInputLayer1, NNetInputLayer2]));
407+
NN.AddLayer( TNNetFullConnectReLU.Create( NumNeurons ) );
408+
NN.AddLayer( TNNetFullConnectReLU.Create( NumNeurons ) );
409+
NN.AddLayer( TNNetFullConnectReLU.Create( NumNeurons ) );
410+
NN.AddLayer( TNNetFullConnect.Create( (pStateByteLen)*8 ) );
411+
NN.SetLearningRate(0.01, 0.9);
412+
NN.DebugStructure();
413+
end;
414+
415+
destructor TEasyLearnAndPredictDeepLearning.Destroy();
416+
begin
417+
FOutput.Free;
418+
FActions.Free;
419+
FStates.Free;
420+
FPredictedStates.Free;
421+
NN.Free;
422+
inherited Destroy();
423+
end;
424+
425+
procedure TEasyLearnAndPredictDeepLearning.Predict(var pActions,
426+
pCurrentState: array of byte; var pPredictedState: array of byte);
427+
var
428+
idxCache: longint;
429+
Equal: boolean;
430+
begin
431+
ABCopy(aActions, pActions);
432+
ABCopy(aCurrentState, pCurrentState);
433+
if FUseCache then
434+
idxCache := FCache.Read(pActions, pPredictedState);
435+
Equal := ABCmp(pActions, pCurrentState);
436+
if FUseCache and (idxCache <> -1) and Equal then
437+
begin
438+
FCached := True;
439+
end
440+
else
441+
begin
442+
//BytePred.Prediction(aActions, aCurrentState, pPredictedState, FRelationProbability, FVictoryIndex);
443+
FActions.CopyAsBits(aActions);
444+
FStates.CopyAsBits(pCurrentState);
445+
NN.Compute([FActions, FStates]);
446+
NN.GetOutput(FPredictedStates);
447+
FPredictedStates.ReadAsBits(pPredictedState);
448+
FCached := False;
449+
end;
450+
ABCopy(aPredictedState, pPredictedState);
451+
end;
452+
453+
function TEasyLearnAndPredictDeepLearning.newStateFound(stateFound: array of byte): extended;
454+
begin
455+
newStateFound := ABCountDif(stateFound, aPredictedState);
456+
// Do we have a cached prediction
457+
if Not(FCached) then
458+
begin
459+
FPredictedStates.CopyAsBits(stateFound);
460+
NN.GetOutput(FOutput);
461+
NN.Backpropagate(FPredictedStates);
462+
//newStateFound := FOutput.SumDiff(FPredictedStates);
463+
end;
464+
if FUseCache then
465+
FCache.Include(aActions, stateFound);
466+
end;
467+
351468
{ TClassifier }
352469

353470
procedure TClassifier.AddClassifier(NumClasses, NumStates: integer);

neural/neuralvolume.pas

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
// TVolume has also been inpired on Exentia
3636
// http://www.tommesani.com/ExentiaWhatsNew.html
3737

38+
{$IFDEF FPC}
39+
{$mode objfpc}
40+
{$ENDIF}
3841

3942
interface
4043

0 commit comments

Comments
 (0)