|
47 | 47 |
|
48 | 48 | interface
|
49 | 49 |
|
50 |
| -uses neuralabfun, neuralcache; |
| 50 | +uses neuralabfun, neuralcache, neuralnetwork, neuralvolume; |
51 | 51 |
|
52 | 52 | type
|
53 | 53 | TCountings = array of longint;
|
@@ -344,10 +344,127 @@ interface
|
344 | 344 | function GetD(posPredictedState: longint): single;
|
345 | 345 | end;
|
346 | 346 |
|
  // This class is very experimental - do not use it.
  // Learns environment state transitions with a neural network: Predict maps
  // (actions, current state) byte arrays to a predicted next-state byte array,
  // and newStateFound trains the network online from the state actually
  // observed. Predictions can optionally be cached per action array.
  TEasyLearnAndPredictDeepLearning = class(TMObject)
    private
      NN: TNNet;                // network mapping action+state bits to predicted state bits
      FActions, FStates, FPredictedStates, FOutput: TNNetVolume; // working volumes reused across calls
      aActions, aCurrentState, aPredictedState: array of byte;   // local copies of the last Predict inputs/output
      FCached: boolean;         // true when the last Predict was answered from FCache
      FUseCache: boolean;       // whether the prediction cache is enabled
    public
      FCache: TCacheMem;

      constructor Create(
        pActionByteLen {action array size in bytes},
        pStateByteLen {state array size in bytes}: word;
        // the higher the number, more computations are used on each step. If you don't know what number to use, give 40.
        NumNeurons: integer;
        // replies the same prediction for the same given state. Use false if you aren't sure.
        pUseCache: boolean
      );
      destructor Destroy(); override;

      // THIS METHOD WILL PREDICT THE NEXT STATE GIVEN AN ARRAY OF ACTIONS AND STATES.
      // You can understand ACTIONS as a kind of "current state".
      // Returned value "predicted states" contains the neural network prediction.
      procedure Predict(var pActions, pCurrentState: array of byte;
        var pPredictedState: array of byte);

      // Call this method to train the neural network so it can learn from the "found state".
      // Call this method and when the state of your environment changes so the neural
      // network can learn how the state changes from time to time.
      // Returns the number of entries that differ between stateFound and the
      // last prediction (0 means the prediction was perfect).
      function newStateFound(stateFound: array of byte): extended;
  end;
| 380 | + |
347 | 381 | implementation
|
348 | 382 |
|
349 | 383 | uses neuralab, SysUtils, Classes;
|
350 | 384 |
|
// Builds the prediction network and the working buffers.
// One input neuron is allocated per BIT of the action/state byte arrays;
// the output layer has one neuron per predicted state bit.
constructor TEasyLearnAndPredictDeepLearning.Create(pActionByteLen,
  pStateByteLen: word; NumNeurons: integer;
  pUseCache: boolean);
var
  ActionInput, StateInput: TNNetLayer;
begin
  inherited Create();
  // Working volumes shared by Predict/newStateFound.
  NN := TNNet.Create;
  FActions := TNNetVolume.Create();
  FStates := TNNetVolume.Create();
  FPredictedStates := TNNetVolume.Create();
  FOutput := TNNetVolume.Create();
  SetLength(aActions, pActionByteLen);
  SetLength(aCurrentState, pStateByteLen);
  SetLength(aPredictedState, pStateByteLen);
  FUseCache := pUseCache;
  if pUseCache then
    FCache.Init(pActionByteLen, pStateByteLen)
  else
    FCache.Init(1, 1); // minimal placeholder cache when caching is disabled
  // Two input branches (action bits, state bits) concatenated into a
  // 3-hidden-layer fully connected network.
  ActionInput := NN.AddLayer(TNNetInput.Create(pActionByteLen * 8));
  StateInput := NN.AddLayer(TNNetInput.Create(pStateByteLen * 8));
  NN.AddLayer(TNNetConcat.Create([ActionInput, StateInput]));
  NN.AddLayer(TNNetFullConnectReLU.Create(NumNeurons));
  NN.AddLayer(TNNetFullConnectReLU.Create(NumNeurons));
  NN.AddLayer(TNNetFullConnectReLU.Create(NumNeurons));
  NN.AddLayer(TNNetFullConnect.Create(pStateByteLen * 8));
  NN.SetLearningRate(0.01, 0.9);
  NN.DebugStructure();
end;
| 414 | + |
// Releases the network and every working volume created by Create.
destructor TEasyLearnAndPredictDeepLearning.Destroy();
begin
  NN.Free;
  FPredictedStates.Free;
  FStates.Free;
  FActions.Free;
  FOutput.Free;
  inherited Destroy();
end;
| 424 | + |
// Predicts the next state for the given (actions, current state) pair.
// When the cache is enabled, holds a hit for these actions and the actions
// equal the current state, the cached prediction is replied; otherwise the
// neural network computes a fresh prediction. The inputs and the resulting
// prediction are kept in local copies for newStateFound.
procedure TEasyLearnAndPredictDeepLearning.Predict(var pActions,
  pCurrentState: array of byte; var pPredictedState: array of byte);
var
  idxCache: longint;
  Equal: boolean;
begin
  ABCopy(aActions, pActions);
  ABCopy(aCurrentState, pCurrentState);
  // -1 means "no cache hit". Initialized explicitly: the original relied on
  // short-circuit evaluation ({$B-}) to avoid reading an uninitialized local
  // when FUseCache is false; under {$B+} that read would be undefined.
  idxCache := -1;
  if FUseCache then
    idxCache := FCache.Read(pActions, pPredictedState);
  Equal := ABCmp(pActions, pCurrentState);
  if FUseCache and (idxCache <> -1) and Equal then
  begin
    // Cache hit: pPredictedState was already filled by FCache.Read above.
    FCached := True;
  end
  else
  begin
    // Feed the action and state bits through the network and read the
    // predicted state bits back out. Uses the local copies for consistency.
    FActions.CopyAsBits(aActions);
    FStates.CopyAsBits(aCurrentState);
    NN.Compute([FActions, FStates]);
    NN.GetOutput(FPredictedStates);
    FPredictedStates.ReadAsBits(pPredictedState);
    FCached := False;
  end;
  // Remember the prediction so newStateFound can score/train against it.
  ABCopy(aPredictedState, pPredictedState);
end;
| 452 | + |
// Trains the network with the state actually observed after the last Predict.
// Returns the number of entries that differ between stateFound and the last
// prediction (0 means the prediction was perfect).
function TEasyLearnAndPredictDeepLearning.newStateFound(stateFound: array of byte): extended;
begin
  newStateFound := ABCountDif(stateFound, aPredictedState);
  // Only backpropagate when the last prediction was computed by the network;
  // a cached reply has no forward pass to learn from.
  // (Removed a dead NN.GetOutput(FOutput) call whose only consumer was
  // commented out.)
  if not FCached then
  begin
    FPredictedStates.CopyAsBits(stateFound);
    NN.Backpropagate(FPredictedStates);
  end;
  if FUseCache then
    FCache.Include(aActions, stateFound);
end;
| 467 | + |
351 | 468 | { TClassifier }
|
352 | 469 |
|
353 | 470 | procedure TClassifier.AddClassifier(NumClasses, NumStates: integer);
|
|
0 commit comments