From 002b09f45a16331e4ebdb27e39228523988056e9 Mon Sep 17 00:00:00 2001 From: Dev Chauhan Date: Tue, 25 Feb 2020 01:55:06 +0530 Subject: [PATCH 1/6] This commit fixes #191 --- text/char-rnn/char-rnn.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text/char-rnn/char-rnn.jl b/text/char-rnn/char-rnn.jl index 0db14463d..1abc0d2d6 100644 --- a/text/char-rnn/char-rnn.jl +++ b/text/char-rnn/char-rnn.jl @@ -31,7 +31,7 @@ m = gpu(m) function loss(xs, ys) l = sum(crossentropy.(m.(gpu.(xs)), gpu.(ys))) - Flux.truncate!(m) + Flux.reset!(m) return l end From d198eb316dc646db0daf2c90eb3829c8343ae03e Mon Sep 17 00:00:00 2001 From: Dev Chauhan Date: Mon, 2 Mar 2020 16:36:22 +0530 Subject: [PATCH 2/6] env for char-rnn --- text/char-rnn/Manifest.toml | 379 ++++++++++++++++++++++++++++++++++++ text/char-rnn/Project.toml | 2 + text/char-rnn/char-rnn.jl | 3 +- 3 files changed, 382 insertions(+), 2 deletions(-) create mode 100644 text/char-rnn/Manifest.toml create mode 100644 text/char-rnn/Project.toml diff --git a/text/char-rnn/Manifest.toml b/text/char-rnn/Manifest.toml new file mode 100644 index 000000000..edc8b56f9 --- /dev/null +++ b/text/char-rnn/Manifest.toml @@ -0,0 +1,379 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "0.5.0" + +[[AbstractTrees]] +deps = ["Markdown"] +git-tree-sha1 = "86d092c2599f1f7bb01668bf8eb3412f98d61e47" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.3.2" + +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "1.0.1" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BinaryProvider]] +deps = ["Libdl", "SHA"] +git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.8" + +[[CEnum]] +git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.2.0" + +[[CUDAapi]] +deps = ["Libdl", "Logging"] +git-tree-sha1 = "d7ceadd8f821177d05b897c0517e94633db535fe" +uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" +version = "3.1.0" + +[[CUDAdrv]] +deps = ["CEnum", "CUDAapi", "Printf"] +git-tree-sha1 = "01e90fa34e25776bc7c8661183d4519149ebfe59" +uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" +version = "6.0.0" + +[[CUDAnative]] +deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"] +git-tree-sha1 = "f86269ff60ebe082a2806ecbce51f3cadc68afe9" +uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" +version = "2.10.2" + +[[CodecZlib]] +deps = ["BinaryProvider", "Libdl", "TranscodingStreams"] +git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.6.0" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "7b62b728a5f3dd6ee3b23910303ccf27e82fad5e" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.8.1" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"] +git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.9.6" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "b57c5d019367c90f234a7bc7e24ff0a84971da5d" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.2.0+1" + +[[CuArrays]] +deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] +git-tree-sha1 = "7c20c5a45bb245cf248f454d26966ea70255b271" +uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" +version = "1.7.2" + +[[DataAPI]] +git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.1.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "5a431d46abf2ef2a4d5d00bd0ae61f651cf854c8" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.17.10" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.2" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.1" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FFTW]] +deps = ["AbstractFFTs", "FFTW_jll", "IntelOpenMP_jll", "Libdl", "LinearAlgebra", "MKL_jll", "Reexport"] +git-tree-sha1 = "109d82fa4b00429f9afcce873e9f746f11f018d3" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "1.2.0" + +[[FFTW_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "ddb57f4cf125243b4aa4908c94d73a805f3cbf2c" +uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" +version = "3.3.9+4" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "85c6b57e2680fa28d5c8adc798967377646fbf66" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.8.5" + +[[FixedPointNumbers]] +git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.6.1" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "CuArrays", "DelimitedFiles", "Juno", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] +git-tree-sha1 = "8134adbb0f10b0d22b22f8b4299d0d20509edc5f" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.10.1" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "88b082d492be6b63f967b6c96b352e25ced1a34c" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.9" + +[[GPUArrays]] +deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] +git-tree-sha1 = "e756da6cee76a5f1436a05827fa8fdf3badc577f" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "2.0.1" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "1a4355e4b5b50be2311ebb644f34f3306dbd0410" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.3.1" + +[[IntelOpenMP_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "fb8e1c7a5594ba56f9011310790e03b5384998d6" +uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" +version = "2018.0.3+0" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile", "Test"] +git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.7.2" + +[[LLVM]] +deps = ["CEnum", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "1d08d7e4250f452f6cb20e4574daaebfdbee0ff7" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "1.3.3" + +[[LibGit2]] +deps = ["Printf"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MKL_jll]] +deps = ["IntelOpenMP_jll", "Libdl", "Pkg"] +git-tree-sha1 = "720629cc8cbd12c146ca01b661fd1a6cf66e2ff4" +uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +version = "2019.0.117+2" + +[[MacroTools]] +deps = ["DataStructures", "Markdown", "Random"] +git-tree-sha1 = "07ee65e03e28ca88bc9a338a3726ae0c3efaa94b" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.4" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.3" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NNlib]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] +git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.5" + +[[NaNMath]] +git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.3" + +[[OpenSpecFun_jll]] +deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] +git-tree-sha1 = "d110040968b9afe95c6bd9c6233570b0fe8abd22" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+2" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.1.0" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.0.1" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.10.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.12.1" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "be5c7d45daa449d12868f4466dbf5882242cf2d9" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.32.1" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TimerOutputs]] +deps = ["Printf"] +git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.3" + +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.5" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "8748302cfdec02c4ae9c97b112cf10003f7f767f" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.9.1" + +[[Zlib_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "fd36a6739e256527287c5444960d0266712cd49e" +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.11+8" + +[[Zygote]] +deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "f8329b595c465caf3ca87c4f744e6041a4983e43" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.4.8" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.0" diff --git a/text/char-rnn/Project.toml b/text/char-rnn/Project.toml new file mode 100644 index 000000000..77df42abf --- /dev/null +++ b/text/char-rnn/Project.toml @@ -0,0 +1,2 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/text/char-rnn/char-rnn.jl b/text/char-rnn/char-rnn.jl index 1abc0d2d6..64515e919 100644 --- a/text/char-rnn/char-rnn.jl +++ b/text/char-rnn/char-rnn.jl @@ -31,7 +31,6 @@ m = gpu(m) function loss(xs, ys) l = sum(crossentropy.(m.(gpu.(xs)), gpu.(ys))) - Flux.reset!(m) return l end @@ -51,7 +50,7 @@ function sample(m, alphabet, len) c = rand(alphabet) for i = 1:len write(buf, c) - c = wsample(alphabet, m(onehot(c, alphabet)).data) + c = wsample(alphabet, m(onehot(c, alphabet))) end return String(take!(buf)) end From fb29d7194c9b8acee54e1c80287e9d5fe94b2010 Mon Sep 17 00:00:00 2001 From: Dev Chauhan Date: Mon, 2 Mar 2020 16:47:12 +0530 Subject: [PATCH 3/6] updated Project.toml char-rnn --- text/char-rnn/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/text/char-rnn/Project.toml b/text/char-rnn/Project.toml index 77df42abf..2b4ae433b 100644 --- a/text/char-rnn/Project.toml +++ b/text/char-rnn/Project.toml @@ -1,2 +1,3 @@ [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" From 2d3d7454d62f8880fa678dfd42250bec5c7af6c2 Mon Sep 17 00:00:00 2001 From: Dev Chauhan Date: Thu, 5 Mar 2020 04:26:36 +0530 Subject: [PATCH 4/6] format code --- text/char-rnn/Manifest.toml | 12 ++++ text/char-rnn/Project.toml | 2 + text/char-rnn/char-rnn.jl | 121 +++++++++++++++++++++++------------- 3 files changed, 92 insertions(+), 43 deletions(-) diff --git a/text/char-rnn/Manifest.toml b/text/char-rnn/Manifest.toml index edc8b56f9..c5db70be7 100644 --- a/text/char-rnn/Manifest.toml +++ b/text/char-rnn/Manifest.toml @@ -258,6 +258,12 @@ git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.1.0" +[[Parameters]] +deps = ["OrderedCollections"] +git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.0" + [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -270,6 +276,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "ea1f4fa0ff5e8b771bf130d87af5b7ef400760bd" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.2.0" + [[REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" diff --git a/text/char-rnn/Project.toml b/text/char-rnn/Project.toml index 2b4ae433b..942a5987e 100644 --- a/text/char-rnn/Project.toml +++ b/text/char-rnn/Project.toml @@ -1,3 +1,5 @@ [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/text/char-rnn/char-rnn.jl b/text/char-rnn/char-rnn.jl index 64515e919..4ad8930a7 100644 --- a/text/char-rnn/char-rnn.jl +++ b/text/char-rnn/char-rnn.jl @@ -2,57 +2,92 @@ using Flux using Flux: onehot, chunk, batchseq, throttle, crossentropy using StatsBase: wsample using Base.Iterators: partition +using Parameters: @with_kw +using ProgressMeter cd(@__DIR__) -isfile("input.txt") || - download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", - "input.txt") - -text = collect(String(read("input.txt"))) -alphabet = [unique(text)..., '_'] -text = map(ch -> onehot(ch, alphabet), text) -stop = onehot('_', alphabet) - -N = length(alphabet) -seqlen = 50 -nbatch = 50 +@with_kw struct HyperParams + nbatch::Int = 50 + seqlen::Int = 50 + epochs::Int = 1 + lr::Float64 = 0.01 + val_char_len::Int = 1000 + verbose_freq::Int = 1 +end -Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen)) -Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen)) +function get_data(hparams::HyperParams) + isfile("input.txt") || + download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", + "input.txt") + + text = collect(String(read("input.txt"))) + alphabet = [unique(text)..., '_'] + text = map(ch -> onehot(ch, alphabet), text) + stop = onehot('_', alphabet) + N = length(alphabet) -m = Chain( - LSTM(N, 128), - LSTM(128, 128), - Dense(128, N), - softmax) + Xs = collect(partition(batchseq(chunk(text, hparams.nbatch)|>gpu, stop), hparams.seqlen)) + Ys = collect(partition(batchseq(chunk(text[2:end], hparams.nbatch)|>gpu, stop), hparams.seqlen)) -m = gpu(m) + return Xs, Ys, alphabet, N +end -function loss(xs, ys) - l = sum(crossentropy.(m.(gpu.(xs)), gpu.(ys))) - return l +function sample(m, alphabet, hparams::HyperParams) + Flux.reset!(m) + buf = IOBuffer() + c = rand(alphabet) + for i = 1:hparams.val_char_len + write(buf, c) + c = wsample(alphabet, m(onehot(c, alphabet))) + end + return String(take!(buf)) end -opt = ADAM(0.01) -tx, ty = (Xs[5], Ys[5]) -evalcb = () -> @show loss(tx, ty) - -Flux.train!(loss, params(m), zip(Xs, Ys), opt, - cb = throttle(evalcb, 30)) - -# Sampling - -function sample(m, alphabet, len) - m = cpu(m) - Flux.reset!(m) - buf = IOBuffer() - c = rand(alphabet) - for i = 1:len - write(buf, c) - c = wsample(alphabet, m(onehot(c, alphabet))) - end - return String(take!(buf)) +function train(; kws...) + # Parameters for training + hparams = HyperParams() + # data + Xs, Ys, alphabet, N = get_data(hparams) + # model + m = Chain(LSTM(N, 128), + LSTM(128, 128), + Dense(128, N), + softmax) |> gpu + # optimizer + opt = ADAM(hparams.lr) + # validation data + tx, ty = (Xs[5], Ys[5]) + # progress bar + p = Progress(hparams.epochs * length(Xs)) + + function loss(xs, ys) + l = sum(crossentropy.(m.(gpu.(xs)), gpu.(ys))) + return l + end + + for ep in 1:hparams.epochs + iter = 0 + # train + Flux.train!(params(m), zip(Xs, Ys), opt, + cb=function () + if iter % hparams.verbose_freq == 0 + # calc val loss + val_loss = sum(crossentropy.(m.(tx), ty)) + info_val = [(:epoch, ep), (:iter, iter), (:train_loss, loss), (:val_loss, val_loss)] + else + info_val = [(:epoch, ep), (:iter, iter), (:train_loss, loss)] + end + ProgressMeter.next!(p; showvalues=info_val) + iter += 1 + end + ) do x, y + # calc loss + loss = sum(crossentropy.(m.(x), y)) + end + # sample and dump sequence of chars + write("output/sample_$(ep).txt", sample(m, alphabet, hparams)) + end end -sample(m, alphabet, 1000) |> println +train() From f9e9655bbea1378dc730a46ad541714fe12f2292 Mon Sep 17 00:00:00 2001 From: Dev Chauhan Date: Thu, 5 Mar 2020 04:29:15 +0530 Subject: [PATCH 5/6] remove unnecessary --- text/char-rnn/char-rnn.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/text/char-rnn/char-rnn.jl b/text/char-rnn/char-rnn.jl index 4ad8930a7..008a9978a 100644 --- a/text/char-rnn/char-rnn.jl +++ b/text/char-rnn/char-rnn.jl @@ -61,11 +61,6 @@ function train(; kws...) # progress bar p = Progress(hparams.epochs * length(Xs)) - function loss(xs, ys) - l = sum(crossentropy.(m.(gpu.(xs)), gpu.(ys))) - return l - end - for ep in 1:hparams.epochs iter = 0 # train From db9412293bb6a85e36e87a9c604d249539e46823 Mon Sep 17 00:00:00 2001 From: Dev Chauhan Date: Thu, 5 Mar 2020 04:55:07 +0530 Subject: [PATCH 6/6] add output --- text/char-rnn/output/sample_1.txt | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 text/char-rnn/output/sample_1.txt diff --git a/text/char-rnn/output/sample_1.txt b/text/char-rnn/output/sample_1.txt new file mode 100644 index 000000000..b3c5e6ad5 --- /dev/null +++ b/text/char-rnn/output/sample_1.txt @@ -0,0 +1,29 @@ +;! , grapl ply th is fredrsper lefno hor behee heartom me the venteave. + +OGARO: +Sughuch shathey and he py t? the preweeeray youritoor mearoirge dres urur akist eod? + +LERIA: +Wow,-brbsto to she my not comy ay ll, llsaot mitiveaI ia fain th idnder? + +DARomy oraNemy do, thip? +IINIANENANAELOPAREENAO: MOREPENANCICIA: +No hild my hald I +EARrING OEQNeavor thes? her, I call'y.'d low mflrow I noto sanby lapl I firteave thats: SENTER: +I guthesoptheare scanoto met, +Woonow, my sato his ne; Bal forNnentnes! +I d. Whim hight ahow; ife-t th, of to non ths mandre: +Nohat th dgfepenveace: I ved, preid try tive. ais bight yout; + +ROIUCUESALZOI Peate, vy nespoit, orohI to riz. + +PAONETR: +Lorththlrhome teveri'st'm +Aving hshoby you lets them it in is teu iserth. + +DUKING ER: +Dracord etlaiderirs, I extt roalet sorrirohING, LO: +Ay ed ure I hitd +esughmy, ied oity ine fonfswe pior hiD lhower, +Coiller. beall ryemeave womsent hay melrairs, yot, wisterears fade or bees ses seakinerren tohthist go, seo fheake tuf nas a, +B \ No newline at end of file