diff --git a/text/word-rnn/Manifest.toml b/text/word-rnn/Manifest.toml new file mode 100644 index 00000000..c425622b --- /dev/null +++ b/text/word-rnn/Manifest.toml @@ -0,0 +1,651 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.7.2" +manifest_format = "2.0" + +[[deps.AbstractFFTs]] +deps = ["ChainRulesCore", "LinearAlgebra"] +git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.1.0" + +[[deps.Accessors]] +deps = ["Compat", "CompositionsBase", "ConstructionBase", "Future", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "0264a938934447408c7f0be8985afec2a2237af4" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.11" + +[[deps.Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.3" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[deps.ArrayInterface]] +deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "81f0cb60dc994ca17f68d9fb7c942a5ae70d9ee4" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "5.0.8" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] +git-tree-sha1 = "b15a6bc52594f5e4a3b825858d1089618871bf9d" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.36" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.2" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "bc6de7d0852de77a036a8648823b7edaf5a82852" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.9.1" + +[[deps.ChainRules]] +deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] +git-tree-sha1 = "ab656fb36197083c5817667e76cccd10d11f5c30" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.32.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.14.0" + +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.3" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.43.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[deps.CompositionsBase]] +git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.1" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.3.0" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "8ccaa8c655bc1b83d2da4d569c9b28254ababd6e" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.2" + +[[deps.DataAPI]] +git-tree-sha1 = "fb5f5316dd3fd4c5e7c30a24d50643b73e37cd40" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.10.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.12" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[deps.DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.11.0" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + +[[deps.Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[deps.ExprTools]] +git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.8" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "4391d3ed58db9dc5a9883b23a0578316b4798b1f" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.0" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.13.2" + +[[deps.Flux]] +deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] +git-tree-sha1 = "f84e50845ab88702c721dc7c6129a85cbc1de332" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.13.1" + +[[deps.FoldsThreads]] +deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] +git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" +uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" +version = "0.1.1" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "89cc49bf5819f0a10a7a3c38885e7c7ee048de57" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.29" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.2" + +[[deps.Functors]] +git-tree-sha1 = "223fffa49ca0ff9ce4f875be001ffe173b2b7de4" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.2.8" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "LLVM", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "c783e8883028bf26fb05ed4022c450ef44edd875" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.3.2" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "556190e1e0ea3e37d83059fc9aa576f1e2104375" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.14.1" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "af14a478780ca78d5eb9908b263023096c2b9d64" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.6" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "336cc738f03e069ef2cac55a104eb823455dca75" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.4" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.4.1" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "c8d47589611803a0f3b4813d9e267cd4e3dbcefb" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.11.1" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.16+0" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.15" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MLStyle]] +git-tree-sha1 = "e49789e5eb7b2d5577aaea395bfcac769df64bb8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.11" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"] +git-tree-sha1 = "202617a5a49a8b5f3b4abf96621f2519b1592c74" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.2.4" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "6bb7786e4f24d44b4e29df03c69add1b63d88f01" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[deps.NNlib]] +deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "f89de462a7bc3243f95834e75751d70b3a33e59d" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.8.5" + +[[deps.NNlibCUDA]] +deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "0d18b4c80a92a00d3d96e8f9677511a7422a946e" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.2" + +[[deps.NaNMath]] +git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.0" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "2442c3ddbda547c80e8b6451a103719d6a3593dd" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.4" + +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[deps.Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.3.0" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "afeacaecf4ed1649555a19cb2cad3c141bbc9474" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.5.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] +git-tree-sha1 = "38d88503f695eb0301479bc9b0d4320b378bafe5" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "0.8.2" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[deps.SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "5ba658aeecaaf96923dce0da9e703bd1fe7666f9" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.1.4" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "39c9f91521de844bad65049efd4f9223e7ed43f9" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.14" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "5309da1cdef03e95b73cd3251ac3a39f887da53e" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.6.4" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.4.4" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c82aaa13b44ea00134f8c9c89819477bd3986ecd" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.3.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.16" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "5ce79ce186cc678bbb5c5681ca3379d1ddae11a1" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.7.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "7638550aaea1c9a1e86817a231ef0faa9aca79bd" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.19" + +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "c76399a3bbe6f5a88faa33c8f8a65aa631d95013" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.73" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "a49267a2e5f113c7afe93843deea7461c0f6b206" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.40" + +[[deps.ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/text/word-rnn/Project.toml b/text/word-rnn/Project.toml new file mode 100644 index 00000000..233c8b53 --- /dev/null +++ b/text/word-rnn/Project.toml @@ -0,0 +1,4 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/text/word-rnn/README.md b/text/word-rnn/README.md new file mode 100644 index 00000000..7e106c1f --- /dev/null +++ b/text/word-rnn/README.md @@ -0,0 +1,84 @@ +# Word-level generative language model + +Created using julia 1.7.2 + +## Introduction +This model and training sequence is based on the model in https://arxiv.org/pdf/1409.2329.pdf, and uses LSTMs with +dropout layers to avoid large RNNs tendency to over fit. + +## Model Architecture & Parameters + +Customizable Parameters +```julia +# Model Params +em_size::Int = 200 +recur_size::Int = 400 +clip::Float64 = 0.1 +dropout::Float64 = 0.25 + +# Training Params +γ::Float64 = 2.5 +epochs::Int = 8 + +# Data Params +nbatch::Int = 250 +seqlen::Int = 20 +``` + +Model Architecture +```julia +## word-rnn.jl +Flux.Embedding(vocab_size, args.em_size), +Flux.Dropout(args.dropout), +Flux.LSTM(args.em_size, args.recur_size), +Flux.Dropout(args.dropout), +Flux.LSTM(args.recur_size, args.recur_size), +Flux.Dropout(args.dropout), +Flux.Dense(args.recur_size, vocab_size) +``` + +## Corpus +The model is trained on Plato's republic, freely available from MIT +[here](http://classics.mit.edu/Plato/republic.mb.txt). The corpus is parsed in [corpus.jl](corpus.jl) + +## Training and Evaluation + +Set up the word-level generative language model project and install +dependencies from the `text/word-rnn` directory: + +```julia +using Pkg +Pkg.activate(".") +Pkg.instantiate() +``` + +Then begin training the model. This will take 20-25 minutes on a Macbook M1 chip, and 2-4 minutes +on an NVIDIA GPU with CUDA installed. +```julia +include("word-rnn.jl") +``` + +The model's final validation perplexity should be approximately `180 bits`, which is comparable to the model performance +in the paper accounting for the differences in the corpus and training time. + +Once the model's training is complete, the `word-rnn.jl` script will automatically sample from the model to +produce three random text sequences of length 20 using the following seeds, `socrates` and `liberty`, as well as +using a randomly sampled seed from the corpus. Example output: + +```shell +socrates should have been such said as reason they were them selected and take his UNK will binding not reduced me + +liberty and fain them in knowledge there were guilty many subjects which a tempers the introduction numerous like to alas yes + +vulgarity received and behold amusements and really and they UNK their qualities which were commanded at their own own UNK mind +``` + +To manually produce text sequences from the model use the `sample` method provided in `word-rnn.jl`. For example, +```julia +# 50 word sequence with random seed +sample(model, vocab, word2ind, 50) |> println + +# 25 word sequence with "freedom" as the seed +sample(model, vocab, word2ind, 25; seed="freedom") |> println +``` + diff --git a/text/word-rnn/corpus.jl b/text/word-rnn/corpus.jl new file mode 100644 index 00000000..b7c4dcad --- /dev/null +++ b/text/word-rnn/corpus.jl @@ -0,0 +1,86 @@ +using Flux: chunk, batchseq +using Base.Iterators: partition +using Downloads: download + +function load_and_clean_data() + isfile("train.txt") || + download("http://classics.mit.edu/Plato/republic.mb.txt", "train.txt") + + # read string + doc = lowercase(String(read("train.txt"))) + + # prepare corpus for tokenizer + doc = replace( + doc, + # replace common contractions + "n't" => " not", + "'s" => 's', + # replace chapter dividers with white space + "--" => ' ', + # replace quotes with white space + '\'' => ' ', + '"' => ' ' + ) + + # remove end of sentence punctuation + tokens = replace.(split(doc), r"(? "") + + # keep only alphabetic tokens + filter!(w -> all(isletter(c) for c in w), tokens) + + return tokens +end + +function get_tokens_and_vocabulary() + tokens = load_and_clean_data() + + ##### borrowed from model-zoo/text/treebank/data.jl ##### + # Count how many times each token appears. + freqs = Dict{String,Int}() + for t in tokens + freqs[t] = get(freqs, t, 0) + 1 + end + + # Replace singleton tokens with an "unknown" marker. + # This roughly cuts our "vocabulary" of tokens in half. + tokens = replace(t -> get(freqs, t, 0) == 1 ? "UNK" : t, tokens) + ########## + + # create vocabulary + vocabulary = unique(tokens) + + return tokens, vocabulary +end + +function onehot_data(batch, labels) + # create targets with dimension (sequence_length x vocab_size x samples) + return [[Flux.onehotbatch(b_i, labels) for b_i in b] for b in batch] +end + +function batchify_data(tokens, unk_token, args) + # restructure data into batches of dimension sequence_length x (features x samples) + return batchseq.(partition.(chunk(tokens, args.nbatch), args.seqlen), unk_token) +end + +function get_data(args) + # load the raw data + tokens, vocabulary = get_tokens_and_vocabulary() + + # vocab_size calculated from corpus + vocab_size = length(vocabulary) + + # map words to their indices in vocabulary array + word2ind = Dict(vocabulary .=> 1:vocab_size) + + # unknown token in vocabulary + unk = word2ind["UNK"] + + # convert string tokens to integers + tokens = map(x -> get(word2ind, x, nothing), tokens) + + # final data format + x_train = batchify_data(tokens[1:end-1], unk, args) + y_train = onehot_data(batchify_data(tokens[2:end], unk, args), 1:vocab_size) + + return x_train, y_train, word2ind, vocabulary +end diff --git a/text/word-rnn/word-rnn.jl b/text/word-rnn/word-rnn.jl new file mode 100644 index 00000000..a7195c86 --- /dev/null +++ b/text/word-rnn/word-rnn.jl @@ -0,0 +1,132 @@ +# architecture based on https://arxiv.org/pdf/1409.2329.pdf +include("corpus.jl") +using Flux +using Parameters: @with_kw +using Statistics: mean +using StatsBase: wsample + +@with_kw mutable struct Args + # Model Params + em_size::Int = 200 + recur_size::Int = 400 + clip::Float64 = 0.1 + dropout::Float64 = 0.25 + + # Training Params + γ::Float64 = 2.5 + epochs::Int = 8 + + # Data Params + nbatch::Int = 250 + seqlen::Int = 20 +end + +function create_model(vocab_size, args) + return Chain( + Flux.Embedding(vocab_size, args.em_size), + Flux.Dropout(args.dropout), + Flux.LSTM(args.em_size, args.recur_size), + Flux.Dropout(args.dropout), + Flux.LSTM(args.recur_size, args.recur_size), + Flux.Dropout(args.dropout), + Flux.Dense(args.recur_size, vocab_size) + ) +end + +function train(; kws...) + # initialize parameter struct + args = Args() + + # load train data and vocabulary + x_train, y_train, word2ind, vocab = get_data(args) + vocab_size = length(vocab) + + # create model + model = create_model(vocab_size, args; kws...) + + # logit cross entropy loss function + function loss(x, y) + Flux.reset!(model) + return mean(Flux.logitcrossentropy(model(x_i), y_i) for (x_i, y_i) in zip(x,y)) + end + + # reference to model params, and optimizer + ps = Flux.params(model) + + # create batch iterators for data and validation + data_loader = zip(x_train[1:end-5], y_train[1:end-5]) + hold_out = zip(x_train[end-5:end], y_train[end-5:end]) + + # used for updating hyperparameters + best_val_loss = nothing + lr = args.γ + + # begin training loop + @info "Start Training, total $(args.epochs) epochs" + for epoch = 1:args.epochs + + @info "Epoch $(epoch) / $(args.epochs)" + + for batch in data_loader + + gradient = Flux.gradient(ps) do + # compute loss for this batch + training_loss = loss(batch...) + return training_loss + end + + for x in ps + # apply clip to handle exploding gradients + grad_x = clamp!(gradient[x], -args.clip, args.clip) + # backprop + x .-= lr .* grad_x + end + end + + # compute and show the loss for the hold out set + validation_loss = mean([loss(x_v, y_v) for (x_v, y_v) in hold_out]) + @show(validation_loss, lr) + + if best_val_loss == nothing || validation_loss < best_val_loss + best_val_loss = validation_loss + else + # Anneal the learning rate if hold out set loss did not improve + lr /= 4.0 + end + end + + + # show final lr, and hold out set perplexity + valid_perplex = exp(mean([loss(x_v, y_v) for (x_v, y_v) in hold_out])) + @info "Training finished, final validation perplexity: $(valid_perplex) bits, final lr: $(lr)" + + return model, vocab, word2ind +end + +function sample(model, vocab, word2ind, len; seed="") + # load the model, and generate a sentence of length `len` + model = cpu(model) + Flux.reset!(model) + buf = IOBuffer() + if seed == "" + seed = string(rand(vocab)) + end + write(buf, seed) + c = wsample(vocab, Flux.softmax(model(word2ind[seed]))) + for i = 1:len + write(buf, ' ') + write(buf, c) + c = wsample(vocab, softmax(model(word2ind[c]))) + end + write(buf, '\n') + return String(take!(buf)) +end + +cd(@__DIR__) +@time begin + model, vocab, word2ind = train() +end +@info "Word language model generation examples:" +sample(model, vocab, word2ind, 20; seed="socrates") |> println +sample(model, vocab, word2ind, 20; seed="liberty") |> println +sample(model, vocab, word2ind, 20) |> println