diff --git a/smiles/ckpt/SMILESReDi_base.ckpt b/smiles/ckpt/SMILESReDi_base.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..5661438859ab23698ee873ca4d0a4e662f294377 --- /dev/null +++ b/smiles/ckpt/SMILESReDi_base.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b2a6049e7ce6c2ef2ab0da0987ee14a5c7e3df737c5c32202631f5c230a81df +size 1038182356 diff --git a/smiles/ckpt/SMILESReDi_v1.ckpt b/smiles/ckpt/SMILESReDi_v1.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..f698ef340bef10b473403589de8088a6a6e5fdce --- /dev/null +++ b/smiles/ckpt/SMILESReDi_v1.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2740ae1b81d5011c8de7109a8c5c9cead11e776feb46660b59697e040d5e595 +size 1038182548 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/dataset_dict.json b/smiles/data/11M_smiles_old_tokenizer_no_limit/dataset_dict.json new file mode 100644 index 0000000000000000000000000000000000000000..14dbc8c17a07765896e4991dacef79ff8ef64df1 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/dataset_dict.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b75ca00c6464fb0545b1a34339647a18ac5a94a18dc4906a2744889ec1a20d +size 28 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00000-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00000-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..c0dddb04a765cee1480e711049c170e40492dbc8 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00000-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2871015033b5ee02228dbaa80ad9939a7902b062b9b10bdc6bc27f69fe8cd2e +size 477588328 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00001-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00001-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..f76e4ede7fd274708839511f5f624720df60919c --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00001-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:286ae4a066882217ff95c91ca5315427ce1a9a8e0a11e72a9646beb9d8d286b0 +size 464191456 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00002-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00002-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..068ff851c4bed6d5e02c168af1e1a307dc7b15fe --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00002-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a88f73782ad97f90ea04198cdf1d9f1d685d812026f298edaea9ce9beca00570 +size 479308880 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00003-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00003-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..efcfe3bdea248ff849904edf5fe9118e3fea4c5d --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00003-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd3615d74a0d36c86df8d6056d7203ba8a6caabf5c09a83192c6712bb4bd75d1 +size 492330784 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00004-of-00024.arrow 
b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00004-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..0e7a5f1e2bcaf504f527dc8c05658fc8b5519614 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00004-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7bb43f10e018b88dbbc8148cf8575cc31ec87cd05b7c7c66fc3ffa70d23a298 +size 491461384 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00005-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00005-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..70933333a8148e5bbd8c8898ed467ad1d94882e7 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00005-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:473af91a61fdbb15570789c55cb85e3425f505e4895448a1babcd68abfeacacf +size 490155080 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00006-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00006-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..eb9c50baed8c6e9f53914b5f6e5615effda6ece3 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00006-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be40dce89b1ecd5b5476d491f252e4852c9a9369e48162b4ef7797c13f8508f5 +size 489402848 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00007-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00007-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..c5b6fb2f47dd49d027495f113824450e47e7e5de --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00007-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:994f608ee9089d45f43c3e577afc73597855c98e42436621f5b8a9f4d1959fd4 +size 488308088 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00008-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00008-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..6ce7c210eee37fbbd20b1d1eb92dcbbd79bbcaec --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00008-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13f1a9ef28679b02cdf49b5fb26401b6a6fd819ce26b2d6ef4c9cc8486b5e5a8 +size 488565816 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00009-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00009-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..334011b6311e91d5d28035b44b9c07d910d0799c --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00009-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42fab66a9a83a5b7ab14f4d01949e5591168d3bdc37ab76f3abcda043b63c0b8 +size 487136800 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00010-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00010-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..33b6cfd44b5659690ff68a820d34c88c0fcf3136 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00010-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:09b5a103a3d305405890db95b82de8624cf57944a12bb7fe3b24644af3ef96bb +size 487375280 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00011-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00011-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..2cfdcc1b4e4a37bce1ff64cc7e6c2559cdf24179 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00011-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95cec0ffca845fb12bcd511f876f98dfd8ef8420a9e0ec030547ed6bde47210a +size 486001120 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00012-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00012-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..54be2adf9d14c97ffe09ff86d2c30eb4003487fb --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00012-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37149287fd5f4e70df2d00dc54b0db0c5f4e590b449b69b3553ebb00ff8ae208 +size 485408104 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00013-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00013-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..252add2234e3fae38b1508e20881ccd3ac4ec963 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00013-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a16996dc09ccca1e1c3e9063e4dadb3b6329ad4081f35454bd261c0de774a911 +size 485633824 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00014-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00014-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..88af390a8297e648f4e1042654ef622e2605328f --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00014-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77f4beaefc1a08ae233e2ab5644e103af3f0c11284b0b49553d1c11c884ef69f +size 485841176 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00015-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00015-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..abfcb7f8d3424bcbde6ea58f585c776eca200a54 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00015-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27d38dbf2c523723e5f3e29416ebca7a70b0002124372e2d40796bbaec966456 +size 484020296 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00016-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00016-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..5e7618a2a8400e1400b448f8f9927cfc35cfe81a --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00016-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0931dd1c981afabcc52e6ace64b986b7683e244be585cf5598c7cbe3165308b +size 484883888 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00017-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00017-of-00024.arrow new file mode 100644 index 
0000000000000000000000000000000000000000..0997e399e32efd31c9ca62bb6799de625f2d819e --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00017-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e0002265365f32697aa9b970443b6a4b4bdbf245fdafc30f4cececb6a7d40ee +size 483414096 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00018-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00018-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..9b69d097da9041eecb46440f9ba2cc99006eacdd --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00018-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9ed92e0b1c0b107eb1386dd6c498ee8df15b42a079bb80650b26f2de3b75475 +size 481238056 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00019-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00019-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..a40e1fd2dea6bb836587aec9c06c8623034cab3a --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00019-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d61618770e6ebad69b32e7d3b0b93a94cf2ea9cbc467cedf1fcccb1dbad6c2c +size 481634360 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00020-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00020-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..23d69d876df897d5d8b42251d59a3bf865ca5677 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00020-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:996aeae99180d8d4c3e3c1db18bfc12f37707d47d584e68255edb71e9ebc99da +size 483933744 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00021-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00021-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..1e0228a0123ee8c1d669d2a714d4af465204b518 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00021-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37561a782f79cf1cce57a63cd6a429cf0cebba33f6c4339d2e75f4e63da00217 +size 481501560 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00022-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00022-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..6a39278d3c812858a70c4ec7f8b425c32139954f --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00022-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc4e8f69b678826007870368b3bd9b4a46f4f055aaeabfecb6dc4ca946a9a0e6 +size 478094088 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00023-of-00024.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00023-of-00024.arrow new file mode 100644 index 0000000000000000000000000000000000000000..836c11c5d31755c0dcc4d80957eb350b86fb7319 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/data-00023-of-00024.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e8ce766667969f716e2145501823d3d2ec22e3135297b24d0fa33589acf778d +size 478855336 diff --git 
a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/dataset_info.json b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..36f39bd9e92a42956fb710ae1b8f7f745cb4f784 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb71c7d0f7d8760fb31c407e18c8fe4917a8eb9ad1a5e356b767d94e0e000ff2 +size 619 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/train/state.json b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/state.json new file mode 100644 index 0000000000000000000000000000000000000000..44600c0fb830b8318075372e64d513231fa00cd2 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aeda4310bf416dcbba857e167834b0e3ccdc8f8f9be0ede33b1d0924dc5f755c +size 1604 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00000-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00000-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..fa53e76e5f04503d3d4844e3f8b7e0d520f31ff4 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00000-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b81ab15dbcc73e25ca02140b3c101a3fd2407660a59fbf81ccbc3af96fe38e0f +size 479009832 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00001-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00001-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..f48f5691126008cf8a1d44a3a2e975718c7498fd --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00001-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43ac35ede1a6adf2ca8159a568a8302fad4f6ea04c0a1ecda428b6e1a2e849de +size 483483872 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00002-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00002-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..e8bd96b8159ef7b8d7436b45c2a226ece7e4c2ba --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00002-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c075e124ba0c213d76552b930f7dae7a8823202e3ac2939196b730fe105b7090 +size 481710208 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00003-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00003-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..8e36aa31f1b2a3e1cf0f20b7a2c0a4c7cc1189d1 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00003-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:576facb6744a5930942432cef811962a92e8f9d3274dfe2fdb40c72273e14ea0 +size 480840896 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00004-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00004-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..9082bbdcc428eab05a40dd1473c29aa24ed3e7bb --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00004-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d35117920343110ba621ee376db7eccdb96d1f6084ed773142aa1948ef1fe14d +size 480085056 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00005-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00005-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..082a18e447428d1853d96861750273ccc7614e3c --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00005-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65023f39b5a8856a3e85307b7a36b33fa31591a0d0d3ae1fc6cf86775d0e6f57 +size 479035216 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00006-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00006-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..8da370cb9af4cdc5f181b6f917ea610aab570c47 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00006-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76246331cfb5db44e2ef1d6e2b42636cc270a5bb8f457ccca09598aa909b2818 +size 478170416 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00007-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00007-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..bc89a0caa3d2271586a85d4a9833a13a1a2e56dc --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00007-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60e0c9660f347e5bc71c5065349fbe4026f495321e81975996066b59361d50b3 +size 477289400 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00008-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00008-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..518867f88fd2b1a081bfaea96becbf3df2693124 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00008-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb17deb75bc25626877746d2f39a029a036d6a8ee180a9a6bfa1dd6c99b4b793 +size 478124120 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00009-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00009-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..a1f83c277090ca0464da9b4d2fda79bfa35eec3a --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00009-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a1721819aec27ca4c06b12ee9f3cd4aa5b5ea3ed1ac5d956398ab2790e8c921 +size 474046656 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00010-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00010-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..32c44ec554a6e8b9f27b096bb83f8f1c0507a30e --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00010-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59d5ae6b9357e5cc2909b9f9b219a2989cc64ab686eb770eb4b4d3836b9563d8 +size 475710864 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00011-of-00012.arrow b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00011-of-00012.arrow new file mode 100644 index 0000000000000000000000000000000000000000..e10dc1eef5c5f977d16c43cc02f42fa8e77ebfa6 --- /dev/null +++ 
b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/data-00011-of-00012.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d18bc2c8060542c65d5928816e744ecb8050ca0cac22e6d0ddd0fc9d556d8df9 +size 471405008 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/dataset_info.json b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..36f39bd9e92a42956fb710ae1b8f7f745cb4f784 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb71c7d0f7d8760fb31c407e18c8fe4917a8eb9ad1a5e356b767d94e0e000ff2 +size 619 diff --git a/smiles/data/11M_smiles_old_tokenizer_no_limit/val/state.json b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/state.json new file mode 100644 index 0000000000000000000000000000000000000000..b9858bca17e4006de821b97b5cde2b739a57b076 --- /dev/null +++ b/smiles/data/11M_smiles_old_tokenizer_no_limit/val/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5de26b88b1538c191ccbce345f636ffe16659c6ffe178327981d0b0f81f6ce07 +size 896 diff --git a/smiles/data/v1/test/data-00000-of-00001.arrow b/smiles/data/v1/test/data-00000-of-00001.arrow new file mode 100644 index 0000000000000000000000000000000000000000..ea4ed3db3d48a60de0bd7d34cdbf2ad8c821197a --- /dev/null +++ b/smiles/data/v1/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76a591da7810044a102a07adcb9959bf1e70cf3352e491d8a09da56f54f1b31e +size 125377832 diff --git a/smiles/data/v1/test/dataset_info.json b/smiles/data/v1/test/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..325bfcb0a70462335a16af9dd89ab015d969f4cc --- /dev/null +++ b/smiles/data/v1/test/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cf092add18034dcc835e91443ed9ac9e6ac2107c0215c59b0c0cf866bd655e6 +size 659 diff --git a/smiles/data/v1/test/state.json b/smiles/data/v1/test/state.json new file mode 100644 index 0000000000000000000000000000000000000000..3824f60096075c5eb65364f29be56fbdc65daf64 --- /dev/null +++ b/smiles/data/v1/test/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e13e0bc082bfb5b099336a1cf0f1629b124ed4d8798c07715abda0c3277eb001 +size 300 diff --git a/smiles/data/v1/train/data-00000-of-00003.arrow b/smiles/data/v1/train/data-00000-of-00003.arrow new file mode 100644 index 0000000000000000000000000000000000000000..830e93a10e87530af6670c6233d26cb14de0ba53 --- /dev/null +++ b/smiles/data/v1/train/data-00000-of-00003.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0dc3ce657e78f4125aa3c7d075da1b69dc46e25367bff8e988303a92192427eb +size 374779640 diff --git a/smiles/data/v1/train/data-00001-of-00003.arrow b/smiles/data/v1/train/data-00001-of-00003.arrow new file mode 100644 index 0000000000000000000000000000000000000000..532706925a890f6e8aa0a2feccc1eb08183881a8 --- /dev/null +++ b/smiles/data/v1/train/data-00001-of-00003.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:766a5b47948aacad2347a1b3d1c197455bab80a7be7d6eb520d52f3b2be41880 +size 374916056 diff --git a/smiles/data/v1/train/data-00002-of-00003.arrow b/smiles/data/v1/train/data-00002-of-00003.arrow new file mode 100644 index 0000000000000000000000000000000000000000..dbccab260e30038611a021bfc137f6824e49fb8f --- /dev/null +++ 
b/smiles/data/v1/train/data-00002-of-00003.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db993610385a6ac26e8b2b8f8c003a43f8000055893e1dd15a7541505906591e +size 374201336 diff --git a/smiles/data/v1/train/dataset_info.json b/smiles/data/v1/train/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..325bfcb0a70462335a16af9dd89ab015d969f4cc --- /dev/null +++ b/smiles/data/v1/train/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cf092add18034dcc835e91443ed9ac9e6ac2107c0215c59b0c0cf866bd655e6 +size 659 diff --git a/smiles/data/v1/train/state.json b/smiles/data/v1/train/state.json new file mode 100644 index 0000000000000000000000000000000000000000..254529932561438409fedc25d7acd71528884824 --- /dev/null +++ b/smiles/data/v1/train/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70b8df8bfdd139e4ed390de43a604fc5759824c3d1ba21305799cd232b539b7a +size 418 diff --git a/smiles/data/v1/validation/data-00000-of-00001.arrow b/smiles/data/v1/validation/data-00000-of-00001.arrow new file mode 100644 index 0000000000000000000000000000000000000000..2b05c9af50a2e08a0ca5ea22aabc5b652b653668 --- /dev/null +++ b/smiles/data/v1/validation/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98b295cc9fd12b718cb57d44d35a20c03263589e017d1da38ab3f733ea16ad20 +size 125247968 diff --git a/smiles/data/v1/validation/dataset_info.json b/smiles/data/v1/validation/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..325bfcb0a70462335a16af9dd89ab015d969f4cc --- /dev/null +++ b/smiles/data/v1/validation/dataset_info.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cf092add18034dcc835e91443ed9ac9e6ac2107c0215c59b0c0cf866bd655e6 +size 659 diff --git a/smiles/data/v1/validation/state.json b/smiles/data/v1/validation/state.json new file mode 100644 index 0000000000000000000000000000000000000000..827ee09ea2a703ada962e7d5c6605a1487090038 --- /dev/null +++ b/smiles/data/v1/validation/state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eff9b32fae3303ceb21768b02edbdef6443a1897c3cdae44e55d6bcf18a7449f +size 300 diff --git a/smiles/dataloading_for_dynamic_batching.py b/smiles/dataloading_for_dynamic_batching.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fbce1846897913a211c7ac6db33443243a88f1 --- /dev/null +++ b/smiles/dataloading_for_dynamic_batching.py @@ -0,0 +1,210 @@ +#!/usr/bin/env +import torch +from torch.utils.data import Dataset, DataLoader +from datasets import Dataset,load_from_disk +import sys +import pytorch_lightning as pl +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +from functools import partial +import re +from tqdm import tqdm +import os +import pdb + + +class DynamicBatchingDataset(Dataset): + def __init__(self, dataset_dict, tokenizer): + print('Initializing dataset...') + self.dataset_dict = { + 'attention_mask': [torch.tensor(item) for item in tqdm(dataset_dict['attention_mask'])], + 'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']], + 'labels': dataset_dict['labels'] + } + self.tokenizer = tokenizer + + def __len__(self): + return len(self.dataset_dict['attention_mask']) + + def __getitem__(self, idx): + if isinstance(idx, int): + return { + 'input_ids': self.dataset_dict['input_ids'][idx], + 'attention_mask': self.dataset_dict['attention_mask'][idx], + 'labels': 
self.dataset_dict['labels'][idx] + } + elif isinstance(idx, list): + return { + 'input_ids': [self.dataset_dict['input_ids'][i] for i in idx], + 'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx], + 'labels': [self.dataset_dict['labels'][i] for i in idx] + } + else: + raise ValueError(f"Expected idx to be int or list, but got {type(idx)}") + +class CustomDataModule(pl.LightningDataModule): + def __init__(self, dataset_path, tokenizer): + super().__init__() + self.dataset = load_from_disk(dataset_path) + self.tokenizer = tokenizer + self.dataset_path = dataset_path + + def peptide_bond_mask(self, smiles_list): + """ + Returns a mask with shape (batch_size, seq_length) that has 1 at the locations + of recognized bonds in the positions dictionary and 0 elsewhere. + + Args: + smiles_list: List of peptide SMILES strings (batch of SMILES strings). + + Returns: + np.ndarray: A mask of shape (batch_size, seq_length) with 1s at bond positions. + """ + # Initialize the batch mask + batch_size = len(smiles_list) + max_seq_length = 1035 #max(len(smiles) for smiles in smiles_list) # Find the longest SMILES + mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros + + bond_patterns = [ + (r'OC\(=O\)', 'ester'), + (r'N\(C\)C\(=O\)', 'n_methyl'), + (r'N[12]C\(=O\)', 'peptide'), # Pro peptide bonds + (r'NC\(=O\)', 'peptide'), # Regular peptide bonds + (r'C\(=O\)N\(C\)', 'n_methyl'), + (r'C\(=O\)N[12]?', 'peptide') + ] + + for batch_idx, smiles in enumerate(smiles_list): + positions = [] + used = set() + + # Identify bonds + for pattern, bond_type in bond_patterns: + for match in re.finditer(pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': bond_type, + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + # Update the mask for the current SMILES + for pos in positions: + mask[batch_idx, pos['start']:pos['end']] = 1 + + return mask + + def peptide_token_mask(self, smiles_list, token_lists): + """ + Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens + where any part of the token overlaps with a peptide bond, and 0 elsewhere. + + Args: + smiles_list: List of peptide SMILES strings (batch of SMILES strings). + token_lists: List of tokenized SMILES strings (split into tokens). + + Returns: + np.ndarray: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens. 
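+
+        Example (illustrative sketch; the token split and the special tokens shown
+        here are hypothetical, not taken from the real SMILES tokenizer):
+
+            smiles = ["NC(=O)C"]                          # "NC(=O)" matches the peptide-bond regex
+            tokens = [["<bos>", "NC(=O)", "C", "<eos>"]]  # first/last tokens are always skipped
+            mask = self.peptide_token_mask(smiles, tokens)
+            # -> tensor([[0, 1, 0, 0]]): only the token overlapping the bond is flagged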
+ """ + # Initialize the batch mask + batch_size = len(smiles_list) + token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence + tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros + atomwise_masks = self.peptide_bond_mask(smiles_list) + + + for batch_idx, atomwise_mask in enumerate(atomwise_masks): + token_seq = token_lists[batch_idx] + atom_idx = 0 + + for token_idx, token in enumerate(token_seq): + if token_idx != 0 and token_idx != len(token_seq) - 1: + if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1: + tokenized_masks[batch_idx][token_idx] = 1 + atom_idx += len(token) + + return tokenized_masks + + def collate_fn(self, batch): + item = batch[0] + # print(item) + # pdb.set_trace() + + token_array = self.tokenizer.get_token_split(item['input_ids']) + bond_mask = self.peptide_token_mask(item['labels'], token_array) + + return { + 'input_ids': item['input_ids'], + 'attention_mask': item['attention_mask'], + 'bond_mask': bond_mask + } + + def _train_dataset(self): + train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer) + return train_dataset + + def _val_dataset(self): + val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer) + return val_dataset + + def train_dataloader(self): + train_dataset = self._train_dataset() + # if train_dataset is None: + # # train_dataset = self._train_dataset() + # train_dataset = self.dataset_path + 'train' + + return DataLoader( + train_dataset, + batch_size=1, + collate_fn=self.collate_fn, # Use the instance method + shuffle=True, + num_workers=12, + pin_memory=True + ) + + def val_dataloader(self): + val_dataset = self._val_dataset() + # if val_dataset is None: + # # val_dataset = self._val_dataset() + # val_dataset = self.dataset_path + 'val' + + return DataLoader( + val_dataset, + batch_size=1, + collate_fn=self.collate_fn, # Use the instance method + num_workers=8, + pin_memory=True + ) + +class RectifyDataModule(pl.LightningDataModule): + def __init__(self, dataset_path): + super().__init__() + self.dataset_path = dataset_path + + def collate_fn(self, batch): + return { + 'source_ids': torch.tensor(batch[0]['source_ids']), + 'target_ids': torch.tensor(batch[0]['target_ids']), + 'bond_mask': torch.tensor(batch[0]['bond_mask']), + } + + def train_dataloader(self): + train_dataset = load_from_disk(os.path.join(self.dataset_path, 'train')) + return DataLoader( + train_dataset, + batch_size=1, + collate_fn=self.collate_fn, + num_workers=12, + pin_memory=True + ) + + def val_dataloader(self): + val_dataset = load_from_disk(os.path.join(self.dataset_path, 'validation')) + return DataLoader( + val_dataset, + batch_size=1, + collate_fn=self.collate_fn, + num_workers=8, + pin_memory=True + ) \ No newline at end of file diff --git a/smiles/generation.py b/smiles/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..93171ec5a4078cc7b0fb9d39a861b8f2b43cc1d5 --- /dev/null +++ b/smiles/generation.py @@ -0,0 +1,153 @@ +import argparse +from pathlib import Path + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +# Import necessary classes from your training script +from smiles_train import MDLMLightningModule, PeptideAnalyzer +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer + +import pdb + + +def generate_smiles(model, tokenizer, args): + """ + Generates peptide SMILES strings using the trained MDLM model + with a forward (t=0 to t=1) 
flow matching process. + + Args: + model (MDLMLightningModule): The trained PyTorch Lightning model. + tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training. + args (argparse.Namespace): Command-line arguments containing sampling parameters. + + Returns: + list[str]: A list of generated SMILES strings. + float: The validity rate of the generated SMILES. + """ + print("Starting SMILES generation with forward flow matching (t=0 to t=1)...") + model.eval() + device = args.device + + # 1. Start with a tensor of random tokens (pure noise at t=0) + x = torch.randint( + 0, + model.model.vocab_size, + (args.n_samples, args.seq_len), + device=device + ) + + # 2. Define the time schedule for the forward process (0.0 to 1.0) + time_steps = torch.linspace(0.0, 1.0, args.n_steps + 1, device=device) + + # 3. Iteratively follow the flow from noise to data + with torch.no_grad(): + for i in tqdm(range(args.n_steps), desc="Flow Matching Steps"): + t_curr = time_steps[i] + t_next = time_steps[i+1] + + # Prepare the current timestep tensor for the model + t_tensor = torch.full((args.n_samples,), t_curr, device=device) + + # Get the model's prediction for the final clean sequence (at t=1) + logits = model(x, t_tensor) + logits = logits / args.temperature + + pred_x1 = torch.argmax(logits, dim=-1) + + # On the last step, the result is the final prediction + if i == args.n_steps - 1: + x = pred_x1 + break + + # --- Construct the next state x_{t_next} --- + # The probability of a token being noise at time t_next is (1 - t_next). + noise_prob = 1.0 - t_next + mask = torch.rand(x.shape, device=device) < noise_prob + + # Generate new random tokens for the noise positions + noise = torch.randint( + 0, + model.model.vocab_size, + x.shape, + device=device + ) + + # Combine the final prediction with noise to form the next intermediate state + x = torch.where(mask, noise, pred_x1) + + # 4. Decode the final token IDs into SMILES strings + generated_sequences = tokenizer.batch_decode(x) + + # 5. Analyze the validity of the generated sequences + peptide_analyzer = PeptideAnalyzer() + valid_count = 0 + valid_smiles = [] + for seq in generated_sequences: + if peptide_analyzer.is_peptide(seq): + valid_count += 1 + valid_smiles.append(seq) + + validity_rate = valid_count / len(generated_sequences) + + print(f"\nGeneration complete. Validity rate: {validity_rate:.2%}") + return valid_smiles, validity_rate + + +def main(): + parser = argparse.ArgumentParser(description="Sample from a trained ReDi model.") + + # --- Required Arguments --- + parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).") + + # --- Sampling Arguments --- + parser.add_argument("--n_samples", type=int, default=16, help="Number of SMILES strings to generate.") + parser.add_argument("--seq_len", type=int, default=256, help="Maximum sequence length for generated SMILES.") + parser.add_argument("--n_steps", type=int, default=100, help="Number of denoising steps for sampling.") + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. 
Higher values increase diversity.") + + # --- Environment Arguments --- + parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.") + parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.") + parser.add_argument("--output_file", type=str, default="generated_smiles.txt", help="File to save the valid generated SMILES.") + + args = parser.parse_args() + + # Set up device + device = "cuda" if torch.cuda.is_available() else "cpu" + args.device = device + print(f"Using device: {device}") + + # --- Load Model and Tokenizer --- + print("Loading tokenizer...") + tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path) + + print(f"Loading model from checkpoint: {args.checkpoint_path}") + # Load hyperparameters from the checkpoint to ensure model architecture matches + checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False) + model_hparams = checkpoint["hyper_parameters"]["args"] + + # Instantiate the model with the loaded hyperparameters + model = MDLMLightningModule.load_from_checkpoint( + args.checkpoint_path, + args=model_hparams, + tokenizer=tokenizer, + map_location=device, + strict=False # Recommended if you have updated the code since training + ) + model.to(device) + + # --- Generate SMILES --- + valid_smiles, validity_rate = generate_smiles(model, tokenizer, args) + + # pdb.set_trace() + + with open('./v0_samples_200.csv', 'a') as f: + for smiles in valid_smiles: + # print(smiles) + f.write(smiles + '\n') + print(validity_rate) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/smiles/metrics.py b/smiles/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b6192e607a2264a2920479e181355aad8e471178 --- /dev/null +++ b/smiles/metrics.py @@ -0,0 +1,266 @@ +import torch +import warnings +import numpy as np +import pandas as pd +from typing import List +from rdkit import Chem, rdBase, DataStructs +import pickle +import gzip +from rdkit.Chem import AllChem, Descriptors +# from utils.utils import mapper +import math +import os.path as op +from rdkit.Chem import rdMolDescriptors + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False): + """ Create ECFP fingerprint of a molecule """ + if hashed: + fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size) + else: + fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size) + fp_np = np.zeros((1,)) + DataStructs.ConvertToNumpyArray(fp_bits, fp_np) + return fp_np.reshape(1, -1) + +def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cuda', p=1): + """ + For each molecule in gen_vecs finds closest molecule in stock_vecs. 
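+
+    (Added note) The similarity computed below is the Tanimoto/Jaccard score on
+    fingerprint vectors,  T(x, y) = <x, y> / (sum(x) + sum(y) - <x, y>),  taken for
+    each generated molecule either as the max over stock_vecs (agg='max', the SNN
+    metric reported later) or as the mean, and finally averaged over gen_vecs.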
+ Returns average tanimoto score for between these molecules + + Parameters: + stock_vecs: numpy array + gen_vecs: numpy array + agg: max or mean + p: power for averaging: (mean x^p)^(1/p) + """ + assert agg in ['max', 'mean'], "Can aggregate only max or mean" + agg_tanimoto = np.zeros(len(gen_vecs)) + total = np.zeros(len(gen_vecs)) + for j in range(0, stock_vecs.shape[0], batch_size): + x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float() + for i in range(0, gen_vecs.shape[0], batch_size): + y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float() + y_gen = y_gen.transpose(0, 1) + tp = torch.mm(x_stock, y_gen) + jac = (tp / (x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy() + jac[np.isnan(jac)] = 1 + if p != 1: + jac = jac**p + if agg == 'max': + agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0)) + elif agg == 'mean': + agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0) + total[i:i + y_gen.shape[1]] += jac.shape[0] + if agg == 'mean': + agg_tanimoto /= total + if p != 1: + agg_tanimoto = (agg_tanimoto)**(1/p) + return np.mean(agg_tanimoto) + +def get_mol(smiles_or_mol): + ''' + Loads SMILES/molecule into RDKit's object + ''' + if isinstance(smiles_or_mol, str): + if len(smiles_or_mol) == 0: + return None + mol = Chem.MolFromSmiles(smiles_or_mol) + if mol is None: + return None + try: + Chem.SanitizeMol(mol) + except ValueError: + return None + return mol + return smiles_or_mol + +def canonic_smiles(smiles_or_mol): + mol = get_mol(smiles_or_mol) + if mol is None: + return None + return Chem.MolToSmiles(mol) + +class SAScorer: + def __init__(self, model_path='/home/st512/peptune/scripts/peptide-mdlm-mcts/utils/sascore/SA_score_prediction.pkl.gz', input_type='smiles'): + self.clf = pickle.load(gzip.open(model_path, "rb")) + self.input_type = 'smiles' + + def __call__(self, smiles_file): + df = pd.read_csv(smiles_file) + smiles = df["SMILES"].tolist() + scores = self.get_scores(smiles) + return scores, scores + + @staticmethod + def _get_descriptors_from_smiles(smiles: List, radius=3, size=4096): # + """ + Add fingerprints together with SAscore and molecular weights + """ + fps = [] + valid_mask = [] + for i, smile in enumerate(smiles): + mol = Chem.MolFromSmiles(smile) if smile is not None else None + valid_mask.append(int(mol is not None)) + fp = fingerprints_from_mol(mol, radius, size=size) if mol else np.zeros((1, size)) + others = np.array([calculateScore(mol), Descriptors.ExactMolWt(mol)]) if mol else np.zeros(2) + prop_np = np.concatenate([others.T, fp.T[:, 0]]) + fps.append(prop_np) + + return fps, valid_mask + + def get_scores(self, smiles: List, valid_only=False): + descriptors, valid_mask = self._get_descriptors_from_smiles(smiles) + scores = self.clf.predict_proba(descriptors)[:, 1] + if valid_only: # filter by valid mask + return np.float32([scores[i] for i in range(len(scores)) if valid_mask[i]]) + return np.float32(scores * np.array(valid_mask)) + + +class Metrics: + + def __init__(self, prior_path='/scratch/pranamlab/tong/data/smiles/30K_all.csv', n_jobs=100, input_type='smiles'): + train_set_cano_smi = pd.read_csv(prior_path)['SMILES'].astype(str).tolist() + #print(train_set_cano_smi[:5]) + + for smi in train_set_cano_smi: + mol = Chem.MolFromSmiles(smi) + + if mol is None: + print(f"Invalid SMILES: {smi}") + + self.train_set_cano_smi = train_set_cano_smi + self.n_jobs = n_jobs + #self.input_type = 'helm' if input_type != 'smiles' else 'smiles' + self.ref_fps = np.vstack([ + 
fingerprints_from_mol(Chem.MolFromSmiles(smi)) + for smi in train_set_cano_smi + if Chem.MolFromSmiles(smi) is not None + ]) + #self.ref_fps = np.vstack([fingerprints_from_mol(Chem.MolFromSmiles(smi)) for smi in train_set_cano_smi]) + + def get_metrics(self, generated_path): + generated_smi = pd.read_csv(generated_path)['SMILES'].astype(str).tolist() + #generated_smi = pd.read_csv(generated_path, usecols=['Generated SMILES'])['Generated SMILES'].tolist() + #mols = [Chem.MolFromSmiles(smi) if smi else None for smi in generated_smi] + mols = [Chem.MolFromSmiles(smi) if smi else None for smi in generated_smi] + is_valid = [1 if mol else 0 for mol in mols] + validity = sum(is_valid) / len(is_valid) + + valid_canon_smiles = [Chem.MolToSmiles(mol) for mol in mols if mol] + uniqueness = len(set(valid_canon_smiles)) / len(valid_canon_smiles) + + uniq_smis = list(set(valid_canon_smiles)) + uniq_mols = [Chem.MolFromSmiles(smi) for smi in uniq_smis] + + fps = np.vstack([fingerprints_from_mol(mol) for mol in uniq_mols]) + diversity = 1 - (average_agg_tanimoto(fps, fps, agg='mean', p=1)).mean() + + snn = average_agg_tanimoto(self.ref_fps, fps, agg='max', p=1) + + # gen_smiles = mapper(self.n_jobs)(canonic_smiles, valid_canon_smiles) + # gen_smiles_set = set(gen_smiles) - {None} + # train_set = set(self.train_set_cano_smi) + # novelty = len(gen_smiles_set - train_set) / len(gen_smiles_set) + + # print(f"validity\tuniqueness\tdiversity\tsnn\tnovelty") + # print(f"{validity:.3f}\t{uniqueness:.3f}\t{diversity:.3f}\t{snn:.3f}\t{novelty:.3f}") + print(f"validity\tuniqueness\tdiversity\tsnn") + print(f"{validity:.3f}\t{uniqueness:.3f}\t{diversity:.3f}\t{snn:.3f}") + return { + "validity": validity, + "uniqueness": uniqueness, + "diversity": diversity, + "snn": snn, # "structural novelty" + # "novelty": novelty, + } + + +def readFragmentScores(name='fpscores'): + import gzip + global _fscores + # generate the full path filename: + if name == "fpscores": + name = op.join(op.dirname(__file__), name) + data = pickle.load(gzip.open('%s.pkl.gz' % name)) + outDict = {} + for i in data: + for j in range(1, len(i)): + outDict[i[j]] = float(i[0]) + _fscores = outDict + + +def numBridgeheadsAndSpiro(mol, ri=None): + nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) + nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) + return nBridgehead, nSpiro + + +def calculateScore(m): + if _fscores is None: + readFragmentScores() + + # fragment score + fp = rdMolDescriptors.GetMorganFingerprint(m, 2) # <- 2 is the *radius* of the circular fingerprint + fps = fp.GetNonzeroElements() + score1 = 0. + nf = 0 + for bitId, v in fps.items(): + nf += v + sfp = bitId + score1 += _fscores.get(sfp, -4) * v + try: + score1 /= nf + except ZeroDivisionError: # where nf is 0 + score1 = 1 + + # features score + nAtoms = m.GetNumAtoms() + nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) + ri = m.GetRingInfo() + nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) + nMacrocycles = 0 + for x in ri.AtomRings(): + if len(x) > 8: + nMacrocycles += 1 + + sizePenalty = nAtoms**1.005 - nAtoms + stereoPenalty = math.log10(nChiralCenters + 1) + spiroPenalty = math.log10(nSpiro + 1) + bridgePenalty = math.log10(nBridgeheads + 1) + macrocyclePenalty = 0. 
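+    # (Added note) These penalties appear to follow the standard SA-score heuristic
+    # (Ertl & Schuffenhauer): larger size and more stereocenters, spiro atoms,
+    # bridgeheads or macrocycles all lower the feature score (score2 below),
+    # i.e. the molecule is treated as harder to synthesize.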
+ # --------------------------------------- + # This differs from the paper, which defines: + # macrocyclePenalty = math.log10(nMacrocycles+1) + # This form generates better results when 2 or more macrocycles are present + if nMacrocycles > 0: + macrocyclePenalty = math.log10(2) + + score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty + + # correction for the fingerprint density + # not in the original publication, added in version 1.1 + # to make highly symmetrical molecules easier to synthetise + score3 = 0. + if nAtoms > len(fps): + score3 = math.log(float(nAtoms) / len(fps)) * .5 + + sascore = score1 + score2 + score3 + + # need to transform "raw" value into scale between 1 and 10 + min = -4.0 + max = 2.5 + sascore = 11. - (sascore - min + 1) / (max - min) * 9. + # smooth the 10-end + if sascore > 8.: + sascore = 8. + math.log(sascore + 1. - 9.) + if sascore > 10.: + sascore = 10.0 + elif sascore < 1.: + sascore = 1.0 + + return sascore diff --git a/smiles/moo.py b/smiles/moo.py new file mode 100644 index 0000000000000000000000000000000000000000..328bf7796b47bd4e576fac16080dbf3b32870585 --- /dev/null +++ b/smiles/moo.py @@ -0,0 +1,333 @@ +import argparse +import re +import random +from collections import Counter +import csv +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +from train import MDLMLightningModule, PeptideAnalyzer +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +import numpy as np + +from smiles_classifiers import Analyzer, Hemolysis, Nonfouling, Solubility, BindingAffinity + +def peptide_bond_mask(smiles_list): + # Initialize the batch mask + batch_size = len(smiles_list) + max_seq_length = 1035 #max(len(smiles) for smiles in smiles_list) # Find the longest SMILES + mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros + + bond_patterns = [ + (r'OC\(=O\)', 'ester'), + (r'N\(C\)C\(=O\)', 'n_methyl'), + (r'N[12]C\(=O\)', 'peptide'), # Pro peptide bonds + (r'NC\(=O\)', 'peptide'), # Regular peptide bonds + (r'C\(=O\)N\(C\)', 'n_methyl'), + (r'C\(=O\)N[12]?', 'peptide') + ] + + for batch_idx, smiles in enumerate(smiles_list): + positions = [] + used = set() + + # Identify bonds + for pattern, bond_type in bond_patterns: + for match in re.finditer(pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': bond_type, + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + # Update the mask for the current SMILES + for pos in positions: + mask[batch_idx, pos['start']:pos['end']] = 1 + + return mask + +def peptide_token_mask(smiles_list, token_lists): + # Initialize the batch mask + batch_size = len(smiles_list) + token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence + tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros + atomwise_masks = peptide_bond_mask(smiles_list) + + + for batch_idx, atomwise_mask in enumerate(atomwise_masks): + token_seq = token_lists[batch_idx] + atom_idx = 0 + + for token_idx, token in enumerate(token_seq): + if token_idx != 0 and token_idx != len(token_seq) - 1: + if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1: + tokenized_masks[batch_idx][token_idx] = 1 + atom_idx += len(token) + + return tokenized_masks + +class MOGGenerator: + def __init__(self, model, device, 
objectives, args): + self.model = model + self.device = device + self.objectives = objectives + self.args = args + self.num_objectives = len(objectives) + self.peptide_analyzer = PeptideAnalyzer() + self.invalid = [0, 1, 2, 3, 4, 585, 586] + + def generate_x0(self, n_steps=16, temperature=1.0): + print("Starting initial SMILES generation...") + self.model.eval() + + # 1. Start with a tensor of random tokens (pure noise at t=0) + x = torch.randint( + 0, + self.model.tokenizer.vocab_size, + (args.num_samples, args.gen_len), + device=self.device + ) + + # 2. Define the time schedule for the forward process (0.0 to 1.0) + time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device) + + # 3. Iteratively follow the flow from noise to data + with torch.no_grad(): + for i in tqdm(range(n_steps), desc="Flow Matching Steps"): + t_curr = time_steps[i] + t_next = time_steps[i+1] + + # Prepare the current timestep tensor for the model + t_tensor = torch.full((args.num_samples,), t_curr, device=self.device) + + # Get the model's prediction for the final clean sequence (at t=1) + logits = self.model(x, t_tensor) + logits = logits / temperature + logits[:, :, 586] = -1000 + + pred_x1 = torch.argmax(logits, dim=-1) + + # On the last step, the result is the final prediction + if i == n_steps - 1: + x = pred_x1 + break + + # --- Construct the next state x_{t_next} --- + # The probability of a token being noise at time t_next is (1 - t_next). + noise_prob = 1.0 - t_next + mask = torch.rand(x.shape, device=self.device) < noise_prob + + # Generate new random tokens for the noise positions + noise = torch.randint( + 0, + self.model.tokenizer.vocab_size, + x.shape, + device=self.device + ) + + # Combine the final prediction with noise to form the next intermediate state + x = torch.where(mask, noise, pred_x1) + + # 4. Decode the final token IDs into SMILES strings + generated_sequences = self.model.tokenizer.batch_decode(x) + + # 5. 
Analyze the validity of the generated sequences + validities = [] + for seq in generated_sequences: + validities.append(self.peptide_analyzer.is_peptide(seq)) + print(seq) + print(f"Initial Sequence Validity: {validities}") + return x + + def _get_scores(self, x_batch): + """Calculates the normalized scores for a batch of sequences.""" + scores = [] + for obj_func in self.objectives: + scores.append(obj_func(x_batch.to(self.device))) + return torch.stack(scores, dim=0).to(self.device) + + def _barker_g(self, u): + """Barker balancing function.""" + return u / (1 + u) + + def validity(self, x): + sampled_sequences = self.model.tokenizer.batch_decode(x) + + return [1.0 if self.peptide_analyzer.is_peptide(seq) else 0.0 for seq in sampled_sequences] + + def generate(self): + """Main generation loop.""" + shape = (self.args.num_samples, self.args.gen_len) + x = self.generate_x0() + print(x) + + if args.weights is None: + # The first weight is for peptide analyzer + # We need to ensure the SMILES sequences are valid peptides + weights = torch.tensor([1] + [1/(self.num_objectives-1)] * (self.num_objectives-1), device=self.device).view(-1,1) + else: + weights = torch.tensor(self.args.weights, device=self.device).view(-1, 1) + if len(weights) != self.num_objectives: + raise ValueError("Number of weights must match number of objectives.") + print(f"Weights: {weights}") + + with torch.no_grad(): + for t in tqdm(range(self.args.optimization_steps), desc="MOG Generation"): + improved = False + # Anneal guidance strength + eta_t = self.args.eta_min + (self.args.eta_max - self.args.eta_min) * (t / (self.args.optimization_steps - 1)) + # Choose a random position to mutate + mut_idx = random.randint(1, self.args.gen_len-2) + + # Determine the generation timestep + # We cycle through the timesteps to ensure all are visited + generation_step = t % self.args.optimization_steps + time_t = torch.full((self.args.num_samples,), (generation_step / self.args.optimization_steps), device=self.device) + + # Get proposal distribution from ReDi model for the chosen position + logits = self.model(x, time_t) + probs = F.softmax(logits, dim=-1) + pos_probs = probs[:, mut_idx, :] + pos_probs[:, x[:, mut_idx]] = 0 # We don't evalute the same token + pos_probs[:, self.invalid] = 0 + + # Prune candidate vocabulary using top-p sampling + sorted_probs, sorted_indices = torch.sort(pos_probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + remove_mask = cumulative_probs > self.args.top_p + remove_mask[..., 1:] = remove_mask[..., :-1].clone() + remove_mask[..., 0] = 0 + + # Get the set of candidate tokens for each sample in the batch + candidate_tokens_list = [] + for i in range(self.args.num_samples): + sample_mask = remove_mask[i] + candidates = sorted_indices[i, ~sample_mask] + candidate_tokens_list.append(candidates) + + # Get current scores + current_scores = self._get_scores(x) + w_current = torch.exp(eta_t * torch.min(weights * current_scores, dim=0).values) + + if t == 0: + print(f"Initial Scores: {current_scores}") + + # Evaluate all candidate tokens for each sample + final_proposal_tokens = [] + for i in range(self.args.num_samples): + candidates = candidate_tokens_list[i] + candidates = torch.tensor([token for token in candidates if token not in self.invalid], device=candidates.device) + if len(candidates) >= 200: + candidates = candidates[:200] + num_candidates = len(candidates) + + # Create a batch of proposed sequences for the current sample + x_prop_batch = x[i].repeat(num_candidates, 1) + 
x_prop_batch[:, mut_idx] = candidates + + # Evaluate all proposals + proposal_scores = self._get_scores(x_prop_batch) + proposal_s_omega = torch.min(weights * proposal_scores, dim=0).values + w_proposal = torch.exp(eta_t * proposal_s_omega) + + # Get ReDi probabilities for the candidates + redi_probs = pos_probs[i, candidates] + + # Calculate unnormalized guided probabilities + tilde_q = redi_probs * self._barker_g(w_proposal / w_current[i]) + + # Normalize and sample the final token + final_probs = tilde_q / (torch.sum(tilde_q) + 1e-9) + + index = torch.multinomial(final_probs, 1).item() + + if (random.uniform(0,1) > 1 and proposal_scores[:,index][0] == 1) or torch.sum(weights.squeeze(1) * proposal_scores[:, index]) >= torch.sum(weights.squeeze(1) * current_scores[:,i]): + final_token = candidates[index] + print(f"Previous Weighted Sum: {torch.sum(weights.squeeze(1) * current_scores[:,i])}") + print(f"Previous Scores: {current_scores[:,i]}") + + print(f"New Weighted Sum: {torch.sum(weights.squeeze(1) * proposal_scores[:, index])}") + print(f"New Scores: {proposal_scores[:,index]}") + improved = True + + else: + final_token = x[i][mut_idx] + + final_proposal_tokens.append(final_token) + + # Update the sequences with the chosen tokens + x[torch.arange(self.args.num_samples), mut_idx] = torch.stack(final_proposal_tokens) + if improved: + print(self.model.tokenizer.batch_decode(x)) + + return x + + +# --- Main Execution --- +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + target = args.target + tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt') + + analyzer_model = Analyzer(tokenizer) + hemolysis_model = Hemolysis(device) + nonfouling_model = Nonfouling(device) + solubility_model = Solubility(device) + permeability_model = Permeability(device) + affinity_model = BindingAffinity(target, device) + + # List of all objective functions + OBJECTIVE_FUNCTIONS = [analyzer_model, hemolysis_model, nonfouling_model, solubility_model, affinity_model] + + # --- Load Model --- + print(f"Loading model from checkpoint: {args.checkpoint}") + checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) + model_hparams = checkpoint["hyper_parameters"]["args"] + + model = MDLMLightningModule.load_from_checkpoint( + args.checkpoint, + args=model_hparams, + tokenizer=tokenizer, + map_location=device, + strict=False + ) + model.to(device) + print("Model loaded successfully.") + + mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args) + + for _ in range(args.num_batches): + generated_tokens = mog_generator.generate() + final_scores = mog_generator._get_scores(generated_tokens.detach()).detach().cpu().numpy() + sequence_str = tokenizer.batch_decode(generated_tokens) + + print(sequence_str) + print(final_scores) + + print("Generation complete.") + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).") + + parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.") + parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate.") + parser.add_argument("--num_batches", type=int, default=10, help="Number of samples to generate.") + parser.add_argument("--output_file", type=str, 
default="./smiles.txt", help="File to save the generated sequences.") + parser.add_argument("--gen_len", type=int, default=50, help="Length of the sequences to generate.") + parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.") + parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives (e.g., 0.5 0.5).") + parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.") + parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.") + parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.") + + parser.add_argument("--target", type=str, required=True) + args = parser.parse_args() + main(args) diff --git a/smiles/new_coupling.py b/smiles/new_coupling.py new file mode 100644 index 0000000000000000000000000000000000000000..b27140a33091fa27a13b8fa082fb54224e495a4f --- /dev/null +++ b/smiles/new_coupling.py @@ -0,0 +1,270 @@ +import argparse +from pathlib import Path +import os +import re +import torch +import torch.nn.functional as F +from tqdm import tqdm +from datasets import Dataset, concatenate_datasets +import pdb + +# Import necessary classes from your provided scripts +# Ensure that smiles_train.py and the tokenizer are in the Python path +from train import MDLMLightningModule +from peptide_analyzer import PeptideAnalyzer +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer + +def peptide_bond_mask(smiles_list): + """ + Returns a mask with shape (batch_size, seq_length) that has 1 at the locations + of recognized bonds in the positions dictionary and 0 elsewhere. + + Args: + smiles_list: List of peptide SMILES strings (batch of SMILES strings). + + Returns: + np.ndarray: A mask of shape (batch_size, seq_length) with 1s at bond positions. + """ + # Initialize the batch mask + batch_size = len(smiles_list) + max_seq_length = 1035 #max(len(smiles) for smiles in smiles_list) # Find the longest SMILES + mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros + + bond_patterns = [ + (r'OC\(=O\)', 'ester'), + (r'N\(C\)C\(=O\)', 'n_methyl'), + (r'N[12]C\(=O\)', 'peptide'), # Pro peptide bonds + (r'NC\(=O\)', 'peptide'), # Regular peptide bonds + (r'C\(=O\)N\(C\)', 'n_methyl'), + (r'C\(=O\)N[12]?', 'peptide') + ] + + for batch_idx, smiles in enumerate(smiles_list): + positions = [] + used = set() + + # Identify bonds + for pattern, bond_type in bond_patterns: + for match in re.finditer(pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': bond_type, + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + # Update the mask for the current SMILES + for pos in positions: + mask[batch_idx, pos['start']:pos['end']] = 1 + + return mask + +def peptide_token_mask(smiles_list, token_lists): + """ + Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens + where any part of the token overlaps with a peptide bond, and 0 elsewhere. + + Args: + smiles_list: List of peptide SMILES strings (batch of SMILES strings). + token_lists: List of tokenized SMILES strings (split into tokens). + + Returns: + np.ndarray: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens. 
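+
+    Example (illustrative; assumes the token boundaries line up with the SMILES characters):
+        smiles_list = ["C(C)NC(=O)C"]
+        token_lists = [["C(C)", "NC(=O)", "C"]]
+        peptide_token_mask(smiles_list, token_lists)  # -> [[0, 1, 0]]
+        # Only the "NC(=O)" token overlaps a recognized peptide-bond match; the first and
+        # last token of every sequence are always left at 0.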
+ """ + # Initialize the batch mask + batch_size = len(smiles_list) + token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence + tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros + atomwise_masks = peptide_bond_mask(smiles_list) + + + for batch_idx, atomwise_mask in enumerate(atomwise_masks): + token_seq = token_lists[batch_idx] + atom_idx = 0 + + for token_idx, token in enumerate(token_seq): + if token_idx != 0 and token_idx != len(token_seq) - 1: + if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1: + tokenized_masks[batch_idx][token_idx] = 1 + atom_idx += len(token) + + return tokenized_masks + +def generate_and_filter_batch(model, tokenizer, peptide_analyzer, seq_len, batch_size, n_steps, temperature, device): + """ + Generates a single batch of SMILES, filters them for validity, and returns the valid ones + along with their original corresponding noise tensors (x0) and final token tensors (x1). + + Args: + model (MDLMLightningModule): The trained PyTorch Lightning model. + tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training. + peptide_analyzer (PeptideAnalyzer): The analyzer to validate peptides. + seq_len (int): The sequence length for this batch. + batch_size (int): The number of samples to generate in this batch. + n_steps (int): The number of steps for the flow matching process. + temperature (float): The sampling temperature. + device (str): The device to run generation on ('cuda' or 'cpu'). + + Returns: + tuple[list[str], list[torch.Tensor], list[torch.Tensor]]: A tuple containing: + - A list of valid, generated peptide SMILES strings. + - A list of the corresponding x0 tensors (noise). + - A list of the corresponding x1 tensors (final generated tokens). + """ + # 1. Start with a tensor of random tokens (pure noise at t=0) + x0 = torch.randint( + 0, + model.model.vocab_size, + (batch_size, seq_len), + device=device + ) + x = x0.clone() + + # 2. Define the time schedule for the forward process (0.0 to 1.0) + time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device) + + # 3. Iteratively follow the flow from noise to data + with torch.no_grad(): + for i in range(n_steps): + t_curr = time_steps[i] + # Prepare the current timestep tensor for the model + t_tensor = torch.full((batch_size,), t_curr, device=device) + + # Get the model's prediction for the final clean sequence (at t=1) + logits = model(x, t_tensor) + if temperature > 0: + logits = logits / temperature + + pred_x1 = torch.argmax(logits, dim=-1) + + if i == n_steps - 1: + x = pred_x1 + break + + # --- Construct the next state x_{t_next} --- + t_next = time_steps[i+1] + noise_prob = 1.0 - t_next + mask = torch.rand(x.shape, device=device) < noise_prob + noise = torch.randint(0, model.model.vocab_size, x.shape, device=device) + x = torch.where(mask, noise, pred_x1) + + generated_sequences = tokenizer.batch_decode(x) + + # 5. 
Analyze the validity and collect valid (SMILES, x0, x1) triplets + valid_smiles = [] + valid_x0s = [] + valid_x1s = [] + for i, seq in enumerate(generated_sequences): + if peptide_analyzer.is_peptide(seq): + valid_smiles.append(seq) + valid_x0s.append(x0[i]) + valid_x1s.append(x[i]) # Store the final token tensor + + return valid_smiles, valid_x0s, valid_x1s + + +def main(args): + device = "cuda" if torch.cuda.is_available() else "cpu" + tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path) + checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False) + model = MDLMLightningModule.load_from_checkpoint( + args.checkpoint_path, args=checkpoint["hyper_parameters"]["args"], + tokenizer=tokenizer, strict=False + ).to(device).eval() + pa = PeptideAnalyzer() + + all_sources = [] + all_targets = [] + all_bonds = [] + + for length in range(args.max_length, args.min_length - 1, -1): + print(f"\n--- Generating for length {length} ---") + + collected_for_len = 0 + pbar = tqdm(total=args.num_sequences_per_length, desc=f"Length {length}") + + # Accumulators for the current "save batch" + chunk_source, chunk_target, chunk_bond = [], [], [] + max_batch_size = args.max_tokens_in_batch // length + + while collected_for_len < args.num_sequences_per_length: + num_needed = args.num_sequences_per_length - collected_for_len + + gen_bsz = max_batch_size - len(chunk_target) if max_batch_size > len(chunk_target) else max_batch_size + if gen_bsz == 0: + print(f"Warning: Length {length} too long for token limit. Skipping.") + break + + actual_bsz = min(num_needed, gen_bsz) + + smiles, x0s, x1s = generate_and_filter_batch( + model, tokenizer, pa, length, actual_bsz, + args.n_steps, args.temperature, device + ) + + if smiles: + tokens = tokenizer.get_token_split(x1s) + b_masks = peptide_token_mask(smiles, tokens) + + chunk_source.extend([x.tolist() for x in x0s]) + chunk_target.extend([x.tolist() for x in x1s]) + chunk_bond.extend(b_masks.tolist()) + + collected_for_len += len(smiles) + pbar.update(len(smiles)) + + + # Check if current chunk hits the token limit, and if so, save it + if len(chunk_target) == min(max_batch_size, args.num_sequences_per_length): + all_sources.append(chunk_source) + all_targets.append(chunk_target) + all_bonds.append(chunk_bond) + chunk_source, chunk_target, chunk_bond = [], [], [] + + pbar.close() + + all_data = Dataset.from_dict({ + 'source_ids': all_sources, + 'target_ids': all_targets, + 'bond_mask': all_bonds + }) + print("\n--- Combining all generated data chunks ---") + print(f"Total valid sequences collected: {len(all_data)}") + + print(f"Saving new rectified dataset to {args.output_dir}...") + train_val = all_data.train_test_split(test_size=0.1, seed=42) + final_split = train_val['train'].train_test_split(test_size=(1/9), seed=42) + + train_val['train'].save_to_disk(os.path.join(args.output_dir, 'train')) + final_split['test'].save_to_disk(os.path.join(args.output_dir, 'validation')) + train_val['test'].save_to_disk(os.path.join(args.output_dir, 'test')) + + print("\nDataset combination and saving complete.") + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate rectified data couplings using a trained ReDi model for a range of lengths.") + + # --- Required Arguments --- + parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the new rectified dataset.") + + 
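+    # Example invocation (hypothetical paths, shown only for illustration):
+    #   python new_coupling.py \
+    #       --checkpoint_path path/to/redi_model.ckpt \
+    #       --output_dir data/rectified_couplings \
+    #       --num_sequences_per_length 200 --min_length 4 --max_length 50 --n_steps 100
+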
# --- Generation Arguments --- + parser.add_argument("--num_sequences_per_length", type=int, default=100, help="Number of valid sequences to generate for each length.") + parser.add_argument("--min_length", type=int, default=4, help="Minimum sequence length to generate.") + parser.add_argument("--max_length", type=int, default=1035, help="Maximum sequence length to generate (and padding length).") + parser.add_argument("--max_tokens_in_batch", type=int, default=5200, help="Maximum number of tokens in a single generation batch (batch_size * seq_len).") + parser.add_argument("--n_steps", type=int, default=100, help="Number of steps for the flow matching process.") + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity. Set to 0 for pure argmax.") + + # --- Environment Arguments --- + parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.") + parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.") + + args = parser.parse_args() + + main(args) + diff --git a/smiles/peptide_analyzer.py b/smiles/peptide_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6d0046661f6ad0608169747e43514a726c442c --- /dev/null +++ b/smiles/peptide_analyzer.py @@ -0,0 +1,1259 @@ +import os +import re +import pandas as pd +from io import StringIO +import rdkit +from rdkit import Chem +from rdkit.Chem import AllChem, Draw +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import matplotlib.pyplot as plt +import matplotlib.patches as patches +from io import BytesIO +import tempfile +from rdkit import Chem, RDLogger +RDLogger.DisableLog('rdApp.*') + +import pdb + +class PeptideAnalyzer: + def __init__(self): + self.bond_patterns = [ + (r'OC\(=O\)', 'ester'), # Ester bond + (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond + (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond + (r'NC\(=O\)', 'peptide'), # Standard peptide bond + (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated + (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond + ] + # Three to one letter code mapping + self.three_to_one = { + 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E', + 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I', + 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N', + 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S', + 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y' + } + + def is_peptide(self, smiles): + """Check if the SMILES represents a peptide structure""" + # pdb.set_trace() + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return False + + # Look for peptide bonds: NC(=O) pattern + peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)') + if mol.HasSubstructMatch(peptide_bond_pattern): + return True + + # Look for N-methylated peptide bonds: N(C)C(=O) pattern + n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)') + if mol.HasSubstructMatch(n_methyl_pattern): + return True + + return False + + def is_cyclic(self, smiles): + """Improved cyclic peptide detection""" + # Check for C-terminal carboxyl + if smiles.endswith('C(=O)O'): + return False, [], [] + + # Find all numbers used in ring closures + ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles) + + # Find aromatic ring numbers + aromatic_matches = 
re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles) + aromatic_cycles = [] + for match in aromatic_matches: + numbers = re.findall(r'[0-9]', match) + aromatic_cycles.extend(numbers) + + # Numbers that aren't part of aromatic rings are peptide cycles + peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles] + + is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O') + return is_cyclic, peptide_cycles, aromatic_cycles + + def split_on_bonds(self, smiles): + """Split SMILES into segments with simplified Pro handling""" + positions = [] + used = set() + + # Find Gly pattern first + gly_pattern = r'NCC\(=O\)' + for match in re.finditer(gly_pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': 'gly', + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + for pattern, bond_type in self.bond_patterns: + for match in re.finditer(pattern, smiles): + if not any(p in range(match.start(), match.end()) for p in used): + positions.append({ + 'start': match.start(), + 'end': match.end(), + 'type': bond_type, + 'pattern': match.group() + }) + used.update(range(match.start(), match.end())) + + # Sort by position + positions.sort(key=lambda x: x['start']) + + # Create segments + segments = [] + + if positions: + # First segment + if positions[0]['start'] > 0: + segments.append({ + 'content': smiles[0:positions[0]['start']], + 'bond_after': positions[0]['pattern'] + }) + + # Process segments + for i in range(len(positions)-1): + current = positions[i] + next_pos = positions[i+1] + + if current['type'] == 'gly': + segments.append({ + 'content': 'NCC(=O)', + 'bond_before': positions[i-1]['pattern'] if i > 0 else None, + 'bond_after': next_pos['pattern'] + }) + else: + content = smiles[current['end']:next_pos['start']] + if content: + segments.append({ + 'content': content, + 'bond_before': current['pattern'], + 'bond_after': next_pos['pattern'] + }) + + # Last segment + if positions[-1]['end'] < len(smiles): + segments.append({ + 'content': smiles[positions[-1]['end']:], + 'bond_before': positions[-1]['pattern'] + }) + + return segments + + def clean_terminal_carboxyl(self, segment): + """Remove C-terminal carboxyl only if it's the true terminus""" + content = segment['content'] + + # Only clean if: + # 1. Contains C(=O)O + # 2. No bond_after exists (meaning it's the last segment) + # 3. 
C(=O)O is at the end of the content + if 'C(=O)O' in content and not segment.get('bond_after'): + print('recognized?') + # Remove C(=O)O pattern regardless of position + cleaned = re.sub(r'\(C\(=O\)O\)', '', content) + # Remove any leftover empty parentheses + cleaned = re.sub(r'\(\)', '', cleaned) + print(cleaned) + return cleaned + return content + + def identify_residue(self, segment): + """Identify residue with Pro reconstruction""" + # Only clean terminal carboxyl if this is the last segment + content = self.clean_terminal_carboxyl(segment) + mods = self.get_modifications(segment) + + # UAA pattern matching section - before regular residues + # Phenylglycine and derivatives + if 'c1ccccc1' in content: + if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content: + return '4', mods # Base phenylglycine + + # 4-substituted phenylalanines + if 'Cc1ccc' in content: + if 'OMe' in content or 'OCc1ccc' in content: + return '0A1', mods # 4-methoxy-Phenylalanine + elif 'Clc1ccc' in content: + return '200', mods # 4-chloro-Phenylalanine + elif 'Brc1ccc' in content: + return '4BF', mods # 4-Bromo-phenylalanine + elif 'C#Nc1ccc' in content: + return '4CF', mods # 4-cyano-phenylalanine + elif 'Ic1ccc' in content: + return 'PHI', mods # 4-Iodo-phenylalanine + elif 'Fc1ccc' in content: + return 'PFF', mods # 4-Fluoro-phenylalanine + + # Modified tryptophans + if 'c[nH]c2' in content: + if 'Oc2cccc2' in content: + return '0AF', mods # 7-hydroxy-tryptophan + elif 'Fc2cccc2' in content: + return '4FW', mods # 4-fluoro-tryptophan + elif 'Clc2cccc2' in content: + return '6CW', mods # 6-chloro-tryptophan + elif 'Brc2cccc2' in content: + return 'BTR', mods # 6-bromo-tryptophan + elif 'COc2cccc2' in content: + return 'MOT5', mods # 5-Methoxy-tryptophan + elif 'Cc2cccc2' in content: + return 'MTR5', mods # 5-Methyl-tryptophan + + # Special amino acids + if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content: + return 'BUG', mods # Tertleucine + + if 'CCCNC(=N)N' in content: + return 'CIR', mods # Citrulline + + if '[SeH]' in content: + return 'CSE', mods # Selenocysteine + + if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content: + return 'DAB', mods # Diaminobutyric acid + + if 'C1CCCCC1' in content: + if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content: + return 'CHG', mods # Cyclohexylglycine + elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content: + return 'ALC', mods # 3-cyclohexyl-alanine + + # Naphthalene derivatives + if 'c1cccc2c1cccc2' in content: + if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content: + return 'NAL', mods # 2-Naphthyl-alanine + + # Heteroaromatic derivatives + if 'c1cncc' in content: + return 'PYR4', mods # 3-(4-Pyridyl)-alanine + if 'c1cscc' in content: + return 'THA3', mods # 3-(3-thienyl)-alanine + if 'c1nnc' in content: + return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine + + # Modified serines and threonines + if 'OP(O)(O)O' in content: + if '[C@@H](COP' in content or '[C@H](COP' in content: + return 'SEP', mods # phosphoserine + elif '[C@@H](OP' in content or '[C@H](OP' in content: + return 'TPO', mods # phosphothreonine + + # Specialized ring systems + if 'c1c2ccccc2cc2c1cccc2' in content: + return 'ANTH', mods # 3-(9-anthryl)-alanine + if 'c1csc2c1cccc2' in content: + return 'BTH3', mods # 3-(3-benzothienyl)-alanine + if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content: + return 'ADAM', mods # Adamanthane + + # Fluorinated derivatives + if 'FC(F)(F)' in content: + if 'CC(F)(F)F' in content: + return 'FLA', mods # 
Trifluoro-alanine + if 'C(F)(F)F)c1' in content: + if 'c1ccccc1C(F)(F)F' in content: + return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine + if 'c1cccc(c1)C(F)(F)F' in content: + return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine + if 'c1ccc(cc1)C(F)(F)F' in content: + return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine + + # Multiple halogen patterns + if 'F' in content and 'c1' in content: + if 'c1ccc(c(c1)F)F' in content: + return 'F2F', mods # 3,4-Difluoro-phenylalanine + if 'cc(F)cc(c1)F' in content: + return 'WFP', mods # 3,5-Difluoro-phenylalanine + if 'Cl' in content and 'c1' in content: + if 'c1ccc(cc1Cl)Cl' in content: + return 'CP24', mods # 2,4-dichloro-phenylalanine + if 'c1ccc(c(c1)Cl)Cl' in content: + return 'CP34', mods # 3,4-dichloro-phenylalanine + + # Hydroxy and amino derivatives + if 'O' in content and 'c1' in content: + if 'c1cc(O)cc(c1)O' in content: + return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid + if 'c1ccc(c(c1)O)O' in content: + return 'DAH', mods # 3,4-Dihydroxy-phenylalanine + + # Cyclic amino acids + if 'C1CCCC1' in content: + return 'CPA3', mods # 3-Cyclopentyl-alanine + if 'C1CCCCC1' in content: + if 'CC1CCCCC1' in content: + return 'ALC', mods # 3-cyclohexyl-alanine + else: + return 'CHG', mods # Cyclohexylglycine + + # Chain-length variants + if 'CCC[C@@H]' in content or 'CCC[C@H]' in content: + return 'NLE', mods # Norleucine + if 'CC[C@@H]' in content or 'CC[C@H]' in content: + if not any(x in content for x in ['CC(C)', 'COC', 'CN(']): + return 'ABA', mods # 2-Aminobutyric acid + + # Modified histidines + if 'c1cnc' in content: + if '[C@@H]1CN[C@@H](N1)F' in content: + return '2HF', mods # 2-fluoro-l-histidine + if 'c1cnc([nH]1)F' in content: + return '2HF1', mods # 2-fluoro-l-histidine variant + if 'c1c[nH]c(n1)F' in content: + return '2HF2', mods # 2-fluoro-l-histidine variant + + # Sulfur and selenium containing + if '[SeH]' in content: + return 'CSE', mods # Selenocysteine + if 'S' in content: + if 'CSCc1ccccc1' in content: + return 'BCS', mods # benzylcysteine + if 'CCSC' in content: + return 'ESC', mods # Ethionine + if 'CCS' in content: + return 'HCS', mods # homocysteine + + # Additional modifications + if 'CN=[N]=N' in content: + return 'AZDA', mods # azido-alanine + if '[NH]=[C](=[NH2])=[NH2]' in content: + if 'CCC[NH]=' in content: + return 'AGM', mods # 5-methyl-arginine + if 'CC[NH]=' in content: + return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid + + if 'CCON' in content: + return 'CAN', mods # canaline + if '[C@@H]1C=C[C@@H](C=C1)' in content: + return 'ACZ', mods # cis-amiclenomycin + if 'CCC(=O)[NH3]' in content: + return 'ONL', mods # 5-oxo-l-norleucine + if 'c1ccncc1' in content: + return 'PYR4', mods # 3-(4-Pyridyl)-alanine + if 'c1ccco1' in content: + return 'FUA2', mods # (2-furyl)-alanine + + if 'c1ccc' in content: + if 'c1ccc(cc1)c1ccccc1' in content: + return 'BIF', mods # 4,4-biphenylalanine + if 'c1ccc(cc1)C(=O)c1ccccc1' in content: + return 'PBF', mods # 4-benzoyl-phenylalanine + if 'c1ccc(cc1)C(C)(C)C' in content: + return 'TBP4', mods # 4-tert-butyl-phenylalanine + if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content: + return '0BN', mods # 4-carbamimidoyl-l-phenylalanine + if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content: + return 'APM', mods # m-amidinophenyl-3-alanine + + # Multiple hydroxy patterns + if 'O' in content: + if '[C@H]([C@H](C)O)O' in content: + return 'ILX', mods # 4,5-dihydroxy-isoleucine + if '[C@H]([C@@H](C)O)O' in content: + return 'ALO', mods # Allo-threonine + if 
'[C@H](COP(O)(O)O)' in content: + return 'SEP', mods # phosphoserine + if '[C@H]([C@@H](C)OP(O)(O)O)' in content: + return 'TPO', mods # phosphothreonine + if '[C@H](c1ccc(O)cc1)O' in content: + return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine + if '[C@H](c1ccc(c(Cl)c1)O)O' in content: + return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine + + # Heterocyclic patterns + if 'n1' in content: + if 'n1cccn1' in content: + return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine + if 'n1nncn1' in content: + return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine + if 'c2c(n1)cccc2' in content: + return 'QU32', mods # 3-(2-Quinolyl)-alanine + if 'c1cnc2c(c1)cccc2' in content: + return 'QU33', mods # 3-(3-quinolyl)-alanine + if 'c1ccnc2c1cccc2' in content: + return 'QU34', mods # 3-(4-quinolyl)-alanine + if 'c1ccc2c(c1)nccc2' in content: + return 'QU35', mods # 3-(5-Quinolyl)-alanine + if 'c1ccc2c(c1)cncc2' in content: + return 'QU36', mods # 3-(6-Quinolyl)-alanine + if 'c1cnc2c(n1)cccc2' in content: + return 'QX32', mods # 3-(2-quinoxalyl)-alanine + + # Multiple nitrogen patterns + if 'N' in content: + if '[NH3]CC[C@@H]' in content: + return 'DAB', mods # Diaminobutyric acid + if '[NH3]C[C@@H]' in content: + return 'DPP', mods # 2,3-Diaminopropanoic acid + if '[NH3]CCCCCC[C@@H]' in content: + return 'HHK', mods # (2s)-2,8-diaminooctanoic acid + if 'CCC[NH]=[C](=[NH2])=[NH2]' in content: + return 'GBUT', mods # 2-Amino-4-guanidinobutryric acid + if '[NH]=[C](=S)=[NH2]' in content: + return 'THIC', mods # Thio-citrulline + + # Chain modified amino acids + if 'CC' in content: + if 'CCCC[C@@H]' in content: + return 'AHP', mods # 2-Aminoheptanoic acid + if 'CCC([C@@H])(C)C' in content: + return 'I2M', mods # 3-methyl-l-alloisoleucine + if 'CC[C@H]([C@@H])C' in content: + return 'IIL', mods # Allo-Isoleucine + if '[C@H](CCC(C)C)' in content: + return 'HLEU', mods # Homoleucine + if '[C@@H]([C@@H](C)O)C' in content: + return 'HLU', mods # beta-hydroxyleucine + + # Modified glutamate/aspartate patterns + if '[C@@H]' in content: + if '[C@@H](C[C@@H](F))' in content: + return 'FGA4', mods # 4-Fluoro-glutamic acid + if '[C@@H](C[C@@H](O))' in content: + return '3GL', mods # 4-hydroxy-glutamic-acid + if '[C@@H](C[C@H](C))' in content: + return 'LME', mods # (3r)-3-methyl-l-glutamic acid + if '[C@@H](CC[C@H](C))' in content: + return 'MEG', mods # (3s)-3-methyl-l-glutamic acid + + # Sulfur and selenium modifications + if 'S' in content: + if 'SCC[C@@H]' in content: + return 'HSER', mods # homoserine + if 'SCCN' in content: + return 'SLZ', mods # thialysine + if 'SC(=O)' in content: + return 'CSA', mods # s-acetonylcysteine + if '[S@@](=O)' in content: + return 'SME', mods # Methionine sulfoxide + if 'S(=O)(=O)' in content: + return 'OMT', mods # Methionine sulfone + + # Double bond containing + if 'C=' in content: + if 'C=C[C@@H]' in content: + return '2AG', mods # 2-Allyl-glycine + if 'C=C[C@@H]' in content: + return 'LVG', mods # vinylglycine + if 'C=Cc1ccccc1' in content: + return 'STYA', mods # Styrylalanine + + # Special cases + if '[C@@H]1Cc2c(C1)cccc2' in content: + return 'IGL', mods # alpha-amino-2-indanacetic acid + if '[C](=[C](=O)=O)=O' in content: + return '26P', mods # 2-amino-6-oxopimelic acid + if '[C](=[C](=O)=O)=C' in content: + return '2NP', mods # l-2-amino-6-methylene-pimelic acid + if 'c2cnc[nH]2' in content: + return 'HIS', mods # histidine core + if 'c1cccc2c1cc(O)cc2' in content: + return 'NAO1', mods # 5-hydroxy-1-naphthalene + if 'c1ccc2c(c1)cc(O)cc2' in content: + return 'NAO2', mods # 
6-hydroxy-2-naphthalene + + # Proline (P) - flexible ring numbers + if any([ + # Check for any ring number in bond patterns + (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and + any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789')) + for n in '123456789' + ]) or any([ + # Check ending patterns with any ring number + (f'CCCN{n}' in content and content.endswith('=O') and + any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789')) + for n in '123456789' + ]) or any([ + # Handle CCC[C@H]n patterns + (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or + (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or + # N-terminal Pro with any ring number + (f'N{n}CCC[C@H]{n}' in content) or + (f'N{n}CCC[C@@H]{n}' in content) + for n in '123456789' + ]): + return 'Pro', mods + + # Tryptophan (W) - more specific indole pattern + if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \ + 'c[nH]c' in content.replace(' ', ''): + return 'Trp', mods + + # Lysine (K) - both patterns + if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content: + return 'Lys', mods + + # Arginine (R) - both patterns + if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content: + return 'Arg', mods + + if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content: + return 'Nle', mods + + # Ornithine (Orn) - 3-carbon chain with NH2 + if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content: + return 'Orn', mods + + # 2-Naphthylalanine (2Nal) - distinct from Phe pattern + if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return '2Nal', mods + + # Cyclohexylalanine (Cha) - already in your code but moved here for clarity + if 'N2CCCCC2' in content or 'CCCCC2' in content: + return 'Cha', mods + + # Aminobutyric acid (Abu) - 2-carbon chain + if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']): + return 'Abu', mods + + # Pipecolic acid (Pip) - 6-membered ring like Pro + if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Pip', mods + + # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2 + if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content): + return 'Chg', mods + + # 4-Fluorophenylalanine (4F-Phe) + if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return '4F-Phe', mods + + # Regular residue identification + if ('NCC(=O)' in content) or (content == 'C'): + # Middle case - between bonds + if segment.get('bond_before') and segment.get('bond_after'): + if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']): + return 'Gly', mods + # Terminal case - at the end + elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'): + return 'Gly', mods + + if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content: + return 'Leu', mods + if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content: + return 'Leu', mods + + if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content: + return 'Thr', mods + + if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content: + return 'Phe', mods + + if ('[C@H](C(C)C)' in content or # With outer parentheses + '[C@@H](C(C)C)' in content or # With outer parentheses + '[C@H]C(C)C' in content or # Without outer 
parentheses + '[C@@H]C(C)C' in content): # Without outer parentheses + if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu + return 'Val', mods + + if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content: + return 'O-tBu', mods + + if any([ + 'CC[C@H](C)' in content, + 'CC[C@@H](C)' in content, + 'C(C)C[C@H]' in content and 'CC(C)C' not in content, + 'C(C)C[C@@H]' in content and 'CC(C)C' not in content + ]): + return 'Ile', mods + + if ('[C@H](C)' in content or '[C@@H](C)' in content): + if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']): + return 'Ala', mods + + # Tyrosine (Tyr) - 4-hydroxybenzyl side chain + if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content): + return 'Tyr', mods + + + # Serine (Ser) - Hydroxymethyl side chain + if '[C@H](CO)' in content or '[C@@H](CO)' in content: + if not ('C(C)O' in content or 'COC' in content): + return 'Ser', mods + + # Threonine (Thr) - 1-hydroxyethyl side chain + if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content: + return 'Thr', mods + + # Cysteine (Cys) - Thiol side chain + if '[C@H](CS)' in content or '[C@@H](CS)' in content: + return 'Cys', mods + + # Methionine (Met) - Methylthioethyl side chain + if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content): + return 'Met', mods + + # Asparagine (Asn) - Carbamoylmethyl side chain + if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Asn', mods + + # Glutamine (Gln) - Carbamoylethyl side chain + if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Gln', mods + + # Aspartic acid (Asp) - Carboxymethyl side chain + if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Asp', mods + + # Glutamic acid (Glu) - Carboxyethyl side chainCC(C(=O)O)N + if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Glu', mods + + # Arginine (Arg) - 3-guanidinopropyl side chain + if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'Arg', mods + + # Histidine (His) - Imidazole side chain + if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content): + return 'His', mods + + return None, mods + + def get_modifications(self, segment): + """Get modifications based on bond types""" + mods = [] + if segment.get('bond_after'): + if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'): + mods.append('N-Me') + if 'OC(=O)' in segment['bond_after']: + mods.append('O-linked') + return mods + + def analyze_structure(self, smiles): + """Main analysis function with debug output""" + print("\nAnalyzing structure:", smiles) + + # Split into segments + segments = self.split_on_bonds(smiles) + + print("\nSegment Analysis:") + sequence = [] + for i, segment in enumerate(segments): + print(f"\nSegment {i}:") + print(f"Content: {segment['content']}") + print(f"Bond before: {segment.get('bond_before', 'None')}") + print(f"Bond after: {segment.get('bond_after', 'None')}") + + residue, mods = self.identify_residue(segment) + if residue: + if mods: + sequence.append(f"{residue}({','.join(mods)})") + else: + sequence.append(residue) + print(f"Identified as: {residue}") + print(f"Modifications: {mods}") + else: + print(f"Warning: Could not identify residue in segment: {segment['content']}") + + # Check if cyclic + is_cyclic, peptide_cycles, aromatic_cycles = 
self.is_cyclic(smiles) + three_letter = '-'.join(sequence) + one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence) + + if is_cyclic: + three_letter = f"cyclo({three_letter})" + one_letter = f"cyclo({one_letter})" + + print(f"\nFinal sequence: {three_letter}") + print(f"One-letter code: {one_letter}") + print(f"Is cyclic: {is_cyclic}") + #print(f"Peptide cycles: {peptide_cycles}") + #print(f"Aromatic cycles: {aromatic_cycles}") + + return three_letter, len(segments) + """return { + 'three_letter': three_letter, + #'one_letter': one_letter, + 'is_cyclic': is_cyclic + }""" + + def return_sequence(self, smiles): + """Main analysis function with debug output""" + print("\nAnalyzing structure:", smiles) + + # Split into segments + segments = self.split_on_bonds(smiles) + + print("\nSegment Analysis:") + sequence = [] + for i, segment in enumerate(segments): + print(f"\nSegment {i}:") + print(f"Content: {segment['content']}") + print(f"Bond before: {segment.get('bond_before', 'None')}") + print(f"Bond after: {segment.get('bond_after', 'None')}") + + residue, mods = self.identify_residue(segment) + if residue: + if mods: + sequence.append(f"{residue}({','.join(mods)})") + else: + sequence.append(residue) + print(f"Identified as: {residue}") + print(f"Modifications: {mods}") + else: + print(f"Warning: Could not identify residue in segment: {segment['content']}") + + return sequence + +""" +def annotate_cyclic_structure(mol, sequence): + '''Create annotated 2D structure with clear, non-overlapping residue labels''' + # Generate 2D coordinates + # Generate 2D coordinates + AllChem.Compute2DCoords(mol) + + # Create drawer with larger size for annotations + drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size + + # Get residue list and reverse it to match structural representation + if sequence.startswith('cyclo('): + residues = sequence[6:-1].split('-') + else: + residues = sequence.split('-') + residues = list(reversed(residues)) # Reverse the sequence + + # Draw molecule first to get its bounds + drawer.drawOptions().addAtomIndices = False + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + + # Convert to PIL Image + img = Image.open(BytesIO(drawer.GetDrawingText())) + draw = ImageDraw.Draw(img) + + try: + # Try to use DejaVuSans as it's commonly available on Linux systems + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60) + small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60) + except OSError: + try: + # Fallback to Arial if available (common on Windows) + font = ImageFont.truetype("arial.ttf", 60) + small_font = ImageFont.truetype("arial.ttf", 60) + except OSError: + # If no TrueType fonts are available, fall back to default + print("Warning: TrueType fonts not available, using default font") + font = ImageFont.load_default() + small_font = ImageFont.load_default() + # Get molecule bounds + conf = mol.GetConformer() + positions = [] + for i in range(mol.GetNumAtoms()): + pos = conf.GetAtomPosition(i) + positions.append((pos.x, pos.y)) + + x_coords = [p[0] for p in positions] + y_coords = [p[1] for p in positions] + min_x, max_x = min(x_coords), max(x_coords) + min_y, max_y = min(y_coords), max(y_coords) + + # Calculate scaling factors + scale = 150 # Increased scale factor + center_x = 1000 # Image center + center_y = 1000 + + # Add residue labels in a circular arrangement around the structure + n_residues = len(residues) + radius = 700 # Distance of labels from center + + # Start 
from the rightmost point (3 o'clock position) and go counterclockwise + # Offset by -3 positions to align with structure + offset = 0 # Adjust this value to match the structure alignment + for i, residue in enumerate(residues): + # Calculate position in a circle around the structure + # Start from 0 (3 o'clock) and go counterclockwise + angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues) + + # Calculate label position + label_x = center_x + radius * np.cos(angle) + label_y = center_y + radius * np.sin(angle) + + # Draw residue label + text = f"{i+1}. {residue}" + bbox = draw.textbbox((label_x, label_y), text, font=font) + padding = 10 + draw.rectangle([bbox[0]-padding, bbox[1]-padding, + bbox[2]+padding, bbox[3]+padding], + fill='white', outline='white') + draw.text((label_x, label_y), text, + font=font, fill='black', anchor="mm") + + # Add sequence at the top with white background + seq_text = f"Sequence: {sequence}" + bbox = draw.textbbox((center_x, 100), seq_text, font=small_font) + padding = 10 + draw.rectangle([bbox[0]-padding, bbox[1]-padding, + bbox[2]+padding, bbox[3]+padding], + fill='white', outline='white') + draw.text((center_x, 100), seq_text, + font=small_font, fill='black', anchor="mm") + + return img +""" +def annotate_cyclic_structure(mol, sequence): + """Create structure visualization with just the sequence header""" + # Generate 2D coordinates + AllChem.Compute2DCoords(mol) + + # Create drawer with larger size for annotations + drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) + + # Draw molecule first + drawer.drawOptions().addAtomIndices = False + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + + # Convert to PIL Image + img = Image.open(BytesIO(drawer.GetDrawingText())) + draw = ImageDraw.Draw(img) + try: + small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60) + except OSError: + try: + small_font = ImageFont.truetype("arial.ttf", 60) + except OSError: + print("Warning: TrueType fonts not available, using default font") + small_font = ImageFont.load_default() + + # Add just the sequence header at the top + seq_text = f"Sequence: {sequence}" + bbox = draw.textbbox((1000, 100), seq_text, font=small_font) + padding = 10 + draw.rectangle([bbox[0]-padding, bbox[1]-padding, + bbox[2]+padding, bbox[3]+padding], + fill='white', outline='white') + draw.text((1000, 100), seq_text, + font=small_font, fill='black', anchor="mm") + + return img + +def create_enhanced_linear_viz(sequence, smiles): + """Create an enhanced linear representation using PeptideAnalyzer""" + analyzer = PeptideAnalyzer() # Create analyzer instance + + # Create figure with two subplots + fig = plt.figure(figsize=(15, 10)) + gs = fig.add_gridspec(2, 1, height_ratios=[1, 2]) + ax_struct = fig.add_subplot(gs[0]) + ax_detail = fig.add_subplot(gs[1]) + + # Parse sequence and get residues + if sequence.startswith('cyclo('): + residues = sequence[6:-1].split('-') + else: + residues = sequence.split('-') + + # Get segments using analyzer + segments = analyzer.split_on_bonds(smiles) + + # Debug print + print(f"Number of residues: {len(residues)}") + print(f"Number of segments: {len(segments)}") + + # Top subplot - Basic structure + ax_struct.set_xlim(0, 10) + ax_struct.set_ylim(0, 2) + + num_residues = len(residues) + spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0 + + # Draw basic structure + y_pos = 1.5 + for i in range(num_residues): + x_pos = 0.5 + i * spacing + + # Draw amino acid box + rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4, 
+ facecolor='lightblue', edgecolor='black') + ax_struct.add_patch(rect) + + # Draw connecting bonds if not the last residue + if i < num_residues - 1: + segment = segments[i] if i < len(segments) else None + if segment: + # Determine bond type from segment info + bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide' + is_n_methylated = 'N-Me' in segment.get('bond_after', '') + + bond_color = 'red' if bond_type == 'ester' else 'black' + linestyle = '--' if bond_type == 'ester' else '-' + + # Draw bond line + ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos], + color=bond_color, linestyle=linestyle, linewidth=2) + + # Add bond type label + mid_x = x_pos + spacing/2 + bond_label = f"{bond_type}" + if is_n_methylated: + bond_label += "\n(N-Me)" + ax_struct.text(mid_x, y_pos+0.1, bond_label, + ha='center', va='bottom', fontsize=10, + color=bond_color) + + # Add residue label + ax_struct.text(x_pos, y_pos-0.5, residues[i], + ha='center', va='top', fontsize=14) + + # Bottom subplot - Detailed breakdown + ax_detail.set_ylim(0, len(segments)+1) + ax_detail.set_xlim(0, 1) + + # Create detailed breakdown + segment_y = len(segments) # Start from top + for i, segment in enumerate(segments): + y = segment_y - i + + # Check if this is a bond or residue + residue, mods = analyzer.identify_residue(segment) + if residue: + text = f"Residue {i+1}: {residue}" + if mods: + text += f" ({', '.join(mods)})" + color = 'blue' + else: + # Must be a bond + text = f"Bond {i}: " + if 'O-linked' in segment.get('bond_after', ''): + text += "ester" + elif 'N-Me' in segment.get('bond_after', ''): + text += "peptide (N-methylated)" + else: + text += "peptide" + color = 'red' + + # Add segment analysis + ax_detail.text(0.05, y, text, fontsize=12, color=color) + ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray') + + # If cyclic, add connection indicator + if sequence.startswith('cyclo('): + ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos), + arrowprops=dict(arrowstyle='<->', color='red', lw=2)) + ax_struct.text(5, y_pos+0.3, 'Cyclic Connection', + ha='center', color='red', fontsize=14) + + # Add titles and adjust layout + ax_struct.set_title("Peptide Structure Overview", pad=20) + ax_detail.set_title("Segment Analysis Breakdown", pad=20) + + # Remove axes + for ax in [ax_struct, ax_detail]: + ax.set_xticks([]) + ax.set_yticks([]) + ax.axis('off') + + plt.tight_layout() + return fig + +class PeptideStructureGenerator: + """A class to generate 3D structures of peptides using different embedding methods""" + + @staticmethod + def prepare_molecule(smiles): + """Prepare molecule with proper hydrogen handling""" + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol is None: + raise ValueError("Failed to create molecule from SMILES") + + # Calculate valence for each atom + for atom in mol.GetAtoms(): + atom.UpdatePropertyCache(strict=False) + + # Sanitize with reduced requirements + Chem.SanitizeMol(mol, + sanitizeOps=Chem.SANITIZE_FINDRADICALS| + Chem.SANITIZE_KEKULIZE| + Chem.SANITIZE_SETAROMATICITY| + Chem.SANITIZE_SETCONJUGATION| + Chem.SANITIZE_SETHYBRIDIZATION| + Chem.SANITIZE_CLEANUPCHIRALITY) + + mol = Chem.AddHs(mol) + return mol + + @staticmethod + def get_etkdg_params(attempt=0): + """Get ETKDG parameters with optional modifications based on attempt number""" + params = AllChem.ETKDGv3() + params.randomSeed = -1 + params.maxIterations = 200 + params.numThreads = 4 # Reduced for web interface + params.useBasicKnowledge = True + 
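+        # ETKDGv3 is configured with macrocycle torsion terms because many of the analyzed
+        # peptides are cyclic; enforcing chirality preserves the stereocenters specified in
+        # the SMILES, and later attempts (attempt > 10) relax the experimental torsion
+        # preferences and slightly stretch bond lengths to help difficult embeddings converge.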
params.enforceChirality = True + params.useExpTorsionAnglePrefs = True + params.useSmallRingTorsions = True + params.useMacrocycleTorsions = True + params.ETversion = 2 + params.pruneRmsThresh = -1 + params.embedRmsThresh = 0.5 + + if attempt > 10: + params.bondLength = 1.5 + (attempt - 10) * 0.02 + params.useExpTorsionAnglePrefs = False + + return params + + def generate_structure_etkdg(self, smiles, max_attempts=20): + """Generate 3D structure using ETKDG without UFF optimization""" + success = False + mol = None + + for attempt in range(max_attempts): + try: + mol = self.prepare_molecule(smiles) + params = self.get_etkdg_params(attempt) + + if AllChem.EmbedMolecule(mol, params) == 0: + success = True + break + except Exception as e: + continue + + if not success: + raise ValueError("Failed to generate structure with ETKDG") + + return mol + + def generate_structure_uff(self, smiles, max_attempts=20): + """Generate 3D structure using ETKDG followed by UFF optimization""" + best_mol = None + lowest_energy = float('inf') + + for attempt in range(max_attempts): + try: + test_mol = self.prepare_molecule(smiles) + params = self.get_etkdg_params(attempt) + + if AllChem.EmbedMolecule(test_mol, params) == 0: + res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000, + vdwThresh=10.0, confId=0, + ignoreInterfragInteractions=True) + + if res == 0: + ff = AllChem.UFFGetMoleculeForceField(test_mol) + if ff: + current_energy = ff.CalcEnergy() + if current_energy < lowest_energy: + lowest_energy = current_energy + best_mol = Chem.Mol(test_mol) + except Exception: + continue + + if best_mol is None: + raise ValueError("Failed to generate optimized structure") + + return best_mol + + @staticmethod + def mol_to_sdf_bytes(mol): + """Convert RDKit molecule to SDF file bytes""" + # First write to StringIO in text mode + sio = StringIO() + writer = Chem.SDWriter(sio) + writer.write(mol) + writer.close() + + # Convert the string to bytes + return sio.getvalue().encode('utf-8') + +def process_input(smiles_input=None, file_obj=None, show_linear=False, + show_segment_details=False, generate_3d=False, use_uff=False): + """Process input and create visualizations using PeptideAnalyzer""" + analyzer = PeptideAnalyzer() + temp_dir = tempfile.mkdtemp() if generate_3d else None + structure_files = [] + + # Handle direct SMILES input + if smiles_input: + smiles = smiles_input.strip() + + # First check if it's a peptide using analyzer's method + if not analyzer.is_peptide(smiles): + return "Error: Input SMILES does not appear to be a peptide structure.", None, None + + try: + # Create molecule + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return "Error: Invalid SMILES notation.", None, None + + # Generate 3D structures if requested + if generate_3d: + generator = PeptideStructureGenerator() + + try: + # Generate ETKDG structure + mol_etkdg = generator.generate_structure_etkdg(smiles) + etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf") + writer = Chem.SDWriter(etkdg_path) + writer.write(mol_etkdg) + writer.close() + structure_files.append(etkdg_path) + + # Generate UFF structure if requested + if use_uff: + mol_uff = generator.generate_structure_uff(smiles) + uff_path = os.path.join(temp_dir, "structure_uff.sdf") + writer = Chem.SDWriter(uff_path) + writer.write(mol_uff) + writer.close() + structure_files.append(uff_path) + + except Exception as e: + return f"Error generating 3D structures: {str(e)}", None, None, None + + # Use analyzer to get sequence + segments = analyzer.split_on_bonds(smiles) + + # 
Process segments and build sequence + sequence_parts = [] + output_text = "" + + # Only include segment analysis in output if requested + if show_segment_details: + output_text += "Segment Analysis:\n" + for i, segment in enumerate(segments): + output_text += f"\nSegment {i}:\n" + output_text += f"Content: {segment['content']}\n" + output_text += f"Bond before: {segment.get('bond_before', 'None')}\n" + output_text += f"Bond after: {segment.get('bond_after', 'None')}\n" + + residue, mods = analyzer.identify_residue(segment) + if residue: + if mods: + sequence_parts.append(f"{residue}({','.join(mods)})") + else: + sequence_parts.append(residue) + output_text += f"Identified as: {residue}\n" + output_text += f"Modifications: {mods}\n" + else: + output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n" + output_text += "\n" + else: + # Just build sequence without detailed analysis in output + for segment in segments: + residue, mods = analyzer.identify_residue(segment) + if residue: + if mods: + sequence_parts.append(f"{residue}({','.join(mods)})") + else: + sequence_parts.append(residue) + + # Check if cyclic using analyzer's method + is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles) + three_letter = '-'.join(sequence_parts) + one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts) + + if is_cyclic: + three_letter = f"cyclo({three_letter})" + one_letter = f"cyclo({one_letter})" + + # Create cyclic structure visualization + img_cyclic = annotate_cyclic_structure(mol, three_letter) + + # Create linear representation if requested + img_linear = None + if show_linear: + fig_linear = create_enhanced_linear_viz(three_letter, smiles) + buf = BytesIO() + fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300) + buf.seek(0) + img_linear = Image.open(buf) + plt.close(fig_linear) + + # Add summary to output + summary = "Summary:\n" + summary += f"Sequence: {three_letter}\n" + summary += f"One-letter code: {one_letter}\n" + summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n" + #if is_cyclic: + #summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n" + #summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n" + + if structure_files: + summary += "\n3D Structures Generated:\n" + for filepath in structure_files: + summary += f"- {os.path.basename(filepath)}\n" + + return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None + + except Exception as e: + return f"Error processing SMILES: {str(e)}", None, None, None + + # Handle file input + if file_obj is not None: + try: + # Handle file content + if hasattr(file_obj, 'name'): + with open(file_obj.name, 'r') as f: + content = f.read() + else: + content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj) + + output_text = "" + for line in content.splitlines(): + smiles = line.strip() + if smiles: + # Check if it's a peptide + if not analyzer.is_peptide(smiles): + output_text += f"Skipping non-peptide SMILES: {smiles}\n" + continue + + # Process this SMILES + segments = analyzer.split_on_bonds(smiles) + sequence_parts = [] + + # Add segment details if requested + if show_segment_details: + output_text += f"\nSegment Analysis for SMILES: {smiles}\n" + for i, segment in enumerate(segments): + output_text += f"\nSegment {i}:\n" + output_text += f"Content: {segment['content']}\n" + output_text += f"Bond before: {segment.get('bond_before', 'None')}\n" + output_text += f"Bond after: 
{segment.get('bond_after', 'None')}\n" + residue, mods = analyzer.identify_residue(segment) + if residue: + if mods: + sequence_parts.append(f"{residue}({','.join(mods)})") + else: + sequence_parts.append(residue) + output_text += f"Identified as: {residue}\n" + output_text += f"Modifications: {mods}\n" + else: + for segment in segments: + residue, mods = analyzer.identify_residue(segment) + if residue: + if mods: + sequence_parts.append(f"{residue}({','.join(mods)})") + else: + sequence_parts.append(residue) + + # Get cyclicity and create sequence + is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles) + sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts) + + output_text += f"\nSummary for SMILES: {smiles}\n" + output_text += f"Sequence: {sequence}\n" + output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n" + if is_cyclic: + output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n" + #output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n" + output_text += "-" * 50 + "\n" + + return output_text, None, None + + except Exception as e: + return f"Error processing file: {str(e)}", None, None + + return "No input provided.", None, None + + diff --git a/smiles/rectify_train.py b/smiles/rectify_train.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc1edc4b6311042663abe3e8f8258b5fc8e9c4a --- /dev/null +++ b/smiles/rectify_train.py @@ -0,0 +1,553 @@ +import argparse +import math +import os +from functools import partial +from collections import Counter + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import load_from_disk +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.strategies import DDPStrategy +from rdkit import Chem + +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +from peptide_analyzer import PeptideAnalyzer +import dataloading_for_dynamic_batching as dynamic_dataloader + + +class RotaryPositionalEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_len=None): + if seq_len is None: + seq_len = x.shape[1] + + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + cos_emb = emb.cos()[None, :, :] + sin_emb = emb.sin()[None, :, :] + + return cos_emb, sin_emb + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +# --- Model Architecture with RoPE --- +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(1, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, 
hidden_size, bias=True), + ) + + def forward(self, t): + return self.mlp(t.unsqueeze(-1)) + +class MultiHeadAttentionWithRoPE(nn.Module): + def __init__(self, hidden_size, n_heads): + super().__init__() + self.hidden_size = hidden_size + self.n_heads = n_heads + self.head_dim = hidden_size // n_heads + + assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads" + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.out_proj = nn.Linear(hidden_size, hidden_size) + + self.rope = RotaryPositionalEmbedding(self.head_dim) + + def forward(self, x): + batch_size, seq_len, hidden_size = x.shape + + q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rope(q, seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = F.softmax(scores, dim=-1) + attn_output = torch.matmul(attn_weights, v) + + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size) + output = self.out_proj(attn_output) + + return output + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, n_heads): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.GELU(), + nn.Linear(4 * hidden_size, hidden_size) + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa) + attn_output = self.attn(x_norm1) + x = x + gate_msa.unsqueeze(1) * attn_output + x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp) + mlp_output = self.mlp(x_norm2) + x = x + gate_mlp.unsqueeze(1) * mlp_output + return x + +class MDLM(nn.Module): + def __init__(self, vocab_size, model_dim, n_heads, n_layers): + super().__init__() + self.vocab_size = vocab_size + self.model_dim = model_dim + self.mask_token_id = vocab_size + + self.token_embedder = nn.Embedding(vocab_size, model_dim) + self.time_embedder = TimestepEmbedder(model_dim) + + self.transformer_blocks = nn.ModuleList([ + DiTBlock(model_dim, n_heads) for _ in range(n_layers) + ]) + + self.final_norm = nn.LayerNorm(model_dim) + self.lm_head = nn.Linear(model_dim, vocab_size) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + if module.bias is not None: + module.bias.data.zero_() + if module.weight is not None: + module.weight.data.fill_(1.0) + + def forward(self, x, t): + x_embed = self.token_embedder(x) + t_embed = self.time_embedder(t) + for block in self.transformer_blocks: + x_embed = block(x_embed, t_embed) + 
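+        # Note: t_embed conditions every DiT block through adaLN (per-block shift/scale/gate
+        # terms); token positions are injected separately via RoPE inside the attention
+        # layers, so no absolute positional embedding is added to x_embed here.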
x_embed = self.final_norm(x_embed) + logits = self.lm_head(x_embed) + return logits + +# --- PyTorch Lightning Module --- +class MDLMLightningModule(pl.LightningModule): + def __init__(self, args, tokenizer): + super().__init__() + self.save_hyperparameters(ignore=['tokenizer']) + self.args = args + self.tokenizer = tokenizer + self.peptide_analyzer = PeptideAnalyzer() + + # Initialize model + self.model = MDLM( + vocab_size=tokenizer.vocab_size, + model_dim=args.model_dim, + n_heads=args.n_heads, + n_layers=args.n_layers + ) + + self.automatic_optimization = True + self.validation_step_outputs = [] + + # Track training progress + self.register_buffer('epoch_progress', torch.tensor(0.0)) + + def forward(self, x, t): + return self.model(x, t) + + def _compute_invalid_loss(self, logits, t_continuous=None): + """ + Original invalid loss computation from PepTune + with optional time-dependent weighting + """ + batch_token_ids = torch.argmax(logits, dim=-1) # (batch_size, seq_length) + sampled_sequences = self.tokenizer.batch_decode(batch_token_ids) + + # Check validity using peptide analyzer + penalties = torch.tensor( + [1.0 if not self.peptide_analyzer.is_peptide(seq) else 0.0 for seq in sampled_sequences], + dtype=torch.float32, + device=self.device + ) # (batch_size,) + + # Optional: Apply time-dependent scaling + if t_continuous is not None and self.args.time_dependent_validity: + # Less penalty at early timesteps (when t is close to 0) + time_weight = t_continuous ** self.args.validity_time_power # Default power = 0.5 + penalties = penalties * time_weight + + # Get softmax probabilities for selected tokens + sampled_probs = torch.softmax(logits, dim=-1).gather( + dim=-1, index=batch_token_ids.unsqueeze(-1) + ).squeeze(-1).to(self.device) # (batch_size, seq_length) + + # Scale penalty by token probabilities (makes it differentiable) + scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length) + + return scaled_penalty + + def get_validity_weight(self): + """ + Compute annealed validity weight based on training progress + """ + current_epoch = self.current_epoch + + # Stage 1: No validity loss for first N epochs + if current_epoch < self.args.validity_start_epoch: + return 0.0 + + # Stage 2: Gradually increase validity weight + epochs_with_validity = current_epoch - self.args.validity_start_epoch + max_epochs_with_validity = self.args.epochs - self.args.validity_start_epoch + + if self.args.validity_schedule == 'linear': + # Linear increase from min to max weight + progress = epochs_with_validity / max_epochs_with_validity + weight = (self.args.validity_weight_min + + (self.args.validity_weight_max - self.args.validity_weight_min) * progress) + + elif self.args.validity_schedule == 'exponential': + # Exponential increase (starts slow, accelerates) + progress = epochs_with_validity / max_epochs_with_validity + weight = (self.args.validity_weight_min * + (self.args.validity_weight_max / self.args.validity_weight_min) ** progress) + + elif self.args.validity_schedule == 'cosine': + # Cosine schedule (smooth increase) + progress = epochs_with_validity / max_epochs_with_validity + cosine_factor = 0.5 * (1 - math.cos(math.pi * progress)) + weight = (self.args.validity_weight_min + + (self.args.validity_weight_max - self.args.validity_weight_min) * cosine_factor) + + elif self.args.validity_schedule == 'step': + # Step-wise increase + steps = [0.25, 0.5, 0.75, 1.0] + weights = [self.args.validity_weight_min, + self.args.validity_weight_min * 2, + self.args.validity_weight_min 
* 5, + self.args.validity_weight_max] + progress = epochs_with_validity / max_epochs_with_validity + for i, step in enumerate(steps): + if progress <= step: + weight = weights[i] + break + else: + # Constant weight + weight = self.args.validity_weight_max + + return weight + + def _loss(self, logits, x_1, attn_mask, t_continuous=None): + """ + Combined loss with staged validity loss + """ + # Standard cross-entropy loss + ce_loss = F.cross_entropy( + logits.view(-1, self.model.vocab_size), + x_1.view(-1), + reduction='none' + ).view(x_1.shape[0], -1) + + # Get current validity weight + validity_weight = self.get_validity_weight() + + # Compute invalid loss only if weight > 0 + if validity_weight > 0: + invalid_loss = self._compute_invalid_loss(logits, t_continuous) + else: + invalid_loss = torch.zeros_like(ce_loss) + + # Combine losses + total_loss = ce_loss + validity_weight * invalid_loss + + # Apply attention mask + masked_loss = total_loss * attn_mask + num_tokens = attn_mask.sum() + token_nll = masked_loss.sum() / num_tokens + + # Individual components for logging + ce_token_loss = (ce_loss * attn_mask).sum() / num_tokens + invalid_token_loss = (invalid_loss * attn_mask).sum() / num_tokens + + return token_nll, ce_token_loss, invalid_token_loss, validity_weight + + def training_step(self, batch, batch_idx): + x_0 = batch['source_ids'].to(self.device) + x_1 = batch['target_ids'].to(self.device) + attn_mask = torch.ones_like(x_1).to(self.device) + bond_mask = batch['bond_mask'].to(self.device).bool() + batch_size, _ = x_1.shape + + # ReDi approach: random start -> target + t_continuous = torch.rand(batch_size, device=self.device) + + # Bond-aware masking + peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma + non_peptide_prob = t_continuous.view(-1, 1) + + masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob) + mask = torch.rand(x_1.shape, device=self.device) < masking_prob + x_t = torch.where(mask, x_1, x_0) + + # Forward pass + logits = self.model(x_t, t_continuous) + + # Compute loss with staged validity + token_nll, ce_loss, invalid_loss, validity_weight = self._loss( + logits, x_1, attn_mask, t_continuous + ) + + # Extensive logging + self.log('train/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size, sync_dist=True) + self.log('train/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True) + self.log('train/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True) + self.log('train/validity_weight', validity_weight, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True) + + # Log gradient norm for debugging + if batch_idx % 1000 == 0: + total_norm = 0 + for p in self.model.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + self.log('train/grad_norm', total_norm, batch_size=batch_size, sync_dist=True) + + return token_nll + + def validation_step(self, batch, batch_idx): + x_0 = batch['source_ids'].to(self.device) + x_1 = batch['target_ids'].to(self.device) + attn_mask = torch.ones_like(x_1).to(self.device) + bond_mask = batch['bond_mask'].to(self.device).bool() + batch_size, _ = x_1.shape + + # Same masking as training + t_continuous = torch.rand(batch_size, device=self.device) + + peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma + non_peptide_prob = t_continuous.view(-1, 1) + + masking_prob = 
torch.where(bond_mask, peptide_bond_prob, non_peptide_prob) + mask = torch.rand(x_1.shape, device=self.device) < masking_prob + x_t = torch.where(mask, x_1, x_0) + + logits = self.model(x_t, t_continuous) + + token_nll, ce_loss, invalid_loss, validity_weight = self._loss( + logits, x_1, attn_mask, t_continuous + ) + + self.log('val/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size, sync_dist=True) + self.log('val/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True) + self.log('val/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True) + + # Sample and check validity at different timesteps + if batch_idx == 0: + with torch.no_grad(): + validity_results = {} + for t_val in [0.9, 0.5, 0.1]: # Different timesteps + t_test = torch.full((batch_size,), t_val, device=self.device) + test_mask = torch.rand(x_1.shape, device=self.device) < t_val + x_test = torch.where(test_mask, x_1, x_0) + + test_logits = self.model(x_test, t_test) + test_preds = torch.argmax(test_logits, dim=-1) + + sequences = self.tokenizer.batch_decode(test_preds) + valid_count = sum(1 for seq in sequences if self.peptide_analyzer.is_peptide(seq)) + validity_rate = valid_count / len(sequences) + + self.log(f'val/validity_rate_t{t_val}', validity_rate, batch_size=batch_size, sync_dist=True) + + def configure_optimizers(self): + optimizer = AdamW( + self.parameters(), + lr=self.args.learning_rate, + weight_decay=self.args.weight_decay + ) + + # Calculate total steps + if hasattr(self.trainer, 'estimated_stepping_batches'): + num_training_steps = self.trainer.estimated_stepping_batches + else: + num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs + + warmup_steps = int(num_training_steps * 0.1) + + def lr_lambda(current_step): + if current_step < warmup_steps: + # Linear warmup + lr_factor = current_step / warmup_steps + return lr_factor + else: + # Cosine decay with min LR + progress = (current_step - warmup_steps) / (num_training_steps - warmup_steps) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + min_lr_ratio = 0.1 + return min_lr_ratio + (1 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + }, + } + +def main(args): + # Set up checkpoint directory + checkpoint_dir = (args.checkpoint_dir + + f"new_lr{args.learning_rate}_layer{args.n_layers}_" + f"head{args.n_heads}_{args.validity_schedule}") + print(f"Saving to {checkpoint_dir}") + os.makedirs(checkpoint_dir, exist_ok=True) + + print("Loading tokenizer...") + tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', + '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt') + print(f"Tokenizer loaded. 
Vocab size: {tokenizer.vocab_size}") + + # Initialize data module + data_module = dynamic_dataloader.RectifyDataModule('/scratch/pranamlab/tong/data/smiles/v1') + + model = MDLMLightningModule(args, tokenizer) + model = MDLMLightningModule.load_from_checkpoint( + checkpoint_path=args.checkpoint, + args=args, + tokenizer=tokenizer + ) + # Set up logger + logger = WandbLogger( + project="smiles-redi-staged-training", + entity="programmablebio", + name=f"v1_lr{args.learning_rate}_epochs{args.validity_start_epoch}_{args.validity_schedule}", + save_dir=checkpoint_dir + ) + + # Set up callbacks + callbacks = [ + ModelCheckpoint( + dirpath=checkpoint_dir, + filename='best', + monitor='val/token_nll', + mode='min', + save_top_k=1, + save_last=True, + # every_n_train_steps=5000 + ), + # Save every epoch + ModelCheckpoint( + dirpath=checkpoint_dir, + filename='{epoch:02d}', + save_top_k=-1, + every_n_epochs=1, + save_on_train_epoch_end=True + ), + LearningRateMonitor(logging_interval='step') + ] + + # Initialize trainer + trainer = pl.Trainer( + max_epochs=args.epochs, + devices=torch.cuda.device_count(), + accelerator='gpu', + strategy=DDPStrategy(find_unused_parameters=False), + num_nodes=int(os.environ.get("SLURM_NNODES", 1)), + precision="bf16", + gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None, + callbacks=callbacks, + logger=logger, + log_every_n_steps=100, + check_val_every_n_epoch=None, + # val_check_interval=5000, + accumulate_grad_batches=1, + enable_progress_bar=True, + enable_model_summary=True + ) + + print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.") + print(f"Training strategy: CE-only for {args.validity_start_epoch} epochs, then staged validity loss") + print("Starting training...") + + # Train the model + trainer.fit(model, data_module) + + print("Training complete.") + print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train ReDi model with staged validity loss") + + # Model arguments + parser.add_argument("--model_dim", type=int, default=1024) + parser.add_argument("--n_heads", type=int, default=8) + parser.add_argument("--n_layers", type=int, default=6) + + # Training arguments + parser.add_argument("--epochs", type=int, default=5) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--weight_decay", type=float, default=1e-5) + parser.add_argument("--label_smoothing", type=float, default=0) + parser.add_argument("--grad_clip", type=float, default=1.0) + parser.add_argument("--gamma", type=float, default=2.0) + + # Staged validity arguments + parser.add_argument("--validity_start_epoch", type=int, default=2, help="Epoch to start adding validity loss (0-indexed)") + parser.add_argument("--validity_weight_min", type=float, default=10.0, help="Initial validity weight when starting") + parser.add_argument("--validity_weight_max", type=float, default=200.0, help="Maximum validity weight") + parser.add_argument("--validity_schedule", type=str, default="linear", choices=['linear', 'exponential', 'cosine', 'step', 'constant'], help="Schedule for increasing validity weight") + parser.add_argument("--time_dependent_validity", type=bool, default=False, help="Whether to apply time-dependent scaling to validity loss") + parser.add_argument("--validity_time_power", type=float, default=0.5, help="Power for time-dependent validity scaling") + + # Other arguments + 
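# Example launch (checkpoint path is illustrative): + #   python rectify_train.py --checkpoint /path/to/pretrained.ckpt --validity_schedule cosine + 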
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles") + parser.add_argument("--checkpoint", type=str, required=True) + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/smiles/scoring/__init__.py b/smiles/scoring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/smiles/scoring/binary_xg.py b/smiles/scoring/binary_xg.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4122879d4180ba94e053b2fb5a819e649052d1 --- /dev/null +++ b/smiles/scoring/binary_xg.py @@ -0,0 +1,280 @@ +import pandas as pd +import numpy as np +import torch +from sklearn.model_selection import train_test_split +from sklearn.metrics import precision_recall_curve, f1_score +import optuna +from optuna.trial import TrialState +import xgboost as xgb +import os +from datasets import load_from_disk +from lightning.pytorch import seed_everything +from rdkit import Chem, rdBase, DataStructs +from typing import List +from rdkit.Chem import AllChem +import matplotlib.pyplot as plt +from sklearn.metrics import accuracy_score, roc_auc_score +import seaborn as sns + +def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path): + """ + Saves the true and predicted values for training and validation sets, and generates binary classification plots. + + Parameters: + y_true_train (array): True labels for the training set. + y_pred_train (array): Predicted probabilities for the training set. + y_true_val (array): True labels for the validation set. + y_pred_val (array): Predicted probabilities for the validation set. + threshold (float): Classification threshold for predictions. + output_path (str): Directory to save the CSV files and plots. 
+ """ + os.makedirs(output_path, exist_ok=True) + + # Convert probabilities to binary predictions + y_pred_train_binary = (y_pred_train >= threshold).astype(int) + y_pred_val_binary = (y_pred_val >= threshold).astype(int) + + # Save training predictions + train_df = pd.DataFrame({ + 'True Label': y_true_train, + 'Predicted Probability': y_pred_train, + 'Predicted Label': y_pred_train_binary + }) + train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False) + + # Save validation predictions + val_df = pd.DataFrame({ + 'True Label': y_true_val, + 'Predicted Probability': y_pred_val, + 'Predicted Label': y_pred_val_binary + }) + val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False) + + # Plot training predictions + plot_boxplot_with_threshold( + y_true_train, + y_pred_train, + threshold, + title="Training Set Binary Classification Plot", + output_file=os.path.join(output_path, 'train_classification_plot.png') + ) + + # Plot validation predictions + plot_boxplot_with_threshold( + y_true_val, + y_pred_val, + threshold, + title="Validation Set Binary Classification Plot", + output_file=os.path.join(output_path, 'val_classification_plot.png') + ) + +def plot_binary_correlation(y_true, y_pred, threshold, title, output_file): + # Scatter plot + plt.figure(figsize=(10, 8)) + plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF') + + # Add threshold line + plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}') + + # Add annotations + plt.title(title) + plt.xlabel("True Labels") + plt.ylabel("Predicted Probability") + plt.legend() + + # Save and show the plot + plt.tight_layout() + plt.savefig(output_file) + plt.show() + +def plot_boxplot_with_threshold(y_true, y_pred, threshold, title, output_file): + """ + Generates a boxplot for binary classification and includes a threshold line. + + Parameters: + y_true (array): True labels. + y_pred (array): Predicted probabilities. + threshold (float): Classification threshold for predictions. + title (str): Title of the plot. + output_file (str): Path to save the plot. + """ + plt.figure(figsize=(10, 8)) + + # Combine data into a DataFrame for seaborn + df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred}) + + # Boxplot + sns.boxplot(x='True Label', y='Predicted Probability', data=df) + + # Add threshold line + plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}') + plt.text( + x=0.5, y=threshold + 0.05, s=f"Threshold = {threshold}", color="red", fontsize=10 + ) + + # Add annotations + plt.title(title) + plt.xlabel("True Label") + plt.ylabel("Predicted Probability") + plt.legend() + + # Save and show the plot + plt.tight_layout() + plt.savefig(output_file) + plt.show() + +def plot_boxplot(y_true, y_pred, title, output_file): + plt.figure(figsize=(10, 8)) + + # Combine data into a single DataFrame for seaborn + df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred}) + sns.boxplot(x='True Label', y='Predicted Probability', data=df) + + # Add annotations + plt.title(title) + plt.xlabel("True Label") + plt.ylabel("Predicted Probability") + + # Save and show the plot + plt.tight_layout() + plt.savefig(output_file) + plt.show() + +def plot_binary_correlation_with_density(y_true, y_pred, threshold, title, output_file): + """ + Generates a scatter plot with a density plot for binary classification and saves it to a file. 
+ """ + plt.figure(figsize=(10, 8)) + + # Scatter plot + plt.scatter(range(len(y_true)), y_pred, alpha=0.5, label='Predicted Probabilities', color='#BC80FF') + + # Add density plot + sns.kdeplot(y_pred, color='green', fill=True, alpha=0.3, label='Probability Density') + + # Add threshold line + plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}') + + # Add annotations + plt.title(title) + plt.xlabel("Index") + plt.ylabel("Predicted Probability") + plt.legend() + + # Save and show the plot + plt.tight_layout() + plt.savefig(output_file) + plt.show() + +seed_everything(42) + +dataset = load_from_disk('/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/functions/solubility/new_data') + +sequences = np.stack(dataset['sequence']) # Ensure sequences are SMILES strings +labels = np.stack(dataset['labels']) +embeddings = np.stack(dataset['embedding']) + +# Initialize best F1 score and model path +best_f1 = -np.inf +best_model_path = "/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/functions/solubility/new_train/" + +# Trial callback +def trial_info_callback(study, trial): + if study.best_trial == trial: + print(f"Trial {trial.number}:") + print(f" Weighted F1 Score: {trial.value}") + + +def objective(trial): + params = { + 'objective': 'binary:logistic', + 'lambda': trial.suggest_float('lambda', 1e-8, 10.0, log=True), + 'alpha': trial.suggest_float('alpha', 1e-8, 10.0, log=True), + 'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 1.0), + 'subsample': trial.suggest_float('subsample', 0.1, 1.0), + 'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3), + 'max_depth': trial.suggest_int('max_depth', 2, 30), + 'min_child_weight': trial.suggest_int('min_child_weight', 1, 20), + 'tree_method': 'hist', + 'device': 'cuda:0', + } + num_boost_round = trial.suggest_int('num_boost_round', 10, 1000) + + # Split the data + train_idx, val_idx = train_test_split( + np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42 + ) + train_subset = dataset.select(train_idx).with_format("torch") + val_subset = dataset.select(val_idx).with_format("torch") + + # Extract embeddings and labels for train/validation + train_embeddings = train_subset['embedding'] + valid_embeddings = val_subset['embedding'] + train_labels = train_subset['labels'] + valid_labels = val_subset['labels'] + + # Prepare training and validation sets + dtrain = xgb.DMatrix(train_embeddings, label=train_labels) + dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels) + + # Train the model + model = xgb.train( + params=params, + dtrain=dtrain, + num_boost_round=num_boost_round, + evals=[(dvalid, "validation")], + early_stopping_rounds=50, + verbose_eval=False, + ) + + # Predict probabilities + preds_train = model.predict(dtrain) + preds_val = model.predict(dvalid) + + # Perform dynamic thresholding on validation predictions + best_f1_val = -np.inf + best_threshold = 0.5 + + for threshold in np.arange(0.1, 1.0, 0.05): # Try thresholds from 0.1 to 1.0 + preds_val_binary = (preds_val >= threshold).astype(int) + f1_temp = f1_score(valid_labels, preds_val_binary, average="weighted") + if f1_temp > best_f1_val: + best_f1_val = f1_temp + best_threshold = threshold + + print(f"Best F1 Score: {best_f1_val:.3f} at Threshold: {best_threshold:.3f}") + + # Calculate AUC for additional insight + auc_val = roc_auc_score(valid_labels, preds_val) + print(f"AUC: {auc_val:.3f}") + + # Save the best model if the F1 score is improved + if trial.study.user_attrs.get("best_f1", -np.inf) < 
best_f1_val: + trial.study.set_user_attr("best_f1", best_f1_val) + trial.study.set_user_attr("best_threshold", best_threshold) # Save the best threshold + os.makedirs(best_model_path, exist_ok=True) + + model.save_model(os.path.join(best_model_path, "best_model.json")) + print(f"Best model saved to {os.path.join(best_model_path, 'best_model.json')}") + + # Save and plot binary predictions with the best threshold + save_and_plot_binary_predictions( + train_labels, + preds_train, + valid_labels, + preds_val, + best_threshold, + best_model_path + ) + + return best_f1_val + +if __name__ == "__main__": + study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner()) + study.optimize(objective, n_trials=200) + + print("Study statistics: ") + print(f" Number of finished trials: {len(study.trials)}") + print(f" Best AUC: {study.user_attrs.get('best_auc', None)}") + for key, value in study.best_trial.params.items(): + print(f" {key}: {value}") \ No newline at end of file diff --git a/smiles/scoring/checkpoints/binding-affinity.pt b/smiles/scoring/checkpoints/binding-affinity.pt new file mode 100644 index 0000000000000000000000000000000000000000..e9a1c343d195a49c63a8555d1715095c13fb3971 --- /dev/null +++ b/smiles/scoring/checkpoints/binding-affinity.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8daf2ce5ec84045e94bbe41b0a57669287ac937a2a51e97da1af36ae8dc5e16f +size 132489414 diff --git a/smiles/scoring/checkpoints/hemolysis-xgboost.json b/smiles/scoring/checkpoints/hemolysis-xgboost.json new file mode 100644 index 0000000000000000000000000000000000000000..aa0e9161d5024beda00804b0731d6e7deb43580c --- /dev/null +++ b/smiles/scoring/checkpoints/hemolysis-xgboost.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69552199fa7722369211aa0e33adc9850816fb816f555ce9e207c73e0543b0ee +size 1932561 diff --git a/smiles/scoring/checkpoints/nonfouling-xgboost.json b/smiles/scoring/checkpoints/nonfouling-xgboost.json new file mode 100644 index 0000000000000000000000000000000000000000..073146b25527be74e43948f9ca0ab75c5f0f4656 --- /dev/null +++ b/smiles/scoring/checkpoints/nonfouling-xgboost.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08364879c255e236cf0a337505e99e36824f1a72090f5829d3f7870d4ca11644 +size 645478 diff --git a/smiles/scoring/checkpoints/solubility-xgboost.json b/smiles/scoring/checkpoints/solubility-xgboost.json new file mode 100644 index 0000000000000000000000000000000000000000..b24d6573faf8594c42e198856d0f347e58bbe878 --- /dev/null +++ b/smiles/scoring/checkpoints/solubility-xgboost.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c6a534db85550c947d864aa5cf0b1a2862333bc50eb4b6d5053074cc7d66f02 +size 196887 diff --git a/smiles/scoring/functions/__pycache__/analyzer.cpython-39.pyc b/smiles/scoring/functions/__pycache__/analyzer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91bebfc3f9a6f82f2818b884c38a1b2bee57e404 Binary files /dev/null and b/smiles/scoring/functions/__pycache__/analyzer.cpython-39.pyc differ diff --git a/smiles/scoring/functions/__pycache__/binding.cpython-39.pyc b/smiles/scoring/functions/__pycache__/binding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e6301e0bd725a80ceae441cd10aca44a920f6f1 Binary files /dev/null and b/smiles/scoring/functions/__pycache__/binding.cpython-39.pyc differ diff --git a/smiles/scoring/functions/__pycache__/hemolysis.cpython-39.pyc 
b/smiles/scoring/functions/__pycache__/hemolysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2285e4aafe0beaaa0f2057bdf3733bd38aeed316 Binary files /dev/null and b/smiles/scoring/functions/__pycache__/hemolysis.cpython-39.pyc differ diff --git a/smiles/scoring/functions/__pycache__/nonfouling.cpython-39.pyc b/smiles/scoring/functions/__pycache__/nonfouling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f6de4b0e7c68a744aef21d5e0e526e6a969324e Binary files /dev/null and b/smiles/scoring/functions/__pycache__/nonfouling.cpython-39.pyc differ diff --git a/smiles/scoring/functions/__pycache__/permeability.cpython-39.pyc b/smiles/scoring/functions/__pycache__/permeability.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6caf342df56fe904a0ee17dd7eff760fb367f13d Binary files /dev/null and b/smiles/scoring/functions/__pycache__/permeability.cpython-39.pyc differ diff --git a/smiles/scoring/functions/__pycache__/solubility.cpython-39.pyc b/smiles/scoring/functions/__pycache__/solubility.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f22e456118002138611aa5eaeba32b859ef4d42 Binary files /dev/null and b/smiles/scoring/functions/__pycache__/solubility.cpython-39.pyc differ diff --git a/smiles/scoring/functions/analyzer.py b/smiles/scoring/functions/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..b69fdddf9e95c61e2eca7dea06bdd3fd2880e372 --- /dev/null +++ b/smiles/scoring/functions/analyzer.py @@ -0,0 +1,42 @@ +import sys +import os +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import xgboost as xgb +import torch +import numpy as np +from transformers import AutoModelForMaskedLM +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +import numpy as np +from rdkit import Chem, rdBase, DataStructs + +class Analyzer: + + def __init__(self, device): + self.device = device + + def get_scores(self, input_seqs): + """Check if the SMILES represents a peptide structure""" + results = [] + + for smiles in input_seqs: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + results.append(0) + continue + + # Look for peptide bonds: NC(=O) pattern + peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)') + + # Look for N-methylated peptide bonds: N(C)C(=O) pattern + n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)') + + if mol.HasSubstructMatch(peptide_bond_pattern) or mol.HasSubstructMatch(n_methyl_pattern): + results.append(1) + else: + results.append(0) + + return results + + def __call__(self, input_seqs): + scores = self.get_scores(input_seqs) + return torch.tensor(scores) \ No newline at end of file diff --git a/smiles/scoring/functions/binding.py b/smiles/scoring/functions/binding.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2d5cdcd968fa55c29fe3ce8568b8bb482b9e54 --- /dev/null +++ b/smiles/scoring/functions/binding.py @@ -0,0 +1,205 @@ +import sys +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import numpy as np +from torch.utils.data import Dataset, DataLoader +from sklearn.model_selection import train_test_split +from collections import defaultdict +import torch +import pandas as pd +import torch.nn as nn +import esm +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModel + + +class ImprovedBindingPredictor(nn.Module): + def 
__init__(self, + esm_dim=1280, + smiles_dim=768, + hidden_dim=512, + n_heads=8, + n_layers=3, + dropout=0.1): + super().__init__() + + # Define binding thresholds + self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM + self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM + + # Project to same dimension + self.smiles_projection = nn.Linear(smiles_dim, hidden_dim) + self.protein_projection = nn.Linear(esm_dim, hidden_dim) + self.protein_norm = nn.LayerNorm(hidden_dim) + self.smiles_norm = nn.LayerNorm(hidden_dim) + + # Cross attention blocks with layer norm + self.cross_attention_layers = nn.ModuleList([ + nn.ModuleDict({ + 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), + 'norm1': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim) + ), + 'norm2': nn.LayerNorm(hidden_dim) + }) for _ in range(n_layers) + ]) + + # Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = nn.Linear(hidden_dim, 3) + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def forward(self, protein_emb, smiles_emb): + protein = self.protein_norm(self.protein_projection(protein_emb)) + smiles = self.smiles_norm(self.smiles_projection(smiles_emb)) + + #protein = protein.transpose(0, 1) + #smiles = smiles.transpose(0, 1) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to SMILES + attended_protein = layer['attention']( + protein, smiles, smiles + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # SMILES attending to protein + attended_smiles = layer['attention']( + smiles, protein, protein + )[0] + smiles = layer['norm1'](smiles + attended_smiles) + smiles = layer['norm2'](smiles + layer['ffn'](smiles)) + + # Get sequence-level representations + protein_pool = torch.mean(protein, dim=0) + smiles_pool = torch.mean(smiles, dim=0) + + # Concatenate both representations + combined = torch.cat([protein_pool, smiles_pool], dim=-1) + + # Shared features + shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + classification_logits = self.classification_head(shared_features) + + return regression_output, classification_logits + +class BindingAffinity: + def __init__(self, prot_seq, device, model_type='PeptideCLM'): + super().__init__() + + if model_type == 'PepDoRA': + # peptide embeddings + model_name = "ChatterjeeLab/PepDoRA" + self.pep_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.pep_model = AutoModel.from_pretrained(model_name) + + self.model = 
ImprovedBindingPredictor(smiles_dim=384) + checkpoint = torch.load('/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/functions/binding/best_model_optuna1.pt') + self.model.load_state_dict(checkpoint['model_state_dict']) + else: + # peptide embeddings + self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer + self.pep_tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', + '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt') + + + self.model = ImprovedBindingPredictor(smiles_dim=768).to(device) + checkpoint = torch.load('/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/binding-affinity.pt', weights_only=False) + self.model.load_state_dict(checkpoint['model_state_dict']) + + self.model.eval() + + self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model + self.esm_model.to(device) + self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer + + data = [("target", prot_seq)] + # get tokenized protein + _, _, prot_tokens = self.prot_tokenizer(data) + prot_tokens = prot_tokens.to(device) + with torch.no_grad(): + results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2 + prot_emb = results["representations"][33] + + self.prot_emb = prot_emb[0] + self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True).to(device) + + self.device = device + + def forward(self, input_seqs): + with torch.no_grad(): + scores = [] + for seq in input_seqs: + pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True) + + with torch.no_grad(): + emb = self.pep_model(input_ids=pep_tokens['input_ids'], + attention_mask=pep_tokens['attention_mask'], + output_hidden_states=True) + + #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask']) + pep_emb = emb.last_hidden_state.squeeze(0).to(self.device) + pep_emb = torch.mean(pep_emb, dim=0, keepdim=True) + + score, logits = self.model.forward(self.prot_emb, pep_emb) + scores.append(score.item()) + return torch.tensor(scores) + + def __call__(self, input_seqs: list): + return self.forward(input_seqs) + +def unittest(): + amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV' + tfr = 
'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF' + gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM' + glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS' + glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM' + ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF' + + binding = BindingAffinity(amhr, device='cuda:0') + seq = ['N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCSC(F)F)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)O'] + + scores = binding(seq) + print(scores) + print(len(scores)) + +if __name__ == '__main__': + unittest() \ No newline at end of file diff --git a/smiles/scoring/functions/binding_utils.py b/smiles/scoring/functions/binding_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..b66e0b9ada1893dc78f9bcbb8b65a684a43bdfe3 --- /dev/null +++ b/smiles/scoring/functions/binding_utils.py @@ -0,0 +1,291 @@ +from torch import nn +import pdb +import torch +import numpy as np + +def to_var(x): + if torch.cuda.is_available(): + x = x.cuda() + return x + +class MultiHeadAttentionSequence(nn.Module): + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + + self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + + Q = Q.transpose(1, 2) + K = K.transpose(1, 2).transpose(2, 3) + V = V.transpose(1, 2) + + attention = torch.matmul(Q, K) + + attention = attention / np.sqrt(self.d_k) + + attention = F.softmax(attention, dim=-1) + + output = torch.matmul(attention, V) + + output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head]) + + output = self.W_O(output) + + output = self.dropout(output) + + output = self.layer_norm(output + q) + + return output, attention + +class MultiHeadAttentionReciprocal(nn.Module): + + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + self.W_V_2 = nn.Linear(d_model, n_head*d_v) + self.W_O_2 = nn.Linear(n_head*d_v, d_model) + + self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + self.layer_norm_2 = nn.LayerNorm(d_model) + + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, q, k, v, v_2): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + batch, len_v_2, _ = v_2.size() + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v]) + + Q = Q.transpose(1, 2) + K = K.transpose(1, 2).transpose(2, 3) + V = V.transpose(1, 2) + V_2 = V_2.transpose(1,2) + + attention = torch.matmul(Q, K) + + + attention = attention /np.sqrt(self.d_k) + + attention_2 = attention.transpose(-2, -1) + + + + attention = F.softmax(attention, dim=-1) + + attention_2 = F.softmax(attention_2, dim=-1) + + + output = torch.matmul(attention, V) + + output_2 = torch.matmul(attention_2, V_2) + + output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head]) + + output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head]) + + output = self.W_O(output) + + output_2 = self.W_O_2(output_2) + + output = self.dropout(output) + + output = self.layer_norm(output + q) + + output_2 = self.dropout(output_2) + + output_2 = self.layer_norm(output_2 + k) + + + return output, output_2, attention, attention_2 + + +class FFN(nn.Module): + 
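# Position-wise feed-forward sub-layer: two kernel-size-1 Conv1d layers with ReLU and dropout, + # plus a residual connection and LayerNorm (mirroring the attention sub-layers above).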
+ def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + + self.layer_1 = nn.Conv1d(d_in, d_hid,1) + self.layer_2 = nn.Conv1d(d_hid, d_in,1) + self.relu = nn.ReLU() + self.layer_norm = nn.LayerNorm(d_in) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + + residual = x + output = self.layer_1(x.transpose(1, 2)) + + output = self.relu(output) + + output = self.layer_2(output) + + output = self.dropout(output) + + output = self.layer_norm(output.transpose(1, 2)+residual) + + return output + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding, dilation): + super(ConvLayer, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + self.relu = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + +class DilatedCNN(nn.Module): + def __init__(self, d_model, d_hidden): + super(DilatedCNN, self).__init__() + self.first_ = nn.ModuleList() + self.second_ = nn.ModuleList() + self.third_ = nn.ModuleList() + + dilation_tuple = (1, 2, 3) + dim_in_tuple = (d_model, d_hidden, d_hidden) + dim_out_tuple = (d_hidden, d_hidden, d_hidden) + + for i, dilation_rate in enumerate(dilation_tuple): + self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate, + dilation=dilation_rate)) + + for i, dilation_rate in enumerate(dilation_tuple): + self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate, + dilation=dilation_rate)) + + for i, dilation_rate in enumerate(dilation_tuple): + self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate, + dilation=dilation_rate)) + + def forward(self, protein_seq_enc): + # pdb.set_trace() + protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L + + first_embedding = protein_seq_enc + second_embedding = protein_seq_enc + third_embedding = protein_seq_enc + + for i in range(len(self.first_)): + first_embedding = self.first_[i](first_embedding) + + for i in range(len(self.second_)): + second_embedding = self.second_[i](second_embedding) + + for i in range(len(self.third_)): + third_embedding = self.third_[i](third_embedding) + + # pdb.set_trace() + + protein_seq_enc = first_embedding + second_embedding + third_embedding + + return protein_seq_enc.transpose(1, 2) + + +class ReciprocalLayerwithCNN(nn.Module): + + def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v): + super().__init__() + + self.cnn = DilatedCNN(d_model, d_hidden) + + self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v) + + self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v) + + self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v) + + self.ffn_seq = FFN(d_hidden, d_inner) + + self.ffn_protein = FFN(d_hidden, d_inner) + + def forward(self, sequence_enc, protein_seq_enc): + # pdb.set_trace() # protein_seq_enc.shape = B * L * d_model + protein_seq_enc = self.cnn(protein_seq_enc) + prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc) + + seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc) + + prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc) + + prot_enc = 
self.ffn_protein(prot_enc) + + seq_enc = self.ffn_seq(seq_enc) + + return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention + + +class ReciprocalLayer(nn.Module): + + def __init__(self, d_model, d_inner, n_head, d_k, d_v): + + super().__init__() + + self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v) + + self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v) + + self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v) + + self.ffn_seq = FFN(d_model, d_inner) + + self.ffn_protein = FFN(d_model, d_inner) + + def forward(self, sequence_enc, protein_seq_enc): + prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc) + + seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc) + + + prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc) + prot_enc = self.ffn_protein(prot_enc) + + seq_enc = self.ffn_seq(seq_enc) + + return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention \ No newline at end of file diff --git a/smiles/scoring/functions/hemolysis.py b/smiles/scoring/functions/hemolysis.py new file mode 100644 index 0000000000000000000000000000000000000000..77dca9cb2484f191b1ad9dc357c669e5d1901524 --- /dev/null +++ b/smiles/scoring/functions/hemolysis.py @@ -0,0 +1,71 @@ +import sys +import os +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import xgboost as xgb +import torch +import numpy as np +from transformers import AutoModelForMaskedLM +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +import warnings +import numpy as np +from rdkit.Chem import Descriptors, rdMolDescriptors +from rdkit import Chem, rdBase, DataStructs +from transformers import AutoModelForMaskedLM + + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class Hemolysis: + + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/hemolysis-xgboost.json') + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + self.tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', + '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt') + + self.device = device + + def generate_embeddings(self, sequences): + embeddings = [] + for sequence in sequences: + tokenized = self.tokenizer(sequence, return_tensors='pt').to(self.device) + with torch.no_grad(): + output = self.emb_model(**tokenized) + # Mean pooling across sequence length + embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy() + embeddings.append(embedding) + return np.array(embeddings) + + def get_scores(self, input_seqs: list): + scores = np.ones(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
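+ # The booster predicts P(hemolytic); get_scores returns 1 - P(hemolytic) (see the return below), + # so higher scores correspond to peptides predicted to be non-hemolytic.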
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + probs = self.predictor.predict(features) + # return the probability of it being not hemolytic + return scores - probs + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return torch.tensor(scores) + +def unittest(): + hemo = Hemolysis(device='cuda:6') + seq = ["N[C@@H](CC(=O)O)-N[C@@H](C[C@@H](C))C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@H](CO)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CO)C(=O)N[C@H](CO)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCN)C(=O)N[C@@H](CCCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCC(=O)N)C(=O)O"] + + scores = hemo(input_seqs=seq) + print(scores) + + +if __name__ == '__main__': + unittest() \ No newline at end of file diff --git a/smiles/scoring/functions/nonfouling.py b/smiles/scoring/functions/nonfouling.py new file mode 100644 index 0000000000000000000000000000000000000000..fe94d39a9fcae5d639de20f41dd49eebae063587 --- /dev/null +++ b/smiles/scoring/functions/nonfouling.py @@ -0,0 +1,69 @@ +import sys +import os +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import xgboost as xgb +import torch +import numpy as np +from transformers import AutoModelForMaskedLM +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +import warnings +import numpy as np +from rdkit import Chem, rdBase, DataStructs +from transformers import AutoModelForMaskedLM + + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class Nonfouling: + + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/nonfouling-xgboost.json') + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + self.tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt', + '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt') + self.device = device + + def generate_embeddings(self, sequences): + embeddings = [] + for sequence in sequences: + tokenized = self.tokenizer(sequence, return_tensors='pt').to(self.device) + with torch.no_grad(): + output = self.emb_model(**tokenized) + # Mean pooling across sequence length + embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy() + embeddings.append(embedding) + return np.array(embeddings) + + def get_scores(self, input_seqs: list): + scores = np.zeros(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
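+ # NOTE: unlike the hemolysis scorer, the booster's probability is returned as-is here, + # presumably as the probability of the nonfouling (positive) class.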
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + scores = self.predictor.predict(features) + # return the probability of it being not hemolytic + return scores + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return torch.tensor(scores) + +def unittest(): + nf = Nonfouling(device='cuda:6') + seq = ["N[C@@H](CC(=O)O)-N[C@@H](C[C@@H](C))C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@H](CO)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CO)C(=O)N[C@H](CO)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCN)C(=O)N[C@@H](CCCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCC(=O)N)C(=O)O"] + + scores = nf(input_seqs=seq) + print(scores) + + +if __name__ == '__main__': + unittest() \ No newline at end of file diff --git a/smiles/scoring/functions/scoring_utils.py b/smiles/scoring/functions/scoring_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..642b6e9d49d652a14a2f39bdea01410eba30a911 --- /dev/null +++ b/smiles/scoring/functions/scoring_utils.py @@ -0,0 +1,111 @@ +import warnings +import numpy as np +from loguru import logger +from sklearn.ensemble import RandomForestRegressor +from rdkit.Chem import Descriptors, rdMolDescriptors +import joblib +from transformation import TransformFunction +from rdkit import Chem, rdBase, DataStructs +from rdkit.Chem import AllChem +from typing import List + + +def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False): + """ + Create ECFP fingerprint of a molecule + """ + if hashed: + fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size) + else: + fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size) + fp_np = np.zeros((1,)) + DataStructs.ConvertToNumpyArray(fp_bits, fp_np) + return fp_np.reshape(1, -1) + + +def fingerprints_from_smiles(smiles: List, size=2048): + """ Create ECFP fingerprints of smiles, with validity check """ + fps = [] + valid_mask = [] + for i, smile in enumerate(smiles): + mol = Chem.MolFromSmiles(smile) + valid_mask.append(int(mol is not None)) + fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size)) + fps.append(fp) + + fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size)) + return fps, valid_mask + + +def getMolDescriptors(mol, missingVal=0): + """ calculate the full list of descriptors for a molecule """ + + values, names = [], [] + for nm, fn in Descriptors._descList: + try: + val = fn(mol) + except: + val = missingVal + values.append(val) + names.append(nm) + + custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD, + 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA, + 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,} + + for nm, fn in custom_descriptors.items(): + try: + val = fn(mol) + except: + val = missingVal + values.append(val) + names.append(nm) + return values, names + + +def get_pep_dps_from_smi(smi): + try: + mol = Chem.MolFromSmiles(smi) + except: + print(f"convert smi {smi} to molecule failed!") + mol = None + + dps, _ = getMolDescriptors(mol) + return np.array(dps) + + +def get_pep_dps(smi_list): + if 
len(smi_list) == 0: + return np.zeros((0, 211)) + return np.array([get_pep_dps_from_smi(smi) for smi in smi_list]) + + +"""def get_smi_from_helms(helm_seqs: list): + valid_idxes = [] + valid_smiles = [] + + for idx, helm in enumerate(helm_seqs): + # Ignore helm which cannot converted into molecules + try: + smi = get_cycpep_smi_from_helm(helm) + if smi: + valid_idxes.append(idx) + valid_smiles.append(smi) + except Exception as e: + # logger.debug(f'Error: {e} in helm {helm}') + pass + return valid_smiles, valid_idxes""" + + +def check_smi_validity(smiles: list): + valid_smi, valid_idx = [], [] + for idx, smi in enumerate(smiles): + try: + mol = Chem.MolFromSmiles(smi) if smi else None + if mol: + valid_smi.append(smi) + valid_idx.append(idx) + except Exception as e: + # logger.debug(f'Error: {e} in smiles {smi}') + pass + return valid_smi, valid_idx \ No newline at end of file diff --git a/smiles/scoring/functions/solubility.py b/smiles/scoring/functions/solubility.py new file mode 100644 index 0000000000000000000000000000000000000000..996c68903d139c51af4703bf33b2ea68ef681a30 --- /dev/null +++ b/smiles/scoring/functions/solubility.py @@ -0,0 +1,64 @@ +import sys +import os +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import xgboost as xgb +import torch +import numpy as np +from transformers import AutoModelForMaskedLM +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +import warnings +import numpy as np +from rdkit import Chem, rdBase, DataStructs + + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class Solubility: + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/solubility-xgboost.json') + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + self.tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt', + '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt') + self.device = device + + def generate_embeddings(self, sequences): + embeddings = [] + for sequence in sequences: + tokenized = self.tokenizer(sequence, return_tensors='pt').to(self.device) + with torch.no_grad(): + output = self.emb_model(**tokenized) + # Mean pooling across sequence length + embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy() + embeddings.append(embedding) + return np.array(embeddings) + + def get_scores(self, input_seqs: list): + scores = np.zeros(len(input_seqs)) + features = self.generate_embeddings(input_seqs) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
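+ # Same feature pipeline as the other XGBoost scorers: mean-pooled PeptideCLM hidden states, + # NaN-cleaned and clipped to the float32 range before building the DMatrix.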
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + scores = self.predictor.predict(features) + return scores + + def __call__(self, input_seqs: list): + scores = self.get_scores(input_seqs) + return torch.tensor(scores) + +def unittest(): + solubility = Solubility(device='cuda:6') + seq = ["N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H]([C@H](O)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)O"] + scores = solubility(input_seqs=seq) + print(scores) + +if __name__ == '__main__': + unittest() \ No newline at end of file diff --git a/smiles/scoring/scoring_functions.py b/smiles/scoring/scoring_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..4839445d0573db7aa58ea7ca7685a3dd5aa61fb1 --- /dev/null +++ b/smiles/scoring/scoring_functions.py @@ -0,0 +1,104 @@ +import sys +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import io +import subprocess +import warnings +import numpy as np +import pandas as pd +from typing import List +from loguru import logger +from tqdm import tqdm +from rdkit import Chem, rdBase, DataStructs +from rdkit.Chem import AllChem +import torch +from scoring.functions.binding.binding import BindingAffinity +from scoring.functions.permeability.permeability import Permeability +from scoring.functions.solubility.solubility import Solubility +from scoring.functions.hemolysis.hemolysis import Hemolysis +from scoring.functions.nonfouling.nonfouling import Nonfouling + +class ScoringFunctions: + def __init__(self, score_func_names=None, prot_seqs=[]): + """ + Class for generating score vectors given generated sequence + + Args: + score_func_names: list of scoring function names to be evaluated + score_weights: weights to scale scores (default: 1) + target_protein: sequence of target protein binder + """ + if score_func_names is None: + # just do unmasking based on validity of peptide bonds + self.score_func_names = [] + else: + self.score_func_names = score_func_names + + # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights) + + # binding affinities + self.target_protein = prot_seqs + print(len(prot_seqs)) + + if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1): + binding_affinity1 = BindingAffinity(prot_seqs[0]) + binding_affinity2 = None + elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2): + binding_affinity1 = BindingAffinity(prot_seqs[0]) + binding_affinity2 = BindingAffinity(prot_seqs[1]) + else: + print("here") + binding_affinity1 = None + binding_affinity2 = None + + permeability = Permeability() + sol = Solubility() + nonfouling = Nonfouling() + hemo = Hemolysis() + + self.all_funcs = {'binding_affinity1': binding_affinity1, + 'binding_affinity2': binding_affinity2, + 'permeability': permeability, + 'nonfouling': nonfouling, + 'solubility': sol, + 'hemolysis': hemo + } + + def forward(self, input_seqs): + scores = [] + + for i, score_func in enumerate(self.score_func_names): + score = self.all_funcs[score_func](input_seqs = input_seqs) + + scores.append(score) + + # convert to numpy arrays with shape (num_sequences, num_functions) + scores = np.float32(scores).T + + return scores + + def __call__(self, 
input_seqs: list): + return self.forward(input_seqs) + + +def unittest(): + amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV' + tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF' + gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM' + glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS' + glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM' + ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF' + cereblon = 
'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL' + + num_iter = 0 + score_func_times = [0, 1, 2, 3, 4, 5] + + scoring = ScoringFunctions(score_func_names=['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability'], prot_seqs=[tfr]) + + smiles = ['N2[C@H](CC(C)C)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](Cc1ccccc1C(F)(F)F)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](CC(=O)N)C2(=O)'] + + scores = scoring(input_seqs=smiles) + print(scores) + print(len(scores)) + +if __name__ == '__main__': + unittest() \ No newline at end of file diff --git a/smiles/smiles_classifiers.py b/smiles/smiles_classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..806f2337bb15e4133d11bcb574e183d53be0a4a7 --- /dev/null +++ b/smiles/smiles_classifiers.py @@ -0,0 +1,281 @@ +import sys +import os +sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') +import xgboost as xgb +import torch +import numpy as np +from transformers import AutoModelForMaskedLM +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +import warnings +import numpy as np +import esm +import torch.nn as nn +from rdkit import Chem, rdBase, DataStructs + +rdBase.DisableLog('rdApp.error') +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class Analyzer: + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def get_scores(self, x): + """Check if the SMILES represents a peptide structure""" + results = [] + + smiles_list = self.tokenizer.batch_decode(x) + for smiles in smiles_list: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + results.append(0) + continue + + # Look for peptide bonds: NC(=O) pattern + peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)') + + # Look for N-methylated peptide bonds: N(C)C(=O) pattern + n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)') + + if mol.HasSubstructMatch(peptide_bond_pattern) or mol.HasSubstructMatch(n_methyl_pattern): + results.append(1) + else: + results.append(0) + + return torch.tensor(results) + + def __call__(self, x): + scores = self.get_scores(x) + return torch.tensor(scores) + + +class Hemolysis: + + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/hemolysis-xgboost.json') + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + + def get_scores(self, x): + scores = np.ones(len(x)) + features = np.array(self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).detach().cpu()) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) 
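+ # Features here are mean-pooled PeptideCLM hidden states; NaNs were zeroed above and values are clipped below before scoring with the XGBoost booster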
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + probs = self.predictor.predict(features) + # return the probability of it being not hemolytic + return scores - probs + + def __call__(self, x): + scores = self.get_scores(x) + return torch.tensor(scores) + + +class Nonfouling: + + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/nonfouling-xgboost.json') + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + + def get_scores(self, x): + scores = np.zeros(len(x)) + features = np.array(self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).detach().cpu()) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) + features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + scores = self.predictor.predict(features) + # return the probability of it being not hemolytic + return scores + + def __call__(self, x): + scores = self.get_scores(x) + return torch.tensor(scores) + + +class Solubility: + def __init__(self, device): + self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/solubility-xgboost.json') + self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + + def get_scores(self, x): + scores = np.zeros(len(x)) + features = np.array(self.emb_model(input_ids=x).last_hidden_state.mean(dim=1).detach().cpu()) + + if len(features) == 0: + return scores + + features = np.nan_to_num(features, nan=0.) + features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max) + + features = xgb.DMatrix(features) + + scores = self.predictor.predict(features) + return scores + + def __call__(self, x): + scores = self.get_scores(x) + return torch.tensor(scores) + + +class ImprovedBindingPredictor(nn.Module): + def __init__(self, + esm_dim=1280, + smiles_dim=768, + hidden_dim=512, + n_heads=8, + n_layers=3, + dropout=0.1): + super().__init__() + + # Define binding thresholds + self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM + self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM + + # Project to same dimension + self.smiles_projection = nn.Linear(smiles_dim, hidden_dim) + self.protein_projection = nn.Linear(esm_dim, hidden_dim) + self.protein_norm = nn.LayerNorm(hidden_dim) + self.smiles_norm = nn.LayerNorm(hidden_dim) + + # Cross attention blocks with layer norm + self.cross_attention_layers = nn.ModuleList([ + nn.ModuleDict({ + 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), + 'norm1': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 4, hidden_dim) + ), + 'norm2': nn.LayerNorm(hidden_dim) + }) for _ in range(n_layers) + ]) + + # Prediction heads + self.shared_head = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + ) + + # Regression head + self.regression_head = nn.Linear(hidden_dim, 1) + + # Classification head (3 classes: tight, medium, loose binding) + self.classification_head = nn.Linear(hidden_dim, 3) + + + def get_binding_class(self, affinity): + """Convert affinity values to class indices + 0: tight binding (>= 7.5) + 1: medium binding (6.0-7.5) + 2: weak binding (< 6.0) + """ + if isinstance(affinity, 
torch.Tensor): + tight_mask = affinity >= self.tight_threshold + weak_mask = affinity < self.weak_threshold + medium_mask = ~(tight_mask | weak_mask) + + classes = torch.zeros_like(affinity, dtype=torch.long) + classes[medium_mask] = 1 + classes[weak_mask] = 2 + return classes + else: + if affinity >= self.tight_threshold: + return 0 # tight binding + elif affinity < self.weak_threshold: + return 2 # weak binding + else: + return 1 # medium binding + + def forward(self, protein_emb, smiles_emb): + protein = self.protein_norm(self.protein_projection(protein_emb)) + smiles = self.smiles_norm(self.smiles_projection(smiles_emb)) + + #protein = protein.transpose(0, 1) + #smiles = smiles.transpose(0, 1) + + # Cross attention layers + for layer in self.cross_attention_layers: + # Protein attending to SMILES + attended_protein = layer['attention']( + protein, smiles, smiles + )[0] + protein = layer['norm1'](protein + attended_protein) + protein = layer['norm2'](protein + layer['ffn'](protein)) + + # SMILES attending to protein + attended_smiles = layer['attention']( + smiles, protein, protein + )[0] + smiles = layer['norm1'](smiles + attended_smiles) + smiles = layer['norm2'](smiles + layer['ffn'](smiles)) + + # Get sequence-level representations + protein_pool = torch.mean(protein, dim=0) + smiles_pool = torch.mean(smiles, dim=0) + + # Concatenate both representations + combined = torch.cat([protein_pool, smiles_pool], dim=-1) + + # Shared features + shared_features = self.shared_head(combined) + + regression_output = self.regression_head(shared_features) + classification_logits = self.classification_head(shared_features) + + return regression_output, classification_logits + +class BindingAffinity: + def __init__(self, prot_seq, device, model_type='PeptideCLM'): + super().__init__() + + self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device) + + self.model = ImprovedBindingPredictor(smiles_dim=768).to(device) + checkpoint = torch.load('/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/binding-affinity.pt', weights_only=False) + self.model.load_state_dict(checkpoint['model_state_dict']) + + self.model.eval() + + self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model + self.esm_model.to(device) + self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer + + data = [("target", prot_seq)] + # get tokenized protein + _, _, prot_tokens = self.prot_tokenizer(data) + prot_tokens = prot_tokens.to(device) + with torch.no_grad(): + results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2 + prot_emb = results["representations"][33] + + self.prot_emb = prot_emb[0] + self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True).to(device) + + def forward(self, x): + with torch.no_grad(): + scores = [] + pep_emb = self.pep_model(input_ids=x, output_hidden_states=True).last_hidden_state.mean(dim=1, keepdim=True) + for pep in pep_emb: + score, logits = self.model.forward(self.prot_emb, pep) + scores.append(score.item() / 10) + + return torch.tensor(scores) + + def __call__(self, x): + return self.forward(x) diff --git a/smiles/smiles_tokenizer/my_tokenizers.py b/smiles/smiles_tokenizer/my_tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcd57c8b958b54461553df2e0a6f60cee470e93 --- /dev/null +++ b/smiles/smiles_tokenizer/my_tokenizers.py @@ -0,0 +1,445 @@ +import collections +import logging +import os +import re +import codecs +import 
unicodedata +from typing import List, Optional +from transformers import PreTrainedTokenizer +from SmilesPE.tokenizer import SPE_Tokenizer +import torch + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + +class Atomwise_Tokenizer(object): + """Run atom-level SMILES tokenization""" + + def __init__(self): + """ Constructs a atom-level Tokenizer. + """ + # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])" + self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])" + + self.regex = re.compile(self.regex_pattern) + + def tokenize(self, text): + """ Basic Tokenization of a SMILES. + """ + tokens = [token for token in self.regex.findall(text)] + return tokens + +class SMILES_SPE_Tokenizer(PreTrainedTokenizer): + r""" + Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE). + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + Args: + vocab_file (:obj:`string`): + File containing the vocabulary. + spe_file (:obj:`string`): + File containing the trained SMILES Pair Encoding vocabulary. + unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
+ """ + + def __init__(self, vocab_file, spe_file, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs): + if not os.path.isfile(vocab_file): + raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file)) + if not os.path.isfile(spe_file): + raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file)) + + self.vocab = load_vocab(vocab_file) + self.spe_vocab = open(spe_file, 'r', encoding='utf-8') + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab) + + super().__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + return self.spe_tokenizer.tokenize(text).split(' ') + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + # changed encode and decode functions + def encode(self, token_array): + token_ids = [] + token_ids.append(2) + for token in token_array: + id = self._convert_token_to_id(token) + token_ids.append(id) + token_ids.append(3) + token_ids = torch.tensor([token_ids]) + attn_mask = torch.ones_like(token_ids) + return {'input_ids': token_ids, 'attention_mask': attn_mask} + + def decode(self, token_ids, skip_special_tokens=True): + token_ids = token_ids.squeeze(0).cpu().tolist() + token_array = [] + for idx in token_ids: + if idx == 3: # Stop decoding when token ID 3 is encountered + break + if skip_special_tokens and idx in self.all_special_ids: + continue + token = self._convert_id_to_token(idx) + token_array.append(token) + sequence = "".join(token_array) + return sequence + + def batch_decode(self, batch_token_ids, skip_special_tokens=True): + sequences = [] + for token_ids in batch_token_ids: + sequences.append(self.decode(token_ids)) + return sequences + + def get_token_split(self, token_ids): + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.cpu().tolist() + + # import pdb + # pdb.set_trace() + token_array = [] + for seq_ids in token_ids: + seq_array = [] + if isinstance(seq_ids, torch.Tensor): + seq_ids = seq_ids.cpu().tolist() + for id in seq_ids: + token = self._convert_id_to_token(id) + seq_array.append(token) + token_array.append(seq_array) + + return token_array + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. 
+ A BERT sequence has the following format: + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A BERT sequence pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + if token_ids_1 is None, only returns the first portion of the mask (0's). + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, vocab_path): + """ + Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. + Args: + vocab_path (:obj:`str`): + The directory in which to save the vocabulary. + Returns: + :obj:`Tuple(str)`: Paths to the files saved. 
+ """ + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) + else: + vocab_file = vocab_path + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + "Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file) + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + +class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer): + r""" + Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE). + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + Args: + vocab_file (:obj:`string`): + File containing the vocabulary. + unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + """ + + def __init__( + self, + vocab_file, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs + ): + super().__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'.".format(vocab_file) + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.tokenizer = Atomwise_Tokenizer() + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + + def _tokenize(self, text): + return self.tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. 
""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A BERT sequence has the following format: + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A BERT sequence pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + if token_ids_1 is None, only returns the first portion of the mask (0's). + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, vocab_path): + """ + Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. + Args: + vocab_path (:obj:`str`): + The directory in which to save the vocabulary. 
+ Returns: + :obj:`Tuple(str)`: Paths to the files saved. + """ + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) + else: + vocab_file = vocab_path + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + "Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file) + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) diff --git a/smiles/smiles_tokenizer/new_splits.txt b/smiles/smiles_tokenizer/new_splits.txt new file mode 100644 index 0000000000000000000000000000000000000000..e48784ed46b8515935fe27feaeff177b63e5da5b --- /dev/null +++ b/smiles/smiles_tokenizer/new_splits.txt @@ -0,0 +1,159 @@ +c 1 +c 2 +c 3 +c 4 +c 5 +c 6 +c 7 +c 8 +c 9 +( c1 +( c2 +c1 ) +c2 ) +n 1 +n 2 +n 3 +n 4 +n 5 +n 6 +n 7 +n 8 +n 9 +( n1 +( n2 +n1 ) +n2 ) +O 1 +O 2 +O 3 +O 4 +O 5 +O 6 +O 7 +O 8 +O 9 +( O1 +( O2 +O2 ) +O2 ) += O += C += c += N += n +=C C +=C N +=C c +=c c +=N C +=N c +=n C +=n c +# N +# C +#N C +#C C +#C N +#N N +( C +C ) +( O +O ) +( N +N ) +Br c +( =O +(=O ) +C (=O) +C =O +C =N +C #N +C #C +C C +CC C +CC N +CC O +CC S +CC c +CC n +C N +CN C +CN c +C O +CO C +CO N +CO c +C S +CS C +CS S +CS c +C c +Cl c +C n +F c +N C +NC C +NC c +N N +N O +N c +N n +O C +OC C +OC O +OC c +O N +O O +O c +S C +SC C +SC c +S S +S c +c c +cc c +cc n +cc o +cc s +cc cc +c n +cn c +cn n +c o +co c +c s +cs c +cs n +n c +nc c +nc n +nc o +nc s +n n +nn c +nn n +n o +no c +no n +n s +ns c +ns n +o c +oc c +o n +s c +sc c +sc n +s n +N P +P N +C P +P C +N S +S N +C S +S C +S P +P S +C I \ No newline at end of file diff --git a/smiles/smiles_tokenizer/new_vocab.txt b/smiles/smiles_tokenizer/new_vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..1127d24d4cea8480d413e57949b4c6742dfd5fe9 --- /dev/null +++ b/smiles/smiles_tokenizer/new_vocab.txt @@ -0,0 +1,587 @@ +[PAD] +[UNK] +[CLS] +[SEP] +[MASK] +# +% +( +) ++ +- +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 += +@ +A +B +Br +Brc +C +CC +CCC +CCN +CCO +CCS +CCc +CCn +CN +CNC +CNc +CO +COC +CON +COc +CS +CSC +CSS +CSc +Cc +Cl +Clc +Cn +F +Fc +H +I +K +L +M +N +NC +NCC +NCc +NN +NO +Nc +Nn +O +OC +OCC +OCO +OCc +ON +OO +Oc +P +R +S +SC +SCC +SCc +SS +Sc +T +X +Z +[ +\\ +(/ +] +a +b +c +cc +ccc +cccc +ccn +cco +ccs +cn +cnc +cnn +co +coc +cs +csc +csn +e +g +i +l +n +nc +ncc +ncn +nco +ncs +nn +nnc +nnn +no +noc +non +ns +nsc +nsn +o +oc +occ +on +p +r +s +sc +scc +scn +sn +t +c1 +c2 +c3 +c4 +c5 +c6 +c7 +c8 +c9 +n1 +n2 +n3 +n4 +n5 +n6 +n7 +n8 +n9 +O1 +O2 +O3 +O4 +O5 +O6 +O7 +O8 +O9 +(c1 +(c2 +c1) +c2) +(n1 +(n2 +n1) +n2) +(O1 +(O2 +O2) +=O +=C +=c +=N +=n +=CC +=CN +=Cc +=cc +=NC +=Nc +=nC +=nc +#C +#CC +#CN +#N +#NC +#NN +(C +C) +(O +O) +(N +N) +NP +PN +CP +PC +NS +SN +SP +PS +C(=O) +(/Br) +(/C#N) +(/C) +(/C=N) +(/C=O) +(/CBr) +(/CC) +(/CCC) +(/CCF) +(/CCN) +(/CCO) +(/CCl) +(/CI) +(/CN) +(/CO) +(/CS) +(/Cl) +(/F) +(/I) +(/N) +(/NC) +(/NCC) +(/NO) +(/O) +(/OC) +(/OCC) +(/S) +(/SC) +(=C) +(=C/C) +(=C/F) +(=C/I) +(=C/N) +(=C/O) +(=CBr) +(=CC) +(=CCF) +(=CCN) +(=CCO) +(=CCl) +(=CF) +(=CI) +(=CN) +(=CO) +(=C\\C) +(=C\\F) +(=C\\I) +(=C\\N) +(=C\\O) +(=N) +(=N/C) +(=N/N) +(=N/O) +(=NBr) +(=NC) +(=NCC) +(=NCl) +(=NN) +(=NO) +(=NOC) +(=N\\C) +(=N\\N) +(=N\\O) +(=O) +(=S) +(B) +(Br) +(C#C) +(C#CC) +(C#CI) +(C#CO) +(C#N) +(C#SN) +(C) +(C=C) +(C=CF) +(C=CI) +(C=N) +(C=NN) 
+(C=NO) +(C=O) +(C=S) +(CBr) +(CC#C) +(CC#N) +(CC) +(CC=C) +(CC=O) +(CCBr) +(CCC) +(CCCC) +(CCCF) +(CCCI) +(CCCN) +(CCCO) +(CCCS) +(CCCl) +(CCF) +(CCI) +(CCN) +(CCNC) +(CCNN) +(CCNO) +(CCO) +(CCOC) +(CCON) +(CCS) +(CCSC) +(CCl) +(CF) +(CI) +(CN) +(CN=O) +(CNC) +(CNCC) +(CNCO) +(CNN) +(CNNC) +(CNO) +(CNOC) +(CO) +(COC) +(COCC) +(COCI) +(COCN) +(COCO) +(COF) +(CON) +(COO) +(CS) +(CSC) +(CSCC) +(CSCF) +(CSO) +(Cl) +(F) +(I) +(N) +(N=N) +(N=NO) +(N=O) +(N=S) +(NBr) +(NC#N) +(NC) +(NC=N) +(NC=O) +(NC=S) +(NCBr) +(NCC) +(NCCC) +(NCCF) +(NCCN) +(NCCO) +(NCCS) +(NCCl) +(NCNC) +(NCO) +(NCS) +(NCl) +(NN) +(NN=O) +(NNC) +(NO) +(NOC) +(O) +(OC#N) +(OC) +(OC=C) +(OC=O) +(OC=S) +(OCBr) +(OCC) +(OCCC) +(OCCF) +(OCCI) +(OCCN) +(OCCO) +(OCCS) +(OCCl) +(OCF) +(OCI) +(OCO) +(OCOC) +(OCON) +(OCSC) +(OCl) +(OI) +(ON) +(OO) +(OOC) +(OOCC) +(OOSN) +(OSC) +(P) +(S) +(SC#N) +(SC) +(SCC) +(SCCC) +(SCCF) +(SCCN) +(SCCO) +(SCCS) +(SCCl) +(SCF) +(SCN) +(SCOC) +(SCSC) +(SCl) +(SI) +(SN) +(SN=O) +(SO) +(SOC) +(SOOO) +(SS) +(SSC) +(SSCC) +([At]) +([O-]) +([O]) +([S-]) +(\\Br) +(\\C#N) +(\\C) +(\\C=N) +(\\C=O) +(\\CBr) +(\\CC) +(\\CCC) +(\\CCO) +(\\CCl) +(\\CF) +(\\CN) +(\\CNC) +(\\CO) +(\\COC) +(\\Cl) +(\\F) +(\\I) +(\\N) +(\\NC) +(\\NCC) +(\\NN) +(\\NO) +(\\NOC) +(\\O) +(\\OC) +(\\OCC) +(\\ON) +(\\S) +(\\SC) +(\\SCC) +[Ag+] +[Ag-4] +[Ag] +[Al-3] +[Al] +[As+] +[AsH3] +[AsH] +[As] +[At] +[B-] +[B@-] +[B@@-] +[BH-] +[BH2-] +[BH3-] +[B] +[Ba] +[Br+2] +[BrH] +[Br] +[C+] +[C-] +[C@@H] +[C@@] +[C@H] +[C@] +[CH-] +[CH2] +[CH3] +[CH] +[C] +[CaH2] +[Ca] +[Cl+2] +[Cl+3] +[Cl+] +[Cs] +[FH] +[F] +[H] +[He] +[I+2] +[I+3] +[I+] +[IH] +[I] +[K] +[Kr] +[Li+] +[LiH] +[MgH2] +[Mg] +[N+] +[N-] +[N@+] +[N@@+] +[N@@] +[N@] +[NH+] +[NH-] +[NH2+] +[NH3] +[NH] +[N] +[Na] +[O+] +[O-] +[OH+] +[OH2] +[OH] +[O] +[P+] +[P@+] +[P@@+] +[P@@] +[P@] +[PH2] +[PH] +[P] +[Ra] +[Rb] +[S+] +[S-] +[S@+] +[S@@+] +[S@@] +[S@] +[SH+] +[SH2] +[SH] +[S] +[Se+] +[Se-2] +[SeH2] +[SeH] +[Se] +[Si@] +[SiH2] +[SiH] +[Si] +[SrH2] +[TeH] +[Te] +[Xe] +[Zn+2] +[Zn-2] +[Zn] +[b-] +[c+] +[c-] +[cH-] +[cH] +[c] +[n+] +[n-] +[nH] +[n] +[o+] +[s+] +[se+] +[se] +[te+] +[te] \ No newline at end of file diff --git a/smiles/train.py b/smiles/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4b3134032c2013cf90658fa0669f15fdf5f6bf --- /dev/null +++ b/smiles/train.py @@ -0,0 +1,427 @@ +import argparse +import math +import os +from functools import partial +from collections import Counter + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import load_from_disk +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.strategies import DDPStrategy +from rdkit import Chem + +from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer +from peptide_analyzer import PeptideAnalyzer +import dataloading_for_dynamic_batching as dynamic_dataloader + + +class RotaryPositionalEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_len=None): + if seq_len is None: + seq_len = x.shape[1] + + t = 
torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + cos_emb = emb.cos()[None, :, :] + sin_emb = emb.sin()[None, :, :] + + return cos_emb, sin_emb + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +# --- Model Architecture with RoPE --- +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(1, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, t): + return self.mlp(t.unsqueeze(-1)) + +class MultiHeadAttentionWithRoPE(nn.Module): + def __init__(self, hidden_size, n_heads): + super().__init__() + self.hidden_size = hidden_size + self.n_heads = n_heads + self.head_dim = hidden_size // n_heads + + assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads" + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.out_proj = nn.Linear(hidden_size, hidden_size) + + self.rope = RotaryPositionalEmbedding(self.head_dim) + + def forward(self, x): + batch_size, seq_len, hidden_size = x.shape + + q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rope(q, seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = F.softmax(scores, dim=-1) + attn_output = torch.matmul(attn_weights, v) + + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size) + output = self.out_proj(attn_output) + + return output + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, n_heads): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.GELU(), + nn.Linear(4 * hidden_size, hidden_size) + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa) + attn_output = self.attn(x_norm1) + x = x + gate_msa.unsqueeze(1) * attn_output + x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp) + mlp_output = self.mlp(x_norm2) + x = x + gate_mlp.unsqueeze(1) * mlp_output + return x + +class MDLM(nn.Module): + def __init__(self, vocab_size, model_dim, n_heads, n_layers): + super().__init__() + self.vocab_size = vocab_size + self.model_dim = model_dim + self.mask_token_id = vocab_size + + self.token_embedder = 
nn.Embedding(vocab_size, model_dim) + self.time_embedder = TimestepEmbedder(model_dim) + + self.transformer_blocks = nn.ModuleList([ + DiTBlock(model_dim, n_heads) for _ in range(n_layers) + ]) + + self.final_norm = nn.LayerNorm(model_dim) + self.lm_head = nn.Linear(model_dim, vocab_size) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + if module.bias is not None: + module.bias.data.zero_() + if module.weight is not None: + module.weight.data.fill_(1.0) + + def forward(self, x, t): + x_embed = self.token_embedder(x) + t_embed = self.time_embedder(t) + for block in self.transformer_blocks: + x_embed = block(x_embed, t_embed) + x_embed = self.final_norm(x_embed) + logits = self.lm_head(x_embed) + return logits + +# --- PyTorch Lightning Module --- +class MDLMLightningModule(pl.LightningModule): + def __init__(self, args, tokenizer): + super().__init__() + self.save_hyperparameters(ignore=['tokenizer']) + self.args = args + self.tokenizer = tokenizer + self.peptide_analyzer = PeptideAnalyzer() + + # Initialize model + self.model = MDLM( + vocab_size=tokenizer.vocab_size, + model_dim=args.model_dim, + n_heads=args.n_heads, + n_layers=args.n_layers + ) + + # For tracking steps + self.automatic_optimization = True + + self.validation_step_outputs = [] + + def forward(self, x, t): + return self.model(x, t) + + def _compute_invalid_loss(self, logits): + batch_token_ids = torch.argmax(logits, dim=-1) # (batch_size, seq_length) + + sampled_sequences = self.tokenizer.batch_decode(batch_token_ids) + + penalties = torch.tensor( + [1 if not self.peptide_analyzer.is_peptide(seq) else 0 for seq in sampled_sequences], + dtype=torch.float32, + device=self.device + ) + sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device) + + scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length) + + return scaled_penalty + + def _loss(self, logits, x_1, attn_mask): + # Standard cross-entropy loss + ce_loss = F.cross_entropy( + logits.view(-1, self.model.vocab_size), + x_1.view(-1), + reduction='none' + ).view(x_1.shape[0], -1) + # ce_loss = (ce_loss * attn_mask).sum() / attn_mask.sum() + + # validity_weight = self.args.validity_weight * min(1.0, (self.current_epoch + 1) / self.trainer.max_epochs) + + invalid_loss = self._compute_invalid_loss(logits) # (batch_size, seq_length) + + loss = ce_loss + self.args.validity_weight * invalid_loss + nlls = loss * attn_mask + + num_tokens = attn_mask.sum() + batch_nll = nlls.sum() + token_nll = batch_nll / num_tokens + + return token_nll, (ce_loss*attn_mask).sum() / num_tokens, (invalid_loss*attn_mask).sum() / num_tokens + + def training_step(self, batch, batch_idx): + x_1 = batch['input_ids'].clone().detach().to(self.device) + attn_mask = batch['attention_mask'].clone().detach().to(self.device) + bond_mask = batch['bond_mask'].clone().detach().to(self.device).bool() + batch_size, _ = x_1.shape + + # ReDi approach: random start -> target + x_0 = torch.randint(0, self.model.vocab_size, x_1.shape, device=self.device) + t_continuous = torch.rand(batch_size, device=self.device) + # mask = torch.rand(x_1.shape, device=self.device) < t_continuous.view(-1, 1) + # x_t = torch.where(mask, x_1, x_0) + + peptide_bond_prob = t_continuous.view(-1, 1) 
** self.args.gamma # slower increase + non_peptide_prob = t_continuous.view(-1, 1) # linear increase + + masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob) + mask = torch.rand(x_1.shape, device=self.device) < masking_prob + x_t = torch.where(mask, x_1, x_0) + + logits = self.model(x_t, t_continuous) + + token_nll, ce_loss, invalid_loss = self._loss(logits, x_1, attn_mask) + + # Logging + self.log('train/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=x_1.size(0), sync_dist=True) + self.log('train/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True) + self.log('train/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True) + # self.log('train/validity_weight', validity_weight, on_step=False, on_epoch=True, batch_size=x_1.size(0)) + + return token_nll + + def validation_step(self, batch, batch_idx): + x_1 = batch['input_ids'].clone().detach().to(self.device) + attn_mask = batch['attention_mask'].clone().detach().to(self.device) + bond_mask = batch['bond_mask'].clone().detach().to(self.device).bool() + batch_size, _ = x_1.shape + + # ReDi approach: random start -> target + x_0 = torch.randint(0, self.model.vocab_size, x_1.shape, device=self.device) + t_continuous = torch.rand(batch_size, device=self.device) + + peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma # slower increase + non_peptide_prob = t_continuous.view(-1, 1) # linear increase + + masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob) + mask = torch.rand(x_1.shape, device=self.device) < masking_prob + x_t = torch.where(mask, x_1, x_0) + + logits = self.model(x_t, t_continuous) + + token_nll, ce_loss, invalid_loss = self._loss(logits, x_1, attn_mask) + self.log('val/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=x_1.size(0), sync_dist=True) + self.log('val/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True) + self.log('val/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True) + + def configure_optimizers(self): + optimizer = AdamW( + self.parameters(), + lr=self.args.learning_rate, + weight_decay=self.args.weight_decay + ) + + # Calculate total steps + if hasattr(self.trainer, 'estimated_stepping_batches'): + num_training_steps = self.trainer.estimated_stepping_batches + else: + # Fallback calculation + num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs + + warmup_steps = int(num_training_steps * 0.1) + + def lr_lambda(current_step): + if current_step < warmup_steps: + lr_range = self.args.learning_rate - (self.args.learning_rate * 0.1) + lr = (self.args.learning_rate * 0.1) + lr_range * (current_step / warmup_steps) + return lr / self.args.learning_rate + else: + progress = (current_step - warmup_steps) / (num_training_steps - warmup_steps) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + lr_range = self.args.learning_rate - (self.args.learning_rate * 0.1) + lr = (self.args.learning_rate * 0.1) + lr_range * cosine_decay + return lr / self.args.learning_rate + + scheduler = LambdaLR(optimizer, lr_lambda) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + }, + } + + +# --- Main Execution --- +def main(args): + # Set up checkpoint directory + checkpoint_dir = (args.checkpoint_dir + + 
f"correct_lr{args.learning_rate}_wd{args.weight_decay}_layer{args.n_layers}_" + f"head{args.n_heads}_valweight{args.validity_weight}") + print(f"Saving to {checkpoint_dir}") + os.makedirs(checkpoint_dir, exist_ok=True) + + print("Loading tokenizer...") + tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt', + '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt') + print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}") + + # Initialize data module + data_module = dynamic_dataloader.CustomDataModule('./data/11M_smiles_old_tokenizer_no_limit/', tokenizer) + + # Initialize model + model = MDLMLightningModule(args, tokenizer) + + # Set up logger + logger = WandbLogger( + project="smiles-redi-training", # or your preferred project name + entity="programmablebio", + name=f"lr{args.learning_rate}_dim{args.model_dim}_head{args.n_heads}_layer{args.n_layers}", + save_dir=checkpoint_dir + ) + + # Set up callbacks + callbacks = [ + ModelCheckpoint( + dirpath=checkpoint_dir, + filename='best', + monitor='val/token_nll', + mode='min', + save_top_k=1, + save_last=True, + # every_n_train_steps=10000 # This will save every 1000 steps AND when val/nll improves + ), + LearningRateMonitor(logging_interval='step') + ] + + # Initialize trainer + trainer = pl.Trainer( + max_epochs=args.epochs, + devices=torch.cuda.device_count(), + accelerator='gpu', + strategy=DDPStrategy(find_unused_parameters=False), + num_nodes=int(os.environ.get("SLURM_NNODES", 1)), + precision="bf16", + gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None, + callbacks=callbacks, + logger=logger, + log_every_n_steps=100, + check_val_every_n_epoch=True, + # val_check_interval=10000, + accumulate_grad_batches=8, + enable_progress_bar=True, + enable_model_summary=True + ) + + print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.") + print("Starting training...") + + # Train the model + trainer.fit(model, data_module) + + print("Training complete.") + print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train ReDi model for SMILES generation with RoPE using PyTorch Lightning") + + # Model arguments + parser.add_argument("--model_dim", type=int, default=1024) + parser.add_argument("--n_heads", type=int, default=8) + parser.add_argument("--n_layers", type=int, default=6) + + # Training arguments + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--weight_decay", type=float, default=1e-5) + parser.add_argument("--label_smoothing", type=float, default=0) + parser.add_argument("--grad_clip", type=float, default=1.0) + parser.add_argument("--gamma", type=float, default=2.0) + + # Validity arguments + parser.add_argument("--validity_weight", type=float, default=0.1) + parser.add_argument("--validity_check_freq", type=int, default=10) + parser.add_argument("--validity_eval_batches", type=int, default=20) + + # Logging arguments + parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles") + + args = parser.parse_args() + main(args) \ No newline at end of file