Open Media Library Platform
This commit is contained in:
commit
411ad5b16f
5849 changed files with 1778641 additions and 0 deletions
|
|
@ -0,0 +1 @@
|
|||
'conch tests'
|
||||
208
Linux/lib/python2.7/site-packages/twisted/conch/test/keydata.py
Normal file
208
Linux/lib/python2.7/site-packages/twisted/conch/test/keydata.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_keys -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Data used by test_keys as well as others.
|
||||
"""
|
||||
RSAData = {
|
||||
'n':long('1062486685755247411169438309495398947372127791189432809481'
|
||||
'382072971106157632182084539383569281493520117634129557550415277'
|
||||
'516685881326038852354459895734875625093273594925884531272867425'
|
||||
'864910490065695876046999646807138717162833156501L'),
|
||||
'e':35L,
|
||||
'd':long('6678487739032983727350755088256793383481946116047863373882'
|
||||
'973030104095847973715959961839578340816412167985957218887914482'
|
||||
'713602371850869127033494910375212470664166001439410214474266799'
|
||||
'85974425203903884190893469297150446322896587555L'),
|
||||
'q':long('3395694744258061291019136154000709371890447462086362702627'
|
||||
'9704149412726577280741108645721676968699696898960891593323L'),
|
||||
'p':long('3128922844292337321766351031842562691837301298995834258844'
|
||||
'4720539204069737532863831050930719431498338835415515173887L')}
|
||||
|
||||
DSAData = {
|
||||
'y':long('2300663509295750360093768159135720439490120577534296730713'
|
||||
'348508834878775464483169644934425336771277908527130096489120714'
|
||||
'610188630979820723924744291603865L'),
|
||||
'g':long('4451569990409370769930903934104221766858515498655655091803'
|
||||
'866645719060300558655677517139568505649468378587802312867198352'
|
||||
'1161998270001677664063945776405L'),
|
||||
'p':long('7067311773048598659694590252855127633397024017439939353776'
|
||||
'608320410518694001356789646664502838652272205440894335303988504'
|
||||
'978724817717069039110940675621677L'),
|
||||
'q':1184501645189849666738820838619601267690550087703L,
|
||||
'x':863951293559205482820041244219051653999559962819L}
|
||||
|
||||
publicRSA_openssh = ("ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBE"
|
||||
"vLi8DVPrJ3/c9k2I/Az64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYL"
|
||||
"h5KmRpslkYHRivcJSkbh/C+BR3utDS555mV comment")
|
||||
|
||||
privateRSA_openssh = """-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
|
||||
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
|
||||
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
|
||||
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
|
||||
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
|
||||
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
|
||||
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
|
||||
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
|
||||
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
|
||||
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
# some versions of OpenSSH generate these (slightly different keys)
|
||||
privateRSA_openssh_alternate = """-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIBzjCCAcgCAQACYQCvMnHw5g6cmbN/i18ES8uLwNU+snf9z2TYj8DPrh/GMd/2
|
||||
KbJEluLG1CGUf2V82NQjH7guaskflA1GwWmitwcMo5PBNNguHkqZGmyWRgdGK9wl
|
||||
KRuH8L4FHe60NLnnmZUCASMCYG4ftVWX6+1n7SuZbuzB7ahNUtbz1mUGBN/lVJ/M
|
||||
iQA8m2eH7GWgq81vZZCKl5BNxiGPqI3YWYZDtYGxtNdfLCIKYcElikcStJr4ehEc
|
||||
SqiLdcSRCTu+BMpF2VeKDSfLIwIxANyfa9mYIVYRjelfA50K05NuE3dBPIVPAHD9
|
||||
BVT/vD0Jv4P2l39kEJEE/qJnR1RCawIxAMtKS9BAR+hFUvfHrwwgbUMNtjmU+dql
|
||||
5QMGdoMk64ihVaKo3hI7d0mSiqlx0gKT/wIwS6Rffc3CSWUatmm4GJX/Zb9XIZK1
|
||||
qgx1Lg2bbZmCXhH4hQQWr1WCBdXTpWU9BvIzAjAXO7DkmaHRZwIq8j/kIPaLUgYy
|
||||
d20DC6Uk6su3N2tgEnAv2MjsJA2iAh55w918ozMCMQC0c5dLUBCjF7OoR/E6FHZS
|
||||
0TgqzxIUNMGoVEwpNYCgOLjw+kzEwoWr24eCutzr2yowAA==
|
||||
------END RSA PRIVATE KEY------"""
|
||||
|
||||
# encrypted with the passphrase 'encrypted'
|
||||
privateRSA_openssh_encrypted = """-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
|
||||
|
||||
30qUR7DYY/rpVJu159paRM1mUqt/IMibfEMTKWSjNhCVD21hskftZCJROw/WgIFt
|
||||
ncusHpJMkjgwEpho0KyKilcC7zxjpunTex24Meb5pCdXCrYft8AyUkRdq3dugMqT
|
||||
4nuWuWxziluBhKQ2M9tPGcEOeulU4vVjceZt2pZhZQVBf08o3XUv5/7RYd24M9md
|
||||
WIo+5zdj2YQkI6xMFTP954O/X32ME1KQt98wgNEy6mxhItbvf00mH3woALwEKP3v
|
||||
PSMxxtx3VKeDKd9YTOm1giKkXZUf91vZWs0378tUBrU4U5qJxgryTjvvVKOtofj6
|
||||
4qQy6+r6M6wtwVlXBgeRm2gBPvL3nv6MsROp3E6ztBd/e7A8fSec+UTq3ko/EbGP
|
||||
0QG+IG5tg8FsdITxQ9WAIITZL3Rc6hA5Ymx1VNhySp3iSiso8Jof27lku4pyuvRV
|
||||
ko/B3N2H7LnQrGV0GyrjeYocW/qZh/PCsY48JBFhlNQexn2mn44AJW3y5xgbhvKA
|
||||
3mrmMD1hD17ZvZxi4fPHjbuAyM1vFqhQx63eT9ijbwJ91svKJl5O5MIv41mCRonm
|
||||
hxvOXw8S0mjSasyofptzzQCtXxFLQigXbpQBltII+Ys=
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
# encrypted with the passphrase 'testxp'. NB: this key was generated by
|
||||
# OpenSSH, so it doesn't use the same key data as the other keys here.
|
||||
privateRSA_openssh_encrypted_aes = """-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
publicRSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuOTc6AK8yc"
|
||||
"fDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW4sbUIZR/ZXzY1CMfuC5qyR+UDUbB"
|
||||
"aaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fwvgUd7rQ0ueeZlSkoMTplMTojKSkp}")
|
||||
|
||||
privateRSA_lsh = ("(11:private-key(9:rsa-pkcs1(1:n97:\x00\xaf2q\xf0\xe6\x0e"
|
||||
"\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd\xcfd\xd8\x8f\xc0\xcf"
|
||||
"\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|\xd8\xd4#\x1f\xb8.j"
|
||||
"\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8.\x1eJ\x99\x1al\x96F"
|
||||
"\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7\x99\x95)(1:e1:#)(1:d9"
|
||||
"6:n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
|
||||
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
|
||||
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
|
||||
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#)(1:p49:\x00"
|
||||
"\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd"
|
||||
"\x05T\xff\xbc=\t\xbf\x83\xf6\x97\x7fd\x10\x91\x04\xfe\xa2gGTBk)(1:q49:\x00"
|
||||
"\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r\xb69\x94\xf9\xda\xa5\xe5\x03\x06v"
|
||||
"\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92\x8a\xa9q\xd2\x02\x93\xff)(1:a48:K"
|
||||
"\xa4_}\xcd\xc2Ie\x1a\xb6i\xb8\x18\x95\xffe\xbfW!\x92\xb5\xaa\x0cu.\r\x9bm"
|
||||
"\x99\x82^\x11\xf8\x85\x04\x16\xafU\x82\x05\xd5\xd3\xa5e=\x06\xf23)(1:b48:"
|
||||
"\x17;\xb0\xe4\x99\xa1\xd1g\x02*\xf2?\xe4 \xf6\x8bR\x062wm\x03\x0b\xa5$\xea"
|
||||
"\xcb\xb77k`\x12p/\xd8\xc8\xec$\r\xa2\x02\x1ey\xc3\xdd|\xa33)(1:c49:\x00\xb4"
|
||||
"s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*\xcf\x12\x144\xc1\xa8TL)5\x80"
|
||||
"\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82\xba\xdc\xeb\xdb*)))")
|
||||
|
||||
privateRSA_agentv3 = ("\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00`"
|
||||
"n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
|
||||
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
|
||||
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
|
||||
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#\x00\x00\x00a"
|
||||
"\x00\xaf2q\xf0\xe6\x0e\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd"
|
||||
"\xcfd\xd8\x8f\xc0\xcf\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|"
|
||||
"\xd8\xd4#\x1f\xb8.j\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8."
|
||||
"\x1eJ\x99\x1al\x96F\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7"
|
||||
"\x99\x95\x00\x00\x001\x00\xb4s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*"
|
||||
"\xcf\x12\x144\xc1\xa8TL)5\x80\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82"
|
||||
"\xba\xdc\xeb\xdb*\x00\x00\x001\x00\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r"
|
||||
"\xb69\x94\xf9\xda\xa5\xe5\x03\x06v\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92"
|
||||
"\x8a\xa9q\xd2\x02\x93\xff\x00\x00\x001\x00\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_"
|
||||
"\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd\x05T\xff\xbc=\t\xbf\x83\xf6\x97"
|
||||
"\x7fd\x10\x91\x04\xfe\xa2gGTBk")
|
||||
|
||||
publicDSA_openssh = ("ssh-dss AAAAB3NzaC1kc3MAAABBAIbwTOSsZ7Bl7U1KyMNqV13Tu7"
|
||||
"yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0AAAAVAM965Akmo"
|
||||
"6eAi7K+k9qDR4TotFAXAAAAQADZlpTW964haQWS4vC063NGdldT6xpUGDcDRqbm90CoPEa2RmNO"
|
||||
"uOqi8lnbhYraEzypYH3K4Gzv/bxCBnKtHRUAAABAK+1osyWBS0+P90u/rAuko6chZ98thUSY2kL"
|
||||
"SHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmQ== comment")
|
||||
|
||||
privateDSA_openssh = """-----BEGIN DSA PRIVATE KEY-----
|
||||
MIH4AgEAAkEAhvBM5KxnsGXtTUrIw2pXXdO7vJEC1OvvQ9UjdCd+s+6Z/ZTMKCkv
|
||||
WWNsrFJ8CLTle+sT/W98IUCwVhdFkuFDLQIVAM965Akmo6eAi7K+k9qDR4TotFAX
|
||||
AkAA2ZaU1veuIWkFkuLwtOtzRnZXU+saVBg3A0am5vdAqDxGtkZjTrjqovJZ24WK
|
||||
2hM8qWB9yuBs7/28QgZyrR0VAkAr7WizJYFLT4/3S7+sC6SjpyFn3y2FRJjaQtIe
|
||||
nqEsrLLZuOdPb2HuFoQsT5cd+rZsz19yQPHYUs5IgnPLzdWZAhUAl1TqdmlAG/b4
|
||||
nnVchGiO9sML8MM=
|
||||
-----END DSA PRIVATE KEY-----"""
|
||||
|
||||
publicDSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMzpkc2EoMTpwNjU6AIbwTOSsZ7Bl7U1KyMNqV"
|
||||
"13Tu7yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0pKDE6cTIx"
|
||||
"OgDPeuQJJqOngIuyvpPag0eE6LRQFykoMTpnNjQ6ANmWlNb3riFpBZLi8LTrc0Z2V1PrGlQYNwN"
|
||||
"Gpub3QKg8RrZGY0646qLyWduFitoTPKlgfcrgbO/9vEIGcq0dFSkoMTp5NjQ6K+1osyWBS0+P90"
|
||||
"u/rAuko6chZ98thUSY2kLSHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmSkpK"
|
||||
"Q==}")
|
||||
|
||||
privateDSA_lsh = ("(11:private-key(3:dsa(1:p65:\x00\x86\xf0L\xe4\xacg\xb0e"
|
||||
"\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3\xee\x99\xfd"
|
||||
"\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92\xe1C-)(1:q2"
|
||||
"1:\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G\x84\xe8\xb4P\x17)(1"
|
||||
":g64:\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2\xf0\xb4\xebsFvWS\xeb\x1aT"
|
||||
"\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2\xf2Y\xdb\x85\x8a\xda\x13<"
|
||||
"\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15)(1:y64:+\xedh\xb3%\x81KO\x8f"
|
||||
"\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98\xdaB\xd2\x1e\x9e\xa1,\xac\xb2"
|
||||
"\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6l\xcf_r@\xf1\xd8R\xceH\x82s"
|
||||
"\xcb\xcd\xd5\x99)(1:x21:\x00\x97T\xeavi@\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6"
|
||||
"\xc3\x0b\xf0\xc3)))")
|
||||
|
||||
privateDSA_agentv3 = ("\x00\x00\x00\x07ssh-dss\x00\x00\x00A\x00\x86\xf0L\xe4"
|
||||
"\xacg\xb0e\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3"
|
||||
"\xee\x99\xfd\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92"
|
||||
"\xe1C-\x00\x00\x00\x15\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G"
|
||||
"\x84\xe8\xb4P\x17\x00\x00\x00@\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2"
|
||||
"\xf0\xb4\xebsFvWS\xeb\x1aT\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2"
|
||||
"\xf2Y\xdb\x85\x8a\xda\x13<\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15\x00"
|
||||
"\x00\x00@+\xedh\xb3%\x81KO\x8f\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98"
|
||||
"\xdaB\xd2\x1e\x9e\xa1,\xac\xb2\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6"
|
||||
"l\xcf_r@\xf1\xd8R\xceH\x82s\xcb\xcd\xd5\x99\x00\x00\x00\x15\x00\x97T\xeavi@"
|
||||
"\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6\xc3\x0b\xf0\xc3")
|
||||
|
||||
__all__ = ['DSAData', 'RSAData', 'privateDSA_agentv3', 'privateDSA_lsh',
|
||||
'privateDSA_openssh', 'privateRSA_agentv3', 'privateRSA_lsh',
|
||||
'privateRSA_openssh', 'publicDSA_lsh', 'publicDSA_openssh',
|
||||
'publicRSA_lsh', 'publicRSA_openssh', 'privateRSA_openssh_alternate']
|
||||
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{SSHTransportAddrress} in ssh/address.py
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.test.test_address import AddressTestCaseMixin
|
||||
|
||||
from twisted.conch.ssh.address import SSHTransportAddress
|
||||
|
||||
|
||||
|
||||
class SSHTransportAddressTestCase(unittest.TestCase, AddressTestCaseMixin):
|
||||
"""
|
||||
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
|
||||
use to represent the other side of the SSH connection. This tests the
|
||||
basic functionality of that class (string representation, comparison, &c).
|
||||
"""
|
||||
|
||||
|
||||
def _stringRepresentation(self, stringFunction):
|
||||
"""
|
||||
The string representation of C{SSHTransportAddress} should be
|
||||
"SSHTransportAddress(<stringFunction on address>)".
|
||||
"""
|
||||
addr = self.buildAddress()
|
||||
stringValue = stringFunction(addr)
|
||||
addressValue = stringFunction(addr.address)
|
||||
self.assertEqual(stringValue,
|
||||
"SSHTransportAddress(%s)" % addressValue)
|
||||
|
||||
|
||||
def buildAddress(self):
|
||||
"""
|
||||
Create an arbitrary new C{SSHTransportAddress}. A new instance is
|
||||
created for each call, but always for the same address.
|
||||
"""
|
||||
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
|
||||
|
||||
|
||||
def buildDifferentAddress(self):
|
||||
"""
|
||||
Like C{buildAddress}, but with a different fixed address.
|
||||
"""
|
||||
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))
|
||||
|
||||
|
|
@ -0,0 +1,399 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.agent}.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
try:
|
||||
import OpenSSL
|
||||
except ImportError:
|
||||
iosim = None
|
||||
else:
|
||||
from twisted.test import iosim
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
except ImportError:
|
||||
Crypto = None
|
||||
|
||||
try:
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
pyasn1 = None
|
||||
|
||||
if Crypto and pyasn1:
|
||||
from twisted.conch.ssh import keys, agent
|
||||
else:
|
||||
keys = agent = None
|
||||
|
||||
from twisted.conch.test import keydata
|
||||
from twisted.conch.error import ConchError, MissingKeyStoreError
|
||||
|
||||
|
||||
class StubFactory(object):
|
||||
"""
|
||||
Mock factory that provides the keys attribute required by the
|
||||
SSHAgentServerProtocol
|
||||
"""
|
||||
def __init__(self):
|
||||
self.keys = {}
|
||||
|
||||
|
||||
|
||||
class AgentTestBase(unittest.TestCase):
|
||||
"""
|
||||
Tests for SSHAgentServer/Client.
|
||||
"""
|
||||
if iosim is None:
|
||||
skip = "iosim requires SSL, but SSL is not available"
|
||||
elif agent is None or keys is None:
|
||||
skip = "Cannot run without PyCrypto or PyASN1"
|
||||
|
||||
def setUp(self):
|
||||
# wire up our client <-> server
|
||||
self.client, self.server, self.pump = iosim.connectedServerAndClient(
|
||||
agent.SSHAgentServer, agent.SSHAgentClient)
|
||||
|
||||
# the server's end of the protocol is stateful and we store it on the
|
||||
# factory, for which we only need a mock
|
||||
self.server.factory = StubFactory()
|
||||
|
||||
# pub/priv keys of each kind
|
||||
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
|
||||
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
|
||||
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
|
||||
|
||||
|
||||
|
||||
class TestServerProtocolContractWithFactory(AgentTestBase):
|
||||
"""
|
||||
The server protocol is stateful and so uses its factory to track state
|
||||
across requests. This test asserts that the protocol raises if its factory
|
||||
doesn't provide the necessary storage for that state.
|
||||
"""
|
||||
def test_factorySuppliesKeyStorageForServerProtocol(self):
|
||||
# need a message to send into the server
|
||||
msg = struct.pack('!LB',1, agent.AGENTC_REQUEST_IDENTITIES)
|
||||
del self.server.factory.__dict__['keys']
|
||||
self.assertRaises(MissingKeyStoreError,
|
||||
self.server.dataReceived, msg)
|
||||
|
||||
|
||||
|
||||
class TestUnimplementedVersionOneServer(AgentTestBase):
|
||||
"""
|
||||
Tests for methods with no-op implementations on the server. We need these
|
||||
for clients, such as openssh, that try v1 methods before going to v2.
|
||||
|
||||
Because the client doesn't expose these operations with nice method names,
|
||||
we invoke sendRequest directly with an op code.
|
||||
"""
|
||||
|
||||
def test_agentc_REQUEST_RSA_IDENTITIES(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA identities request
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, '')
|
||||
self.pump.flush()
|
||||
def _cb(packet):
|
||||
self.assertEqual(
|
||||
agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0]))
|
||||
return d.addCallback(_cb)
|
||||
|
||||
|
||||
def test_agentc_REMOVE_RSA_IDENTITY(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA remove identity request
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, '')
|
||||
self.pump.flush()
|
||||
return d.addCallback(self.assertEqual, '')
|
||||
|
||||
|
||||
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA remove all identities
|
||||
request.
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, '')
|
||||
self.pump.flush()
|
||||
return d.addCallback(self.assertEqual, '')
|
||||
|
||||
|
||||
|
||||
if agent is not None:
|
||||
class CorruptServer(agent.SSHAgentServer):
|
||||
"""
|
||||
A misbehaving server that returns bogus response op codes so that we can
|
||||
verify that our callbacks that deal with these op codes handle such
|
||||
miscreants.
|
||||
"""
|
||||
def agentc_REQUEST_IDENTITIES(self, data):
|
||||
self.sendResponse(254, '')
|
||||
|
||||
|
||||
def agentc_SIGN_REQUEST(self, data):
|
||||
self.sendResponse(254, '')
|
||||
|
||||
|
||||
|
||||
class TestClientWithBrokenServer(AgentTestBase):
|
||||
"""
|
||||
verify error handling code in the client using a misbehaving server
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.client, self.server, self.pump = iosim.connectedServerAndClient(
|
||||
CorruptServer, agent.SSHAgentClient)
|
||||
# the server's end of the protocol is stateful and we store it on the
|
||||
# factory, for which we only need a mock
|
||||
self.server.factory = StubFactory()
|
||||
|
||||
|
||||
def test_signDataCallbackErrorHandling(self):
|
||||
"""
|
||||
Assert that L{SSHAgentClient.signData} raises a ConchError
|
||||
if we get a response from the server whose opcode doesn't match
|
||||
the protocol for data signing requests.
|
||||
"""
|
||||
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
def test_requestIdentitiesCallbackErrorHandling(self):
|
||||
"""
|
||||
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
|
||||
if we get a response from the server whose opcode doesn't match
|
||||
the protocol for identity requests.
|
||||
"""
|
||||
d = self.client.requestIdentities()
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
|
||||
class TestAgentKeyAddition(AgentTestBase):
|
||||
"""
|
||||
Test adding different flavors of keys to an agent.
|
||||
"""
|
||||
|
||||
def test_addRSAIdentityNoComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that ommitting the comment produces an
|
||||
empty string for the comment on the server.
|
||||
"""
|
||||
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
|
||||
self.assertEqual(self.rsaPrivate, serverKey[0])
|
||||
self.assertEqual('', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_addDSAIdentityNoComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that ommitting the comment produces an
|
||||
empty string for the comment on the server.
|
||||
"""
|
||||
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
|
||||
self.assertEqual(self.dsaPrivate, serverKey[0])
|
||||
self.assertEqual('', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_addRSAIdentityWithComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that the server receives/stores the comment
|
||||
as sent by the client.
|
||||
"""
|
||||
d = self.client.addIdentity(
|
||||
self.rsaPrivate.privateBlob(), comment='My special key')
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
|
||||
self.assertEqual(self.rsaPrivate, serverKey[0])
|
||||
self.assertEqual('My special key', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_addDSAIdentityWithComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that the server receives/stores the comment
|
||||
as sent by the client.
|
||||
"""
|
||||
d = self.client.addIdentity(
|
||||
self.dsaPrivate.privateBlob(), comment='My special key')
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
|
||||
self.assertEqual(self.dsaPrivate, serverKey[0])
|
||||
self.assertEqual('My special key', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
|
||||
class TestAgentClientFailure(AgentTestBase):
|
||||
def test_agentFailure(self):
|
||||
"""
|
||||
verify that the client raises ConchError on AGENT_FAILURE
|
||||
"""
|
||||
d = self.client.sendRequest(254, '')
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
|
||||
class TestAgentIdentityRequests(AgentTestBase):
|
||||
"""
|
||||
Test operations against a server with identities already loaded.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.server.factory.keys[self.dsaPrivate.blob()] = (
|
||||
self.dsaPrivate, 'a comment')
|
||||
self.server.factory.keys[self.rsaPrivate.blob()] = (
|
||||
self.rsaPrivate, 'another comment')
|
||||
|
||||
|
||||
def test_signDataRSA(self):
|
||||
"""
|
||||
Sign data with an RSA private key and then verify it with the public
|
||||
key.
|
||||
"""
|
||||
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
def _check(sig):
|
||||
expected = self.rsaPrivate.sign("John Hancock")
|
||||
self.assertEqual(expected, sig)
|
||||
self.assertTrue(self.rsaPublic.verify(sig, "John Hancock"))
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_signDataDSA(self):
|
||||
"""
|
||||
Sign data with a DSA private key and then verify it with the public
|
||||
key.
|
||||
"""
|
||||
d = self.client.signData(self.dsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
def _check(sig):
|
||||
# Cannot do this b/c DSA uses random numbers when signing
|
||||
# expected = self.dsaPrivate.sign("John Hancock")
|
||||
# self.assertEqual(expected, sig)
|
||||
self.assertTrue(self.dsaPublic.verify(sig, "John Hancock"))
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_signDataRSAErrbackOnUnknownBlob(self):
|
||||
"""
|
||||
Assert that we get an errback if we try to sign data using a key that
|
||||
wasn't added.
|
||||
"""
|
||||
del self.server.factory.keys[self.rsaPublic.blob()]
|
||||
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
def test_requestIdentities(self):
|
||||
"""
|
||||
Assert that we get all of the keys/comments that we add when we issue a
|
||||
request for all identities.
|
||||
"""
|
||||
d = self.client.requestIdentities()
|
||||
self.pump.flush()
|
||||
def _check(keyt):
|
||||
expected = {}
|
||||
expected[self.dsaPublic.blob()] = 'a comment'
|
||||
expected[self.rsaPublic.blob()] = 'another comment'
|
||||
|
||||
received = {}
|
||||
for k in keyt:
|
||||
received[keys.Key.fromString(k[0], type='blob').blob()] = k[1]
|
||||
self.assertEqual(expected, received)
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
|
||||
class TestAgentKeyRemoval(AgentTestBase):
|
||||
"""
|
||||
Test support for removing keys in a remote server.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.server.factory.keys[self.dsaPrivate.blob()] = (
|
||||
self.dsaPrivate, 'a comment')
|
||||
self.server.factory.keys[self.rsaPrivate.blob()] = (
|
||||
self.rsaPrivate, 'another comment')
|
||||
|
||||
|
||||
def test_removeRSAIdentity(self):
|
||||
"""
|
||||
Assert that we can remove an RSA identity.
|
||||
"""
|
||||
# only need public key for this
|
||||
d = self.client.removeIdentity(self.rsaPrivate.blob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(1, len(self.server.factory.keys))
|
||||
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
|
||||
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_removeDSAIdentity(self):
|
||||
"""
|
||||
Assert that we can remove a DSA identity.
|
||||
"""
|
||||
# only need public key for this
|
||||
d = self.client.removeIdentity(self.dsaPrivate.blob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(1, len(self.server.factory.keys))
|
||||
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_removeAllIdentities(self):
|
||||
"""
|
||||
Assert that we can remove all identities.
|
||||
"""
|
||||
d = self.client.removeAllIdentities()
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(0, len(self.server.factory.keys))
|
||||
return d.addCallback(_check)
|
||||
|
|
@ -0,0 +1,992 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_cftp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE file for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.cftp}.
|
||||
"""
|
||||
|
||||
import locale
|
||||
import time, sys, os, operator, getpass, struct
|
||||
from StringIO import StringIO
|
||||
|
||||
from twisted.conch.test.test_ssh import Crypto, pyasn1
|
||||
|
||||
_reason = None
|
||||
if Crypto and pyasn1:
|
||||
try:
|
||||
from twisted.conch import unix
|
||||
from twisted.conch.scripts import cftp
|
||||
from twisted.conch.scripts.cftp import SSHSession
|
||||
from twisted.conch.test.test_filetransfer import FileTransferForTestAvatar
|
||||
except ImportError as e:
|
||||
unix = None
|
||||
_reason = str(e)
|
||||
del e
|
||||
else:
|
||||
unix = None
|
||||
|
||||
from twisted.python.fakepwd import UserDatabase
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.cred import portal
|
||||
from twisted.internet import reactor, protocol, interfaces, defer, error
|
||||
from twisted.internet.utils import getProcessOutputAndValue
|
||||
from twisted.python import log
|
||||
from twisted.conch import ls
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from twisted.internet.task import Clock
|
||||
|
||||
from twisted.conch.test import test_ssh, test_conch
|
||||
from twisted.conch.test.test_filetransfer import SFTPTestBase
|
||||
from twisted.conch.test.test_filetransfer import FileTransferTestAvatar
|
||||
from twisted.conch.test.test_conch import FakeStdio
|
||||
|
||||
|
||||
|
||||
class SSHSessionTests(TestCase):
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.cftp.SSHSession}.
|
||||
"""
|
||||
def test_eofReceived(self):
|
||||
"""
|
||||
L{twisted.conch.scripts.cftp.SSHSession.eofReceived} loses the write
|
||||
half of its stdio connection.
|
||||
"""
|
||||
stdio = FakeStdio()
|
||||
channel = SSHSession()
|
||||
channel.stdio = stdio
|
||||
channel.eofReceived()
|
||||
self.assertTrue(stdio.writeConnLost)
|
||||
|
||||
|
||||
|
||||
class ListingTests(TestCase):
|
||||
"""
|
||||
Tests for L{lsLine}, the function which generates an entry for a file or
|
||||
directory in an SFTP I{ls} command's output.
|
||||
"""
|
||||
if getattr(time, 'tzset', None) is None:
|
||||
skip = "Cannot test timestamp formatting code without time.tzset"
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Patch the L{ls} module's time function so the results of L{lsLine} are
|
||||
deterministic.
|
||||
"""
|
||||
self.now = 123456789
|
||||
def fakeTime():
|
||||
return self.now
|
||||
self.patch(ls, 'time', fakeTime)
|
||||
|
||||
# Make sure that the timezone ends up the same after these tests as
|
||||
# it was before.
|
||||
if 'TZ' in os.environ:
|
||||
self.addCleanup(operator.setitem, os.environ, 'TZ', os.environ['TZ'])
|
||||
self.addCleanup(time.tzset)
|
||||
else:
|
||||
def cleanup():
|
||||
# os.environ.pop is broken! Don't use it! Ever! Or die!
|
||||
try:
|
||||
del os.environ['TZ']
|
||||
except KeyError:
|
||||
pass
|
||||
time.tzset()
|
||||
self.addCleanup(cleanup)
|
||||
|
||||
|
||||
def _lsInTimezone(self, timezone, stat):
|
||||
"""
|
||||
Call L{ls.lsLine} after setting the timezone to C{timezone} and return
|
||||
the result.
|
||||
"""
|
||||
# Set the timezone to a well-known value so the timestamps are
|
||||
# predictable.
|
||||
os.environ['TZ'] = timezone
|
||||
time.tzset()
|
||||
return ls.lsLine('foo', stat)
|
||||
|
||||
|
||||
def test_oldFile(self):
|
||||
"""
|
||||
A file with an mtime six months (approximately) or more in the past has
|
||||
a listing including a low-resolution timestamp.
|
||||
"""
|
||||
# Go with 7 months. That's more than 6 months.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 7)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Apr 26 1973 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Apr 27 1973 foo')
|
||||
|
||||
|
||||
def test_oldSingleDigitDayOfMonth(self):
|
||||
"""
|
||||
A file with a high-resolution timestamp which falls on a day of the
|
||||
month which can be represented by one decimal digit is formatted with
|
||||
one padding 0 to preserve the columns which come after it.
|
||||
"""
|
||||
# A point about 7 months in the past, tweaked to fall on the first of a
|
||||
# month so we test the case we want to test.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 7) + (60 * 60 * 24 * 5)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 May 01 1973 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 May 02 1973 foo')
|
||||
|
||||
|
||||
def test_newFile(self):
|
||||
"""
|
||||
A file with an mtime fewer than six months (approximately) in the past
|
||||
has a listing including a high-resolution timestamp excluding the year.
|
||||
"""
|
||||
# A point about three months in the past.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 3)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Aug 28 17:33 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Aug 29 09:33 foo')
|
||||
|
||||
|
||||
def test_localeIndependent(self):
|
||||
"""
|
||||
The month name in the date is locale independent.
|
||||
"""
|
||||
# A point about three months in the past.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 3)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
# Fake that we're in a language where August is not Aug (e.g.: Spanish)
|
||||
currentLocale = locale.getlocale()
|
||||
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
|
||||
self.addCleanup(locale.setlocale, locale.LC_ALL, currentLocale)
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Aug 28 17:33 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Aug 29 09:33 foo')
|
||||
|
||||
# if alternate locale is not available, the previous test will be
|
||||
# skipped, please install this locale for it to run
|
||||
currentLocale = locale.getlocale()
|
||||
try:
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
|
||||
except locale.Error:
|
||||
test_localeIndependent.skip = "The es_AR.UTF8 locale is not installed."
|
||||
finally:
|
||||
locale.setlocale(locale.LC_ALL, currentLocale)
|
||||
|
||||
|
||||
def test_newSingleDigitDayOfMonth(self):
|
||||
"""
|
||||
A file with a high-resolution timestamp which falls on a day of the
|
||||
month which can be represented by one decimal digit is formatted with
|
||||
one padding 0 to preserve the columns which come after it.
|
||||
"""
|
||||
# A point about three months in the past, tweaked to fall on the first
|
||||
# of a month so we test the case we want to test.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 3) + (60 * 60 * 24 * 4)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Sep 01 17:33 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Sep 02 09:33 foo')
|
||||
|
||||
|
||||
|
||||
class StdioClientTests(TestCase):
|
||||
"""
|
||||
Tests for L{cftp.StdioClient}.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a L{cftp.StdioClient} hooked up to dummy transport and a fake
|
||||
user database.
|
||||
"""
|
||||
class Connection:
|
||||
pass
|
||||
|
||||
conn = Connection()
|
||||
conn.transport = StringTransport()
|
||||
conn.transport.localClosed = False
|
||||
|
||||
self.client = cftp.StdioClient(conn)
|
||||
self.database = self.client._pwd = UserDatabase()
|
||||
|
||||
# Intentionally bypassing makeConnection - that triggers some code
|
||||
# which uses features not provided by our dumb Connection fake.
|
||||
self.client.transport = StringTransport()
|
||||
|
||||
|
||||
def test_exec(self):
|
||||
"""
|
||||
The I{exec} command runs its arguments locally in a child process
|
||||
using the user's shell.
|
||||
"""
|
||||
self.database.addUser(
|
||||
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
|
||||
sys.executable)
|
||||
|
||||
d = self.client._dispatchCommand("exec print 1 + 2")
|
||||
d.addCallback(self.assertEqual, "3\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_execWithoutShell(self):
|
||||
"""
|
||||
If the local user has no shell, the I{exec} command runs its arguments
|
||||
using I{/bin/sh}.
|
||||
"""
|
||||
self.database.addUser(
|
||||
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar', '')
|
||||
|
||||
d = self.client._dispatchCommand("exec echo hello")
|
||||
d.addCallback(self.assertEqual, "hello\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_bang(self):
|
||||
"""
|
||||
The I{exec} command is run for lines which start with C{"!"}.
|
||||
"""
|
||||
self.database.addUser(
|
||||
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
|
||||
'/bin/sh')
|
||||
|
||||
d = self.client._dispatchCommand("!echo hello")
|
||||
d.addCallback(self.assertEqual, "hello\n")
|
||||
return d
|
||||
|
||||
|
||||
def setKnownConsoleSize(self, width, height):
|
||||
"""
|
||||
For the duration of this test, patch C{cftp}'s C{fcntl} module to return
|
||||
a fixed width and height.
|
||||
|
||||
@param width: the width in characters
|
||||
@type width: C{int}
|
||||
@param height: the height in characters
|
||||
@type height: C{int}
|
||||
"""
|
||||
import tty # local import to avoid win32 issues
|
||||
class FakeFcntl(object):
|
||||
def ioctl(self, fd, opt, mutate):
|
||||
if opt != tty.TIOCGWINSZ:
|
||||
self.fail("Only window-size queries supported.")
|
||||
return struct.pack("4H", height, width, 0, 0)
|
||||
self.patch(cftp, "fcntl", FakeFcntl())
|
||||
|
||||
|
||||
def test_progressReporting(self):
|
||||
"""
|
||||
L{StdioClient._printProgressBar} prints a progress description,
|
||||
including percent done, amount transferred, transfer rate, and time
|
||||
remaining, all based the given start time, the given L{FileWrapper}'s
|
||||
progress information and the reactor's current time.
|
||||
"""
|
||||
# Use a short, known console width because this simple test doesn't need
|
||||
# to test the console padding.
|
||||
self.setKnownConsoleSize(10, 34)
|
||||
clock = self.client.reactor = Clock()
|
||||
wrapped = StringIO("x")
|
||||
wrapped.name = "sample"
|
||||
wrapper = cftp.FileWrapper(wrapped)
|
||||
wrapper.size = 1024 * 10
|
||||
startTime = clock.seconds()
|
||||
clock.advance(2.0)
|
||||
wrapper.total += 4096
|
||||
self.client._printProgressBar(wrapper, startTime)
|
||||
self.assertEqual(self.client.transport.value(),
|
||||
"\rsample 40% 4.0kB 2.0kBps 00:03 ")
|
||||
|
||||
|
||||
def test_reportNoProgress(self):
|
||||
"""
|
||||
L{StdioClient._printProgressBar} prints a progress description that
|
||||
indicates 0 bytes transferred if no bytes have been transferred and no
|
||||
time has passed.
|
||||
"""
|
||||
self.setKnownConsoleSize(10, 34)
|
||||
clock = self.client.reactor = Clock()
|
||||
wrapped = StringIO("x")
|
||||
wrapped.name = "sample"
|
||||
wrapper = cftp.FileWrapper(wrapped)
|
||||
startTime = clock.seconds()
|
||||
self.client._printProgressBar(wrapper, startTime)
|
||||
self.assertEqual(self.client.transport.value(),
|
||||
"\rsample 0% 0.0B 0.0Bps 00:00 ")
|
||||
|
||||
|
||||
|
||||
class FileTransferTestRealm:
|
||||
def __init__(self, testDir):
|
||||
self.testDir = testDir
|
||||
|
||||
def requestAvatar(self, avatarID, mind, *interfaces):
|
||||
a = FileTransferTestAvatar(self.testDir)
|
||||
return interfaces[0], a, lambda: None
|
||||
|
||||
|
||||
class SFTPTestProcess(protocol.ProcessProtocol):
|
||||
"""
|
||||
Protocol for testing cftp. Provides an interface between Python (where all
|
||||
the tests are) and the cftp client process (which does the work that is
|
||||
being tested).
|
||||
"""
|
||||
|
||||
def __init__(self, onOutReceived):
|
||||
"""
|
||||
@param onOutReceived: A L{Deferred} to be fired as soon as data is
|
||||
received from stdout.
|
||||
"""
|
||||
self.clearBuffer()
|
||||
self.onOutReceived = onOutReceived
|
||||
self.onProcessEnd = None
|
||||
self._expectingCommand = None
|
||||
self._processEnded = False
|
||||
|
||||
def clearBuffer(self):
|
||||
"""
|
||||
Clear any buffered data received from stdout. Should be private.
|
||||
"""
|
||||
self.buffer = ''
|
||||
self._linesReceived = []
|
||||
self._lineBuffer = ''
|
||||
|
||||
def outReceived(self, data):
|
||||
"""
|
||||
Called by Twisted when the cftp client prints data to stdout.
|
||||
"""
|
||||
log.msg('got %s' % data)
|
||||
lines = (self._lineBuffer + data).split('\n')
|
||||
self._lineBuffer = lines.pop(-1)
|
||||
self._linesReceived.extend(lines)
|
||||
# XXX - not strictly correct.
|
||||
# We really want onOutReceived to fire after the first 'cftp>' prompt
|
||||
# has been received. (See use in TestOurServerCmdLineClient.setUp)
|
||||
if self.onOutReceived is not None:
|
||||
d, self.onOutReceived = self.onOutReceived, None
|
||||
d.callback(data)
|
||||
self.buffer += data
|
||||
self._checkForCommand()
|
||||
|
||||
def _checkForCommand(self):
|
||||
prompt = 'cftp> '
|
||||
if self._expectingCommand and self._lineBuffer == prompt:
|
||||
buf = '\n'.join(self._linesReceived)
|
||||
if buf.startswith(prompt):
|
||||
buf = buf[len(prompt):]
|
||||
self.clearBuffer()
|
||||
d, self._expectingCommand = self._expectingCommand, None
|
||||
d.callback(buf)
|
||||
|
||||
def errReceived(self, data):
|
||||
"""
|
||||
Called by Twisted when the cftp client prints data to stderr.
|
||||
"""
|
||||
log.msg('err: %s' % data)
|
||||
|
||||
def getBuffer(self):
|
||||
"""
|
||||
Return the contents of the buffer of data received from stdout.
|
||||
"""
|
||||
return self.buffer
|
||||
|
||||
def runCommand(self, command):
|
||||
"""
|
||||
Issue the given command via the cftp client. Return a C{Deferred} that
|
||||
fires when the server returns a result. Note that the C{Deferred} will
|
||||
callback even if the server returns some kind of error.
|
||||
|
||||
@param command: A string containing an sftp command.
|
||||
|
||||
@return: A C{Deferred} that fires when the sftp server returns a
|
||||
result. The payload is the server's response string.
|
||||
"""
|
||||
self._expectingCommand = defer.Deferred()
|
||||
self.clearBuffer()
|
||||
self.transport.write(command + '\n')
|
||||
return self._expectingCommand
|
||||
|
||||
def runScript(self, commands):
|
||||
"""
|
||||
Run each command in sequence and return a Deferred that fires when all
|
||||
commands are completed.
|
||||
|
||||
@param commands: A list of strings containing sftp commands.
|
||||
|
||||
@return: A C{Deferred} that fires when all commands are completed. The
|
||||
payload is a list of response strings from the server, in the same
|
||||
order as the commands.
|
||||
"""
|
||||
sem = defer.DeferredSemaphore(1)
|
||||
dl = [sem.run(self.runCommand, command) for command in commands]
|
||||
return defer.gatherResults(dl)
|
||||
|
||||
def killProcess(self):
|
||||
"""
|
||||
Kill the process if it is still running.
|
||||
|
||||
If the process is still running, sends a KILL signal to the transport
|
||||
and returns a C{Deferred} which fires when L{processEnded} is called.
|
||||
|
||||
@return: a C{Deferred}.
|
||||
"""
|
||||
if self._processEnded:
|
||||
return defer.succeed(None)
|
||||
self.onProcessEnd = defer.Deferred()
|
||||
self.transport.signalProcess('KILL')
|
||||
return self.onProcessEnd
|
||||
|
||||
def processEnded(self, reason):
|
||||
"""
|
||||
Called by Twisted when the cftp client process ends.
|
||||
"""
|
||||
self._processEnded = True
|
||||
if self.onProcessEnd:
|
||||
d, self.onProcessEnd = self.onProcessEnd, None
|
||||
d.callback(None)
|
||||
|
||||
|
||||
class CFTPClientTestBase(SFTPTestBase):
|
||||
def setUp(self):
|
||||
f = open('dsa_test.pub','w')
|
||||
f.write(test_ssh.publicDSA_openssh)
|
||||
f.close()
|
||||
f = open('dsa_test','w')
|
||||
f.write(test_ssh.privateDSA_openssh)
|
||||
f.close()
|
||||
os.chmod('dsa_test', 33152)
|
||||
f = open('kh_test','w')
|
||||
f.write('127.0.0.1 ' + test_ssh.publicRSA_openssh)
|
||||
f.close()
|
||||
return SFTPTestBase.setUp(self)
|
||||
|
||||
def startServer(self):
|
||||
realm = FileTransferTestRealm(self.testDir)
|
||||
p = portal.Portal(realm)
|
||||
p.registerChecker(test_ssh.ConchTestPublicKeyChecker())
|
||||
fac = test_ssh.ConchTestServerFactory()
|
||||
fac.portal = p
|
||||
self.server = reactor.listenTCP(0, fac, interface="127.0.0.1")
|
||||
|
||||
def stopServer(self):
|
||||
if not hasattr(self.server.factory, 'proto'):
|
||||
return self._cbStopServer(None)
|
||||
self.server.factory.proto.expectedLoseConnection = 1
|
||||
d = defer.maybeDeferred(
|
||||
self.server.factory.proto.transport.loseConnection)
|
||||
d.addCallback(self._cbStopServer)
|
||||
return d
|
||||
|
||||
def _cbStopServer(self, ignored):
|
||||
return defer.maybeDeferred(self.server.stopListening)
|
||||
|
||||
def tearDown(self):
|
||||
for f in ['dsa_test.pub', 'dsa_test', 'kh_test']:
|
||||
try:
|
||||
os.remove(f)
|
||||
except:
|
||||
pass
|
||||
return SFTPTestBase.tearDown(self)
|
||||
|
||||
|
||||
|
||||
class TestOurServerCmdLineClient(CFTPClientTestBase):
|
||||
|
||||
def setUp(self):
|
||||
CFTPClientTestBase.setUp(self)
|
||||
|
||||
self.startServer()
|
||||
cmds = ('-p %i -l testuser '
|
||||
'--known-hosts kh_test '
|
||||
'--user-authentications publickey '
|
||||
'--host-key-algorithms ssh-rsa '
|
||||
'-i dsa_test '
|
||||
'-a '
|
||||
'-v '
|
||||
'127.0.0.1')
|
||||
port = self.server.getHost().port
|
||||
cmds = test_conch._makeArgs((cmds % port).split(), mod='cftp')
|
||||
log.msg('running %s %s' % (sys.executable, cmds))
|
||||
d = defer.Deferred()
|
||||
self.processProtocol = SFTPTestProcess(d)
|
||||
d.addCallback(lambda _: self.processProtocol.clearBuffer())
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.pathsep.join(sys.path)
|
||||
reactor.spawnProcess(self.processProtocol, sys.executable, cmds,
|
||||
env=env)
|
||||
return d
|
||||
|
||||
def tearDown(self):
|
||||
d = self.stopServer()
|
||||
d.addCallback(lambda _: self.processProtocol.killProcess())
|
||||
return d
|
||||
|
||||
def _killProcess(self, ignored):
|
||||
try:
|
||||
self.processProtocol.transport.signalProcess('KILL')
|
||||
except error.ProcessExitedAlready:
|
||||
pass
|
||||
|
||||
def runCommand(self, command):
|
||||
"""
|
||||
Run the given command with the cftp client. Return a C{Deferred} that
|
||||
fires when the command is complete. Payload is the server's output for
|
||||
that command.
|
||||
"""
|
||||
return self.processProtocol.runCommand(command)
|
||||
|
||||
def runScript(self, *commands):
|
||||
"""
|
||||
Run the given commands with the cftp client. Returns a C{Deferred}
|
||||
that fires when the commands are all complete. The C{Deferred}'s
|
||||
payload is a list of output for each command.
|
||||
"""
|
||||
return self.processProtocol.runScript(commands)
|
||||
|
||||
def testCdPwd(self):
|
||||
"""
|
||||
Test that 'pwd' reports the current remote directory, that 'lpwd'
|
||||
reports the current local directory, and that changing to a
|
||||
subdirectory then changing to its parent leaves you in the original
|
||||
remote directory.
|
||||
"""
|
||||
# XXX - not actually a unit test, see docstring.
|
||||
homeDir = os.path.join(os.getcwd(), self.testDir)
|
||||
d = self.runScript('pwd', 'lpwd', 'cd testDirectory', 'cd ..', 'pwd')
|
||||
d.addCallback(lambda xs: xs[:3] + xs[4:])
|
||||
d.addCallback(self.assertEqual,
|
||||
[homeDir, os.getcwd(), '', homeDir])
|
||||
return d
|
||||
|
||||
def testChAttrs(self):
|
||||
"""
|
||||
Check that 'ls -l' output includes the access permissions and that
|
||||
this output changes appropriately with 'chmod'.
|
||||
"""
|
||||
def _check(results):
|
||||
self.flushLoggedErrors()
|
||||
self.assertTrue(results[0].startswith('-rw-r--r--'))
|
||||
self.assertEqual(results[1], '')
|
||||
self.assertTrue(results[2].startswith('----------'), results[2])
|
||||
self.assertEqual(results[3], '')
|
||||
|
||||
d = self.runScript('ls -l testfile1', 'chmod 0 testfile1',
|
||||
'ls -l testfile1', 'chmod 644 testfile1')
|
||||
return d.addCallback(_check)
|
||||
# XXX test chgrp/own
|
||||
|
||||
|
||||
def testList(self):
|
||||
"""
|
||||
Check 'ls' works as expected. Checks for wildcards, hidden files,
|
||||
listing directories and listing empty directories.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1'])
|
||||
self.assertEqual(results[1], ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1'])
|
||||
self.assertEqual(results[2], ['testRemoveFile', 'testRenameFile'])
|
||||
self.assertEqual(results[3], ['.testHiddenFile', 'testRemoveFile',
|
||||
'testRenameFile'])
|
||||
self.assertEqual(results[4], [''])
|
||||
d = self.runScript('ls', 'ls ../' + os.path.basename(self.testDir),
|
||||
'ls *File', 'ls -a *File', 'ls -l testDirectory')
|
||||
d.addCallback(lambda xs: [x.split('\n') for x in xs])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def testHelp(self):
|
||||
"""
|
||||
Check that running the '?' command returns help.
|
||||
"""
|
||||
d = self.runCommand('?')
|
||||
d.addCallback(self.assertEqual,
|
||||
cftp.StdioClient(None).cmd_HELP('').strip())
|
||||
return d
|
||||
|
||||
def assertFilesEqual(self, name1, name2, msg=None):
|
||||
"""
|
||||
Assert that the files at C{name1} and C{name2} contain exactly the
|
||||
same data.
|
||||
"""
|
||||
f1 = file(name1).read()
|
||||
f2 = file(name2).read()
|
||||
self.assertEqual(f1, f2, msg)
|
||||
|
||||
|
||||
def testGet(self):
|
||||
"""
|
||||
Test that 'get' saves the remote file to the correct local location,
|
||||
that the output of 'get' is correct and that 'rm' actually removes
|
||||
the file.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
expectedOutput = ("Transferred %s/%s/testfile1 to %s/test file2"
|
||||
% (os.getcwd(), self.testDir, self.testDir))
|
||||
def _checkGet(result):
|
||||
self.assertTrue(result.endswith(expectedOutput))
|
||||
self.assertFilesEqual(self.testDir + '/testfile1',
|
||||
self.testDir + '/test file2',
|
||||
"get failed")
|
||||
return self.runCommand('rm "test file2"')
|
||||
|
||||
d = self.runCommand('get testfile1 "%s/test file2"' % (self.testDir,))
|
||||
d.addCallback(_checkGet)
|
||||
d.addCallback(lambda _: self.assertFalse(
|
||||
os.path.exists(self.testDir + '/test file2')))
|
||||
return d
|
||||
|
||||
|
||||
def testWildcardGet(self):
|
||||
"""
|
||||
Test that 'get' works correctly when given wildcard parameters.
|
||||
"""
|
||||
def _check(ignored):
|
||||
self.assertFilesEqual(self.testDir + '/testRemoveFile',
|
||||
'testRemoveFile',
|
||||
'testRemoveFile get failed')
|
||||
self.assertFilesEqual(self.testDir + '/testRenameFile',
|
||||
'testRenameFile',
|
||||
'testRenameFile get failed')
|
||||
|
||||
d = self.runCommand('get testR*')
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def testPut(self):
|
||||
"""
|
||||
Check that 'put' uploads files correctly and that they can be
|
||||
successfully removed. Also check the output of the put command.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
expectedOutput = ('Transferred %s/testfile1 to %s/%s/test"file2'
|
||||
% (self.testDir, os.getcwd(), self.testDir))
|
||||
def _checkPut(result):
|
||||
self.assertFilesEqual(self.testDir + '/testfile1',
|
||||
self.testDir + '/test"file2')
|
||||
self.assertTrue(result.endswith(expectedOutput))
|
||||
return self.runCommand('rm "test\\"file2"')
|
||||
|
||||
d = self.runCommand('put %s/testfile1 "test\\"file2"'
|
||||
% (self.testDir,))
|
||||
d.addCallback(_checkPut)
|
||||
d.addCallback(lambda _: self.assertFalse(
|
||||
os.path.exists(self.testDir + '/test"file2')))
|
||||
return d
|
||||
|
||||
|
||||
def test_putOverLongerFile(self):
|
||||
"""
|
||||
Check that 'put' uploads files correctly when overwriting a longer
|
||||
file.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
f = file(os.path.join(self.testDir, 'shorterFile'), 'w')
|
||||
f.write("a")
|
||||
f.close()
|
||||
f = file(os.path.join(self.testDir, 'longerFile'), 'w')
|
||||
f.write("bb")
|
||||
f.close()
|
||||
def _checkPut(result):
|
||||
self.assertFilesEqual(self.testDir + '/shorterFile',
|
||||
self.testDir + '/longerFile')
|
||||
|
||||
d = self.runCommand('put %s/shorterFile longerFile'
|
||||
% (self.testDir,))
|
||||
d.addCallback(_checkPut)
|
||||
return d
|
||||
|
||||
|
||||
def test_putMultipleOverLongerFile(self):
|
||||
"""
|
||||
Check that 'put' uploads files correctly when overwriting a longer
|
||||
file and you use a wildcard to specify the files to upload.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
os.mkdir(os.path.join(self.testDir, 'dir'))
|
||||
f = file(os.path.join(self.testDir, 'dir', 'file'), 'w')
|
||||
f.write("a")
|
||||
f.close()
|
||||
f = file(os.path.join(self.testDir, 'file'), 'w')
|
||||
f.write("bb")
|
||||
f.close()
|
||||
def _checkPut(result):
|
||||
self.assertFilesEqual(self.testDir + '/dir/file',
|
||||
self.testDir + '/file')
|
||||
|
||||
d = self.runCommand('put %s/dir/*'
|
||||
% (self.testDir,))
|
||||
d.addCallback(_checkPut)
|
||||
return d
|
||||
|
||||
|
||||
def testWildcardPut(self):
|
||||
"""
|
||||
What happens if you issue a 'put' command and include a wildcard (i.e.
|
||||
'*') in parameter? Check that all files matching the wildcard are
|
||||
uploaded to the correct directory.
|
||||
"""
|
||||
def check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertEqual(results[2], '')
|
||||
self.assertFilesEqual(self.testDir + '/testRemoveFile',
|
||||
self.testDir + '/../testRemoveFile',
|
||||
'testRemoveFile get failed')
|
||||
self.assertFilesEqual(self.testDir + '/testRenameFile',
|
||||
self.testDir + '/../testRenameFile',
|
||||
'testRenameFile get failed')
|
||||
|
||||
d = self.runScript('cd ..',
|
||||
'put %s/testR*' % (self.testDir,),
|
||||
'cd %s' % os.path.basename(self.testDir))
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
|
||||
def testLink(self):
|
||||
"""
|
||||
Test that 'ln' creates a file which appears as a link in the output of
|
||||
'ls'. Check that removing the new file succeeds without output.
|
||||
"""
|
||||
def _check(results):
|
||||
self.flushLoggedErrors()
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertTrue(results[1].startswith('l'), 'link failed')
|
||||
return self.runCommand('rm testLink')
|
||||
|
||||
d = self.runScript('ln testLink testfile1', 'ls -l testLink')
|
||||
d.addCallback(_check)
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
def testRemoteDirectory(self):
|
||||
"""
|
||||
Test that we can create and remove directories with the cftp client.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertTrue(results[1].startswith('d'))
|
||||
return self.runCommand('rmdir testMakeDirectory')
|
||||
|
||||
d = self.runScript('mkdir testMakeDirectory',
|
||||
'ls -l testMakeDirector?')
|
||||
d.addCallback(_check)
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
def test_existingRemoteDirectory(self):
|
||||
"""
|
||||
Test that a C{mkdir} on an existing directory fails with the
|
||||
appropriate error, and doesn't log an useless error server side.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertEqual(results[1],
|
||||
'remote error 11: mkdir failed')
|
||||
|
||||
d = self.runScript('mkdir testMakeDirectory',
|
||||
'mkdir testMakeDirectory')
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
|
||||
def testLocalDirectory(self):
|
||||
"""
|
||||
Test that we can create a directory locally and remove it with the
|
||||
cftp client. This test works because the 'remote' server is running
|
||||
out of a local directory.
|
||||
"""
|
||||
d = self.runCommand('lmkdir %s/testLocalDirectory' % (self.testDir,))
|
||||
d.addCallback(self.assertEqual, '')
|
||||
d.addCallback(lambda _: self.runCommand('rmdir testLocalDirectory'))
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
def testRename(self):
|
||||
"""
|
||||
Test that we can rename a file.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertEqual(results[1], 'testfile2')
|
||||
return self.runCommand('rename testfile2 testfile1')
|
||||
|
||||
d = self.runScript('rename testfile1 testfile2', 'ls testfile?')
|
||||
d.addCallback(_check)
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class TestOurServerBatchFile(CFTPClientTestBase):
|
||||
def setUp(self):
|
||||
CFTPClientTestBase.setUp(self)
|
||||
self.startServer()
|
||||
|
||||
def tearDown(self):
|
||||
CFTPClientTestBase.tearDown(self)
|
||||
return self.stopServer()
|
||||
|
||||
def _getBatchOutput(self, f):
|
||||
fn = self.mktemp()
|
||||
open(fn, 'w').write(f)
|
||||
port = self.server.getHost().port
|
||||
cmds = ('-p %i -l testuser '
|
||||
'--known-hosts kh_test '
|
||||
'--user-authentications publickey '
|
||||
'--host-key-algorithms ssh-rsa '
|
||||
'-i dsa_test '
|
||||
'-a '
|
||||
'-v -b %s 127.0.0.1') % (port, fn)
|
||||
cmds = test_conch._makeArgs(cmds.split(), mod='cftp')[1:]
|
||||
log.msg('running %s %s' % (sys.executable, cmds))
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.pathsep.join(sys.path)
|
||||
|
||||
self.server.factory.expectedLoseConnection = 1
|
||||
|
||||
d = getProcessOutputAndValue(sys.executable, cmds, env=env)
|
||||
|
||||
def _cleanup(res):
|
||||
os.remove(fn)
|
||||
return res
|
||||
|
||||
d.addCallback(lambda res: res[0])
|
||||
d.addBoth(_cleanup)
|
||||
|
||||
return d
|
||||
|
||||
def testBatchFile(self):
|
||||
"""Test whether batch file function of cftp ('cftp -b batchfile').
|
||||
This works by treating the file as a list of commands to be run.
|
||||
"""
|
||||
cmds = """pwd
|
||||
ls
|
||||
exit
|
||||
"""
|
||||
def _cbCheckResult(res):
|
||||
res = res.split('\n')
|
||||
log.msg('RES %s' % str(res))
|
||||
self.assertIn(self.testDir, res[1])
|
||||
self.assertEqual(res[3:-2], ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1'])
|
||||
|
||||
d = self._getBatchOutput(cmds)
|
||||
d.addCallback(_cbCheckResult)
|
||||
return d
|
||||
|
||||
def testError(self):
|
||||
"""Test that an error in the batch file stops running the batch.
|
||||
"""
|
||||
cmds = """chown 0 missingFile
|
||||
pwd
|
||||
exit
|
||||
"""
|
||||
def _cbCheckResult(res):
|
||||
self.assertNotIn(self.testDir, res)
|
||||
|
||||
d = self._getBatchOutput(cmds)
|
||||
d.addCallback(_cbCheckResult)
|
||||
return d
|
||||
|
||||
def testIgnoredError(self):
|
||||
"""Test that a minus sign '-' at the front of a line ignores
|
||||
any errors.
|
||||
"""
|
||||
cmds = """-chown 0 missingFile
|
||||
pwd
|
||||
exit
|
||||
"""
|
||||
def _cbCheckResult(res):
|
||||
self.assertIn(self.testDir, res)
|
||||
|
||||
d = self._getBatchOutput(cmds)
|
||||
d.addCallback(_cbCheckResult)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class TestOurServerSftpClient(CFTPClientTestBase):
|
||||
"""
|
||||
Test the sftp server against sftp command line client.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
CFTPClientTestBase.setUp(self)
|
||||
return self.startServer()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
return self.stopServer()
|
||||
|
||||
|
||||
def test_extendedAttributes(self):
|
||||
"""
|
||||
Test the return of extended attributes by the server: the sftp client
|
||||
should ignore them, but still be able to parse the response correctly.
|
||||
|
||||
This test is mainly here to check that
|
||||
L{filetransfer.FILEXFER_ATTR_EXTENDED} has the correct value.
|
||||
"""
|
||||
fn = self.mktemp()
|
||||
open(fn, 'w').write("ls .\nexit")
|
||||
port = self.server.getHost().port
|
||||
|
||||
oldGetAttr = FileTransferForTestAvatar._getAttrs
|
||||
def _getAttrs(self, s):
|
||||
attrs = oldGetAttr(self, s)
|
||||
attrs["ext_foo"] = "bar"
|
||||
return attrs
|
||||
|
||||
self.patch(FileTransferForTestAvatar, "_getAttrs", _getAttrs)
|
||||
|
||||
self.server.factory.expectedLoseConnection = True
|
||||
cmds = ('-o', 'IdentityFile=dsa_test',
|
||||
'-o', 'UserKnownHostsFile=kh_test',
|
||||
'-o', 'HostKeyAlgorithms=ssh-rsa',
|
||||
'-o', 'Port=%i' % (port,), '-b', fn, 'testuser@127.0.0.1')
|
||||
d = getProcessOutputAndValue("sftp", cmds)
|
||||
def check(result):
|
||||
self.assertEqual(result[2], 0)
|
||||
for i in ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1']:
|
||||
self.assertIn(i, result[0])
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
|
||||
if unix is None or Crypto is None or pyasn1 is None or interfaces.IReactorProcess(reactor, None) is None:
|
||||
if _reason is None:
|
||||
_reason = "don't run w/o spawnProcess or PyCrypto or pyasn1"
|
||||
TestOurServerCmdLineClient.skip = _reason
|
||||
TestOurServerBatchFile.skip = _reason
|
||||
TestOurServerSftpClient.skip = _reason
|
||||
StdioClientTests.skip = _reason
|
||||
SSHSessionTests.skip = _reason
|
||||
else:
|
||||
from twisted.python.procutils import which
|
||||
if not which('sftp'):
|
||||
TestOurServerSftpClient.skip = "no sftp command-line client available"
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
|
||||
# See LICENSE for details
|
||||
|
||||
"""
|
||||
Test ssh/channel.py.
|
||||
"""
|
||||
from twisted.conch.ssh import channel
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class MockTransport(object):
|
||||
"""
|
||||
A mock Transport. All we use is the getPeer() and getHost() methods.
|
||||
Channels implement the ITransport interface, and their getPeer() and
|
||||
getHost() methods return ('SSH', <transport's getPeer/Host value>) so
|
||||
we need to implement these methods so they have something to draw
|
||||
from.
|
||||
"""
|
||||
def getPeer(self):
|
||||
return ('MockPeer',)
|
||||
|
||||
def getHost(self):
|
||||
return ('MockHost',)
|
||||
|
||||
|
||||
class MockConnection(object):
|
||||
"""
|
||||
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
|
||||
that channels send, and when they try to close the connection.
|
||||
|
||||
@ivar data: a C{dict} mapping channel id #s to lists of data sent by that
|
||||
channel.
|
||||
@ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples
|
||||
(extended data type, data) sent by that channel.
|
||||
@ivar closes: a C{dict} mapping channel id #s to True if that channel sent
|
||||
a close message.
|
||||
"""
|
||||
transport = MockTransport()
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.extData = {}
|
||||
self.closes = {}
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Return our logging prefix.
|
||||
"""
|
||||
return "MockConnection"
|
||||
|
||||
def sendData(self, channel, data):
|
||||
"""
|
||||
Record the sent data.
|
||||
"""
|
||||
self.data.setdefault(channel, []).append(data)
|
||||
|
||||
def sendExtendedData(self, channel, type, data):
|
||||
"""
|
||||
Record the sent extended data.
|
||||
"""
|
||||
self.extData.setdefault(channel, []).append((type, data))
|
||||
|
||||
def sendClose(self, channel):
|
||||
"""
|
||||
Record that the channel sent a close message.
|
||||
"""
|
||||
self.closes[channel] = True
|
||||
|
||||
|
||||
class ChannelTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize the channel. remoteMaxPacket is 10 so that data is able
|
||||
to be sent (the default of 0 means no data is sent because no packets
|
||||
are made).
|
||||
"""
|
||||
self.conn = MockConnection()
|
||||
self.channel = channel.SSHChannel(conn=self.conn,
|
||||
remoteMaxPacket=10)
|
||||
self.channel.name = 'channel'
|
||||
|
||||
def test_init(self):
|
||||
"""
|
||||
Test that SSHChannel initializes correctly. localWindowSize defaults
|
||||
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
|
||||
defaults (what OpenSSH uses for those variables).
|
||||
|
||||
The values in the second set of assertions are meaningless; they serve
|
||||
only to verify that the instance variables are assigned in the correct
|
||||
order.
|
||||
"""
|
||||
c = channel.SSHChannel(conn=self.conn)
|
||||
self.assertEqual(c.localWindowSize, 131072)
|
||||
self.assertEqual(c.localWindowLeft, 131072)
|
||||
self.assertEqual(c.localMaxPacket, 32768)
|
||||
self.assertEqual(c.remoteWindowLeft, 0)
|
||||
self.assertEqual(c.remoteMaxPacket, 0)
|
||||
self.assertEqual(c.conn, self.conn)
|
||||
self.assertEqual(c.data, None)
|
||||
self.assertEqual(c.avatar, None)
|
||||
|
||||
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
|
||||
self.assertEqual(c2.localWindowSize, 1)
|
||||
self.assertEqual(c2.localWindowLeft, 1)
|
||||
self.assertEqual(c2.localMaxPacket, 2)
|
||||
self.assertEqual(c2.remoteWindowLeft, 3)
|
||||
self.assertEqual(c2.remoteMaxPacket, 4)
|
||||
self.assertEqual(c2.conn, 5)
|
||||
self.assertEqual(c2.data, 6)
|
||||
self.assertEqual(c2.avatar, 7)
|
||||
|
||||
def test_str(self):
|
||||
"""
|
||||
Test that str(SSHChannel) works gives the channel name and local and
|
||||
remote windows at a glance..
|
||||
"""
|
||||
self.assertEqual(str(self.channel), '<SSHChannel channel (lw 131072 '
|
||||
'rw 0)>')
|
||||
|
||||
def test_logPrefix(self):
|
||||
"""
|
||||
Test that SSHChannel.logPrefix gives the name of the channel, the
|
||||
local channel ID and the underlying connection.
|
||||
"""
|
||||
self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
|
||||
'(unknown) on MockConnection')
|
||||
|
||||
def test_addWindowBytes(self):
|
||||
"""
|
||||
Test that addWindowBytes adds bytes to the window and resumes writing
|
||||
if it was paused.
|
||||
"""
|
||||
cb = [False]
|
||||
def stubStartWriting():
|
||||
cb[0] = True
|
||||
self.channel.startWriting = stubStartWriting
|
||||
self.channel.write('test')
|
||||
self.channel.writeExtended(1, 'test')
|
||||
self.channel.addWindowBytes(50)
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
|
||||
self.assertTrue(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(self.channel.buf, '')
|
||||
self.assertEqual(self.conn.data[self.channel], ['test'])
|
||||
self.assertEqual(self.channel.extBuf, [])
|
||||
self.assertEqual(self.conn.extData[self.channel], [(1, 'test')])
|
||||
|
||||
cb[0] = False
|
||||
self.channel.addWindowBytes(20)
|
||||
self.assertFalse(cb[0])
|
||||
|
||||
self.channel.write('a'*80)
|
||||
self.channel.loseConnection()
|
||||
self.channel.addWindowBytes(20)
|
||||
self.assertFalse(cb[0])
|
||||
|
||||
def test_requestReceived(self):
|
||||
"""
|
||||
Test that requestReceived handles requests by dispatching them to
|
||||
request_* methods.
|
||||
"""
|
||||
self.channel.request_test_method = lambda data: data == ''
|
||||
self.assertTrue(self.channel.requestReceived('test-method', ''))
|
||||
self.assertFalse(self.channel.requestReceived('test-method', 'a'))
|
||||
self.assertFalse(self.channel.requestReceived('bad-method', ''))
|
||||
|
||||
def test_closeReceieved(self):
|
||||
"""
|
||||
Test that the default closeReceieved closes the connection.
|
||||
"""
|
||||
self.assertFalse(self.channel.closing)
|
||||
self.channel.closeReceived()
|
||||
self.assertTrue(self.channel.closing)
|
||||
|
||||
def test_write(self):
|
||||
"""
|
||||
Test that write handles data correctly. Send data up to the size
|
||||
of the remote window, splitting the data into packets of length
|
||||
remoteMaxPacket.
|
||||
"""
|
||||
cb = [False]
|
||||
def stubStopWriting():
|
||||
cb[0] = True
|
||||
# no window to start with
|
||||
self.channel.stopWriting = stubStopWriting
|
||||
self.channel.write('d')
|
||||
self.channel.write('a')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
# regular write
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.write('ta')
|
||||
data = self.conn.data[self.channel]
|
||||
self.assertEqual(data, ['da', 'ta'])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 16)
|
||||
# larger than max packet
|
||||
self.channel.write('12345678901')
|
||||
self.assertEqual(data, ['da', 'ta', '1234567890', '1'])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 5)
|
||||
# running out of window
|
||||
cb[0] = False
|
||||
self.channel.write('123456')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(data, ['da', 'ta', '1234567890', '1', '12345'])
|
||||
self.assertEqual(self.channel.buf, '6')
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 0)
|
||||
|
||||
def test_writeExtended(self):
|
||||
"""
|
||||
Test that writeExtended handles data correctly. Send extended data
|
||||
up to the size of the window, splitting the extended data into packets
|
||||
of length remoteMaxPacket.
|
||||
"""
|
||||
cb = [False]
|
||||
def stubStopWriting():
|
||||
cb[0] = True
|
||||
# no window to start with
|
||||
self.channel.stopWriting = stubStopWriting
|
||||
self.channel.writeExtended(1, 'd')
|
||||
self.channel.writeExtended(1, 'a')
|
||||
self.channel.writeExtended(2, 't')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
# regular write
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.writeExtended(2, 'a')
|
||||
data = self.conn.extData[self.channel]
|
||||
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a')])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 16)
|
||||
# larger than max packet
|
||||
self.channel.writeExtended(3, '12345678901')
|
||||
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
|
||||
(3, '1234567890'), (3, '1')])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 5)
|
||||
# running out of window
|
||||
cb[0] = False
|
||||
self.channel.writeExtended(4, '123456')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
|
||||
(3, '1234567890'), (3, '1'), (4, '12345')])
|
||||
self.assertEqual(self.channel.extBuf, [[4, '6']])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 0)
|
||||
|
||||
def test_writeSequence(self):
|
||||
"""
|
||||
Test that writeSequence is equivalent to write(''.join(sequece)).
|
||||
"""
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.writeSequence(map(str, range(10)))
|
||||
self.assertEqual(self.conn.data[self.channel], ['0123456789'])
|
||||
|
||||
def test_loseConnection(self):
|
||||
"""
|
||||
Tesyt that loseConnection() doesn't close the channel until all
|
||||
the data is sent.
|
||||
"""
|
||||
self.channel.write('data')
|
||||
self.channel.writeExtended(1, 'datadata')
|
||||
self.channel.loseConnection()
|
||||
self.assertEqual(self.conn.closes.get(self.channel), None)
|
||||
self.channel.addWindowBytes(4) # send regular data
|
||||
self.assertEqual(self.conn.closes.get(self.channel), None)
|
||||
self.channel.addWindowBytes(8) # send extended data
|
||||
self.assertTrue(self.conn.closes.get(self.channel))
|
||||
|
||||
def test_getPeer(self):
|
||||
"""
|
||||
Test that getPeer() returns ('SSH', <connection transport peer>).
|
||||
"""
|
||||
self.assertEqual(self.channel.getPeer(), ('SSH', 'MockPeer'))
|
||||
|
||||
def test_getHost(self):
|
||||
"""
|
||||
Test that getHost() returns ('SSH', <connection transport host>).
|
||||
"""
|
||||
self.assertEqual(self.channel.getHost(), ('SSH', 'MockHost'))
|
||||
|
|
@ -0,0 +1,603 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.checkers}.
|
||||
"""
|
||||
|
||||
try:
|
||||
import crypt
|
||||
except ImportError:
|
||||
cryptSkip = 'cannot run without crypt module'
|
||||
else:
|
||||
cryptSkip = None
|
||||
|
||||
import os, base64
|
||||
|
||||
from twisted.python import util
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
|
||||
from twisted.cred.credentials import UsernamePassword, IUsernamePassword, \
|
||||
SSHPrivateKey, ISSHPrivateKey
|
||||
from twisted.cred.error import UnhandledCredentials, UnauthorizedLogin
|
||||
from twisted.python.fakepwd import UserDatabase, ShadowDatabase
|
||||
from twisted.test.test_process import MockOS
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
dependencySkip = "can't run without Crypto and PyASN1"
|
||||
else:
|
||||
dependencySkip = None
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.conch import checkers
|
||||
from twisted.conch.error import NotEnoughAuthentication, ValidPublicKey
|
||||
from twisted.conch.test import keydata
|
||||
|
||||
if getattr(os, 'geteuid', None) is None:
|
||||
euidSkip = "Cannot run without effective UIDs (questionable)"
|
||||
else:
|
||||
euidSkip = None
|
||||
|
||||
|
||||
class HelperTests(TestCase):
|
||||
"""
|
||||
Tests for helper functions L{verifyCryptedPassword}, L{_pwdGetByName} and
|
||||
L{_shadowGetByName}.
|
||||
"""
|
||||
skip = cryptSkip or dependencySkip
|
||||
|
||||
def setUp(self):
|
||||
self.mockos = MockOS()
|
||||
|
||||
|
||||
def test_verifyCryptedPassword(self):
|
||||
"""
|
||||
L{verifyCryptedPassword} returns C{True} if the plaintext password
|
||||
passed to it matches the encrypted password passed to it.
|
||||
"""
|
||||
password = 'secret string'
|
||||
salt = 'salty'
|
||||
crypted = crypt.crypt(password, salt)
|
||||
self.assertTrue(
|
||||
checkers.verifyCryptedPassword(crypted, password),
|
||||
'%r supposed to be valid encrypted password for %r' % (
|
||||
crypted, password))
|
||||
|
||||
|
||||
def test_verifyCryptedPasswordMD5(self):
|
||||
"""
|
||||
L{verifyCryptedPassword} returns True if the provided cleartext password
|
||||
matches the provided MD5 password hash.
|
||||
"""
|
||||
password = 'password'
|
||||
salt = '$1$salt'
|
||||
crypted = crypt.crypt(password, salt)
|
||||
self.assertTrue(
|
||||
checkers.verifyCryptedPassword(crypted, password),
|
||||
'%r supposed to be valid encrypted password for %s' % (
|
||||
crypted, password))
|
||||
|
||||
|
||||
def test_refuteCryptedPassword(self):
|
||||
"""
|
||||
L{verifyCryptedPassword} returns C{False} if the plaintext password
|
||||
passed to it does not match the encrypted password passed to it.
|
||||
"""
|
||||
password = 'string secret'
|
||||
wrong = 'secret string'
|
||||
crypted = crypt.crypt(password, password)
|
||||
self.assertFalse(
|
||||
checkers.verifyCryptedPassword(crypted, wrong),
|
||||
'%r not supposed to be valid encrypted password for %s' % (
|
||||
crypted, wrong))
|
||||
|
||||
|
||||
def test_pwdGetByName(self):
|
||||
"""
|
||||
L{_pwdGetByName} returns a tuple of items from the UNIX /etc/passwd
|
||||
database if the L{pwd} module is present.
|
||||
"""
|
||||
userdb = UserDatabase()
|
||||
userdb.addUser(
|
||||
'alice', 'secrit', 1, 2, 'first last', '/foo', '/bin/sh')
|
||||
self.patch(checkers, 'pwd', userdb)
|
||||
self.assertEqual(
|
||||
checkers._pwdGetByName('alice'), userdb.getpwnam('alice'))
|
||||
|
||||
|
||||
def test_pwdGetByNameWithoutPwd(self):
|
||||
"""
|
||||
If the C{pwd} module isn't present, L{_pwdGetByName} returns C{None}.
|
||||
"""
|
||||
self.patch(checkers, 'pwd', None)
|
||||
self.assertIs(checkers._pwdGetByName('alice'), None)
|
||||
|
||||
|
||||
def test_shadowGetByName(self):
|
||||
"""
|
||||
L{_shadowGetByName} returns a tuple of items from the UNIX /etc/shadow
|
||||
database if the L{spwd} is present.
|
||||
"""
|
||||
userdb = ShadowDatabase()
|
||||
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
|
||||
self.patch(checkers, 'spwd', userdb)
|
||||
|
||||
self.mockos.euid = 2345
|
||||
self.mockos.egid = 1234
|
||||
self.patch(util, 'os', self.mockos)
|
||||
|
||||
self.assertEqual(
|
||||
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
|
||||
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
|
||||
|
||||
|
||||
def test_shadowGetByNameWithoutSpwd(self):
|
||||
"""
|
||||
L{_shadowGetByName} uses the C{shadow} module to return a tuple of items
|
||||
from the UNIX /etc/shadow database if the C{spwd} module is not present
|
||||
and the C{shadow} module is.
|
||||
"""
|
||||
userdb = ShadowDatabase()
|
||||
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
|
||||
self.patch(checkers, 'spwd', None)
|
||||
self.patch(checkers, 'shadow', userdb)
|
||||
self.patch(util, 'os', self.mockos)
|
||||
|
||||
self.mockos.euid = 2345
|
||||
self.mockos.egid = 1234
|
||||
|
||||
self.assertEqual(
|
||||
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
|
||||
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
|
||||
|
||||
|
||||
def test_shadowGetByNameWithoutEither(self):
|
||||
"""
|
||||
L{_shadowGetByName} returns C{None} if neither C{spwd} nor C{shadow} is
|
||||
present.
|
||||
"""
|
||||
self.patch(checkers, 'spwd', None)
|
||||
self.patch(checkers, 'shadow', None)
|
||||
|
||||
self.assertIs(checkers._shadowGetByName('bob'), None)
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
|
||||
class SSHPublicKeyDatabaseTestCase(TestCase):
|
||||
"""
|
||||
Tests for L{SSHPublicKeyDatabase}.
|
||||
"""
|
||||
skip = euidSkip or dependencySkip
|
||||
|
||||
def setUp(self):
|
||||
self.checker = checkers.SSHPublicKeyDatabase()
|
||||
self.key1 = base64.encodestring("foobar")
|
||||
self.key2 = base64.encodestring("eggspam")
|
||||
self.content = "t1 %s foo\nt2 %s egg\n" % (self.key1, self.key2)
|
||||
|
||||
self.mockos = MockOS()
|
||||
self.mockos.path = FilePath(self.mktemp())
|
||||
self.mockos.path.makedirs()
|
||||
self.patch(util, 'os', self.mockos)
|
||||
self.sshDir = self.mockos.path.child('.ssh')
|
||||
self.sshDir.makedirs()
|
||||
|
||||
userdb = UserDatabase()
|
||||
userdb.addUser(
|
||||
'user', 'password', 1, 2, 'first last',
|
||||
self.mockos.path.path, '/bin/shell')
|
||||
self.checker._userdb = userdb
|
||||
|
||||
|
||||
def _testCheckKey(self, filename):
|
||||
self.sshDir.child(filename).setContent(self.content)
|
||||
user = UsernamePassword("user", "password")
|
||||
user.blob = "foobar"
|
||||
self.assertTrue(self.checker.checkKey(user))
|
||||
user.blob = "eggspam"
|
||||
self.assertTrue(self.checker.checkKey(user))
|
||||
user.blob = "notallowed"
|
||||
self.assertFalse(self.checker.checkKey(user))
|
||||
|
||||
|
||||
def test_checkKey(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
|
||||
authorized_keys file and check the keys against that file.
|
||||
"""
|
||||
self._testCheckKey("authorized_keys")
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
def test_checkKey2(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
|
||||
authorized_keys2 file and check the keys against that file.
|
||||
"""
|
||||
self._testCheckKey("authorized_keys2")
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
def test_checkKeyAsRoot(self):
|
||||
"""
|
||||
If the key file is readable, L{SSHPublicKeyDatabase.checkKey} should
|
||||
switch its uid/gid to the ones of the authenticated user.
|
||||
"""
|
||||
keyFile = self.sshDir.child("authorized_keys")
|
||||
keyFile.setContent(self.content)
|
||||
# Fake permission error by changing the mode
|
||||
keyFile.chmod(0000)
|
||||
self.addCleanup(keyFile.chmod, 0777)
|
||||
# And restore the right mode when seteuid is called
|
||||
savedSeteuid = self.mockos.seteuid
|
||||
def seteuid(euid):
|
||||
keyFile.chmod(0777)
|
||||
return savedSeteuid(euid)
|
||||
self.mockos.euid = 2345
|
||||
self.mockos.egid = 1234
|
||||
self.patch(self.mockos, "seteuid", seteuid)
|
||||
self.patch(util, 'os', self.mockos)
|
||||
user = UsernamePassword("user", "password")
|
||||
user.blob = "foobar"
|
||||
self.assertTrue(self.checker.checkKey(user))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, 1, 0, 2345])
|
||||
self.assertEqual(self.mockos.setegidCalls, [2, 1234])
|
||||
|
||||
|
||||
def test_requestAvatarId(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.requestAvatarId} should return the avatar id
|
||||
passed in if its C{_checkKey} method returns True.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey(
|
||||
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
|
||||
keys.Key.fromString(keydata.privateRSA_openssh).sign('foo'))
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
def _verify(avatarId):
|
||||
self.assertEqual(avatarId, 'test')
|
||||
return d.addCallback(_verify)
|
||||
|
||||
|
||||
def test_requestAvatarIdWithoutSignature(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.requestAvatarId} should raise L{ValidPublicKey}
|
||||
if the credentials represent a valid key without a signature. This
|
||||
tells the user that the key is valid for login, but does not actually
|
||||
allow that user to do so without a signature.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey(
|
||||
'test', 'ssh-rsa', keydata.publicRSA_openssh, None, None)
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
return self.assertFailure(d, ValidPublicKey)
|
||||
|
||||
|
||||
def test_requestAvatarIdInvalidKey(self):
|
||||
"""
|
||||
If L{SSHPublicKeyDatabase.checkKey} returns False,
|
||||
C{_cbRequestAvatarId} should raise L{UnauthorizedLogin}.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return False
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
d = self.checker.requestAvatarId(None);
|
||||
return self.assertFailure(d, UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_requestAvatarIdInvalidSignature(self):
|
||||
"""
|
||||
Valid keys with invalid signatures should cause
|
||||
L{SSHPublicKeyDatabase.requestAvatarId} to return a {UnauthorizedLogin}
|
||||
failure
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey(
|
||||
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
|
||||
keys.Key.fromString(keydata.privateDSA_openssh).sign('foo'))
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
return self.assertFailure(d, UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_requestAvatarIdNormalizeException(self):
|
||||
"""
|
||||
Exceptions raised while verifying the key should be normalized into an
|
||||
C{UnauthorizedLogin} failure.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey('test', None, 'blob', 'sigData', 'sig')
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
def _verifyLoggedException(failure):
|
||||
errors = self.flushLoggedErrors(keys.BadKeyError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
return failure
|
||||
d.addErrback(_verifyLoggedException)
|
||||
return self.assertFailure(d, UnauthorizedLogin)
|
||||
|
||||
|
||||
|
||||
class SSHProtocolCheckerTestCase(TestCase):
|
||||
"""
|
||||
Tests for L{SSHProtocolChecker}.
|
||||
"""
|
||||
|
||||
skip = dependencySkip
|
||||
|
||||
def test_registerChecker(self):
|
||||
"""
|
||||
L{SSHProcotolChecker.registerChecker} should add the given checker to
|
||||
the list of registered checkers.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
self.assertEqual(checker.credentialInterfaces, [])
|
||||
checker.registerChecker(checkers.SSHPublicKeyDatabase(), )
|
||||
self.assertEqual(checker.credentialInterfaces, [ISSHPrivateKey])
|
||||
self.assertIsInstance(checker.checkers[ISSHPrivateKey],
|
||||
checkers.SSHPublicKeyDatabase)
|
||||
|
||||
|
||||
def test_registerCheckerWithInterface(self):
|
||||
"""
|
||||
If a apecific interface is passed into
|
||||
L{SSHProtocolChecker.registerChecker}, that interface should be
|
||||
registered instead of what the checker specifies in
|
||||
credentialIntefaces.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
self.assertEqual(checker.credentialInterfaces, [])
|
||||
checker.registerChecker(checkers.SSHPublicKeyDatabase(),
|
||||
IUsernamePassword)
|
||||
self.assertEqual(checker.credentialInterfaces, [IUsernamePassword])
|
||||
self.assertIsInstance(checker.checkers[IUsernamePassword],
|
||||
checkers.SSHPublicKeyDatabase)
|
||||
|
||||
|
||||
def test_requestAvatarId(self):
|
||||
"""
|
||||
L{SSHProtocolChecker.requestAvatarId} should defer to one if its
|
||||
registered checkers to authenticate a user.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
|
||||
passwordDatabase.addUser('test', 'test')
|
||||
checker.registerChecker(passwordDatabase)
|
||||
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
|
||||
def _callback(avatarId):
|
||||
self.assertEqual(avatarId, 'test')
|
||||
return d.addCallback(_callback)
|
||||
|
||||
|
||||
def test_requestAvatarIdWithNotEnoughAuthentication(self):
|
||||
"""
|
||||
If the client indicates that it is never satisfied, by always returning
|
||||
False from _areDone, then L{SSHProtocolChecker} should raise
|
||||
L{NotEnoughAuthentication}.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
def _areDone(avatarId):
|
||||
return False
|
||||
self.patch(checker, 'areDone', _areDone)
|
||||
|
||||
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
|
||||
passwordDatabase.addUser('test', 'test')
|
||||
checker.registerChecker(passwordDatabase)
|
||||
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
|
||||
return self.assertFailure(d, NotEnoughAuthentication)
|
||||
|
||||
|
||||
def test_requestAvatarIdInvalidCredential(self):
|
||||
"""
|
||||
If the passed credentials aren't handled by any registered checker,
|
||||
L{SSHProtocolChecker} should raise L{UnhandledCredentials}.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
|
||||
return self.assertFailure(d, UnhandledCredentials)
|
||||
|
||||
|
||||
def test_areDone(self):
|
||||
"""
|
||||
The default L{SSHProcotolChecker.areDone} should simply return True.
|
||||
"""
|
||||
self.assertEqual(checkers.SSHProtocolChecker().areDone(None), True)
|
||||
|
||||
|
||||
|
||||
class UNIXPasswordDatabaseTests(TestCase):
|
||||
"""
|
||||
Tests for L{UNIXPasswordDatabase}.
|
||||
"""
|
||||
skip = cryptSkip or dependencySkip
|
||||
|
||||
def assertLoggedIn(self, d, username):
|
||||
"""
|
||||
Assert that the L{Deferred} passed in is called back with the value
|
||||
'username'. This represents a valid login for this TestCase.
|
||||
|
||||
NOTE: To work, this method's return value must be returned from the
|
||||
test method, or otherwise hooked up to the test machinery.
|
||||
|
||||
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
|
||||
@type d: L{Deferred}
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
result = []
|
||||
d.addBoth(result.append)
|
||||
self.assertEqual(len(result), 1, "login incomplete")
|
||||
if isinstance(result[0], Failure):
|
||||
result[0].raiseException()
|
||||
self.assertEqual(result[0], username)
|
||||
|
||||
|
||||
def test_defaultCheckers(self):
|
||||
"""
|
||||
L{UNIXPasswordDatabase} with no arguments has checks the C{pwd} database
|
||||
and then the C{spwd} database.
|
||||
"""
|
||||
checker = checkers.UNIXPasswordDatabase()
|
||||
|
||||
def crypted(username, password):
|
||||
salt = crypt.crypt(password, username)
|
||||
crypted = crypt.crypt(password, '$1$' + salt)
|
||||
return crypted
|
||||
|
||||
pwd = UserDatabase()
|
||||
pwd.addUser('alice', crypted('alice', 'password'),
|
||||
1, 2, 'foo', '/foo', '/bin/sh')
|
||||
# x and * are convention for "look elsewhere for the password"
|
||||
pwd.addUser('bob', 'x', 1, 2, 'bar', '/bar', '/bin/sh')
|
||||
spwd = ShadowDatabase()
|
||||
spwd.addUser('alice', 'wrong', 1, 2, 3, 4, 5, 6, 7)
|
||||
spwd.addUser('bob', crypted('bob', 'password'),
|
||||
8, 9, 10, 11, 12, 13, 14)
|
||||
|
||||
self.patch(checkers, 'pwd', pwd)
|
||||
self.patch(checkers, 'spwd', spwd)
|
||||
|
||||
mockos = MockOS()
|
||||
self.patch(util, 'os', mockos)
|
||||
|
||||
mockos.euid = 2345
|
||||
mockos.egid = 1234
|
||||
|
||||
cred = UsernamePassword("alice", "password")
|
||||
self.assertLoggedIn(checker.requestAvatarId(cred), 'alice')
|
||||
self.assertEqual(mockos.seteuidCalls, [])
|
||||
self.assertEqual(mockos.setegidCalls, [])
|
||||
cred.username = "bob"
|
||||
self.assertLoggedIn(checker.requestAvatarId(cred), 'bob')
|
||||
self.assertEqual(mockos.seteuidCalls, [0, 2345])
|
||||
self.assertEqual(mockos.setegidCalls, [0, 1234])
|
||||
|
||||
|
||||
def assertUnauthorizedLogin(self, d):
|
||||
"""
|
||||
Asserts that the L{Deferred} passed in is erred back with an
|
||||
L{UnauthorizedLogin} L{Failure}. This reprsents an invalid login for
|
||||
this TestCase.
|
||||
|
||||
NOTE: To work, this method's return value must be returned from the
|
||||
test method, or otherwise hooked up to the test machinery.
|
||||
|
||||
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
|
||||
@type d: L{Deferred}
|
||||
@rtype: L{None}
|
||||
"""
|
||||
self.assertRaises(
|
||||
checkers.UnauthorizedLogin, self.assertLoggedIn, d, 'bogus value')
|
||||
|
||||
|
||||
def test_passInCheckers(self):
|
||||
"""
|
||||
L{UNIXPasswordDatabase} takes a list of functions to check for UNIX
|
||||
user information.
|
||||
"""
|
||||
password = crypt.crypt('secret', 'secret')
|
||||
userdb = UserDatabase()
|
||||
userdb.addUser('anybody', password, 1, 2, 'foo', '/bar', '/bin/sh')
|
||||
checker = checkers.UNIXPasswordDatabase([userdb.getpwnam])
|
||||
self.assertLoggedIn(
|
||||
checker.requestAvatarId(UsernamePassword('anybody', 'secret')),
|
||||
'anybody')
|
||||
|
||||
|
||||
def test_verifyPassword(self):
|
||||
"""
|
||||
If the encrypted password provided by the getpwnam function is valid
|
||||
(verified by the L{verifyCryptedPassword} function), we callback the
|
||||
C{requestAvatarId} L{Deferred} with the username.
|
||||
"""
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return crypted == pw
|
||||
def getpwnam(username):
|
||||
return [username, username]
|
||||
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
|
||||
|
||||
|
||||
def test_failOnKeyError(self):
|
||||
"""
|
||||
If the getpwnam function raises a KeyError, the login fails with an
|
||||
L{UnauthorizedLogin} exception.
|
||||
"""
|
||||
def getpwnam(username):
|
||||
raise KeyError(username)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
|
||||
|
||||
|
||||
def test_failOnBadPassword(self):
|
||||
"""
|
||||
If the verifyCryptedPassword function doesn't verify the password, the
|
||||
login fails with an L{UnauthorizedLogin} exception.
|
||||
"""
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return False
|
||||
def getpwnam(username):
|
||||
return [username, username]
|
||||
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
|
||||
|
||||
|
||||
def test_loopThroughFunctions(self):
|
||||
"""
|
||||
UNIXPasswordDatabase.requestAvatarId loops through each getpwnam
|
||||
function associated with it and returns a L{Deferred} which fires with
|
||||
the result of the first one which returns a value other than None.
|
||||
ones do not verify the password.
|
||||
"""
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return crypted == pw
|
||||
def getpwnam1(username):
|
||||
return [username, 'not the password']
|
||||
def getpwnam2(username):
|
||||
return [username, username]
|
||||
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam1, getpwnam2])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
|
||||
|
||||
|
||||
def test_failOnSpecial(self):
|
||||
"""
|
||||
If the password returned by any function is C{""}, C{"x"}, or C{"*"} it
|
||||
is not compared against the supplied password. Instead it is skipped.
|
||||
"""
|
||||
pwd = UserDatabase()
|
||||
pwd.addUser('alice', '', 1, 2, '', 'foo', 'bar')
|
||||
pwd.addUser('bob', 'x', 1, 2, '', 'foo', 'bar')
|
||||
pwd.addUser('carol', '*', 1, 2, '', 'foo', 'bar')
|
||||
self.patch(checkers, 'pwd', pwd)
|
||||
|
||||
checker = checkers.UNIXPasswordDatabase([checkers._pwdGetByName])
|
||||
cred = UsernamePassword('alice', '')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
|
||||
|
||||
cred = UsernamePassword('bob', 'x')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
|
||||
|
||||
cred = UsernamePassword('carol', '*')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
|
||||
|
|
@ -0,0 +1,340 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.ckeygen}.
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
from StringIO import StringIO
|
||||
|
||||
try:
|
||||
import Crypto
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
skip = "PyCrypto and pyasn1 required for twisted.conch.scripts.ckeygen."
|
||||
else:
|
||||
from twisted.conch.ssh.keys import Key, BadKeyError
|
||||
from twisted.conch.scripts.ckeygen import (
|
||||
changePassPhrase, displayPublicKey, printFingerprint, _saveKey)
|
||||
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.conch.test.keydata import (
|
||||
publicRSA_openssh, privateRSA_openssh, privateRSA_openssh_encrypted)
|
||||
|
||||
|
||||
|
||||
def makeGetpass(*passphrases):
|
||||
"""
|
||||
Return a callable to patch C{getpass.getpass}. Yields a passphrase each
|
||||
time called. Use case is to provide an old, then new passphrase(s) as if
|
||||
requested interactively.
|
||||
|
||||
@param passphrases: The list of passphrases returned, one per each call.
|
||||
"""
|
||||
passphrases = iter(passphrases)
|
||||
|
||||
def fakeGetpass(_):
|
||||
return passphrases.next()
|
||||
|
||||
return fakeGetpass
|
||||
|
||||
|
||||
|
||||
class KeyGenTests(TestCase):
|
||||
"""
|
||||
Tests for various functions used to implement the I{ckeygen} script.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Patch C{sys.stdout} with a L{StringIO} instance to tests can make
|
||||
assertions about what's printed.
|
||||
"""
|
||||
self.stdout = StringIO()
|
||||
self.patch(sys, 'stdout', self.stdout)
|
||||
|
||||
|
||||
def test_printFingerprint(self):
|
||||
"""
|
||||
L{printFingerprint} writes a line to standard out giving the number of
|
||||
bits of the key, its fingerprint, and the basename of the file from it
|
||||
was read.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(publicRSA_openssh)
|
||||
printFingerprint({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue(),
|
||||
'768 3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af temp\n')
|
||||
|
||||
|
||||
def test_saveKey(self):
|
||||
"""
|
||||
L{_saveKey} writes the private and public parts of a key to two
|
||||
different files and writes a report of this to standard out.
|
||||
"""
|
||||
base = FilePath(self.mktemp())
|
||||
base.makedirs()
|
||||
filename = base.child('id_rsa').path
|
||||
key = Key.fromString(privateRSA_openssh)
|
||||
_saveKey(
|
||||
key.keyObject,
|
||||
{'filename': filename, 'pass': 'passphrase'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue(),
|
||||
"Your identification has been saved in %s\n"
|
||||
"Your public key has been saved in %s.pub\n"
|
||||
"The key fingerprint is:\n"
|
||||
"3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af\n" % (
|
||||
filename,
|
||||
filename))
|
||||
self.assertEqual(
|
||||
key.fromString(
|
||||
base.child('id_rsa').getContent(), None, 'passphrase'),
|
||||
key)
|
||||
self.assertEqual(
|
||||
Key.fromString(base.child('id_rsa.pub').getContent()),
|
||||
key.public())
|
||||
|
||||
|
||||
def test_saveKeyEmptyPassphrase(self):
|
||||
"""
|
||||
L{_saveKey} will choose an empty string for the passphrase if
|
||||
no-passphrase is C{True}.
|
||||
"""
|
||||
base = FilePath(self.mktemp())
|
||||
base.makedirs()
|
||||
filename = base.child('id_rsa').path
|
||||
key = Key.fromString(privateRSA_openssh)
|
||||
_saveKey(
|
||||
key.keyObject,
|
||||
{'filename': filename, 'no-passphrase': True})
|
||||
self.assertEqual(
|
||||
key.fromString(
|
||||
base.child('id_rsa').getContent(), None, b''),
|
||||
key)
|
||||
|
||||
|
||||
def test_displayPublicKey(self):
|
||||
"""
|
||||
L{displayPublicKey} prints out the public key associated with a given
|
||||
private key.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
pubKey = Key.fromString(publicRSA_openssh)
|
||||
FilePath(filename).setContent(privateRSA_openssh)
|
||||
displayPublicKey({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
pubKey.toString('openssh'))
|
||||
|
||||
|
||||
def test_displayPublicKeyEncrypted(self):
|
||||
"""
|
||||
L{displayPublicKey} prints out the public key associated with a given
|
||||
private key using the given passphrase when it's encrypted.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
pubKey = Key.fromString(publicRSA_openssh)
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
displayPublicKey({'filename': filename, 'pass': 'encrypted'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
pubKey.toString('openssh'))
|
||||
|
||||
|
||||
def test_displayPublicKeyEncryptedPassphrasePrompt(self):
|
||||
"""
|
||||
L{displayPublicKey} prints out the public key associated with a given
|
||||
private key, asking for the passphrase when it's encrypted.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
pubKey = Key.fromString(publicRSA_openssh)
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
self.patch(getpass, 'getpass', lambda x: 'encrypted')
|
||||
displayPublicKey({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
pubKey.toString('openssh'))
|
||||
|
||||
|
||||
def test_displayPublicKeyWrongPassphrase(self):
|
||||
"""
|
||||
L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
|
||||
an encrypted key with the wrong password.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
self.assertRaises(
|
||||
BadKeyError, displayPublicKey,
|
||||
{'filename': filename, 'pass': 'wrong'})
|
||||
|
||||
|
||||
def test_changePassphrase(self):
|
||||
"""
|
||||
L{changePassPhrase} allows a user to change the passphrase of a
|
||||
private key interactively.
|
||||
"""
|
||||
oldNewConfirm = makeGetpass('encrypted', 'newpass', 'newpass')
|
||||
self.patch(getpass, 'getpass', oldNewConfirm)
|
||||
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
|
||||
changePassPhrase({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
'Your identification has been saved with the new passphrase.')
|
||||
self.assertNotEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseWithOld(self):
|
||||
"""
|
||||
L{changePassPhrase} allows a user to change the passphrase of a
|
||||
private key, providing the old passphrase and prompting for new one.
|
||||
"""
|
||||
newConfirm = makeGetpass('newpass', 'newpass')
|
||||
self.patch(getpass, 'getpass', newConfirm)
|
||||
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
|
||||
changePassPhrase({'filename': filename, 'pass': 'encrypted'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
'Your identification has been saved with the new passphrase.')
|
||||
self.assertNotEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseWithBoth(self):
|
||||
"""
|
||||
L{changePassPhrase} allows a user to change the passphrase of a private
|
||||
key by providing both old and new passphrases without prompting.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
|
||||
changePassPhrase(
|
||||
{'filename': filename, 'pass': 'encrypted',
|
||||
'newpass': 'newencrypt'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
'Your identification has been saved with the new passphrase.')
|
||||
self.assertNotEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseWrongPassphrase(self):
|
||||
"""
|
||||
L{changePassPhrase} exits if passed an invalid old passphrase when
|
||||
trying to change the passphrase of a private key.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename, 'pass': 'wrong'})
|
||||
self.assertEqual('Could not change passphrase: old passphrase error',
|
||||
str(error))
|
||||
self.assertEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseEmptyGetPass(self):
|
||||
"""
|
||||
L{changePassPhrase} exits if no passphrase is specified for the
|
||||
C{getpass} call and the key is encrypted.
|
||||
"""
|
||||
self.patch(getpass, 'getpass', makeGetpass(''))
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase, {'filename': filename})
|
||||
self.assertEqual(
|
||||
'Could not change passphrase: Passphrase must be provided '
|
||||
'for an encrypted key',
|
||||
str(error))
|
||||
self.assertEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseBadKey(self):
|
||||
"""
|
||||
L{changePassPhrase} exits if the file specified points to an invalid
|
||||
key.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent('foobar')
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase, {'filename': filename})
|
||||
self.assertEqual(
|
||||
"Could not change passphrase: cannot guess the type of 'foobar'",
|
||||
str(error))
|
||||
self.assertEqual('foobar', FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseCreateError(self):
|
||||
"""
|
||||
L{changePassPhrase} doesn't modify the key file if an unexpected error
|
||||
happens when trying to create the key with the new passphrase.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh)
|
||||
|
||||
def toString(*args, **kwargs):
|
||||
raise RuntimeError('oops')
|
||||
|
||||
self.patch(Key, 'toString', toString)
|
||||
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename,
|
||||
'newpass': 'newencrypt'})
|
||||
|
||||
self.assertEqual(
|
||||
'Could not change passphrase: oops', str(error))
|
||||
|
||||
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseEmptyStringError(self):
|
||||
"""
|
||||
L{changePassPhrase} doesn't modify the key file if C{toString} returns
|
||||
an empty string.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh)
|
||||
|
||||
def toString(*args, **kwargs):
|
||||
return ''
|
||||
|
||||
self.patch(Key, 'toString', toString)
|
||||
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename, 'newpass': 'newencrypt'})
|
||||
|
||||
self.assertEqual(
|
||||
"Could not change passphrase: "
|
||||
"cannot guess the type of ''", str(error))
|
||||
|
||||
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphrasePublicKey(self):
|
||||
"""
|
||||
L{changePassPhrase} exits when trying to change the passphrase on a
|
||||
public key, and doesn't change the file.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(publicRSA_openssh)
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename, 'newpass': 'pass'})
|
||||
self.assertEqual(
|
||||
'Could not change passphrase: key not encrypted', str(error))
|
||||
self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())
|
||||
|
|
@ -0,0 +1,564 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_conch -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import os, sys, socket
|
||||
from itertools import count
|
||||
|
||||
from zope.interface import implements
|
||||
|
||||
from twisted.cred import portal
|
||||
from twisted.internet import reactor, defer, protocol
|
||||
from twisted.internet.error import ProcessExitedAlready
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.python import log, runtime
|
||||
from twisted.trial import unittest
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.avatar import ConchUser
|
||||
from twisted.conch.ssh.session import ISession, SSHSession, wrapProtocol
|
||||
|
||||
try:
|
||||
from twisted.conch.scripts.conch import SSHSession as StdioInteractingSession
|
||||
except ImportError, e:
|
||||
StdioInteractingSession = None
|
||||
_reason = str(e)
|
||||
del e
|
||||
|
||||
from twisted.conch.test.test_ssh import ConchTestRealm
|
||||
from twisted.python.procutils import which
|
||||
|
||||
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
|
||||
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
|
||||
|
||||
from twisted.conch.test.test_ssh import Crypto, pyasn1
|
||||
try:
|
||||
from twisted.conch.test.test_ssh import ConchTestServerFactory, \
|
||||
ConchTestPublicKeyChecker
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class FakeStdio(object):
|
||||
"""
|
||||
A fake for testing L{twisted.conch.scripts.conch.SSHSession.eofReceived} and
|
||||
L{twisted.conch.scripts.cftp.SSHSession.eofReceived}.
|
||||
|
||||
@ivar writeConnLost: A flag which records whether L{loserWriteConnection}
|
||||
has been called.
|
||||
"""
|
||||
writeConnLost = False
|
||||
|
||||
def loseWriteConnection(self):
|
||||
"""
|
||||
Record the call to loseWriteConnection.
|
||||
"""
|
||||
self.writeConnLost = True
|
||||
|
||||
|
||||
|
||||
class StdioInteractingSessionTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.conch.SSHSession}.
|
||||
"""
|
||||
if StdioInteractingSession is None:
|
||||
skip = _reason
|
||||
|
||||
def test_eofReceived(self):
|
||||
"""
|
||||
L{twisted.conch.scripts.conch.SSHSession.eofReceived} loses the
|
||||
write half of its stdio connection.
|
||||
"""
|
||||
stdio = FakeStdio()
|
||||
channel = StdioInteractingSession()
|
||||
channel.stdio = stdio
|
||||
channel.eofReceived()
|
||||
self.assertTrue(stdio.writeConnLost)
|
||||
|
||||
|
||||
|
||||
class Echo(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
log.msg('ECHO CONNECTION MADE')
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
log.msg('ECHO CONNECTION DONE')
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
if '\n' in data:
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class EchoFactory(protocol.Factory):
|
||||
protocol = Echo
|
||||
|
||||
|
||||
|
||||
class ConchTestOpenSSHProcess(protocol.ProcessProtocol):
|
||||
"""
|
||||
Test protocol for launching an OpenSSH client process.
|
||||
|
||||
@ivar deferred: Set by whatever uses this object. Accessed using
|
||||
L{_getDeferred}, which destroys the value so the Deferred is not
|
||||
fired twice. Fires when the process is terminated.
|
||||
"""
|
||||
|
||||
deferred = None
|
||||
buf = ''
|
||||
|
||||
def _getDeferred(self):
|
||||
d, self.deferred = self.deferred, None
|
||||
return d
|
||||
|
||||
|
||||
def outReceived(self, data):
|
||||
self.buf += data
|
||||
|
||||
|
||||
def processEnded(self, reason):
|
||||
"""
|
||||
Called when the process has ended.
|
||||
|
||||
@param reason: a Failure giving the reason for the process' end.
|
||||
"""
|
||||
if reason.value.exitCode != 0:
|
||||
self._getDeferred().errback(
|
||||
ConchError("exit code was not 0: %s" %
|
||||
reason.value.exitCode))
|
||||
else:
|
||||
buf = self.buf.replace('\r\n', '\n')
|
||||
self._getDeferred().callback(buf)
|
||||
|
||||
|
||||
|
||||
class ConchTestForwardingProcess(protocol.ProcessProtocol):
|
||||
"""
|
||||
Manages a third-party process which launches a server.
|
||||
|
||||
Uses L{ConchTestForwardingPort} to connect to the third-party server.
|
||||
Once L{ConchTestForwardingPort} has disconnected, kill the process and fire
|
||||
a Deferred with the data received by the L{ConchTestForwardingPort}.
|
||||
|
||||
@ivar deferred: Set by whatever uses this object. Accessed using
|
||||
L{_getDeferred}, which destroys the value so the Deferred is not
|
||||
fired twice. Fires when the process is terminated.
|
||||
"""
|
||||
|
||||
deferred = None
|
||||
|
||||
def __init__(self, port, data):
|
||||
"""
|
||||
@type port: C{int}
|
||||
@param port: The port on which the third-party server is listening.
|
||||
(it is assumed that the server is running on localhost).
|
||||
|
||||
@type data: C{str}
|
||||
@param data: This is sent to the third-party server. Must end with '\n'
|
||||
in order to trigger a disconnect.
|
||||
"""
|
||||
self.port = port
|
||||
self.buffer = None
|
||||
self.data = data
|
||||
|
||||
|
||||
def _getDeferred(self):
|
||||
d, self.deferred = self.deferred, None
|
||||
return d
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self._connect()
|
||||
|
||||
|
||||
def _connect(self):
|
||||
"""
|
||||
Connect to the server, which is often a third-party process.
|
||||
Tries to reconnect if it fails because we have no way of determining
|
||||
exactly when the port becomes available for listening -- we can only
|
||||
know when the process starts.
|
||||
"""
|
||||
cc = protocol.ClientCreator(reactor, ConchTestForwardingPort, self,
|
||||
self.data)
|
||||
d = cc.connectTCP('127.0.0.1', self.port)
|
||||
d.addErrback(self._ebConnect)
|
||||
return d
|
||||
|
||||
|
||||
def _ebConnect(self, f):
|
||||
reactor.callLater(.1, self._connect)
|
||||
|
||||
|
||||
def forwardingPortDisconnected(self, buffer):
|
||||
"""
|
||||
The network connection has died; save the buffer of output
|
||||
from the network and attempt to quit the process gracefully,
|
||||
and then (after the reactor has spun) send it a KILL signal.
|
||||
"""
|
||||
self.buffer = buffer
|
||||
self.transport.write('\x03')
|
||||
self.transport.loseConnection()
|
||||
reactor.callLater(0, self._reallyDie)
|
||||
|
||||
|
||||
def _reallyDie(self):
|
||||
try:
|
||||
self.transport.signalProcess('KILL')
|
||||
except ProcessExitedAlready:
|
||||
pass
|
||||
|
||||
|
||||
def processEnded(self, reason):
|
||||
"""
|
||||
Fire the Deferred at self.deferred with the data collected
|
||||
from the L{ConchTestForwardingPort} connection, if any.
|
||||
"""
|
||||
self._getDeferred().callback(self.buffer)
|
||||
|
||||
|
||||
|
||||
class ConchTestForwardingPort(protocol.Protocol):
|
||||
"""
|
||||
Connects to server launched by a third-party process (managed by
|
||||
L{ConchTestForwardingProcess}) sends data, then reports whatever it
|
||||
received back to the L{ConchTestForwardingProcess} once the connection
|
||||
is ended.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, protocol, data):
|
||||
"""
|
||||
@type protocol: L{ConchTestForwardingProcess}
|
||||
@param protocol: The L{ProcessProtocol} which made this connection.
|
||||
|
||||
@type data: str
|
||||
@param data: The data to be sent to the third-party server.
|
||||
"""
|
||||
self.protocol = protocol
|
||||
self.data = data
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.buffer = ''
|
||||
self.transport.write(self.data)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buffer += data
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.protocol.forwardingPortDisconnected(self.buffer)
|
||||
|
||||
|
||||
|
||||
def _makeArgs(args, mod="conch"):
|
||||
start = [sys.executable, '-c'
|
||||
"""
|
||||
### Twisted Preamble
|
||||
import sys, os
|
||||
path = os.path.abspath(sys.argv[0])
|
||||
while os.path.dirname(path) != path:
|
||||
if os.path.basename(path).startswith('Twisted'):
|
||||
sys.path.insert(0, path)
|
||||
break
|
||||
path = os.path.dirname(path)
|
||||
|
||||
from twisted.conch.scripts.%s import run
|
||||
run()""" % mod]
|
||||
return start + list(args)
|
||||
|
||||
|
||||
|
||||
class ConchServerSetupMixin:
|
||||
if not Crypto:
|
||||
skip = "can't run w/o PyCrypto"
|
||||
|
||||
if not pyasn1:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
realmFactory = staticmethod(lambda: ConchTestRealm('testuser'))
|
||||
|
||||
def _createFiles(self):
|
||||
for f in ['rsa_test','rsa_test.pub','dsa_test','dsa_test.pub',
|
||||
'kh_test']:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
open('rsa_test','w').write(privateRSA_openssh)
|
||||
open('rsa_test.pub','w').write(publicRSA_openssh)
|
||||
open('dsa_test.pub','w').write(publicDSA_openssh)
|
||||
open('dsa_test','w').write(privateDSA_openssh)
|
||||
os.chmod('dsa_test', 33152)
|
||||
os.chmod('rsa_test', 33152)
|
||||
open('kh_test','w').write('127.0.0.1 '+publicRSA_openssh)
|
||||
|
||||
|
||||
def _getFreePort(self):
|
||||
s = socket.socket()
|
||||
s.bind(('', 0))
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
def _makeConchFactory(self):
|
||||
"""
|
||||
Make a L{ConchTestServerFactory}, which allows us to start a
|
||||
L{ConchTestServer} -- i.e. an actually listening conch.
|
||||
"""
|
||||
realm = self.realmFactory()
|
||||
p = portal.Portal(realm)
|
||||
p.registerChecker(ConchTestPublicKeyChecker())
|
||||
factory = ConchTestServerFactory()
|
||||
factory.portal = p
|
||||
return factory
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self._createFiles()
|
||||
self.conchFactory = self._makeConchFactory()
|
||||
self.conchFactory.expectedLoseConnection = 1
|
||||
self.conchServer = reactor.listenTCP(0, self.conchFactory,
|
||||
interface="127.0.0.1")
|
||||
self.echoServer = reactor.listenTCP(0, EchoFactory())
|
||||
self.echoPort = self.echoServer.getHost().port
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
self.conchFactory.proto.done = 1
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.conchFactory.proto.transport.loseConnection()
|
||||
return defer.gatherResults([
|
||||
defer.maybeDeferred(self.conchServer.stopListening),
|
||||
defer.maybeDeferred(self.echoServer.stopListening)])
|
||||
|
||||
|
||||
|
||||
class ForwardingMixin(ConchServerSetupMixin):
|
||||
"""
|
||||
Template class for tests of the Conch server's ability to forward arbitrary
|
||||
protocols over SSH.
|
||||
|
||||
These tests are integration tests, not unit tests. They launch a Conch
|
||||
server, a custom TCP server (just an L{EchoProtocol}) and then call
|
||||
L{execute}.
|
||||
|
||||
L{execute} is implemented by subclasses of L{ForwardingMixin}. It should
|
||||
cause an SSH client to connect to the Conch server, asking it to forward
|
||||
data to the custom TCP server.
|
||||
"""
|
||||
|
||||
def test_exec(self):
|
||||
"""
|
||||
Test that we can use whatever client to send the command "echo goodbye"
|
||||
to the Conch server. Make sure we receive "goodbye" back from the
|
||||
server.
|
||||
"""
|
||||
d = self.execute('echo goodbye', ConchTestOpenSSHProcess())
|
||||
return d.addCallback(self.assertEqual, 'goodbye\n')
|
||||
|
||||
|
||||
def test_localToRemoteForwarding(self):
|
||||
"""
|
||||
Test that we can use whatever client to forward a local port to a
|
||||
specified port on the server.
|
||||
"""
|
||||
localPort = self._getFreePort()
|
||||
process = ConchTestForwardingProcess(localPort, 'test\n')
|
||||
d = self.execute('', process,
|
||||
sshArgs='-N -L%i:127.0.0.1:%i'
|
||||
% (localPort, self.echoPort))
|
||||
d.addCallback(self.assertEqual, 'test\n')
|
||||
return d
|
||||
|
||||
|
||||
def test_remoteToLocalForwarding(self):
|
||||
"""
|
||||
Test that we can use whatever client to forward a port from the server
|
||||
to a port locally.
|
||||
"""
|
||||
localPort = self._getFreePort()
|
||||
process = ConchTestForwardingProcess(localPort, 'test\n')
|
||||
d = self.execute('', process,
|
||||
sshArgs='-N -R %i:127.0.0.1:%i'
|
||||
% (localPort, self.echoPort))
|
||||
d.addCallback(self.assertEqual, 'test\n')
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class RekeyAvatar(ConchUser):
|
||||
"""
|
||||
This avatar implements a shell which sends 60 numbered lines to whatever
|
||||
connects to it, then closes the session with a 0 exit status.
|
||||
|
||||
60 lines is selected as being enough to send more than 2kB of traffic, the
|
||||
amount the client is configured to initiate a rekey after.
|
||||
"""
|
||||
# Conventionally there is a separate adapter object which provides ISession
|
||||
# for the user, but making the user provide ISession directly works too.
|
||||
# This isn't a full implementation of ISession though, just enough to make
|
||||
# these tests pass.
|
||||
implements(ISession)
|
||||
|
||||
def __init__(self):
|
||||
ConchUser.__init__(self)
|
||||
self.channelLookup['session'] = SSHSession
|
||||
|
||||
|
||||
def openShell(self, transport):
|
||||
"""
|
||||
Write 60 lines of data to the transport, then exit.
|
||||
"""
|
||||
proto = protocol.Protocol()
|
||||
proto.makeConnection(transport)
|
||||
transport.makeConnection(wrapProtocol(proto))
|
||||
|
||||
# Send enough bytes to the connection so that a rekey is triggered in
|
||||
# the client.
|
||||
def write(counter):
|
||||
i = counter()
|
||||
if i == 60:
|
||||
call.stop()
|
||||
transport.session.conn.sendRequest(
|
||||
transport.session, 'exit-status', '\x00\x00\x00\x00')
|
||||
transport.loseConnection()
|
||||
else:
|
||||
transport.write("line #%02d\n" % (i,))
|
||||
|
||||
# The timing for this loop is an educated guess (and/or the result of
|
||||
# experimentation) to exercise the case where a packet is generated
|
||||
# mid-rekey. Since the other side of the connection is (so far) the
|
||||
# OpenSSH command line client, there's no easy way to determine when the
|
||||
# rekey has been initiated. If there were, then generating a packet
|
||||
# immediately at that time would be a better way to test the
|
||||
# functionality being tested here.
|
||||
call = LoopingCall(write, count().next)
|
||||
call.start(0.01)
|
||||
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
Ignore the close of the session.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class RekeyRealm:
|
||||
"""
|
||||
This realm gives out new L{RekeyAvatar} instances for any avatar request.
|
||||
"""
|
||||
def requestAvatar(self, avatarID, mind, *interfaces):
|
||||
return interfaces[0], RekeyAvatar(), lambda: None
|
||||
|
||||
|
||||
|
||||
class RekeyTestsMixin(ConchServerSetupMixin):
|
||||
"""
|
||||
TestCase mixin which defines tests exercising L{SSHTransportBase}'s handling
|
||||
of rekeying messages.
|
||||
"""
|
||||
realmFactory = RekeyRealm
|
||||
|
||||
def test_clientRekey(self):
|
||||
"""
|
||||
After a client-initiated rekey is completed, application data continues
|
||||
to be passed over the SSH connection.
|
||||
"""
|
||||
process = ConchTestOpenSSHProcess()
|
||||
d = self.execute("", process, '-o RekeyLimit=2K')
|
||||
def finished(result):
|
||||
self.assertEqual(
|
||||
result,
|
||||
'\n'.join(['line #%02d' % (i,) for i in range(60)]) + '\n')
|
||||
d.addCallback(finished)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class OpenSSHClientMixin:
|
||||
if not which('ssh'):
|
||||
skip = "no ssh command-line client available"
|
||||
|
||||
def execute(self, remoteCommand, process, sshArgs=''):
|
||||
"""
|
||||
Connects to the SSH server started in L{ConchServerSetupMixin.setUp} by
|
||||
running the 'ssh' command line tool.
|
||||
|
||||
@type remoteCommand: str
|
||||
@param remoteCommand: The command (with arguments) to run on the
|
||||
remote end.
|
||||
|
||||
@type process: L{ConchTestOpenSSHProcess}
|
||||
|
||||
@type sshArgs: str
|
||||
@param sshArgs: Arguments to pass to the 'ssh' process.
|
||||
|
||||
@return: L{defer.Deferred}
|
||||
"""
|
||||
process.deferred = defer.Deferred()
|
||||
cmdline = ('ssh -2 -l testuser -p %i '
|
||||
'-oUserKnownHostsFile=kh_test '
|
||||
'-oPasswordAuthentication=no '
|
||||
# Always use the RSA key, since that's the one in kh_test.
|
||||
'-oHostKeyAlgorithms=ssh-rsa '
|
||||
'-a '
|
||||
'-i dsa_test ') + sshArgs + \
|
||||
' 127.0.0.1 ' + remoteCommand
|
||||
port = self.conchServer.getHost().port
|
||||
cmds = (cmdline % port).split()
|
||||
reactor.spawnProcess(process, "ssh", cmds)
|
||||
return process.deferred
|
||||
|
||||
|
||||
|
||||
class OpenSSHClientForwardingTestCase(ForwardingMixin, OpenSSHClientMixin,
|
||||
unittest.TestCase):
|
||||
"""
|
||||
Connection forwarding tests run against the OpenSSL command line client.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class OpenSSHClientRekeyTestCase(RekeyTestsMixin, OpenSSHClientMixin,
|
||||
unittest.TestCase):
|
||||
"""
|
||||
Rekeying tests run against the OpenSSL command line client.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class CmdLineClientTestCase(ForwardingMixin, unittest.TestCase):
|
||||
"""
|
||||
Connection forwarding tests run against the Conch command line client.
|
||||
"""
|
||||
if runtime.platformType == 'win32':
|
||||
skip = "can't run cmdline client on win32"
|
||||
|
||||
def execute(self, remoteCommand, process, sshArgs=''):
|
||||
"""
|
||||
As for L{OpenSSHClientTestCase.execute}, except it runs the 'conch'
|
||||
command line tool, not 'ssh'.
|
||||
"""
|
||||
process.deferred = defer.Deferred()
|
||||
port = self.conchServer.getHost().port
|
||||
cmd = ('-p %i -l testuser '
|
||||
'--known-hosts kh_test '
|
||||
'--user-authentications publickey '
|
||||
'--host-key-algorithms ssh-rsa '
|
||||
'-a '
|
||||
'-i dsa_test '
|
||||
'-v ') % port + sshArgs + \
|
||||
' 127.0.0.1 ' + remoteCommand
|
||||
cmds = _makeArgs(cmd.split())
|
||||
log.msg(str(cmds))
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.pathsep.join(sys.path)
|
||||
reactor.spawnProcess(process, sys.executable, cmds, env=env)
|
||||
return process.deferred
|
||||
|
|
@ -0,0 +1,730 @@
|
|||
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
|
||||
# See LICENSE for details
|
||||
|
||||
"""
|
||||
This module tests twisted.conch.ssh.connection.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import channel, common, connection
|
||||
from twisted.trial import unittest
|
||||
from twisted.conch.test import test_userauth
|
||||
|
||||
|
||||
class TestChannel(channel.SSHChannel):
|
||||
"""
|
||||
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
|
||||
|
||||
@ivar gotOpen: True if channelOpen has been called.
|
||||
@type gotOpen: C{bool}
|
||||
@ivar specificData: the specific channel open data passed to channelOpen.
|
||||
@type specificData: C{str}
|
||||
@ivar openFailureReason: the reason passed to openFailed.
|
||||
@type openFailed: C{error.ConchError}
|
||||
@ivar inBuffer: a C{list} of strings received by the channel.
|
||||
@type inBuffer: C{list}
|
||||
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
|
||||
the channel.
|
||||
@type extBuffer: C{list}
|
||||
@ivar numberRequests: the number of requests that have been made to this
|
||||
channel.
|
||||
@type numberRequests: C{int}
|
||||
@ivar gotEOF: True if the other side sent EOF.
|
||||
@type gotEOF: C{bool}
|
||||
@ivar gotOneClose: True if the other side closed the connection.
|
||||
@type gotOneClose: C{bool}
|
||||
@ivar gotClosed: True if the channel is closed.
|
||||
@type gotClosed: C{bool}
|
||||
"""
|
||||
name = "TestChannel"
|
||||
gotOpen = False
|
||||
|
||||
def logPrefix(self):
|
||||
return "TestChannel %i" % self.id
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
"""
|
||||
The channel is open. Set up the instance variables.
|
||||
"""
|
||||
self.gotOpen = True
|
||||
self.specificData = specificData
|
||||
self.inBuffer = []
|
||||
self.extBuffer = []
|
||||
self.numberRequests = 0
|
||||
self.gotEOF = False
|
||||
self.gotOneClose = False
|
||||
self.gotClosed = False
|
||||
|
||||
def openFailed(self, reason):
|
||||
"""
|
||||
Opening the channel failed. Store the reason why.
|
||||
"""
|
||||
self.openFailureReason = reason
|
||||
|
||||
def request_test(self, data):
|
||||
"""
|
||||
A test request. Return True if data is 'data'.
|
||||
|
||||
@type data: C{str}
|
||||
"""
|
||||
self.numberRequests += 1
|
||||
return data == 'data'
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Data was received. Store it in the buffer.
|
||||
"""
|
||||
self.inBuffer.append(data)
|
||||
|
||||
def extReceived(self, code, data):
|
||||
"""
|
||||
Extended data was received. Store it in the buffer.
|
||||
"""
|
||||
self.extBuffer.append((code, data))
|
||||
|
||||
def eofReceived(self):
|
||||
"""
|
||||
EOF was received. Remember it.
|
||||
"""
|
||||
self.gotEOF = True
|
||||
|
||||
def closeReceived(self):
|
||||
"""
|
||||
Close was received. Remember it.
|
||||
"""
|
||||
self.gotOneClose = True
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
The channel is closed. Rembember it.
|
||||
"""
|
||||
self.gotClosed = True
|
||||
|
||||
class TestAvatar:
|
||||
"""
|
||||
A mocked-up version of twisted.conch.avatar.ConchUser
|
||||
"""
|
||||
_ARGS_ERROR_CODE = 123
|
||||
|
||||
def lookupChannel(self, channelType, windowSize, maxPacket, data):
|
||||
"""
|
||||
The server wants us to return a channel. If the requested channel is
|
||||
our TestChannel, return it, otherwise return None.
|
||||
"""
|
||||
if channelType == TestChannel.name:
|
||||
return TestChannel(remoteWindow=windowSize,
|
||||
remoteMaxPacket=maxPacket,
|
||||
data=data, avatar=self)
|
||||
elif channelType == "conch-error-args":
|
||||
# Raise a ConchError with backwards arguments to make sure the
|
||||
# connection fixes it for us. This case should be deprecated and
|
||||
# deleted eventually, but only after all of Conch gets the argument
|
||||
# order right.
|
||||
raise error.ConchError(
|
||||
self._ARGS_ERROR_CODE, "error args in wrong order")
|
||||
|
||||
|
||||
def gotGlobalRequest(self, requestType, data):
|
||||
"""
|
||||
The client has made a global request. If the global request is
|
||||
'TestGlobal', return True. If the global request is 'TestData',
|
||||
return True and the request-specific data we received. Otherwise,
|
||||
return False.
|
||||
"""
|
||||
if requestType == 'TestGlobal':
|
||||
return True
|
||||
elif requestType == 'TestData':
|
||||
return True, data
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
class TestConnection(connection.SSHConnection):
|
||||
"""
|
||||
A subclass of SSHConnection for testing.
|
||||
|
||||
@ivar channel: the current channel.
|
||||
@type channel. C{TestChannel}
|
||||
"""
|
||||
|
||||
def logPrefix(self):
|
||||
return "TestConnection"
|
||||
|
||||
def global_TestGlobal(self, data):
|
||||
"""
|
||||
The other side made the 'TestGlobal' global request. Return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
def global_Test_Data(self, data):
|
||||
"""
|
||||
The other side made the 'Test-Data' global request. Return True and
|
||||
the data we received.
|
||||
"""
|
||||
return True, data
|
||||
|
||||
def channel_TestChannel(self, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side is requesting the TestChannel. Create a C{TestChannel}
|
||||
instance, store it, and return it.
|
||||
"""
|
||||
self.channel = TestChannel(remoteWindow=windowSize,
|
||||
remoteMaxPacket=maxPacket, data=data)
|
||||
return self.channel
|
||||
|
||||
def channel_ErrorChannel(self, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side is requesting the ErrorChannel. Raise an exception.
|
||||
"""
|
||||
raise AssertionError('no such thing')
|
||||
|
||||
|
||||
|
||||
class ConnectionTestCase(unittest.TestCase):
|
||||
|
||||
if test_userauth.transport is None:
|
||||
skip = "Cannot run without both PyCrypto and pyasn1"
|
||||
|
||||
def setUp(self):
|
||||
self.transport = test_userauth.FakeTransport(None)
|
||||
self.transport.avatar = TestAvatar()
|
||||
self.conn = TestConnection()
|
||||
self.conn.transport = self.transport
|
||||
self.conn.serviceStarted()
|
||||
|
||||
def _openChannel(self, channel):
|
||||
"""
|
||||
Open the channel with the default connection.
|
||||
"""
|
||||
self.conn.openChannel(channel)
|
||||
self.transport.packets = self.transport.packets[:-1]
|
||||
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(struct.pack('>2L',
|
||||
channel.id, 255) + '\x00\x02\x00\x00\x00\x00\x80\x00')
|
||||
|
||||
def tearDown(self):
|
||||
self.conn.serviceStopped()
|
||||
|
||||
def test_linkAvatar(self):
|
||||
"""
|
||||
Test that the connection links itself to the avatar in the
|
||||
transport.
|
||||
"""
|
||||
self.assertIs(self.transport.avatar.conn, self.conn)
|
||||
|
||||
def test_serviceStopped(self):
|
||||
"""
|
||||
Test that serviceStopped() closes any open channels.
|
||||
"""
|
||||
channel1 = TestChannel()
|
||||
channel2 = TestChannel()
|
||||
self.conn.openChannel(channel1)
|
||||
self.conn.openChannel(channel2)
|
||||
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00' * 4)
|
||||
self.assertTrue(channel1.gotOpen)
|
||||
self.assertFalse(channel2.gotOpen)
|
||||
self.conn.serviceStopped()
|
||||
self.assertTrue(channel1.gotClosed)
|
||||
|
||||
def test_GLOBAL_REQUEST(self):
|
||||
"""
|
||||
Test that global request packets are dispatched to the global_*
|
||||
methods and the return values are translated into success or failure
|
||||
messages.
|
||||
"""
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\xff')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_REQUEST_SUCCESS, '')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestData') + '\xff' +
|
||||
'test data')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_REQUEST_SUCCESS, 'test data')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestBad') + '\xff')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_REQUEST_FAILURE, '')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\x00')
|
||||
self.assertEqual(self.transport.packets, [])
|
||||
|
||||
def test_REQUEST_SUCCESS(self):
|
||||
"""
|
||||
Test that global request success packets cause the Deferred to be
|
||||
called back.
|
||||
"""
|
||||
d = self.conn.sendGlobalRequest('request', 'data', True)
|
||||
self.conn.ssh_REQUEST_SUCCESS('data')
|
||||
def check(data):
|
||||
self.assertEqual(data, 'data')
|
||||
d.addCallback(check)
|
||||
d.addErrback(self.fail)
|
||||
return d
|
||||
|
||||
def test_REQUEST_FAILURE(self):
|
||||
"""
|
||||
Test that global request failure packets cause the Deferred to be
|
||||
erred back.
|
||||
"""
|
||||
d = self.conn.sendGlobalRequest('request', 'data', True)
|
||||
self.conn.ssh_REQUEST_FAILURE('data')
|
||||
def check(f):
|
||||
self.assertEqual(f.value.data, 'data')
|
||||
d.addCallback(self.fail)
|
||||
d.addErrback(check)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_OPEN(self):
|
||||
"""
|
||||
Test that open channel packets cause a channel to be created and
|
||||
opened or a failure message to be returned.
|
||||
"""
|
||||
del self.transport.avatar
|
||||
self.conn.ssh_CHANNEL_OPEN(common.NS('TestChannel') +
|
||||
'\x00\x00\x00\x01' * 4)
|
||||
self.assertTrue(self.conn.channel.gotOpen)
|
||||
self.assertEqual(self.conn.channel.conn, self.conn)
|
||||
self.assertEqual(self.conn.channel.data, '\x00\x00\x00\x01')
|
||||
self.assertEqual(self.conn.channel.specificData, '\x00\x00\x00\x01')
|
||||
self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
|
||||
self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
|
||||
'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
|
||||
'\x00\x00\x80\x00')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_OPEN(common.NS('BadChannel') +
|
||||
'\x00\x00\x00\x02' * 4)
|
||||
self.flushLoggedErrors()
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_FAILURE,
|
||||
'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
|
||||
'unknown channel') + common.NS(''))])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_OPEN(common.NS('ErrorChannel') +
|
||||
'\x00\x00\x00\x02' * 4)
|
||||
self.flushLoggedErrors()
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_FAILURE,
|
||||
'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
|
||||
'unknown failure') + common.NS(''))])
|
||||
|
||||
|
||||
def _lookupChannelErrorTest(self, code):
|
||||
"""
|
||||
Deliver a request for a channel open which will result in an exception
|
||||
being raised during channel lookup. Assert that an error response is
|
||||
delivered as a result.
|
||||
"""
|
||||
self.transport.avatar._ARGS_ERROR_CODE = code
|
||||
self.conn.ssh_CHANNEL_OPEN(
|
||||
common.NS('conch-error-args') + '\x00\x00\x00\x01' * 4)
|
||||
errors = self.flushLoggedErrors(error.ConchError)
|
||||
self.assertEqual(
|
||||
len(errors), 1, "Expected one error, got: %r" % (errors,))
|
||||
self.assertEqual(errors[0].value.args, (123, "error args in wrong order"))
|
||||
self.assertEqual(
|
||||
self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_FAILURE,
|
||||
# The response includes some bytes which identifying the
|
||||
# associated request, as well as the error code (7b in hex) and
|
||||
# the error message.
|
||||
'\x00\x00\x00\x01\x00\x00\x00\x7b' + common.NS(
|
||||
'error args in wrong order') + common.NS(''))])
|
||||
|
||||
|
||||
def test_lookupChannelError(self):
|
||||
"""
|
||||
If a C{lookupChannel} implementation raises L{error.ConchError} with the
|
||||
arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
|
||||
sent in response to the message.
|
||||
|
||||
This is a temporary work-around until L{error.ConchError} is given
|
||||
better attributes and all of the Conch code starts constructing
|
||||
instances of it properly. Eventually this functionality should be
|
||||
deprecated and then removed.
|
||||
"""
|
||||
self._lookupChannelErrorTest(123)
|
||||
|
||||
|
||||
def test_lookupChannelErrorLongCode(self):
|
||||
"""
|
||||
Like L{test_lookupChannelError}, but for the case where the failure code
|
||||
is represented as a C{long} instead of a C{int}.
|
||||
"""
|
||||
self._lookupChannelErrorTest(123L)
|
||||
|
||||
|
||||
def test_CHANNEL_OPEN_CONFIRMATION(self):
|
||||
"""
|
||||
Test that channel open confirmation packets cause the channel to be
|
||||
notified that it's open.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self.conn.openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00'*5)
|
||||
self.assertEqual(channel.remoteWindowLeft, 0)
|
||||
self.assertEqual(channel.remoteMaxPacket, 0)
|
||||
self.assertEqual(channel.specificData, '\x00\x00\x00\x00')
|
||||
self.assertEqual(self.conn.channelsToRemoteChannel[channel],
|
||||
0)
|
||||
self.assertEqual(self.conn.localToRemoteChannel[0], 0)
|
||||
|
||||
def test_CHANNEL_OPEN_FAILURE(self):
|
||||
"""
|
||||
Test that channel open failure packets cause the channel to be
|
||||
notified that its opening failed.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self.conn.openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_OPEN_FAILURE('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x01' + common.NS('failure!'))
|
||||
self.assertEqual(channel.openFailureReason.args, ('failure!', 1))
|
||||
self.assertEqual(self.conn.channels.get(channel), None)
|
||||
|
||||
|
||||
def test_CHANNEL_WINDOW_ADJUST(self):
|
||||
"""
|
||||
Test that channel window adjust messages add bytes to the channel
|
||||
window.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
oldWindowSize = channel.remoteWindowLeft
|
||||
self.conn.ssh_CHANNEL_WINDOW_ADJUST('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x01')
|
||||
self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)
|
||||
|
||||
def test_CHANNEL_DATA(self):
|
||||
"""
|
||||
Test that channel data messages are passed up to the channel, or
|
||||
cause the channel to be closed if the data is too large.
|
||||
"""
|
||||
channel = TestChannel(localWindow=6, localMaxPacket=5)
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS('data'))
|
||||
self.assertEqual(channel.inBuffer, ['data'])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
|
||||
'\x00\x00\x00\x04')])
|
||||
self.transport.packets = []
|
||||
longData = 'a' * (channel.localWindowLeft + 1)
|
||||
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS(longData))
|
||||
self.assertEqual(channel.inBuffer, ['data'])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
bigData = 'a' * (channel.localMaxPacket + 1)
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x01' + common.NS(bigData))
|
||||
self.assertEqual(channel.inBuffer, [])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
|
||||
def test_CHANNEL_EXTENDED_DATA(self):
|
||||
"""
|
||||
Test that channel extended data messages are passed up to the channel,
|
||||
or cause the channel to be closed if they're too big.
|
||||
"""
|
||||
channel = TestChannel(localWindow=6, localMaxPacket=5)
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x00' + common.NS('data'))
|
||||
self.assertEqual(channel.extBuffer, [(0, 'data')])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
|
||||
'\x00\x00\x00\x04')])
|
||||
self.transport.packets = []
|
||||
longData = 'a' * (channel.localWindowLeft + 1)
|
||||
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x00' + common.NS(longData))
|
||||
self.assertEqual(channel.extBuffer, [(0, 'data')])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
bigData = 'a' * (channel.localMaxPacket + 1)
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x01\x00\x00\x00'
|
||||
'\x00' + common.NS(bigData))
|
||||
self.assertEqual(channel.extBuffer, [])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
|
||||
def test_CHANNEL_EOF(self):
|
||||
"""
|
||||
Test that channel eof messages are passed up to the channel.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_EOF('\x00\x00\x00\x00')
|
||||
self.assertTrue(channel.gotEOF)
|
||||
|
||||
def test_CHANNEL_CLOSE(self):
|
||||
"""
|
||||
Test that channel close messages are passed up to the channel. Also,
|
||||
test that channel.close() is called if both sides are closed when this
|
||||
message is received.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendClose(channel)
|
||||
self.conn.ssh_CHANNEL_CLOSE('\x00\x00\x00\x00')
|
||||
self.assertTrue(channel.gotOneClose)
|
||||
self.assertTrue(channel.gotClosed)
|
||||
|
||||
def test_CHANNEL_REQUEST_success(self):
|
||||
"""
|
||||
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS('test')
|
||||
+ '\x00')
|
||||
self.assertEqual(channel.numberRequests, 1)
|
||||
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
|
||||
'test') + '\xff' + 'data')
|
||||
def check(result):
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_SUCCESS, '\x00\x00\x00\xff')])
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_REQUEST_failure(self):
|
||||
"""
|
||||
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
|
||||
'test') + '\xff')
|
||||
def check(result):
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_FAILURE, '\x00\x00\x00\xff'
|
||||
)])
|
||||
d.addCallback(self.fail)
|
||||
d.addErrback(check)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_REQUEST_SUCCESS(self):
|
||||
"""
|
||||
Test that channel request success messages cause the Deferred to be
|
||||
called back.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.sendRequest(channel, 'test', 'data', True)
|
||||
self.conn.ssh_CHANNEL_SUCCESS('\x00\x00\x00\x00')
|
||||
def check(result):
|
||||
self.assertTrue(result)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_REQUEST_FAILURE(self):
|
||||
"""
|
||||
Test that channel request failure messages cause the Deferred to be
|
||||
erred back.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.sendRequest(channel, 'test', '', True)
|
||||
self.conn.ssh_CHANNEL_FAILURE('\x00\x00\x00\x00')
|
||||
def check(result):
|
||||
self.assertEqual(result.value.value, 'channel request failed')
|
||||
d.addCallback(self.fail)
|
||||
d.addErrback(check)
|
||||
return d
|
||||
|
||||
def test_sendGlobalRequest(self):
|
||||
"""
|
||||
Test that global request messages are sent in the right format.
|
||||
"""
|
||||
d = self.conn.sendGlobalRequest('wantReply', 'data', True)
|
||||
# must be added to prevent errbacking during teardown
|
||||
d.addErrback(lambda failure: None)
|
||||
self.conn.sendGlobalRequest('noReply', '', False)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_GLOBAL_REQUEST, common.NS('wantReply') +
|
||||
'\xffdata'),
|
||||
(connection.MSG_GLOBAL_REQUEST, common.NS('noReply') +
|
||||
'\x00')])
|
||||
self.assertEqual(self.conn.deferreds, {'global':[d]})
|
||||
|
||||
def test_openChannel(self):
|
||||
"""
|
||||
Test that open channel messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self.conn.openChannel(channel, 'aaaa')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN, common.NS('TestChannel') +
|
||||
'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
|
||||
self.assertEqual(channel.id, 0)
|
||||
self.assertEqual(self.conn.localChannelID, 1)
|
||||
|
||||
def test_sendRequest(self):
|
||||
"""
|
||||
Test that channel request messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.sendRequest(channel, 'test', 'test', True)
|
||||
# needed to prevent errbacks during teardown.
|
||||
d.addErrback(lambda failure: None)
|
||||
self.conn.sendRequest(channel, 'test2', '', False)
|
||||
channel.localClosed = True # emulate sending a close message
|
||||
self.conn.sendRequest(channel, 'test3', '', True)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
|
||||
common.NS('test') + '\x01test'),
|
||||
(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
|
||||
common.NS('test2') + '\x00')])
|
||||
self.assertEqual(self.conn.deferreds[0], [d])
|
||||
|
||||
def test_adjustWindow(self):
|
||||
"""
|
||||
Test that channel window adjust messages cause bytes to be added
|
||||
to the window.
|
||||
"""
|
||||
channel = TestChannel(localWindow=5)
|
||||
self._openChannel(channel)
|
||||
channel.localWindowLeft = 0
|
||||
self.conn.adjustWindow(channel, 1)
|
||||
self.assertEqual(channel.localWindowLeft, 1)
|
||||
channel.localClosed = True
|
||||
self.conn.adjustWindow(channel, 2)
|
||||
self.assertEqual(channel.localWindowLeft, 1)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
|
||||
'\x00\x00\x00\x01')])
|
||||
|
||||
def test_sendData(self):
|
||||
"""
|
||||
Test that channel data messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendData(channel, 'a')
|
||||
channel.localClosed = True
|
||||
self.conn.sendData(channel, 'b')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_DATA, '\x00\x00\x00\xff' +
|
||||
common.NS('a'))])
|
||||
|
||||
def test_sendExtendedData(self):
|
||||
"""
|
||||
Test that channel extended data messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendExtendedData(channel, 1, 'test')
|
||||
channel.localClosed = True
|
||||
self.conn.sendExtendedData(channel, 2, 'test2')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_EXTENDED_DATA, '\x00\x00\x00\xff' +
|
||||
'\x00\x00\x00\x01' + common.NS('test'))])
|
||||
|
||||
def test_sendEOF(self):
|
||||
"""
|
||||
Test that channel EOF messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendEOF(channel)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
|
||||
channel.localClosed = True
|
||||
self.conn.sendEOF(channel)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
|
||||
|
||||
def test_sendClose(self):
|
||||
"""
|
||||
Test that channel close messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendClose(channel)
|
||||
self.assertTrue(channel.localClosed)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
self.conn.sendClose(channel)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
|
||||
channel2 = TestChannel()
|
||||
self._openChannel(channel2)
|
||||
channel2.remoteClosed = True
|
||||
self.conn.sendClose(channel2)
|
||||
self.assertTrue(channel2.gotClosed)
|
||||
|
||||
def test_getChannelWithAvatar(self):
|
||||
"""
|
||||
Test that getChannel dispatches to the avatar when an avatar is
|
||||
present. Correct functioning without the avatar is verified in
|
||||
test_CHANNEL_OPEN.
|
||||
"""
|
||||
channel = self.conn.getChannel('TestChannel', 50, 30, 'data')
|
||||
self.assertEqual(channel.data, 'data')
|
||||
self.assertEqual(channel.remoteWindowLeft, 50)
|
||||
self.assertEqual(channel.remoteMaxPacket, 30)
|
||||
self.assertRaises(error.ConchError, self.conn.getChannel,
|
||||
'BadChannel', 50, 30, 'data')
|
||||
|
||||
def test_gotGlobalRequestWithoutAvatar(self):
|
||||
"""
|
||||
Test that gotGlobalRequests dispatches to global_* without an avatar.
|
||||
"""
|
||||
del self.transport.avatar
|
||||
self.assertTrue(self.conn.gotGlobalRequest('TestGlobal', 'data'))
|
||||
self.assertEqual(self.conn.gotGlobalRequest('Test-Data', 'data'),
|
||||
(True, 'data'))
|
||||
self.assertFalse(self.conn.gotGlobalRequest('BadGlobal', 'data'))
|
||||
|
||||
|
||||
def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
|
||||
"""
|
||||
Whenever an SSH channel gets closed any Deferred that was returned by a
|
||||
sendRequest() on its parent connection must be errbacked.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
|
||||
d = self.conn.sendRequest(
|
||||
channel, "dummyrequest", "dummydata", wantReply=1)
|
||||
d = self.assertFailure(d, error.ConchError)
|
||||
self.conn.channelClosed(channel)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class TestCleanConnectionShutdown(unittest.TestCase):
|
||||
"""
|
||||
Check whether correct cleanup is performed on connection shutdown.
|
||||
"""
|
||||
if test_userauth.transport is None:
|
||||
skip = "Cannot run without both PyCrypto and pyasn1"
|
||||
|
||||
def setUp(self):
|
||||
self.transport = test_userauth.FakeTransport(None)
|
||||
self.transport.avatar = TestAvatar()
|
||||
self.conn = TestConnection()
|
||||
self.conn.transport = self.transport
|
||||
|
||||
|
||||
def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
|
||||
"""
|
||||
Once the service is stopped any leftover global deferred returned by
|
||||
a sendGlobalRequest() call must be errbacked.
|
||||
"""
|
||||
self.conn.serviceStarted()
|
||||
|
||||
d = self.conn.sendGlobalRequest(
|
||||
"dummyrequest", "dummydata", wantReply=1)
|
||||
d = self.assertFailure(d, error.ConchError)
|
||||
self.conn.serviceStopped()
|
||||
return d
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.client.default}.
|
||||
"""
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
skip = "PyCrypto and PyASN1 required for twisted.conch.client.default."
|
||||
else:
|
||||
from twisted.conch.client.agent import SSHAgentClient
|
||||
from twisted.conch.client.default import SSHUserAuthClient
|
||||
from twisted.conch.client.options import ConchOptions
|
||||
from twisted.conch.ssh.keys import Key
|
||||
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.conch.test import keydata
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
|
||||
|
||||
class SSHUserAuthClientTest(TestCase):
|
||||
"""
|
||||
Tests for L{SSHUserAuthClient}.
|
||||
|
||||
@type rsaPublic: L{Key}
|
||||
@ivar rsaPublic: A public RSA key.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.rsaPublic = Key.fromString(keydata.publicRSA_openssh)
|
||||
self.tmpdir = FilePath(self.mktemp())
|
||||
self.tmpdir.makedirs()
|
||||
self.rsaFile = self.tmpdir.child('id_rsa')
|
||||
self.rsaFile.setContent(keydata.privateRSA_openssh)
|
||||
self.tmpdir.child('id_rsa.pub').setContent(keydata.publicRSA_openssh)
|
||||
|
||||
|
||||
def test_signDataWithAgent(self):
|
||||
"""
|
||||
When connected to an agent, L{SSHUserAuthClient} can use it to
|
||||
request signatures of particular data with a particular L{Key}.
|
||||
"""
|
||||
client = SSHUserAuthClient("user", ConchOptions(), None)
|
||||
agent = SSHAgentClient()
|
||||
transport = StringTransport()
|
||||
agent.makeConnection(transport)
|
||||
client.keyAgent = agent
|
||||
cleartext = "Sign here"
|
||||
client.signData(self.rsaPublic, cleartext)
|
||||
self.assertEqual(
|
||||
transport.value(),
|
||||
"\x00\x00\x00\x8b\r\x00\x00\x00u" + self.rsaPublic.blob() +
|
||||
"\x00\x00\x00\t" + cleartext +
|
||||
"\x00\x00\x00\x00")
|
||||
|
||||
|
||||
def test_agentGetPublicKey(self):
|
||||
"""
|
||||
L{SSHUserAuthClient} looks up public keys from the agent using the
|
||||
L{SSHAgentClient} class. That L{SSHAgentClient.getPublicKey} returns a
|
||||
L{Key} object with one of the public keys in the agent. If no more
|
||||
keys are present, it returns C{None}.
|
||||
"""
|
||||
agent = SSHAgentClient()
|
||||
agent.blobs = [self.rsaPublic.blob()]
|
||||
key = agent.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, self.rsaPublic)
|
||||
self.assertEqual(agent.getPublicKey(), None)
|
||||
|
||||
|
||||
def test_getPublicKeyFromFile(self):
|
||||
"""
|
||||
L{SSHUserAuthClient.getPublicKey()} is able to get a public key from
|
||||
the first file described by its options' C{identitys} list, and return
|
||||
the corresponding public L{Key} object.
|
||||
"""
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
key = client.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, self.rsaPublic)
|
||||
|
||||
|
||||
def test_getPublicKeyAgentFallback(self):
|
||||
"""
|
||||
If an agent is present, but doesn't return a key,
|
||||
L{SSHUserAuthClient.getPublicKey} continue with the normal key lookup.
|
||||
"""
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
agent = SSHAgentClient()
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
client.keyAgent = agent
|
||||
key = client.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, self.rsaPublic)
|
||||
|
||||
|
||||
def test_getPublicKeyBadKeyError(self):
|
||||
"""
|
||||
If L{keys.Key.fromFile} raises a L{keys.BadKeyError}, the
|
||||
L{SSHUserAuthClient.getPublicKey} tries again to get a public key by
|
||||
calling itself recursively.
|
||||
"""
|
||||
options = ConchOptions()
|
||||
self.tmpdir.child('id_dsa.pub').setContent(keydata.publicDSA_openssh)
|
||||
dsaFile = self.tmpdir.child('id_dsa')
|
||||
dsaFile.setContent(keydata.privateDSA_openssh)
|
||||
options.identitys = [self.rsaFile.path, dsaFile.path]
|
||||
self.tmpdir.child('id_rsa.pub').setContent('not a key!')
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
key = client.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, Key.fromString(keydata.publicDSA_openssh))
|
||||
self.assertEqual(client.usedFiles, [self.rsaFile.path, dsaFile.path])
|
||||
|
||||
|
||||
def test_getPrivateKey(self):
|
||||
"""
|
||||
L{SSHUserAuthClient.getPrivateKey} will load a private key from the
|
||||
last used file populated by L{SSHUserAuthClient.getPublicKey}, and
|
||||
return a L{Deferred} which fires with the corresponding private L{Key}.
|
||||
"""
|
||||
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
# Populate the list of used files
|
||||
client.getPublicKey()
|
||||
|
||||
def _cbGetPrivateKey(key):
|
||||
self.assertEqual(key.isPublic(), False)
|
||||
self.assertEqual(key, rsaPrivate)
|
||||
|
||||
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
|
||||
|
||||
|
||||
def test_getPrivateKeyPassphrase(self):
|
||||
"""
|
||||
L{SSHUserAuthClient} can get a private key from a file, and return a
|
||||
Deferred called back with a private L{Key} object, even if the key is
|
||||
encrypted.
|
||||
"""
|
||||
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
|
||||
passphrase = 'this is the passphrase'
|
||||
self.rsaFile.setContent(rsaPrivate.toString('openssh', passphrase))
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
# Populate the list of used files
|
||||
client.getPublicKey()
|
||||
|
||||
def _getPassword(prompt):
|
||||
self.assertEqual(prompt,
|
||||
"Enter passphrase for key '%s': " % (
|
||||
self.rsaFile.path,))
|
||||
return passphrase
|
||||
|
||||
def _cbGetPrivateKey(key):
|
||||
self.assertEqual(key.isPublic(), False)
|
||||
self.assertEqual(key, rsaPrivate)
|
||||
|
||||
self.patch(client, '_getPassword', _getPassword)
|
||||
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,771 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE file for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.filetransfer}.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
|
||||
from twisted.trial import unittest
|
||||
try:
|
||||
from twisted.conch import unix
|
||||
unix # shut up pyflakes
|
||||
except ImportError:
|
||||
unix = None
|
||||
|
||||
from twisted.conch import avatar
|
||||
from twisted.conch.ssh import common, connection, filetransfer, session
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols import loopback
|
||||
from twisted.python import components
|
||||
|
||||
|
||||
class TestAvatar(avatar.ConchUser):
|
||||
def __init__(self):
|
||||
avatar.ConchUser.__init__(self)
|
||||
self.channelLookup['session'] = session.SSHSession
|
||||
self.subsystemLookup['sftp'] = filetransfer.FileTransferServer
|
||||
|
||||
def _runAsUser(self, f, *args, **kw):
|
||||
try:
|
||||
f = iter(f)
|
||||
except TypeError:
|
||||
f = [(f, args, kw)]
|
||||
for i in f:
|
||||
func = i[0]
|
||||
args = len(i)>1 and i[1] or ()
|
||||
kw = len(i)>2 and i[2] or {}
|
||||
r = func(*args, **kw)
|
||||
return r
|
||||
|
||||
|
||||
class FileTransferTestAvatar(TestAvatar):
|
||||
|
||||
def __init__(self, homeDir):
|
||||
TestAvatar.__init__(self)
|
||||
self.homeDir = homeDir
|
||||
|
||||
def getHomeDir(self):
|
||||
return os.path.join(os.getcwd(), self.homeDir)
|
||||
|
||||
|
||||
class ConchSessionForTestAvatar:
|
||||
|
||||
def __init__(self, avatar):
|
||||
self.avatar = avatar
|
||||
|
||||
if unix:
|
||||
if not hasattr(unix, 'SFTPServerForUnixConchUser'):
|
||||
# unix should either be a fully working module, or None. I'm not sure
|
||||
# how this happens, but on win32 it does. Try to cope. --spiv.
|
||||
import warnings
|
||||
warnings.warn(("twisted.conch.unix imported %r, "
|
||||
"but doesn't define SFTPServerForUnixConchUser'")
|
||||
% (unix,))
|
||||
unix = None
|
||||
else:
|
||||
class FileTransferForTestAvatar(unix.SFTPServerForUnixConchUser):
|
||||
|
||||
def gotVersion(self, version, otherExt):
|
||||
return {'conchTest' : 'ext data'}
|
||||
|
||||
def extendedRequest(self, extName, extData):
|
||||
if extName == 'testExtendedRequest':
|
||||
return 'bar'
|
||||
raise NotImplementedError
|
||||
|
||||
components.registerAdapter(FileTransferForTestAvatar,
|
||||
TestAvatar,
|
||||
filetransfer.ISFTPServer)
|
||||
|
||||
class SFTPTestBase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.testDir = self.mktemp()
|
||||
# Give the testDir another level so we can safely "cd .." from it in
|
||||
# tests.
|
||||
self.testDir = os.path.join(self.testDir, 'extra')
|
||||
os.makedirs(os.path.join(self.testDir, 'testDirectory'))
|
||||
|
||||
f = file(os.path.join(self.testDir, 'testfile1'),'w')
|
||||
f.write('a'*10+'b'*10)
|
||||
f.write(file('/dev/urandom').read(1024*64)) # random data
|
||||
os.chmod(os.path.join(self.testDir, 'testfile1'), 0644)
|
||||
file(os.path.join(self.testDir, 'testRemoveFile'), 'w').write('a')
|
||||
file(os.path.join(self.testDir, 'testRenameFile'), 'w').write('a')
|
||||
file(os.path.join(self.testDir, '.testHiddenFile'), 'w').write('a')
|
||||
|
||||
|
||||
class TestOurServerOurClient(SFTPTestBase):
|
||||
|
||||
if not unix:
|
||||
skip = "can't run on non-posix computers"
|
||||
|
||||
def setUp(self):
|
||||
SFTPTestBase.setUp(self)
|
||||
|
||||
self.avatar = FileTransferTestAvatar(self.testDir)
|
||||
self.server = filetransfer.FileTransferServer(avatar=self.avatar)
|
||||
clientTransport = loopback.LoopbackRelay(self.server)
|
||||
|
||||
self.client = filetransfer.FileTransferClient()
|
||||
self._serverVersion = None
|
||||
self._extData = None
|
||||
def _(serverVersion, extData):
|
||||
self._serverVersion = serverVersion
|
||||
self._extData = extData
|
||||
self.client.gotServerVersion = _
|
||||
serverTransport = loopback.LoopbackRelay(self.client)
|
||||
self.client.makeConnection(clientTransport)
|
||||
self.server.makeConnection(serverTransport)
|
||||
|
||||
self.clientTransport = clientTransport
|
||||
self.serverTransport = serverTransport
|
||||
|
||||
self._emptyBuffers()
|
||||
|
||||
|
||||
def _emptyBuffers(self):
|
||||
while self.serverTransport.buffer or self.clientTransport.buffer:
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
self.serverTransport.loseConnection()
|
||||
self.clientTransport.loseConnection()
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
|
||||
|
||||
def testServerVersion(self):
|
||||
self.assertEqual(self._serverVersion, 3)
|
||||
self.assertEqual(self._extData, {'conchTest' : 'ext data'})
|
||||
|
||||
|
||||
def test_interface_implementation(self):
|
||||
"""
|
||||
It implements the ISFTPServer interface.
|
||||
"""
|
||||
self.assertTrue(
|
||||
filetransfer.ISFTPServer.providedBy(self.server.client),
|
||||
"ISFTPServer not provided by %r" % (self.server.client,))
|
||||
|
||||
|
||||
def test_openedFileClosedWithConnection(self):
|
||||
"""
|
||||
A file opened with C{openFile} is close when the connection is lost.
|
||||
"""
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
oldClose = os.close
|
||||
closed = []
|
||||
def close(fd):
|
||||
closed.append(fd)
|
||||
oldClose(fd)
|
||||
|
||||
self.patch(os, "close", close)
|
||||
|
||||
def _fileOpened(openFile):
|
||||
fd = self.server.openFiles[openFile.handle[4:]].fd
|
||||
self.serverTransport.loseConnection()
|
||||
self.clientTransport.loseConnection()
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
self.assertEqual(self.server.openFiles, {})
|
||||
self.assertIn(fd, closed)
|
||||
|
||||
d.addCallback(_fileOpened)
|
||||
return d
|
||||
|
||||
|
||||
def test_openedDirectoryClosedWithConnection(self):
|
||||
"""
|
||||
A directory opened with C{openDirectory} is close when the connection
|
||||
is lost.
|
||||
"""
|
||||
d = self.client.openDirectory('')
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getFiles(openDir):
|
||||
self.serverTransport.loseConnection()
|
||||
self.clientTransport.loseConnection()
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
self.assertEqual(self.server.openDirs, {})
|
||||
|
||||
d.addCallback(_getFiles)
|
||||
return d
|
||||
|
||||
|
||||
def testOpenFileIO(self):
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _fileOpened(openFile):
|
||||
self.assertEqual(openFile, filetransfer.ISFTPFile(openFile))
|
||||
d = _readChunk(openFile)
|
||||
d.addCallback(_writeChunk, openFile)
|
||||
return d
|
||||
|
||||
def _readChunk(openFile):
|
||||
d = openFile.readChunk(0, 20)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10)
|
||||
return d
|
||||
|
||||
def _writeChunk(_, openFile):
|
||||
d = openFile.writeChunk(20, 'c'*10)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_readChunk2, openFile)
|
||||
return d
|
||||
|
||||
def _readChunk2(_, openFile):
|
||||
d = openFile.readChunk(0, 30)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10 + 'c'*10)
|
||||
return d
|
||||
|
||||
d.addCallback(_fileOpened)
|
||||
return d
|
||||
|
||||
def testClosedFileGetAttrs(self):
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(_, openFile):
|
||||
d = openFile.getAttrs()
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
def _err(f):
|
||||
self.flushLoggedErrors()
|
||||
return f
|
||||
|
||||
def _close(openFile):
|
||||
d = openFile.close()
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_getAttrs, openFile)
|
||||
d.addErrback(_err)
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
d.addCallback(_close)
|
||||
return d
|
||||
|
||||
def testOpenFileAttributes(self):
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(openFile):
|
||||
d = openFile.getAttrs()
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_getAttrs2)
|
||||
return d
|
||||
|
||||
def _getAttrs2(attrs1):
|
||||
d = self.client.getAttrs('testfile1')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, attrs1)
|
||||
return d
|
||||
|
||||
return d.addCallback(_getAttrs)
|
||||
|
||||
|
||||
def testOpenFileSetAttrs(self):
|
||||
# XXX test setAttrs
|
||||
# Ok, how about this for a start? It caught a bug :) -- spiv.
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(openFile):
|
||||
d = openFile.getAttrs()
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_setAttrs)
|
||||
return d
|
||||
|
||||
def _setAttrs(attrs):
|
||||
attrs['atime'] = 0
|
||||
d = self.client.setAttrs('testfile1', attrs)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_getAttrs2)
|
||||
d.addCallback(self.assertEqual, attrs)
|
||||
return d
|
||||
|
||||
def _getAttrs2(_):
|
||||
d = self.client.getAttrs('testfile1')
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
d.addCallback(_getAttrs)
|
||||
return d
|
||||
|
||||
|
||||
def test_openFileExtendedAttributes(self):
|
||||
"""
|
||||
Check that L{filetransfer.FileTransferClient.openFile} can send
|
||||
extended attributes, that should be extracted server side. By default,
|
||||
they are ignored, so we just verify they are correctly parsed.
|
||||
"""
|
||||
savedAttributes = {}
|
||||
oldOpenFile = self.server.client.openFile
|
||||
def openFile(filename, flags, attrs):
|
||||
savedAttributes.update(attrs)
|
||||
return oldOpenFile(filename, flags, attrs)
|
||||
self.server.client.openFile = openFile
|
||||
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {"ext_foo": "bar"})
|
||||
self._emptyBuffers()
|
||||
|
||||
def check(ign):
|
||||
self.assertEqual(savedAttributes, {"ext_foo": "bar"})
|
||||
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
def testRemoveFile(self):
|
||||
d = self.client.getAttrs("testRemoveFile")
|
||||
self._emptyBuffers()
|
||||
def _removeFile(ignored):
|
||||
d = self.client.removeFile("testRemoveFile")
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
d.addCallback(_removeFile)
|
||||
d.addCallback(_removeFile)
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testRenameFile(self):
|
||||
d = self.client.getAttrs("testRenameFile")
|
||||
self._emptyBuffers()
|
||||
def _rename(attrs):
|
||||
d = self.client.renameFile("testRenameFile", "testRenamedFile")
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_testRenamed, attrs)
|
||||
return d
|
||||
def _testRenamed(_, attrs):
|
||||
d = self.client.getAttrs("testRenamedFile")
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, attrs)
|
||||
return d.addCallback(_rename)
|
||||
|
||||
def testDirectoryBad(self):
|
||||
d = self.client.getAttrs("testMakeDirectory")
|
||||
self._emptyBuffers()
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testDirectoryCreation(self):
|
||||
d = self.client.makeDirectory("testMakeDirectory", {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(_):
|
||||
d = self.client.getAttrs("testMakeDirectory")
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
# XXX not until version 4/5
|
||||
# self.assertEqual(filetransfer.FILEXFER_TYPE_DIRECTORY&attrs['type'],
|
||||
# filetransfer.FILEXFER_TYPE_DIRECTORY)
|
||||
|
||||
def _removeDirectory(_):
|
||||
d = self.client.removeDirectory("testMakeDirectory")
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
d.addCallback(_getAttrs)
|
||||
d.addCallback(_removeDirectory)
|
||||
d.addCallback(_getAttrs)
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testOpenDirectory(self):
|
||||
d = self.client.openDirectory('')
|
||||
self._emptyBuffers()
|
||||
files = []
|
||||
|
||||
def _getFiles(openDir):
|
||||
def append(f):
|
||||
files.append(f)
|
||||
return openDir
|
||||
d = defer.maybeDeferred(openDir.next)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(append)
|
||||
d.addCallback(_getFiles)
|
||||
d.addErrback(_close, openDir)
|
||||
return d
|
||||
|
||||
def _checkFiles(ignored):
|
||||
fs = list(zip(*files)[0])
|
||||
fs.sort()
|
||||
self.assertEqual(fs,
|
||||
['.testHiddenFile', 'testDirectory',
|
||||
'testRemoveFile', 'testRenameFile',
|
||||
'testfile1'])
|
||||
|
||||
def _close(_, openDir):
|
||||
d = openDir.close()
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
d.addCallback(_getFiles)
|
||||
d.addCallback(_checkFiles)
|
||||
return d
|
||||
|
||||
def testLinkDoesntExist(self):
|
||||
d = self.client.getAttrs('testLink')
|
||||
self._emptyBuffers()
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testLinkSharesAttrs(self):
|
||||
d = self.client.makeLink('testLink', 'testfile1')
|
||||
self._emptyBuffers()
|
||||
def _getFirstAttrs(_):
|
||||
d = self.client.getAttrs('testLink', 1)
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
def _getSecondAttrs(firstAttrs):
|
||||
d = self.client.getAttrs('testfile1')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, firstAttrs)
|
||||
return d
|
||||
d.addCallback(_getFirstAttrs)
|
||||
return d.addCallback(_getSecondAttrs)
|
||||
|
||||
def testLinkPath(self):
|
||||
d = self.client.makeLink('testLink', 'testfile1')
|
||||
self._emptyBuffers()
|
||||
def _readLink(_):
|
||||
d = self.client.readLink('testLink')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual,
|
||||
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
|
||||
return d
|
||||
def _realPath(_):
|
||||
d = self.client.realPath('testLink')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual,
|
||||
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
|
||||
return d
|
||||
d.addCallback(_readLink)
|
||||
d.addCallback(_realPath)
|
||||
return d
|
||||
|
||||
def testExtendedRequest(self):
|
||||
d = self.client.extendedRequest('testExtendedRequest', 'foo')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, 'bar')
|
||||
d.addCallback(self._cbTestExtendedRequest)
|
||||
return d
|
||||
|
||||
def _cbTestExtendedRequest(self, ignored):
|
||||
d = self.client.extendedRequest('testBadRequest', '')
|
||||
self._emptyBuffers()
|
||||
return self.assertFailure(d, NotImplementedError)
|
||||
|
||||
|
||||
class FakeConn:
|
||||
def sendClose(self, channel):
|
||||
pass
|
||||
|
||||
|
||||
class TestFileTransferClose(unittest.TestCase):
|
||||
|
||||
if not unix:
|
||||
skip = "can't run on non-posix computers"
|
||||
|
||||
def setUp(self):
|
||||
self.avatar = TestAvatar()
|
||||
|
||||
def buildServerConnection(self):
|
||||
# make a server connection
|
||||
conn = connection.SSHConnection()
|
||||
# server connections have a 'self.transport.avatar'.
|
||||
class DummyTransport:
|
||||
def __init__(self):
|
||||
self.transport = self
|
||||
def sendPacket(self, kind, data):
|
||||
pass
|
||||
def logPrefix(self):
|
||||
return 'dummy transport'
|
||||
conn.transport = DummyTransport()
|
||||
conn.transport.avatar = self.avatar
|
||||
return conn
|
||||
|
||||
def interceptConnectionLost(self, sftpServer):
|
||||
self.connectionLostFired = False
|
||||
origConnectionLost = sftpServer.connectionLost
|
||||
def connectionLost(reason):
|
||||
self.connectionLostFired = True
|
||||
origConnectionLost(reason)
|
||||
sftpServer.connectionLost = connectionLost
|
||||
|
||||
def assertSFTPConnectionLost(self):
|
||||
self.assertTrue(self.connectionLostFired,
|
||||
"sftpServer's connectionLost was not called")
|
||||
|
||||
def test_sessionClose(self):
|
||||
"""
|
||||
Closing a session should notify an SFTP subsystem launched by that
|
||||
session.
|
||||
"""
|
||||
# make a session
|
||||
testSession = session.SSHSession(conn=FakeConn(), avatar=self.avatar)
|
||||
|
||||
# start an SFTP subsystem on the session
|
||||
testSession.request_subsystem(common.NS('sftp'))
|
||||
sftpServer = testSession.client.transport.proto
|
||||
|
||||
# intercept connectionLost so we can check that it's called
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# close session
|
||||
testSession.closeReceived()
|
||||
|
||||
self.assertSFTPConnectionLost()
|
||||
|
||||
def test_clientClosesChannelOnConnnection(self):
|
||||
"""
|
||||
A client sending CHANNEL_CLOSE should trigger closeReceived on the
|
||||
associated channel instance.
|
||||
"""
|
||||
conn = self.buildServerConnection()
|
||||
|
||||
# somehow get a session
|
||||
packet = common.NS('session') + struct.pack('>L', 0) * 3
|
||||
conn.ssh_CHANNEL_OPEN(packet)
|
||||
sessionChannel = conn.channels[0]
|
||||
|
||||
sessionChannel.request_subsystem(common.NS('sftp'))
|
||||
sftpServer = sessionChannel.client.transport.proto
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# intercept closeReceived
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# close the connection
|
||||
conn.ssh_CHANNEL_CLOSE(struct.pack('>L', 0))
|
||||
|
||||
self.assertSFTPConnectionLost()
|
||||
|
||||
|
||||
def test_stopConnectionServiceClosesChannel(self):
|
||||
"""
|
||||
Closing an SSH connection should close all sessions within it.
|
||||
"""
|
||||
conn = self.buildServerConnection()
|
||||
|
||||
# somehow get a session
|
||||
packet = common.NS('session') + struct.pack('>L', 0) * 3
|
||||
conn.ssh_CHANNEL_OPEN(packet)
|
||||
sessionChannel = conn.channels[0]
|
||||
|
||||
sessionChannel.request_subsystem(common.NS('sftp'))
|
||||
sftpServer = sessionChannel.client.transport.proto
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# close the connection
|
||||
conn.serviceStopped()
|
||||
|
||||
self.assertSFTPConnectionLost()
|
||||
|
||||
|
||||
|
||||
class TestConstants(unittest.TestCase):
|
||||
"""
|
||||
Tests for the constants used by the SFTP protocol implementation.
|
||||
|
||||
@ivar filexferSpecExcerpts: Excerpts from the
|
||||
draft-ietf-secsh-filexfer-02.txt (draft) specification of the SFTP
|
||||
protocol. There are more recent drafts of the specification, but this
|
||||
one describes version 3, which is what conch (and OpenSSH) implements.
|
||||
"""
|
||||
|
||||
|
||||
filexferSpecExcerpts = [
|
||||
"""
|
||||
The following values are defined for packet types.
|
||||
|
||||
#define SSH_FXP_INIT 1
|
||||
#define SSH_FXP_VERSION 2
|
||||
#define SSH_FXP_OPEN 3
|
||||
#define SSH_FXP_CLOSE 4
|
||||
#define SSH_FXP_READ 5
|
||||
#define SSH_FXP_WRITE 6
|
||||
#define SSH_FXP_LSTAT 7
|
||||
#define SSH_FXP_FSTAT 8
|
||||
#define SSH_FXP_SETSTAT 9
|
||||
#define SSH_FXP_FSETSTAT 10
|
||||
#define SSH_FXP_OPENDIR 11
|
||||
#define SSH_FXP_READDIR 12
|
||||
#define SSH_FXP_REMOVE 13
|
||||
#define SSH_FXP_MKDIR 14
|
||||
#define SSH_FXP_RMDIR 15
|
||||
#define SSH_FXP_REALPATH 16
|
||||
#define SSH_FXP_STAT 17
|
||||
#define SSH_FXP_RENAME 18
|
||||
#define SSH_FXP_READLINK 19
|
||||
#define SSH_FXP_SYMLINK 20
|
||||
#define SSH_FXP_STATUS 101
|
||||
#define SSH_FXP_HANDLE 102
|
||||
#define SSH_FXP_DATA 103
|
||||
#define SSH_FXP_NAME 104
|
||||
#define SSH_FXP_ATTRS 105
|
||||
#define SSH_FXP_EXTENDED 200
|
||||
#define SSH_FXP_EXTENDED_REPLY 201
|
||||
|
||||
Additional packet types should only be defined if the protocol
|
||||
version number (see Section ``Protocol Initialization'') is
|
||||
incremented, and their use MUST be negotiated using the version
|
||||
number. However, the SSH_FXP_EXTENDED and SSH_FXP_EXTENDED_REPLY
|
||||
packets can be used to implement vendor-specific extensions. See
|
||||
Section ``Vendor-Specific-Extensions'' for more details.
|
||||
""",
|
||||
"""
|
||||
The flags bits are defined to have the following values:
|
||||
|
||||
#define SSH_FILEXFER_ATTR_SIZE 0x00000001
|
||||
#define SSH_FILEXFER_ATTR_UIDGID 0x00000002
|
||||
#define SSH_FILEXFER_ATTR_PERMISSIONS 0x00000004
|
||||
#define SSH_FILEXFER_ATTR_ACMODTIME 0x00000008
|
||||
#define SSH_FILEXFER_ATTR_EXTENDED 0x80000000
|
||||
|
||||
""",
|
||||
"""
|
||||
The `pflags' field is a bitmask. The following bits have been
|
||||
defined.
|
||||
|
||||
#define SSH_FXF_READ 0x00000001
|
||||
#define SSH_FXF_WRITE 0x00000002
|
||||
#define SSH_FXF_APPEND 0x00000004
|
||||
#define SSH_FXF_CREAT 0x00000008
|
||||
#define SSH_FXF_TRUNC 0x00000010
|
||||
#define SSH_FXF_EXCL 0x00000020
|
||||
""",
|
||||
"""
|
||||
Currently, the following values are defined (other values may be
|
||||
defined by future versions of this protocol):
|
||||
|
||||
#define SSH_FX_OK 0
|
||||
#define SSH_FX_EOF 1
|
||||
#define SSH_FX_NO_SUCH_FILE 2
|
||||
#define SSH_FX_PERMISSION_DENIED 3
|
||||
#define SSH_FX_FAILURE 4
|
||||
#define SSH_FX_BAD_MESSAGE 5
|
||||
#define SSH_FX_NO_CONNECTION 6
|
||||
#define SSH_FX_CONNECTION_LOST 7
|
||||
#define SSH_FX_OP_UNSUPPORTED 8
|
||||
"""]
|
||||
|
||||
|
||||
def test_constantsAgainstSpec(self):
|
||||
"""
|
||||
The constants used by the SFTP protocol implementation match those
|
||||
found by searching through the spec.
|
||||
"""
|
||||
constants = {}
|
||||
for excerpt in self.filexferSpecExcerpts:
|
||||
for line in excerpt.splitlines():
|
||||
m = re.match('^\s*#define SSH_([A-Z_]+)\s+([0-9x]*)\s*$', line)
|
||||
if m:
|
||||
constants[m.group(1)] = long(m.group(2), 0)
|
||||
self.assertTrue(
|
||||
len(constants) > 0, "No constants found (the test must be buggy).")
|
||||
for k, v in constants.items():
|
||||
self.assertEqual(v, getattr(filetransfer, k))
|
||||
|
||||
|
||||
|
||||
class TestRawPacketData(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{filetransfer.FileTransferClient} which explicitly craft certain
|
||||
less common protocol messages to exercise their handling.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.ftc = filetransfer.FileTransferClient()
|
||||
|
||||
|
||||
def test_packetSTATUS(self):
|
||||
"""
|
||||
A STATUS packet containing a result code, a message, and a language is
|
||||
parsed to produce the result of an outstanding request L{Deferred}.
|
||||
|
||||
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
|
||||
of the SFTP Internet-Draft.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self._cbTestPacketSTATUS)
|
||||
self.ftc.openRequests[1] = d
|
||||
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg') + common.NS('lang')
|
||||
self.ftc.packet_STATUS(data)
|
||||
return d
|
||||
|
||||
|
||||
def _cbTestPacketSTATUS(self, result):
|
||||
"""
|
||||
Assert that the result is a two-tuple containing the message and
|
||||
language from the STATUS packet.
|
||||
"""
|
||||
self.assertEqual(result[0], 'msg')
|
||||
self.assertEqual(result[1], 'lang')
|
||||
|
||||
|
||||
def test_packetSTATUSShort(self):
|
||||
"""
|
||||
A STATUS packet containing only a result code can also be parsed to
|
||||
produce the result of an outstanding request L{Deferred}. Such packets
|
||||
are sent by some SFTP implementations, though not strictly legal.
|
||||
|
||||
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
|
||||
of the SFTP Internet-Draft.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self._cbTestPacketSTATUSShort)
|
||||
self.ftc.openRequests[1] = d
|
||||
data = struct.pack('!LL', 1, filetransfer.FX_OK)
|
||||
self.ftc.packet_STATUS(data)
|
||||
return d
|
||||
|
||||
|
||||
def _cbTestPacketSTATUSShort(self, result):
|
||||
"""
|
||||
Assert that the result is a two-tuple containing empty strings, since
|
||||
the STATUS packet had neither a message nor a language.
|
||||
"""
|
||||
self.assertEqual(result[0], '')
|
||||
self.assertEqual(result[1], '')
|
||||
|
||||
|
||||
def test_packetSTATUSWithoutLang(self):
|
||||
"""
|
||||
A STATUS packet containing a result code and a message but no language
|
||||
can also be parsed to produce the result of an outstanding request
|
||||
L{Deferred}. Such packets are sent by some SFTP implementations, though
|
||||
not strictly legal.
|
||||
|
||||
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
|
||||
of the SFTP Internet-Draft.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self._cbTestPacketSTATUSWithoutLang)
|
||||
self.ftc.openRequests[1] = d
|
||||
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg')
|
||||
self.ftc.packet_STATUS(data)
|
||||
return d
|
||||
|
||||
|
||||
def _cbTestPacketSTATUSWithoutLang(self, result):
|
||||
"""
|
||||
Assert that the result is a two-tuple containing the message from the
|
||||
STATUS packet and an empty string, since the language was missing.
|
||||
"""
|
||||
self.assertEqual(result[0], 'msg')
|
||||
self.assertEqual(result[1], '')
|
||||
|
|
@ -0,0 +1,614 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_helper -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from twisted.conch.insults import helper
|
||||
from twisted.conch.insults.insults import G0, G1, G2, G3
|
||||
from twisted.conch.insults.insults import modes, privateModes
|
||||
from twisted.conch.insults.insults import (
|
||||
NORMAL, BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
WIDTH = 80
|
||||
HEIGHT = 24
|
||||
|
||||
class BufferTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.term = helper.TerminalBuffer()
|
||||
self.term.connectionMade()
|
||||
|
||||
def testInitialState(self):
|
||||
self.assertEqual(self.term.width, WIDTH)
|
||||
self.assertEqual(self.term.height, HEIGHT)
|
||||
self.assertEqual(str(self.term),
|
||||
'\n' * (HEIGHT - 1))
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
|
||||
def test_initialPrivateModes(self):
|
||||
"""
|
||||
Verify that only DEC Auto Wrap Mode (DECAWM) and DEC Text Cursor Enable
|
||||
Mode (DECTCEM) are initially in the Set Mode (SM) state.
|
||||
"""
|
||||
self.assertEqual(
|
||||
{privateModes.AUTO_WRAP: True,
|
||||
privateModes.CURSOR_MODE: True},
|
||||
self.term.privateModes)
|
||||
|
||||
|
||||
def test_carriageReturn(self):
|
||||
"""
|
||||
C{"\r"} moves the cursor to the first column in the current row.
|
||||
"""
|
||||
self.term.cursorForward(5)
|
||||
self.term.cursorDown(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
|
||||
self.term.insertAtCursor("\r")
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
|
||||
|
||||
|
||||
def test_linefeed(self):
|
||||
"""
|
||||
C{"\n"} moves the cursor to the next row without changing the column.
|
||||
"""
|
||||
self.term.cursorForward(5)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 0))
|
||||
self.term.insertAtCursor("\n")
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
|
||||
|
||||
|
||||
def test_newline(self):
|
||||
"""
|
||||
C{write} transforms C{"\n"} into C{"\r\n"}.
|
||||
"""
|
||||
self.term.cursorForward(5)
|
||||
self.term.cursorDown(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
|
||||
self.term.write("\n")
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
|
||||
|
||||
|
||||
def test_setPrivateModes(self):
|
||||
"""
|
||||
Verify that L{helper.TerminalBuffer.setPrivateModes} changes the Set
|
||||
Mode (SM) state to "set" for the private modes it is passed.
|
||||
"""
|
||||
expected = self.term.privateModes.copy()
|
||||
self.term.setPrivateModes([privateModes.SCROLL, privateModes.SCREEN])
|
||||
expected[privateModes.SCROLL] = True
|
||||
expected[privateModes.SCREEN] = True
|
||||
self.assertEqual(expected, self.term.privateModes)
|
||||
|
||||
|
||||
def test_resetPrivateModes(self):
|
||||
"""
|
||||
Verify that L{helper.TerminalBuffer.resetPrivateModes} changes the Set
|
||||
Mode (SM) state to "reset" for the private modes it is passed.
|
||||
"""
|
||||
expected = self.term.privateModes.copy()
|
||||
self.term.resetPrivateModes([privateModes.AUTO_WRAP, privateModes.CURSOR_MODE])
|
||||
del expected[privateModes.AUTO_WRAP]
|
||||
del expected[privateModes.CURSOR_MODE]
|
||||
self.assertEqual(expected, self.term.privateModes)
|
||||
|
||||
|
||||
def testCursorDown(self):
|
||||
self.term.cursorDown(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
|
||||
self.term.cursorDown()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
|
||||
self.term.cursorDown(HEIGHT)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
|
||||
|
||||
def testCursorUp(self):
|
||||
self.term.cursorUp(5)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
self.term.cursorDown(20)
|
||||
self.term.cursorUp(1)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 19))
|
||||
|
||||
self.term.cursorUp(19)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
def testCursorForward(self):
|
||||
self.term.cursorForward(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (2, 0))
|
||||
self.term.cursorForward(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (4, 0))
|
||||
self.term.cursorForward(WIDTH)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (WIDTH, 0))
|
||||
|
||||
def testCursorBackward(self):
|
||||
self.term.cursorForward(10)
|
||||
self.term.cursorBackward(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (8, 0))
|
||||
self.term.cursorBackward(7)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (1, 0))
|
||||
self.term.cursorBackward(1)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
self.term.cursorBackward(1)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
def testCursorPositioning(self):
|
||||
self.term.cursorPosition(3, 9)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (3, 9))
|
||||
|
||||
def testSimpleWriting(self):
|
||||
s = "Hello, world."
|
||||
self.term.write(s)
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testOvertype(self):
|
||||
s = "hello, world."
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(len(s))
|
||||
self.term.resetModes([modes.IRM])
|
||||
self.term.write("H")
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
("H" + s[1:]) + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testInsert(self):
|
||||
s = "ello, world."
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(len(s))
|
||||
self.term.setModes([modes.IRM])
|
||||
self.term.write("H")
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
("H" + s) + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testWritingInTheMiddle(self):
|
||||
s = "Hello, world."
|
||||
self.term.cursorDown(5)
|
||||
self.term.cursorForward(5)
|
||||
self.term.write(s)
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
'\n' * 5 +
|
||||
(self.term.fill * 5) + s + '\n' +
|
||||
'\n' * (HEIGHT - 7))
|
||||
|
||||
def testWritingWrappedAtEndOfLine(self):
|
||||
s = "Hello, world."
|
||||
self.term.cursorForward(WIDTH - 5)
|
||||
self.term.write(s)
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s[:5].rjust(WIDTH) + '\n' +
|
||||
s[5:] + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
def testIndex(self):
|
||||
self.term.index()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
|
||||
self.term.cursorDown(HEIGHT)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
|
||||
self.term.index()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
|
||||
|
||||
def testReverseIndex(self):
|
||||
self.term.reverseIndex()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
self.term.cursorDown(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
|
||||
self.term.reverseIndex()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
|
||||
|
||||
def test_nextLine(self):
|
||||
"""
|
||||
C{nextLine} positions the cursor at the beginning of the row below the
|
||||
current row.
|
||||
"""
|
||||
self.term.nextLine()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
|
||||
self.term.cursorForward(5)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
|
||||
self.term.nextLine()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
|
||||
|
||||
def testSaveCursor(self):
|
||||
self.term.cursorDown(5)
|
||||
self.term.cursorForward(7)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
|
||||
self.term.saveCursor()
|
||||
self.term.cursorDown(7)
|
||||
self.term.cursorBackward(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (4, 12))
|
||||
self.term.restoreCursor()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
|
||||
|
||||
def testSingleShifts(self):
|
||||
self.term.singleShift2()
|
||||
self.term.write('Hi')
|
||||
|
||||
ch = self.term.getCharacter(0, 0)
|
||||
self.assertEqual(ch[0], 'H')
|
||||
self.assertEqual(ch[1].charset, G2)
|
||||
|
||||
ch = self.term.getCharacter(1, 0)
|
||||
self.assertEqual(ch[0], 'i')
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
|
||||
self.term.singleShift3()
|
||||
self.term.write('!!')
|
||||
|
||||
ch = self.term.getCharacter(2, 0)
|
||||
self.assertEqual(ch[0], '!')
|
||||
self.assertEqual(ch[1].charset, G3)
|
||||
|
||||
ch = self.term.getCharacter(3, 0)
|
||||
self.assertEqual(ch[0], '!')
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
|
||||
def testShifting(self):
|
||||
s1 = "Hello"
|
||||
s2 = "World"
|
||||
s3 = "Bye!"
|
||||
self.term.write("Hello\n")
|
||||
self.term.shiftOut()
|
||||
self.term.write("World\n")
|
||||
self.term.shiftIn()
|
||||
self.term.write("Bye!\n")
|
||||
|
||||
g = G0
|
||||
h = 0
|
||||
for s in (s1, s2, s3):
|
||||
for i in range(len(s)):
|
||||
ch = self.term.getCharacter(i, h)
|
||||
self.assertEqual(ch[0], s[i])
|
||||
self.assertEqual(ch[1].charset, g)
|
||||
g = g == G0 and G1 or G0
|
||||
h += 1
|
||||
|
||||
def testGraphicRendition(self):
|
||||
self.term.selectGraphicRendition(BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
|
||||
self.term.write('W')
|
||||
self.term.selectGraphicRendition(NORMAL)
|
||||
self.term.write('X')
|
||||
self.term.selectGraphicRendition(BLINK)
|
||||
self.term.write('Y')
|
||||
self.term.selectGraphicRendition(BOLD)
|
||||
self.term.write('Z')
|
||||
|
||||
ch = self.term.getCharacter(0, 0)
|
||||
self.assertEqual(ch[0], 'W')
|
||||
self.assertTrue(ch[1].bold)
|
||||
self.assertTrue(ch[1].underline)
|
||||
self.assertTrue(ch[1].blink)
|
||||
self.assertTrue(ch[1].reverseVideo)
|
||||
|
||||
ch = self.term.getCharacter(1, 0)
|
||||
self.assertEqual(ch[0], 'X')
|
||||
self.assertFalse(ch[1].bold)
|
||||
self.assertFalse(ch[1].underline)
|
||||
self.assertFalse(ch[1].blink)
|
||||
self.assertFalse(ch[1].reverseVideo)
|
||||
|
||||
ch = self.term.getCharacter(2, 0)
|
||||
self.assertEqual(ch[0], 'Y')
|
||||
self.assertTrue(ch[1].blink)
|
||||
self.assertFalse(ch[1].bold)
|
||||
self.assertFalse(ch[1].underline)
|
||||
self.assertFalse(ch[1].reverseVideo)
|
||||
|
||||
ch = self.term.getCharacter(3, 0)
|
||||
self.assertEqual(ch[0], 'Z')
|
||||
self.assertTrue(ch[1].blink)
|
||||
self.assertTrue(ch[1].bold)
|
||||
self.assertFalse(ch[1].underline)
|
||||
self.assertFalse(ch[1].reverseVideo)
|
||||
|
||||
def testColorAttributes(self):
|
||||
s1 = "Merry xmas"
|
||||
s2 = "Just kidding"
|
||||
self.term.selectGraphicRendition(helper.FOREGROUND + helper.RED,
|
||||
helper.BACKGROUND + helper.GREEN)
|
||||
self.term.write(s1 + "\n")
|
||||
self.term.selectGraphicRendition(NORMAL)
|
||||
self.term.write(s2 + "\n")
|
||||
|
||||
for i in range(len(s1)):
|
||||
ch = self.term.getCharacter(i, 0)
|
||||
self.assertEqual(ch[0], s1[i])
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
self.assertEqual(ch[1].bold, False)
|
||||
self.assertEqual(ch[1].underline, False)
|
||||
self.assertEqual(ch[1].blink, False)
|
||||
self.assertEqual(ch[1].reverseVideo, False)
|
||||
self.assertEqual(ch[1].foreground, helper.RED)
|
||||
self.assertEqual(ch[1].background, helper.GREEN)
|
||||
|
||||
for i in range(len(s2)):
|
||||
ch = self.term.getCharacter(i, 1)
|
||||
self.assertEqual(ch[0], s2[i])
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
self.assertEqual(ch[1].bold, False)
|
||||
self.assertEqual(ch[1].underline, False)
|
||||
self.assertEqual(ch[1].blink, False)
|
||||
self.assertEqual(ch[1].reverseVideo, False)
|
||||
self.assertEqual(ch[1].foreground, helper.WHITE)
|
||||
self.assertEqual(ch[1].background, helper.BLACK)
|
||||
|
||||
def testEraseLine(self):
|
||||
s1 = 'line 1'
|
||||
s2 = 'line 2'
|
||||
s3 = 'line 3'
|
||||
self.term.write('\n'.join((s1, s2, s3)) + '\n')
|
||||
self.term.cursorPosition(1, 1)
|
||||
self.term.eraseLine()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
'\n' +
|
||||
s3 + '\n' +
|
||||
'\n' * (HEIGHT - 4))
|
||||
|
||||
def testEraseToLineEnd(self):
|
||||
s = 'Hello, world.'
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(5)
|
||||
self.term.eraseToLineEnd()
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s[:-5] + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testEraseToLineBeginning(self):
|
||||
s = 'Hello, world.'
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(5)
|
||||
self.term.eraseToLineBeginning()
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s[-4:].rjust(len(s)) + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testEraseDisplay(self):
|
||||
self.term.write('Hello world\n')
|
||||
self.term.write('Goodbye world\n')
|
||||
self.term.eraseDisplay()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
'\n' * (HEIGHT - 1))
|
||||
|
||||
def testEraseToDisplayEnd(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2, '')))
|
||||
self.term.cursorPosition(5, 1)
|
||||
self.term.eraseToDisplayEnd()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
s2[:5] + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
def testEraseToDisplayBeginning(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2)))
|
||||
self.term.cursorPosition(5, 1)
|
||||
self.term.eraseToDisplayBeginning()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
'\n' +
|
||||
s2[6:].rjust(len(s2)) + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
def testLineInsertion(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2)))
|
||||
self.term.cursorPosition(7, 1)
|
||||
self.term.insertLine()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
'\n' +
|
||||
s2 + '\n' +
|
||||
'\n' * (HEIGHT - 4))
|
||||
|
||||
def testLineDeletion(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Middle words"
|
||||
s3 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2, s3)))
|
||||
self.term.cursorPosition(9, 1)
|
||||
self.term.deleteLine()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
s3 + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
class FakeDelayedCall:
|
||||
called = False
|
||||
cancelled = False
|
||||
def __init__(self, fs, timeout, f, a, kw):
|
||||
self.fs = fs
|
||||
self.timeout = timeout
|
||||
self.f = f
|
||||
self.a = a
|
||||
self.kw = kw
|
||||
|
||||
def active(self):
|
||||
return not (self.cancelled or self.called)
|
||||
|
||||
def cancel(self):
|
||||
self.cancelled = True
|
||||
# self.fs.calls.remove(self)
|
||||
|
||||
def call(self):
|
||||
self.called = True
|
||||
self.f(*self.a, **self.kw)
|
||||
|
||||
class FakeScheduler:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def callLater(self, timeout, f, *a, **kw):
|
||||
self.calls.append(FakeDelayedCall(self, timeout, f, a, kw))
|
||||
return self.calls[-1]
|
||||
|
||||
class ExpectTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.term = helper.ExpectableBuffer()
|
||||
self.term.connectionMade()
|
||||
self.fs = FakeScheduler()
|
||||
|
||||
def testSimpleString(self):
|
||||
result = []
|
||||
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
|
||||
d.addCallback(result.append)
|
||||
|
||||
self.term.write("greeting puny earthlings\n")
|
||||
self.assertFalse(result)
|
||||
self.term.write("hello world\n")
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(result[0].group(), "hello world")
|
||||
self.assertEqual(len(self.fs.calls), 1)
|
||||
self.assertFalse(self.fs.calls[0].active())
|
||||
|
||||
def testBrokenUpString(self):
|
||||
result = []
|
||||
d = self.term.expect("hello world")
|
||||
d.addCallback(result.append)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.term.write("hello ")
|
||||
self.assertFalse(result)
|
||||
self.term.write("worl")
|
||||
self.assertFalse(result)
|
||||
self.term.write("d")
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(result[0].group(), "hello world")
|
||||
|
||||
|
||||
def testMultiple(self):
|
||||
result = []
|
||||
d1 = self.term.expect("hello ")
|
||||
d1.addCallback(result.append)
|
||||
d2 = self.term.expect("world")
|
||||
d2.addCallback(result.append)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.term.write("hello")
|
||||
self.assertFalse(result)
|
||||
self.term.write(" ")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.term.write("world")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].group(), "hello ")
|
||||
self.assertEqual(result[1].group(), "world")
|
||||
|
||||
def testSynchronous(self):
|
||||
self.term.write("hello world")
|
||||
|
||||
result = []
|
||||
d = self.term.expect("hello world")
|
||||
d.addCallback(result.append)
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(result[0].group(), "hello world")
|
||||
|
||||
def testMultipleSynchronous(self):
|
||||
self.term.write("goodbye world")
|
||||
|
||||
result = []
|
||||
d1 = self.term.expect("bye")
|
||||
d1.addCallback(result.append)
|
||||
d2 = self.term.expect("world")
|
||||
d2.addCallback(result.append)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].group(), "bye")
|
||||
self.assertEqual(result[1].group(), "world")
|
||||
|
||||
def _cbTestTimeoutFailure(self, res):
|
||||
self.assertTrue(hasattr(res, 'type'))
|
||||
self.assertEqual(res.type, helper.ExpectationTimeout)
|
||||
|
||||
def testTimeoutFailure(self):
|
||||
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
|
||||
d.addBoth(self._cbTestTimeoutFailure)
|
||||
self.fs.calls[0].call()
|
||||
|
||||
def testOverlappingTimeout(self):
|
||||
self.term.write("not zoomtastic")
|
||||
|
||||
result = []
|
||||
d1 = self.term.expect("hello world", timeout=1, scheduler=self.fs)
|
||||
d1.addBoth(self._cbTestTimeoutFailure)
|
||||
d2 = self.term.expect("zoom")
|
||||
d2.addCallback(result.append)
|
||||
|
||||
self.fs.calls[0].call()
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].group(), "zoom")
|
||||
|
||||
|
||||
|
||||
class CharacterAttributeTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{twisted.conch.insults.helper.CharacterAttribute}.
|
||||
"""
|
||||
def test_equality(self):
|
||||
"""
|
||||
L{CharacterAttribute}s must have matching character attribute values
|
||||
(bold, blink, underline, etc) with the same values to be considered
|
||||
equal.
|
||||
"""
|
||||
self.assertEqual(
|
||||
helper.CharacterAttribute(),
|
||||
helper.CharacterAttribute())
|
||||
|
||||
self.assertEqual(
|
||||
helper.CharacterAttribute(),
|
||||
helper.CharacterAttribute(charset=G0))
|
||||
|
||||
self.assertEqual(
|
||||
helper.CharacterAttribute(
|
||||
bold=True, underline=True, blink=False, reverseVideo=True,
|
||||
foreground=helper.BLUE),
|
||||
helper.CharacterAttribute(
|
||||
bold=True, underline=True, blink=False, reverseVideo=True,
|
||||
foreground=helper.BLUE))
|
||||
|
||||
self.assertNotEqual(
|
||||
helper.CharacterAttribute(),
|
||||
helper.CharacterAttribute(charset=G1))
|
||||
|
||||
self.assertNotEqual(
|
||||
helper.CharacterAttribute(bold=True),
|
||||
helper.CharacterAttribute(bold=False))
|
||||
|
||||
|
||||
def test_wantOneDeprecated(self):
|
||||
"""
|
||||
L{twisted.conch.insults.helper.CharacterAttribute.wantOne} emits
|
||||
a deprecation warning when invoked.
|
||||
"""
|
||||
# Trigger the deprecation warning.
|
||||
helper._FormattingState().wantOne(bold=True)
|
||||
|
||||
warningsShown = self.flushWarnings([self.test_wantOneDeprecated])
|
||||
self.assertEqual(len(warningsShown), 1)
|
||||
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
|
||||
self.assertEqual(
|
||||
warningsShown[0]['message'],
|
||||
'twisted.conch.insults.helper.wantOne was deprecated in '
|
||||
'Twisted 13.1.0')
|
||||
|
|
@ -0,0 +1,496 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_insults -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
from twisted.conch.insults.insults import ServerProtocol, ClientProtocol
|
||||
from twisted.conch.insults.insults import CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL
|
||||
from twisted.conch.insults.insults import G0, G1
|
||||
from twisted.conch.insults.insults import modes
|
||||
|
||||
def _getattr(mock, name):
|
||||
return super(Mock, mock).__getattribute__(name)
|
||||
|
||||
def occurrences(mock):
|
||||
return _getattr(mock, 'occurrences')
|
||||
|
||||
def methods(mock):
|
||||
return _getattr(mock, 'methods')
|
||||
|
||||
def _append(mock, obj):
|
||||
occurrences(mock).append(obj)
|
||||
|
||||
default = object()
|
||||
|
||||
class Mock(object):
|
||||
callReturnValue = default
|
||||
|
||||
def __init__(self, methods=None, callReturnValue=default):
|
||||
"""
|
||||
@param methods: Mapping of names to return values
|
||||
@param callReturnValue: object __call__ should return
|
||||
"""
|
||||
self.occurrences = []
|
||||
if methods is None:
|
||||
methods = {}
|
||||
self.methods = methods
|
||||
if callReturnValue is not default:
|
||||
self.callReturnValue = callReturnValue
|
||||
|
||||
def __call__(self, *a, **kw):
|
||||
returnValue = _getattr(self, 'callReturnValue')
|
||||
if returnValue is default:
|
||||
returnValue = Mock()
|
||||
# _getattr(self, 'occurrences').append(('__call__', returnValue, a, kw))
|
||||
_append(self, ('__call__', returnValue, a, kw))
|
||||
return returnValue
|
||||
|
||||
def __getattribute__(self, name):
|
||||
methods = _getattr(self, 'methods')
|
||||
if name in methods:
|
||||
attrValue = Mock(callReturnValue=methods[name])
|
||||
else:
|
||||
attrValue = Mock()
|
||||
# _getattr(self, 'occurrences').append((name, attrValue))
|
||||
_append(self, (name, attrValue))
|
||||
return attrValue
|
||||
|
||||
class MockMixin:
|
||||
def assertCall(self, occurrence, methodName, expectedPositionalArgs=(),
|
||||
expectedKeywordArgs={}):
|
||||
attr, mock = occurrence
|
||||
self.assertEqual(attr, methodName)
|
||||
self.assertEqual(len(occurrences(mock)), 1)
|
||||
[(call, result, args, kw)] = occurrences(mock)
|
||||
self.assertEqual(call, "__call__")
|
||||
self.assertEqual(args, expectedPositionalArgs)
|
||||
self.assertEqual(kw, expectedKeywordArgs)
|
||||
return result
|
||||
|
||||
|
||||
_byteGroupingTestTemplate = """\
|
||||
def testByte%(groupName)s(self):
|
||||
transport = StringTransport()
|
||||
proto = Mock()
|
||||
parser = self.protocolFactory(lambda: proto)
|
||||
parser.factory = self
|
||||
parser.makeConnection(transport)
|
||||
|
||||
bytes = self.TEST_BYTES
|
||||
while bytes:
|
||||
chunk = bytes[:%(bytesPer)d]
|
||||
bytes = bytes[%(bytesPer)d:]
|
||||
parser.dataReceived(chunk)
|
||||
|
||||
self.verifyResults(transport, proto, parser)
|
||||
"""
|
||||
class ByteGroupingsMixin(MockMixin):
|
||||
protocolFactory = None
|
||||
|
||||
for word, n in [('Pairs', 2), ('Triples', 3), ('Quads', 4), ('Quints', 5), ('Sexes', 6)]:
|
||||
exec _byteGroupingTestTemplate % {'groupName': word, 'bytesPer': n}
|
||||
del word, n
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
result = self.assertCall(occurrences(proto).pop(0), "makeConnection", (parser,))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
|
||||
del _byteGroupingTestTemplate
|
||||
|
||||
class ServerArrowKeys(ByteGroupingsMixin, unittest.TestCase):
|
||||
protocolFactory = ServerProtocol
|
||||
|
||||
# All the arrow keys once
|
||||
TEST_BYTES = '\x1b[A\x1b[B\x1b[C\x1b[D'
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
|
||||
for arrow in (parser.UP_ARROW, parser.DOWN_ARROW,
|
||||
parser.RIGHT_ARROW, parser.LEFT_ARROW):
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (arrow, None))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
self.assertFalse(occurrences(proto))
|
||||
|
||||
|
||||
class PrintableCharacters(ByteGroupingsMixin, unittest.TestCase):
|
||||
protocolFactory = ServerProtocol
|
||||
|
||||
# Some letters and digits, first on their own, then capitalized,
|
||||
# then modified with alt
|
||||
|
||||
TEST_BYTES = 'abc123ABC!@#\x1ba\x1bb\x1bc\x1b1\x1b2\x1b3'
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
|
||||
for char in 'abc123ABC!@#':
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, None))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
|
||||
for char in 'abc123':
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, parser.ALT))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
|
||||
occs = occurrences(proto)
|
||||
self.assertFalse(occs, "%r should have been []" % (occs,))
|
||||
|
||||
class ServerFunctionKeys(ByteGroupingsMixin, unittest.TestCase):
|
||||
"""Test for parsing and dispatching function keys (F1 - F12)
|
||||
"""
|
||||
protocolFactory = ServerProtocol
|
||||
|
||||
byteList = []
|
||||
for bytes in ('OP', 'OQ', 'OR', 'OS', # F1 - F4
|
||||
'15~', '17~', '18~', '19~', # F5 - F8
|
||||
'20~', '21~', '23~', '24~'): # F9 - F12
|
||||
byteList.append('\x1b[' + bytes)
|
||||
TEST_BYTES = ''.join(byteList)
|
||||
del byteList, bytes
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
for funcNum in range(1, 13):
|
||||
funcArg = getattr(parser, 'F%d' % (funcNum,))
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (funcArg, None))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
self.assertFalse(occurrences(proto))
|
||||
|
||||
class ClientCursorMovement(ByteGroupingsMixin, unittest.TestCase):
|
||||
protocolFactory = ClientProtocol
|
||||
|
||||
d2 = "\x1b[2B"
|
||||
r4 = "\x1b[4C"
|
||||
u1 = "\x1b[A"
|
||||
l2 = "\x1b[2D"
|
||||
# Move the cursor down two, right four, up one, left two, up one, left two
|
||||
TEST_BYTES = d2 + r4 + u1 + l2 + u1 + l2
|
||||
del d2, r4, u1, l2
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
|
||||
for (method, count) in [('Down', 2), ('Forward', 4), ('Up', 1),
|
||||
('Backward', 2), ('Up', 1), ('Backward', 2)]:
|
||||
result = self.assertCall(occurrences(proto).pop(0), "cursor" + method, (count,))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
self.assertFalse(occurrences(proto))
|
||||
|
||||
class ClientControlSequences(unittest.TestCase, MockMixin):
|
||||
def setUp(self):
|
||||
self.transport = StringTransport()
|
||||
self.proto = Mock()
|
||||
self.parser = ClientProtocol(lambda: self.proto)
|
||||
self.parser.factory = self
|
||||
self.parser.makeConnection(self.transport)
|
||||
result = self.assertCall(occurrences(self.proto).pop(0), "makeConnection", (self.parser,))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
def testSimpleCardinals(self):
|
||||
self.parser.dataReceived(
|
||||
''.join([''.join(['\x1b[' + str(n) + ch for n in ('', 2, 20, 200)]) for ch in 'BACD']))
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for meth in ("Down", "Up", "Forward", "Backward"):
|
||||
for count in (1, 2, 20, 200):
|
||||
result = self.assertCall(occs.pop(0), "cursor" + meth, (count,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testScrollRegion(self):
|
||||
self.parser.dataReceived('\x1b[5;22r\x1b[r')
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "setScrollRegion", (5, 22))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "setScrollRegion", (None, None))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testHeightAndWidth(self):
|
||||
self.parser.dataReceived("\x1b#3\x1b#4\x1b#5\x1b#6")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "doubleHeightLine", (True,))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "doubleHeightLine", (False,))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "singleWidthLine")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "doubleWidthLine")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testCharacterSet(self):
|
||||
self.parser.dataReceived(
|
||||
''.join([''.join(['\x1b' + g + n for n in 'AB012']) for g in '()']))
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for which in (G0, G1):
|
||||
for charset in (CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL):
|
||||
result = self.assertCall(occs.pop(0), "selectCharacterSet", (charset, which))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testShifting(self):
|
||||
self.parser.dataReceived("\x15\x14")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "shiftIn")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "shiftOut")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testSingleShifts(self):
|
||||
self.parser.dataReceived("\x1bN\x1bO")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "singleShift2")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "singleShift3")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testKeypadMode(self):
|
||||
self.parser.dataReceived("\x1b=\x1b>")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "applicationKeypadMode")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "numericKeypadMode")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testCursor(self):
|
||||
self.parser.dataReceived("\x1b7\x1b8")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "saveCursor")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "restoreCursor")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testReset(self):
|
||||
self.parser.dataReceived("\x1bc")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "reset")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testIndex(self):
|
||||
self.parser.dataReceived("\x1bD\x1bM\x1bE")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "index")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "reverseIndex")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "nextLine")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testModes(self):
|
||||
self.parser.dataReceived(
|
||||
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "h")
|
||||
self.parser.dataReceived(
|
||||
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "l")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "setModes", ([modes.KAM, modes.IRM, modes.LNM],))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "resetModes", ([modes.KAM, modes.IRM, modes.LNM],))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testErasure(self):
|
||||
self.parser.dataReceived(
|
||||
"\x1b[K\x1b[1K\x1b[2K\x1b[J\x1b[1J\x1b[2J\x1b[3P")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for meth in ("eraseToLineEnd", "eraseToLineBeginning", "eraseLine",
|
||||
"eraseToDisplayEnd", "eraseToDisplayBeginning",
|
||||
"eraseDisplay"):
|
||||
result = self.assertCall(occs.pop(0), meth)
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "deleteCharacter", (3,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testLineDeletion(self):
|
||||
self.parser.dataReceived("\x1b[M\x1b[3M")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for arg in (1, 3):
|
||||
result = self.assertCall(occs.pop(0), "deleteLine", (arg,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testLineInsertion(self):
|
||||
self.parser.dataReceived("\x1b[L\x1b[3L")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for arg in (1, 3):
|
||||
result = self.assertCall(occs.pop(0), "insertLine", (arg,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testCursorPosition(self):
|
||||
methods(self.proto)['reportCursorPosition'] = (6, 7)
|
||||
self.parser.dataReceived("\x1b[6n")
|
||||
self.assertEqual(self.transport.value(), "\x1b[7;8R")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "reportCursorPosition")
|
||||
# This isn't really an interesting assert, since it only tests that
|
||||
# our mock setup is working right, but I'll include it anyway.
|
||||
self.assertEqual(result, (6, 7))
|
||||
|
||||
|
||||
def test_applicationDataBytes(self):
|
||||
"""
|
||||
Contiguous non-control bytes are passed to a single call to the
|
||||
C{write} method of the terminal to which the L{ClientProtocol} is
|
||||
connected.
|
||||
"""
|
||||
occs = occurrences(self.proto)
|
||||
self.parser.dataReceived('a')
|
||||
self.assertCall(occs.pop(0), "write", ("a",))
|
||||
self.parser.dataReceived('bc')
|
||||
self.assertCall(occs.pop(0), "write", ("bc",))
|
||||
|
||||
|
||||
def _applicationDataTest(self, data, calls):
|
||||
occs = occurrences(self.proto)
|
||||
self.parser.dataReceived(data)
|
||||
while calls:
|
||||
self.assertCall(occs.pop(0), *calls.pop(0))
|
||||
self.assertFalse(occs, "No other calls should happen: %r" % (occs,))
|
||||
|
||||
|
||||
def test_shiftInAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by a shift-in command are passed to a
|
||||
call to C{write} before the terminal's C{shiftIn} method is called.
|
||||
"""
|
||||
self._applicationDataTest(
|
||||
'ab\x15', [
|
||||
("write", ("ab",)),
|
||||
("shiftIn",)])
|
||||
|
||||
|
||||
def test_shiftOutAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by a shift-out command are passed to a
|
||||
call to C{write} before the terminal's C{shiftOut} method is called.
|
||||
"""
|
||||
self._applicationDataTest(
|
||||
'ab\x14', [
|
||||
("write", ("ab",)),
|
||||
("shiftOut",)])
|
||||
|
||||
|
||||
def test_cursorBackwardAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by a cursor-backward command are passed
|
||||
to a call to C{write} before the terminal's C{cursorBackward} method is
|
||||
called.
|
||||
"""
|
||||
self._applicationDataTest(
|
||||
'ab\x08', [
|
||||
("write", ("ab",)),
|
||||
("cursorBackward",)])
|
||||
|
||||
|
||||
def test_escapeAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by an escape character are passed to a
|
||||
call to C{write} before the terminal's handler method for the escape is
|
||||
called.
|
||||
"""
|
||||
# Test a short escape
|
||||
self._applicationDataTest(
|
||||
'ab\x1bD', [
|
||||
("write", ("ab",)),
|
||||
("index",)])
|
||||
|
||||
# And a long escape
|
||||
self._applicationDataTest(
|
||||
'ab\x1b[4h', [
|
||||
("write", ("ab",)),
|
||||
("setModes", ([4],))])
|
||||
|
||||
# There's some other cases too, but they're all handled by the same
|
||||
# codepaths as above.
|
||||
|
||||
|
||||
|
||||
class ServerProtocolOutputTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the bytes L{ServerProtocol} writes to its transport when its
|
||||
methods are called.
|
||||
"""
|
||||
def test_nextLine(self):
|
||||
"""
|
||||
L{ServerProtocol.nextLine} writes C{"\r\n"} to its transport.
|
||||
"""
|
||||
# Why doesn't it write ESC E? Because ESC E is poorly supported. For
|
||||
# example, gnome-terminal (many different versions) fails to scroll if
|
||||
# it receives ESC E and the cursor is already on the last row.
|
||||
protocol = ServerProtocol()
|
||||
transport = StringTransport()
|
||||
protocol.makeConnection(transport)
|
||||
protocol.nextLine()
|
||||
self.assertEqual(transport.value(), "\r\n")
|
||||
|
||||
|
||||
|
||||
class Deprecations(unittest.TestCase):
|
||||
"""
|
||||
Tests to ensure deprecation of L{insults.colors} and L{insults.client}
|
||||
"""
|
||||
|
||||
def ensureDeprecated(self, message):
|
||||
"""
|
||||
Ensures that the correct deprecation warning was issued.
|
||||
"""
|
||||
warnings = self.flushWarnings()
|
||||
self.assertIs(warnings[0]['category'], DeprecationWarning)
|
||||
self.assertEqual(warnings[0]['message'], message)
|
||||
self.assertEqual(len(warnings), 1)
|
||||
|
||||
|
||||
def test_colors(self):
|
||||
"""
|
||||
The L{insults.colors} module is deprecated
|
||||
"""
|
||||
from twisted.conch.insults import colors
|
||||
self.ensureDeprecated("twisted.conch.insults.colors was deprecated "
|
||||
"in Twisted 10.1.0: Please use "
|
||||
"twisted.conch.insults.helper instead.")
|
||||
|
||||
|
||||
def test_client(self):
|
||||
"""
|
||||
The L{insults.client} module is deprecated
|
||||
"""
|
||||
from twisted.conch.insults import client
|
||||
self.ensureDeprecated("twisted.conch.insults.client was deprecated "
|
||||
"in Twisted 10.1.0: Please use "
|
||||
"twisted.conch.insults.insults instead.")
|
||||
|
|
@ -0,0 +1,644 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.keys}.
|
||||
"""
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
except ImportError:
|
||||
# we'll have to skip these tests without PyCypto and pyasn1
|
||||
Crypto = None
|
||||
|
||||
try:
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
pyasn1 = None
|
||||
|
||||
if Crypto and pyasn1:
|
||||
from twisted.conch.ssh import keys, common, sexpy
|
||||
|
||||
import os, base64
|
||||
from hashlib import sha1
|
||||
from twisted.conch.test import keydata
|
||||
from twisted.python import randbytes
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class HelpersTestCase(unittest.TestCase):
|
||||
|
||||
if Crypto is None:
|
||||
skip = "cannot run w/o PyCrypto"
|
||||
if pyasn1 is None:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
def setUp(self):
|
||||
self._secureRandom = randbytes.secureRandom
|
||||
randbytes.secureRandom = lambda x: '\x55' * x
|
||||
|
||||
def tearDown(self):
|
||||
randbytes.secureRandom = self._secureRandom
|
||||
self._secureRandom = None
|
||||
|
||||
def test_pkcs1(self):
|
||||
"""
|
||||
Test Public Key Cryptographic Standard #1 functions.
|
||||
"""
|
||||
data = 'ABC'
|
||||
messageSize = 6
|
||||
self.assertEqual(keys.pkcs1Pad(data, messageSize),
|
||||
'\x01\xff\x00ABC')
|
||||
hash = sha1().digest()
|
||||
messageSize = 40
|
||||
self.assertEqual(keys.pkcs1Digest('', messageSize),
|
||||
'\x01\xff\xff\xff\x00' + keys.ID_SHA1 + hash)
|
||||
|
||||
def _signRSA(self, data):
|
||||
key = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
sig = key.sign(data)
|
||||
return key.keyObject, sig
|
||||
|
||||
def _signDSA(self, data):
|
||||
key = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
sig = key.sign(data)
|
||||
return key.keyObject, sig
|
||||
|
||||
def test_signRSA(self):
|
||||
"""
|
||||
Test that RSA keys return appropriate signatures.
|
||||
"""
|
||||
data = 'data'
|
||||
key, sig = self._signRSA(data)
|
||||
sigData = keys.pkcs1Digest(data, keys.lenSig(key))
|
||||
v = key.sign(sigData, '')[0]
|
||||
self.assertEqual(sig, common.NS('ssh-rsa') + common.MP(v))
|
||||
return key, sig
|
||||
|
||||
def test_signDSA(self):
|
||||
"""
|
||||
Test that DSA keys return appropriate signatures.
|
||||
"""
|
||||
data = 'data'
|
||||
key, sig = self._signDSA(data)
|
||||
sigData = sha1(data).digest()
|
||||
v = key.sign(sigData, '\x55' * 19)
|
||||
self.assertEqual(sig, common.NS('ssh-dss') + common.NS(
|
||||
Crypto.Util.number.long_to_bytes(v[0], 20) +
|
||||
Crypto.Util.number.long_to_bytes(v[1], 20)))
|
||||
return key, sig
|
||||
|
||||
|
||||
def test_objectType(self):
|
||||
"""
|
||||
Test that objectType, returns the correct type for objects.
|
||||
"""
|
||||
self.assertEqual(keys.objectType(keys.Key.fromString(
|
||||
keydata.privateRSA_openssh).keyObject), 'ssh-rsa')
|
||||
self.assertEqual(keys.objectType(keys.Key.fromString(
|
||||
keydata.privateDSA_openssh).keyObject), 'ssh-dss')
|
||||
self.assertRaises(keys.BadKeyError, keys.objectType, None)
|
||||
|
||||
|
||||
class KeyTestCase(unittest.TestCase):
|
||||
|
||||
if Crypto is None:
|
||||
skip = "cannot run w/o PyCrypto"
|
||||
if pyasn1 is None:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
def setUp(self):
|
||||
self.rsaObj = Crypto.PublicKey.RSA.construct((1L, 2L, 3L, 4L, 5L))
|
||||
self.dsaObj = Crypto.PublicKey.DSA.construct((1L, 2L, 3L, 4L, 5L))
|
||||
self.rsaSignature = ('\x00\x00\x00\x07ssh-rsa\x00'
|
||||
'\x00\x00`N\xac\xb4@qK\xa0(\xc3\xf2h \xd3\xdd\xee6Np\x9d_'
|
||||
'\xb0>\xe3\x0c(L\x9d{\txUd|!\xf6m\x9c\xd3\x93\x842\x7fU'
|
||||
'\x05\xf4\xf7\xfaD\xda\xce\x81\x8ea\x7f=Y\xed*\xb7\xba\x81'
|
||||
'\xf2\xad\xda\xeb(\x97\x03S\x08\x81\xc7\xb1\xb7\xe6\xe3'
|
||||
'\xcd*\xd4\xbd\xc0wt\xf7y\xcd\xf0\xb7\x7f\xfb\x1e>\xf9r'
|
||||
'\x8c\xba')
|
||||
self.dsaSignature = ('\x00\x00\x00\x07ssh-dss\x00\x00'
|
||||
'\x00(\x18z)H\x8a\x1b\xc6\r\xbbq\xa2\xd7f\x7f$\xa7\xbf'
|
||||
'\xe8\x87\x8c\x88\xef\xd9k\x1a\x98\xdd{=\xdec\x18\t\xe3'
|
||||
'\x87\xa9\xc72h\x95')
|
||||
self.oldSecureRandom = randbytes.secureRandom
|
||||
randbytes.secureRandom = lambda x: '\xff' * x
|
||||
self.keyFile = self.mktemp()
|
||||
file(self.keyFile, 'wb').write(keydata.privateRSA_lsh)
|
||||
|
||||
def tearDown(self):
|
||||
randbytes.secureRandom = self.oldSecureRandom
|
||||
del self.oldSecureRandom
|
||||
os.unlink(self.keyFile)
|
||||
|
||||
def test__guessStringType(self):
|
||||
"""
|
||||
Test that the _guessStringType method guesses string types
|
||||
correctly.
|
||||
"""
|
||||
self.assertEqual(keys.Key._guessStringType(keydata.publicRSA_openssh),
|
||||
'public_openssh')
|
||||
self.assertEqual(keys.Key._guessStringType(keydata.publicDSA_openssh),
|
||||
'public_openssh')
|
||||
self.assertEqual(keys.Key._guessStringType(
|
||||
keydata.privateRSA_openssh), 'private_openssh')
|
||||
self.assertEqual(keys.Key._guessStringType(
|
||||
keydata.privateDSA_openssh), 'private_openssh')
|
||||
self.assertEqual(keys.Key._guessStringType(keydata.publicRSA_lsh),
|
||||
'public_lsh')
|
||||
self.assertEqual(keys.Key._guessStringType(keydata.publicDSA_lsh),
|
||||
'public_lsh')
|
||||
self.assertEqual(keys.Key._guessStringType(keydata.privateRSA_lsh),
|
||||
'private_lsh')
|
||||
self.assertEqual(keys.Key._guessStringType(keydata.privateDSA_lsh),
|
||||
'private_lsh')
|
||||
self.assertEqual(keys.Key._guessStringType(
|
||||
keydata.privateRSA_agentv3), 'agentv3')
|
||||
self.assertEqual(keys.Key._guessStringType(
|
||||
keydata.privateDSA_agentv3), 'agentv3')
|
||||
self.assertEqual(keys.Key._guessStringType(
|
||||
'\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01\x01'),
|
||||
'blob')
|
||||
self.assertEqual(keys.Key._guessStringType(
|
||||
'\x00\x00\x00\x07ssh-dss\x00\x00\x00\x01\x01'),
|
||||
'blob')
|
||||
self.assertEqual(keys.Key._guessStringType('not a key'),
|
||||
None)
|
||||
|
||||
def _testPublicPrivateFromString(self, public, private, type, data):
|
||||
self._testPublicFromString(public, type, data)
|
||||
self._testPrivateFromString(private, type, data)
|
||||
|
||||
def _testPublicFromString(self, public, type, data):
|
||||
publicKey = keys.Key.fromString(public)
|
||||
self.assertTrue(publicKey.isPublic())
|
||||
self.assertEqual(publicKey.type(), type)
|
||||
for k, v in publicKey.data().items():
|
||||
self.assertEqual(data[k], v)
|
||||
|
||||
def _testPrivateFromString(self, private, type, data):
|
||||
privateKey = keys.Key.fromString(private)
|
||||
self.assertFalse(privateKey.isPublic())
|
||||
self.assertEqual(privateKey.type(), type)
|
||||
for k, v in data.items():
|
||||
self.assertEqual(privateKey.data()[k], v)
|
||||
|
||||
def test_fromOpenSSH(self):
|
||||
"""
|
||||
Test that keys are correctly generated from OpenSSH strings.
|
||||
"""
|
||||
self._testPublicPrivateFromString(keydata.publicRSA_openssh,
|
||||
keydata.privateRSA_openssh, 'RSA', keydata.RSAData)
|
||||
self.assertEqual(keys.Key.fromString(
|
||||
keydata.privateRSA_openssh_encrypted,
|
||||
passphrase='encrypted'),
|
||||
keys.Key.fromString(keydata.privateRSA_openssh))
|
||||
self.assertEqual(keys.Key.fromString(
|
||||
keydata.privateRSA_openssh_alternate),
|
||||
keys.Key.fromString(keydata.privateRSA_openssh))
|
||||
self._testPublicPrivateFromString(keydata.publicDSA_openssh,
|
||||
keydata.privateDSA_openssh, 'DSA', keydata.DSAData)
|
||||
|
||||
def test_fromOpenSSH_with_whitespace(self):
|
||||
"""
|
||||
If key strings have trailing whitespace, it should be ignored.
|
||||
"""
|
||||
# from bug #3391, since our test key data doesn't have
|
||||
# an issue with appended newlines
|
||||
privateDSAData = """-----BEGIN DSA PRIVATE KEY-----
|
||||
MIIBuwIBAAKBgQDylESNuc61jq2yatCzZbenlr9llG+p9LhIpOLUbXhhHcwC6hrh
|
||||
EZIdCKqTO0USLrGoP5uS9UHAUoeN62Z0KXXWTwOWGEQn/syyPzNJtnBorHpNUT9D
|
||||
Qzwl1yUa53NNgEctpo4NoEFOx8PuU6iFLyvgHCjNn2MsuGuzkZm7sI9ZpQIVAJiR
|
||||
9dPc08KLdpJyRxz8T74b4FQRAoGAGBc4Z5Y6R/HZi7AYM/iNOM8su6hrk8ypkBwR
|
||||
a3Dbhzk97fuV3SF1SDrcQu4zF7c4CtH609N5nfZs2SUjLLGPWln83Ysb8qhh55Em
|
||||
AcHXuROrHS/sDsnqu8FQp86MaudrqMExCOYyVPE7jaBWW+/JWFbKCxmgOCSdViUJ
|
||||
esJpBFsCgYEA7+jtVvSt9yrwsS/YU1QGP5wRAiDYB+T5cK4HytzAqJKRdC5qS4zf
|
||||
C7R0eKcDHHLMYO39aPnCwXjscisnInEhYGNblTDyPyiyNxAOXuC8x7luTmwzMbNJ
|
||||
/ow0IqSj0VF72VJN9uSoPpFd4lLT0zN8v42RWja0M8ohWNf+YNJluPgCFE0PT4Vm
|
||||
SUrCyZXsNh6VXwjs3gKQ
|
||||
-----END DSA PRIVATE KEY-----"""
|
||||
self.assertEqual(keys.Key.fromString(privateDSAData),
|
||||
keys.Key.fromString(privateDSAData + '\n'))
|
||||
|
||||
def test_fromNewerOpenSSH(self):
|
||||
"""
|
||||
Newer versions of OpenSSH generate encrypted keys which have a longer
|
||||
IV than the older versions. These newer keys are also loaded.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.privateRSA_openssh_encrypted_aes,
|
||||
passphrase='testxp')
|
||||
self.assertEqual(key.type(), 'RSA')
|
||||
key2 = keys.Key.fromString(
|
||||
keydata.privateRSA_openssh_encrypted_aes + '\n',
|
||||
passphrase='testxp')
|
||||
self.assertEqual(key, key2)
|
||||
|
||||
|
||||
def test_fromLSH(self):
|
||||
"""
|
||||
Test that keys are correctly generated from LSH strings.
|
||||
"""
|
||||
self._testPublicPrivateFromString(keydata.publicRSA_lsh,
|
||||
keydata.privateRSA_lsh, 'RSA', keydata.RSAData)
|
||||
self._testPublicPrivateFromString(keydata.publicDSA_lsh,
|
||||
keydata.privateDSA_lsh, 'DSA', keydata.DSAData)
|
||||
sexp = sexpy.pack([['public-key', ['bad-key', ['p', '2']]]])
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
|
||||
data='{'+base64.encodestring(sexp)+'}')
|
||||
sexp = sexpy.pack([['private-key', ['bad-key', ['p', '2']]]])
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
|
||||
sexp)
|
||||
|
||||
def test_fromAgentv3(self):
|
||||
"""
|
||||
Test that keys are correctly generated from Agent v3 strings.
|
||||
"""
|
||||
self._testPrivateFromString(keydata.privateRSA_agentv3, 'RSA',
|
||||
keydata.RSAData)
|
||||
self._testPrivateFromString(keydata.privateDSA_agentv3, 'DSA',
|
||||
keydata.DSAData)
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
|
||||
'\x00\x00\x00\x07ssh-foo'+'\x00\x00\x00\x01\x01'*5)
|
||||
|
||||
def test_fromStringErrors(self):
|
||||
"""
|
||||
keys.Key.fromString should raise BadKeyError when the key is invalid.
|
||||
"""
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString, '')
|
||||
# no key data with a bad key type
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString, '',
|
||||
'bad_type')
|
||||
# trying to decrypt a key which doesn't support encryption
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
|
||||
keydata.publicRSA_lsh, passphrase = 'unencrypted')
|
||||
# trying to decrypt a key with the wrong passphrase
|
||||
self.assertRaises(keys.EncryptedKeyError, keys.Key.fromString,
|
||||
keys.Key(self.rsaObj).toString('openssh', 'encrypted'))
|
||||
# key with no key data
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
|
||||
'-----BEGIN RSA KEY-----\nwA==\n')
|
||||
# key with invalid DEK Info
|
||||
self.assertRaises(
|
||||
keys.BadKeyError, keys.Key.fromString,
|
||||
"""-----BEGIN ENCRYPTED RSA KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: weird type
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
|
||||
# key with invalid encryption type
|
||||
self.assertRaises(
|
||||
keys.BadKeyError, keys.Key.fromString,
|
||||
"""-----BEGIN ENCRYPTED RSA KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: FOO-123-BAR,01234567
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
|
||||
# key with bad IV (AES)
|
||||
self.assertRaises(
|
||||
keys.BadKeyError, keys.Key.fromString,
|
||||
"""-----BEGIN ENCRYPTED RSA KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: AES-128-CBC,01234
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
|
||||
# key with bad IV (DES3)
|
||||
self.assertRaises(
|
||||
keys.BadKeyError, keys.Key.fromString,
|
||||
"""-----BEGIN ENCRYPTED RSA KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: DES-EDE3-CBC,01234
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
|
||||
|
||||
def test_fromFile(self):
|
||||
"""
|
||||
Test that fromFile works correctly.
|
||||
"""
|
||||
self.assertEqual(keys.Key.fromFile(self.keyFile),
|
||||
keys.Key.fromString(keydata.privateRSA_lsh))
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromFile,
|
||||
self.keyFile, 'bad_type')
|
||||
self.assertRaises(keys.BadKeyError, keys.Key.fromFile,
|
||||
self.keyFile, passphrase='unencrypted')
|
||||
|
||||
def test_init(self):
|
||||
"""
|
||||
Test that the PublicKey object is initialized correctly.
|
||||
"""
|
||||
obj = Crypto.PublicKey.RSA.construct((1L, 2L))
|
||||
key = keys.Key(obj)
|
||||
self.assertEqual(key.keyObject, obj)
|
||||
|
||||
def test_equal(self):
|
||||
"""
|
||||
Test that Key objects are compared correctly.
|
||||
"""
|
||||
rsa1 = keys.Key(self.rsaObj)
|
||||
rsa2 = keys.Key(self.rsaObj)
|
||||
rsa3 = keys.Key(Crypto.PublicKey.RSA.construct((1L, 2L)))
|
||||
dsa = keys.Key(self.dsaObj)
|
||||
self.assertTrue(rsa1 == rsa2)
|
||||
self.assertFalse(rsa1 == rsa3)
|
||||
self.assertFalse(rsa1 == dsa)
|
||||
self.assertFalse(rsa1 == object)
|
||||
self.assertFalse(rsa1 == None)
|
||||
|
||||
def test_notEqual(self):
|
||||
"""
|
||||
Test that Key objects are not-compared correctly.
|
||||
"""
|
||||
rsa1 = keys.Key(self.rsaObj)
|
||||
rsa2 = keys.Key(self.rsaObj)
|
||||
rsa3 = keys.Key(Crypto.PublicKey.RSA.construct((1L, 2L)))
|
||||
dsa = keys.Key(self.dsaObj)
|
||||
self.assertFalse(rsa1 != rsa2)
|
||||
self.assertTrue(rsa1 != rsa3)
|
||||
self.assertTrue(rsa1 != dsa)
|
||||
self.assertTrue(rsa1 != object)
|
||||
self.assertTrue(rsa1 != None)
|
||||
|
||||
def test_type(self):
|
||||
"""
|
||||
Test that the type method returns the correct type for an object.
|
||||
"""
|
||||
self.assertEqual(keys.Key(self.rsaObj).type(), 'RSA')
|
||||
self.assertEqual(keys.Key(self.rsaObj).sshType(), 'ssh-rsa')
|
||||
self.assertEqual(keys.Key(self.dsaObj).type(), 'DSA')
|
||||
self.assertEqual(keys.Key(self.dsaObj).sshType(), 'ssh-dss')
|
||||
self.assertRaises(RuntimeError, keys.Key(None).type)
|
||||
self.assertRaises(RuntimeError, keys.Key(None).sshType)
|
||||
self.assertRaises(RuntimeError, keys.Key(self).type)
|
||||
self.assertRaises(RuntimeError, keys.Key(self).sshType)
|
||||
|
||||
def test_fromBlob(self):
|
||||
"""
|
||||
Test that a public key is correctly generated from a public key blob.
|
||||
"""
|
||||
rsaBlob = common.NS('ssh-rsa') + common.MP(2) + common.MP(3)
|
||||
rsaKey = keys.Key.fromString(rsaBlob)
|
||||
dsaBlob = (common.NS('ssh-dss') + common.MP(2) + common.MP(3) +
|
||||
common.MP(4) + common.MP(5))
|
||||
dsaKey = keys.Key.fromString(dsaBlob)
|
||||
badBlob = common.NS('ssh-bad')
|
||||
self.assertTrue(rsaKey.isPublic())
|
||||
self.assertEqual(rsaKey.data(), {'e':2L, 'n':3L})
|
||||
self.assertTrue(dsaKey.isPublic())
|
||||
self.assertEqual(dsaKey.data(), {'p':2L, 'q':3L, 'g':4L, 'y':5L})
|
||||
self.assertRaises(keys.BadKeyError,
|
||||
keys.Key.fromString, badBlob)
|
||||
|
||||
|
||||
def test_fromPrivateBlob(self):
|
||||
"""
|
||||
Test that a private key is correctly generated from a private key blob.
|
||||
"""
|
||||
rsaBlob = (common.NS('ssh-rsa') + common.MP(2) + common.MP(3) +
|
||||
common.MP(4) + common.MP(5) + common.MP(6) + common.MP(7))
|
||||
rsaKey = keys.Key._fromString_PRIVATE_BLOB(rsaBlob)
|
||||
dsaBlob = (common.NS('ssh-dss') + common.MP(2) + common.MP(3) +
|
||||
common.MP(4) + common.MP(5) + common.MP(6))
|
||||
dsaKey = keys.Key._fromString_PRIVATE_BLOB(dsaBlob)
|
||||
badBlob = common.NS('ssh-bad')
|
||||
self.assertFalse(rsaKey.isPublic())
|
||||
self.assertEqual(
|
||||
rsaKey.data(), {'n':2L, 'e':3L, 'd':4L, 'u':5L, 'p':6L, 'q':7L})
|
||||
self.assertFalse(dsaKey.isPublic())
|
||||
self.assertEqual(dsaKey.data(), {'p':2L, 'q':3L, 'g':4L, 'y':5L, 'x':6L})
|
||||
self.assertRaises(
|
||||
keys.BadKeyError, keys.Key._fromString_PRIVATE_BLOB, badBlob)
|
||||
|
||||
|
||||
def test_blob(self):
|
||||
"""
|
||||
Test that the Key object generates blobs correctly.
|
||||
"""
|
||||
self.assertEqual(keys.Key(self.rsaObj).blob(),
|
||||
'\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01\x02'
|
||||
'\x00\x00\x00\x01\x01')
|
||||
self.assertEqual(keys.Key(self.dsaObj).blob(),
|
||||
'\x00\x00\x00\x07ssh-dss\x00\x00\x00\x01\x03'
|
||||
'\x00\x00\x00\x01\x04\x00\x00\x00\x01\x02'
|
||||
'\x00\x00\x00\x01\x01')
|
||||
|
||||
badKey = keys.Key(None)
|
||||
self.assertRaises(RuntimeError, badKey.blob)
|
||||
|
||||
|
||||
def test_privateBlob(self):
|
||||
"""
|
||||
L{Key.privateBlob} returns the SSH protocol-level format of the private
|
||||
key and raises L{RuntimeError} if the underlying key object is invalid.
|
||||
"""
|
||||
self.assertEqual(keys.Key(self.rsaObj).privateBlob(),
|
||||
'\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01\x01'
|
||||
'\x00\x00\x00\x01\x02\x00\x00\x00\x01\x03\x00'
|
||||
'\x00\x00\x01\x04\x00\x00\x00\x01\x04\x00\x00'
|
||||
'\x00\x01\x05')
|
||||
self.assertEqual(keys.Key(self.dsaObj).privateBlob(),
|
||||
'\x00\x00\x00\x07ssh-dss\x00\x00\x00\x01\x03'
|
||||
'\x00\x00\x00\x01\x04\x00\x00\x00\x01\x02\x00'
|
||||
'\x00\x00\x01\x01\x00\x00\x00\x01\x05')
|
||||
|
||||
badKey = keys.Key(None)
|
||||
self.assertRaises(RuntimeError, badKey.privateBlob)
|
||||
|
||||
|
||||
def test_toOpenSSH(self):
|
||||
"""
|
||||
Test that the Key object generates OpenSSH keys correctly.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.privateRSA_lsh)
|
||||
self.assertEqual(key.toString('openssh'), keydata.privateRSA_openssh)
|
||||
self.assertEqual(key.toString('openssh', 'encrypted'),
|
||||
keydata.privateRSA_openssh_encrypted)
|
||||
self.assertEqual(key.public().toString('openssh'),
|
||||
keydata.publicRSA_openssh[:-8]) # no comment
|
||||
self.assertEqual(key.public().toString('openssh', 'comment'),
|
||||
keydata.publicRSA_openssh)
|
||||
key = keys.Key.fromString(keydata.privateDSA_lsh)
|
||||
self.assertEqual(key.toString('openssh'), keydata.privateDSA_openssh)
|
||||
self.assertEqual(key.public().toString('openssh', 'comment'),
|
||||
keydata.publicDSA_openssh)
|
||||
self.assertEqual(key.public().toString('openssh'),
|
||||
keydata.publicDSA_openssh[:-8]) # no comment
|
||||
|
||||
def test_toLSH(self):
|
||||
"""
|
||||
Test that the Key object generates LSH keys correctly.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
self.assertEqual(key.toString('lsh'), keydata.privateRSA_lsh)
|
||||
self.assertEqual(key.public().toString('lsh'),
|
||||
keydata.publicRSA_lsh)
|
||||
key = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
self.assertEqual(key.toString('lsh'), keydata.privateDSA_lsh)
|
||||
self.assertEqual(key.public().toString('lsh'),
|
||||
keydata.publicDSA_lsh)
|
||||
|
||||
def test_toAgentv3(self):
|
||||
"""
|
||||
Test that the Key object generates Agent v3 keys correctly.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
self.assertEqual(key.toString('agentv3'), keydata.privateRSA_agentv3)
|
||||
key = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
self.assertEqual(key.toString('agentv3'), keydata.privateDSA_agentv3)
|
||||
|
||||
def test_toStringErrors(self):
|
||||
"""
|
||||
Test that toString raises errors appropriately.
|
||||
"""
|
||||
self.assertRaises(keys.BadKeyError, keys.Key(self.rsaObj).toString,
|
||||
'bad_type')
|
||||
|
||||
def test_sign(self):
|
||||
"""
|
||||
Test that the Key object generates correct signatures.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
self.assertEqual(key.sign(''), self.rsaSignature)
|
||||
key = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
self.assertEqual(key.sign(''), self.dsaSignature)
|
||||
|
||||
|
||||
def test_verify(self):
|
||||
"""
|
||||
Test that the Key object correctly verifies signatures.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.publicRSA_openssh)
|
||||
self.assertTrue(key.verify(self.rsaSignature, ''))
|
||||
self.assertFalse(key.verify(self.rsaSignature, 'a'))
|
||||
self.assertFalse(key.verify(self.dsaSignature, ''))
|
||||
key = keys.Key.fromString(keydata.publicDSA_openssh)
|
||||
self.assertTrue(key.verify(self.dsaSignature, ''))
|
||||
self.assertFalse(key.verify(self.dsaSignature, 'a'))
|
||||
self.assertFalse(key.verify(self.rsaSignature, ''))
|
||||
|
||||
|
||||
def test_verifyDSANoPrefix(self):
|
||||
"""
|
||||
Some commercial SSH servers send DSA keys as 2 20-byte numbers;
|
||||
they are still verified as valid keys.
|
||||
"""
|
||||
key = keys.Key.fromString(keydata.publicDSA_openssh)
|
||||
self.assertTrue(key.verify(self.dsaSignature[-40:], ''))
|
||||
|
||||
|
||||
def test_repr(self):
|
||||
"""
|
||||
Test the pretty representation of Key.
|
||||
"""
|
||||
self.assertEqual(repr(keys.Key(self.rsaObj)),
|
||||
"""<RSA Private Key (0 bits)
|
||||
attr d:
|
||||
\t03
|
||||
attr e:
|
||||
\t02
|
||||
attr n:
|
||||
\t01
|
||||
attr p:
|
||||
\t04
|
||||
attr q:
|
||||
\t05
|
||||
attr u:
|
||||
\t04>""")
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,372 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.manhole}.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import error, defer
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from twisted.conch.test.test_recvline import _TelnetMixin, _SSHMixin, _StdioMixin, stdio, ssh
|
||||
from twisted.conch import manhole
|
||||
from twisted.conch.insults import insults
|
||||
|
||||
|
||||
def determineDefaultFunctionName():
|
||||
"""
|
||||
Return the string used by Python as the name for code objects which are
|
||||
compiled from interactive input or at the top-level of modules.
|
||||
"""
|
||||
try:
|
||||
1 // 0
|
||||
except:
|
||||
# The last frame is this function. The second to last frame is this
|
||||
# function's caller, which is module-scope, which is what we want,
|
||||
# so -2.
|
||||
return traceback.extract_stack()[-2][2]
|
||||
defaultFunctionName = determineDefaultFunctionName()
|
||||
|
||||
|
||||
|
||||
class ManholeInterpreterTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{manhole.ManholeInterpreter}.
|
||||
"""
|
||||
def test_resetBuffer(self):
|
||||
"""
|
||||
L{ManholeInterpreter.resetBuffer} should empty the input buffer.
|
||||
"""
|
||||
interpreter = manhole.ManholeInterpreter(None)
|
||||
interpreter.buffer.extend(["1", "2"])
|
||||
interpreter.resetBuffer()
|
||||
self.assertFalse(interpreter.buffer)
|
||||
|
||||
|
||||
|
||||
class ManholeProtocolTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{manhole.Manhole}.
|
||||
"""
|
||||
def test_interruptResetsInterpreterBuffer(self):
|
||||
"""
|
||||
L{manhole.Manhole.handle_INT} should cause the interpreter input buffer
|
||||
to be reset.
|
||||
"""
|
||||
transport = StringTransport()
|
||||
terminal = insults.ServerProtocol(manhole.Manhole)
|
||||
terminal.makeConnection(transport)
|
||||
protocol = terminal.terminalProtocol
|
||||
interpreter = protocol.interpreter
|
||||
interpreter.buffer.extend(["1", "2"])
|
||||
protocol.handle_INT()
|
||||
self.assertFalse(interpreter.buffer)
|
||||
|
||||
|
||||
|
||||
class WriterTestCase(unittest.TestCase):
|
||||
def testInteger(self):
|
||||
manhole.lastColorizedLine("1")
|
||||
|
||||
|
||||
def testDoubleQuoteString(self):
|
||||
manhole.lastColorizedLine('"1"')
|
||||
|
||||
|
||||
def testSingleQuoteString(self):
|
||||
manhole.lastColorizedLine("'1'")
|
||||
|
||||
|
||||
def testTripleSingleQuotedString(self):
|
||||
manhole.lastColorizedLine("'''1'''")
|
||||
|
||||
|
||||
def testTripleDoubleQuotedString(self):
|
||||
manhole.lastColorizedLine('"""1"""')
|
||||
|
||||
|
||||
def testFunctionDefinition(self):
|
||||
manhole.lastColorizedLine("def foo():")
|
||||
|
||||
|
||||
def testClassDefinition(self):
|
||||
manhole.lastColorizedLine("class foo:")
|
||||
|
||||
|
||||
class ManholeLoopbackMixin:
|
||||
serverProtocol = manhole.ColoredManhole
|
||||
|
||||
def wfd(self, d):
|
||||
return defer.waitForDeferred(d)
|
||||
|
||||
def testSimpleExpression(self):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(
|
||||
"1 + 1\n"
|
||||
"done")
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(
|
||||
[">>> 1 + 1",
|
||||
"2",
|
||||
">>> done"])
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
def testTripleQuoteLineContinuation(self):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(
|
||||
"'''\n'''\n"
|
||||
"done")
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(
|
||||
[">>> '''",
|
||||
"... '''",
|
||||
"'\\n'",
|
||||
">>> done"])
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
def testFunctionDefinition(self):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(
|
||||
"def foo(bar):\n"
|
||||
"\tprint bar\n\n"
|
||||
"foo(42)\n"
|
||||
"done")
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(
|
||||
[">>> def foo(bar):",
|
||||
"... print bar",
|
||||
"... ",
|
||||
">>> foo(42)",
|
||||
"42",
|
||||
">>> done"])
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
def testClassDefinition(self):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(
|
||||
"class Foo:\n"
|
||||
"\tdef bar(self):\n"
|
||||
"\t\tprint 'Hello, world!'\n\n"
|
||||
"Foo().bar()\n"
|
||||
"done")
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(
|
||||
[">>> class Foo:",
|
||||
"... def bar(self):",
|
||||
"... print 'Hello, world!'",
|
||||
"... ",
|
||||
">>> Foo().bar()",
|
||||
"Hello, world!",
|
||||
">>> done"])
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
def testException(self):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(
|
||||
"raise Exception('foo bar baz')\n"
|
||||
"done")
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(
|
||||
[">>> raise Exception('foo bar baz')",
|
||||
"Traceback (most recent call last):",
|
||||
' File "<console>", line 1, in ' + defaultFunctionName,
|
||||
"Exception: foo bar baz",
|
||||
">>> done"])
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
def testControlC(self):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(
|
||||
"cancelled line" + manhole.CTRL_C +
|
||||
"done")
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(
|
||||
[">>> cancelled line",
|
||||
"KeyboardInterrupt",
|
||||
">>> done"])
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
|
||||
def test_interruptDuringContinuation(self):
|
||||
"""
|
||||
Sending ^C to Manhole while in a state where more input is required to
|
||||
complete a statement should discard the entire ongoing statement and
|
||||
reset the input prompt to the non-continuation prompt.
|
||||
"""
|
||||
continuing = self.recvlineClient.expect("things")
|
||||
|
||||
self._testwrite("(\nthings")
|
||||
|
||||
def gotContinuation(ignored):
|
||||
self._assertBuffer(
|
||||
[">>> (",
|
||||
"... things"])
|
||||
interrupted = self.recvlineClient.expect(">>> ")
|
||||
self._testwrite(manhole.CTRL_C)
|
||||
return interrupted
|
||||
continuing.addCallback(gotContinuation)
|
||||
|
||||
def gotInterruption(ignored):
|
||||
self._assertBuffer(
|
||||
[">>> (",
|
||||
"... things",
|
||||
"KeyboardInterrupt",
|
||||
">>> "])
|
||||
continuing.addCallback(gotInterruption)
|
||||
return continuing
|
||||
|
||||
|
||||
def testControlBackslash(self):
|
||||
self._testwrite("cancelled line")
|
||||
partialLine = self.recvlineClient.expect("cancelled line")
|
||||
|
||||
def gotPartialLine(ign):
|
||||
self._assertBuffer(
|
||||
[">>> cancelled line"])
|
||||
self._testwrite(manhole.CTRL_BACKSLASH)
|
||||
|
||||
d = self.recvlineClient.onDisconnection
|
||||
return self.assertFailure(d, error.ConnectionDone)
|
||||
|
||||
def gotClearedLine(ign):
|
||||
self._assertBuffer(
|
||||
[""])
|
||||
|
||||
return partialLine.addCallback(gotPartialLine).addCallback(gotClearedLine)
|
||||
|
||||
def testControlD(self):
|
||||
self._testwrite("1 + 1")
|
||||
helloWorld = self.wfd(self.recvlineClient.expect(r"\+ 1"))
|
||||
yield helloWorld
|
||||
helloWorld.getResult()
|
||||
self._assertBuffer([">>> 1 + 1"])
|
||||
|
||||
self._testwrite(manhole.CTRL_D + " + 1")
|
||||
cleared = self.wfd(self.recvlineClient.expect(r"\+ 1"))
|
||||
yield cleared
|
||||
cleared.getResult()
|
||||
self._assertBuffer([">>> 1 + 1 + 1"])
|
||||
|
||||
self._testwrite("\n")
|
||||
printed = self.wfd(self.recvlineClient.expect("3\n>>> "))
|
||||
yield printed
|
||||
printed.getResult()
|
||||
|
||||
self._testwrite(manhole.CTRL_D)
|
||||
d = self.recvlineClient.onDisconnection
|
||||
disconnected = self.wfd(self.assertFailure(d, error.ConnectionDone))
|
||||
yield disconnected
|
||||
disconnected.getResult()
|
||||
testControlD = defer.deferredGenerator(testControlD)
|
||||
|
||||
|
||||
def testControlL(self):
|
||||
"""
|
||||
CTRL+L is generally used as a redraw-screen command in terminal
|
||||
applications. Manhole doesn't currently respect this usage of it,
|
||||
but it should at least do something reasonable in response to this
|
||||
event (rather than, say, eating your face).
|
||||
"""
|
||||
# Start off with a newline so that when we clear the display we can
|
||||
# tell by looking for the missing first empty prompt line.
|
||||
self._testwrite("\n1 + 1")
|
||||
helloWorld = self.wfd(self.recvlineClient.expect(r"\+ 1"))
|
||||
yield helloWorld
|
||||
helloWorld.getResult()
|
||||
self._assertBuffer([">>> ", ">>> 1 + 1"])
|
||||
|
||||
self._testwrite(manhole.CTRL_L + " + 1")
|
||||
redrew = self.wfd(self.recvlineClient.expect(r"1 \+ 1 \+ 1"))
|
||||
yield redrew
|
||||
redrew.getResult()
|
||||
self._assertBuffer([">>> 1 + 1 + 1"])
|
||||
testControlL = defer.deferredGenerator(testControlL)
|
||||
|
||||
|
||||
def test_controlA(self):
|
||||
"""
|
||||
CTRL-A can be used as HOME - returning cursor to beginning of
|
||||
current line buffer.
|
||||
"""
|
||||
self._testwrite('rint "hello"' + '\x01' + 'p')
|
||||
d = self.recvlineClient.expect('print "hello"')
|
||||
def cb(ignore):
|
||||
self._assertBuffer(['>>> print "hello"'])
|
||||
return d.addCallback(cb)
|
||||
|
||||
|
||||
def test_controlE(self):
|
||||
"""
|
||||
CTRL-E can be used as END - setting cursor to end of current
|
||||
line buffer.
|
||||
"""
|
||||
self._testwrite('rint "hello' + '\x01' + 'p' + '\x05' + '"')
|
||||
d = self.recvlineClient.expect('print "hello"')
|
||||
def cb(ignore):
|
||||
self._assertBuffer(['>>> print "hello"'])
|
||||
return d.addCallback(cb)
|
||||
|
||||
|
||||
def testDeferred(self):
|
||||
self._testwrite(
|
||||
"from twisted.internet import defer, reactor\n"
|
||||
"d = defer.Deferred()\n"
|
||||
"d\n")
|
||||
|
||||
deferred = self.wfd(self.recvlineClient.expect("<Deferred #0>"))
|
||||
yield deferred
|
||||
deferred.getResult()
|
||||
|
||||
self._testwrite(
|
||||
"c = reactor.callLater(0.1, d.callback, 'Hi!')\n")
|
||||
delayed = self.wfd(self.recvlineClient.expect(">>> "))
|
||||
yield delayed
|
||||
delayed.getResult()
|
||||
|
||||
called = self.wfd(self.recvlineClient.expect("Deferred #0 called back: 'Hi!'\n>>> "))
|
||||
yield called
|
||||
called.getResult()
|
||||
self._assertBuffer(
|
||||
[">>> from twisted.internet import defer, reactor",
|
||||
">>> d = defer.Deferred()",
|
||||
">>> d",
|
||||
"<Deferred #0>",
|
||||
">>> c = reactor.callLater(0.1, d.callback, 'Hi!')",
|
||||
"Deferred #0 called back: 'Hi!'",
|
||||
">>> "])
|
||||
|
||||
testDeferred = defer.deferredGenerator(testDeferred)
|
||||
|
||||
class ManholeLoopbackTelnet(_TelnetMixin, unittest.TestCase, ManholeLoopbackMixin):
|
||||
pass
|
||||
|
||||
class ManholeLoopbackSSH(_SSHMixin, unittest.TestCase, ManholeLoopbackMixin):
|
||||
if ssh is None:
|
||||
skip = "Crypto requirements missing, can't run manhole tests over ssh"
|
||||
|
||||
class ManholeLoopbackStdio(_StdioMixin, unittest.TestCase, ManholeLoopbackMixin):
|
||||
if stdio is None:
|
||||
skip = "Terminal requirements missing, can't run manhole tests over stdio"
|
||||
else:
|
||||
serverProtocol = stdio.ConsoleManhole
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
# -*- twisted.conch.test.test_mixin -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import time
|
||||
|
||||
from twisted.internet import reactor, protocol
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
from twisted.conch import mixin
|
||||
|
||||
|
||||
class TestBufferingProto(mixin.BufferingMixin):
|
||||
scheduled = False
|
||||
rescheduled = 0
|
||||
def schedule(self):
|
||||
self.scheduled = True
|
||||
return object()
|
||||
|
||||
def reschedule(self, token):
|
||||
self.rescheduled += 1
|
||||
|
||||
|
||||
|
||||
class BufferingTest(unittest.TestCase):
|
||||
def testBuffering(self):
|
||||
p = TestBufferingProto()
|
||||
t = p.transport = StringTransport()
|
||||
|
||||
self.assertFalse(p.scheduled)
|
||||
|
||||
L = ['foo', 'bar', 'baz', 'quux']
|
||||
|
||||
p.write('foo')
|
||||
self.assertTrue(p.scheduled)
|
||||
self.assertFalse(p.rescheduled)
|
||||
|
||||
for s in L:
|
||||
n = p.rescheduled
|
||||
p.write(s)
|
||||
self.assertEqual(p.rescheduled, n + 1)
|
||||
self.assertEqual(t.value(), '')
|
||||
|
||||
p.flush()
|
||||
self.assertEqual(t.value(), 'foo' + ''.join(L))
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.openssh_compat}.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
OpenSSHFactory = None
|
||||
else:
|
||||
from twisted.conch.openssh_compat.factory import OpenSSHFactory
|
||||
|
||||
from twisted.conch.test import keydata
|
||||
from twisted.test.test_process import MockOS
|
||||
|
||||
|
||||
class OpenSSHFactoryTests(TestCase):
|
||||
"""
|
||||
Tests for L{OpenSSHFactory}.
|
||||
"""
|
||||
if getattr(os, "geteuid", None) is None:
|
||||
skip = "geteuid/seteuid not available"
|
||||
elif OpenSSHFactory is None:
|
||||
skip = "Cannot run without PyCrypto or PyASN1"
|
||||
|
||||
def setUp(self):
|
||||
self.factory = OpenSSHFactory()
|
||||
self.keysDir = FilePath(self.mktemp())
|
||||
self.keysDir.makedirs()
|
||||
self.factory.dataRoot = self.keysDir.path
|
||||
|
||||
self.keysDir.child("ssh_host_foo").setContent("foo")
|
||||
self.keysDir.child("bar_key").setContent("foo")
|
||||
self.keysDir.child("ssh_host_one_key").setContent(
|
||||
keydata.privateRSA_openssh)
|
||||
self.keysDir.child("ssh_host_two_key").setContent(
|
||||
keydata.privateDSA_openssh)
|
||||
self.keysDir.child("ssh_host_three_key").setContent(
|
||||
"not a key content")
|
||||
|
||||
self.keysDir.child("ssh_host_one_key.pub").setContent(
|
||||
keydata.publicRSA_openssh)
|
||||
|
||||
self.mockos = MockOS()
|
||||
self.patch(os, "seteuid", self.mockos.seteuid)
|
||||
self.patch(os, "setegid", self.mockos.setegid)
|
||||
|
||||
|
||||
def test_getPublicKeys(self):
|
||||
"""
|
||||
L{OpenSSHFactory.getPublicKeys} should return the available public keys
|
||||
in the data directory
|
||||
"""
|
||||
keys = self.factory.getPublicKeys()
|
||||
self.assertEqual(len(keys), 1)
|
||||
keyTypes = keys.keys()
|
||||
self.assertEqual(keyTypes, ['ssh-rsa'])
|
||||
|
||||
|
||||
def test_getPrivateKeys(self):
|
||||
"""
|
||||
L{OpenSSHFactory.getPrivateKeys} should return the available private
|
||||
keys in the data directory.
|
||||
"""
|
||||
keys = self.factory.getPrivateKeys()
|
||||
self.assertEqual(len(keys), 2)
|
||||
keyTypes = keys.keys()
|
||||
self.assertEqual(set(keyTypes), set(['ssh-rsa', 'ssh-dss']))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
def test_getPrivateKeysAsRoot(self):
|
||||
"""
|
||||
L{OpenSSHFactory.getPrivateKeys} should switch to root if the keys
|
||||
aren't readable by the current user.
|
||||
"""
|
||||
keyFile = self.keysDir.child("ssh_host_two_key")
|
||||
# Fake permission error by changing the mode
|
||||
keyFile.chmod(0000)
|
||||
self.addCleanup(keyFile.chmod, 0777)
|
||||
# And restore the right mode when seteuid is called
|
||||
savedSeteuid = os.seteuid
|
||||
def seteuid(euid):
|
||||
keyFile.chmod(0777)
|
||||
return savedSeteuid(euid)
|
||||
self.patch(os, "seteuid", seteuid)
|
||||
keys = self.factory.getPrivateKeys()
|
||||
self.assertEqual(len(keys), 2)
|
||||
keyTypes = keys.keys()
|
||||
self.assertEqual(set(keyTypes), set(['ssh-rsa', 'ssh-dss']))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, os.geteuid()])
|
||||
self.assertEqual(self.mockos.setegidCalls, [0, os.getegid()])
|
||||
|
|
@ -0,0 +1,706 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_recvline -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.recvline} and fixtures for testing related
|
||||
functionality.
|
||||
"""
|
||||
|
||||
import sys, os
|
||||
|
||||
from twisted.conch.insults import insults
|
||||
from twisted.conch import recvline
|
||||
|
||||
from twisted.python import reflect, components
|
||||
from twisted.internet import defer, error
|
||||
from twisted.trial import unittest
|
||||
from twisted.cred import portal
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
class Arrows(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.underlyingTransport = StringTransport()
|
||||
self.pt = insults.ServerProtocol()
|
||||
self.p = recvline.HistoricRecvLine()
|
||||
self.pt.protocolFactory = lambda: self.p
|
||||
self.pt.factory = self
|
||||
self.pt.makeConnection(self.underlyingTransport)
|
||||
# self.p.makeConnection(self.pt)
|
||||
|
||||
def test_printableCharacters(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives a printable character,
|
||||
it adds it to the current line buffer.
|
||||
"""
|
||||
self.p.keystrokeReceived('x', None)
|
||||
self.p.keystrokeReceived('y', None)
|
||||
self.p.keystrokeReceived('z', None)
|
||||
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
def test_horizontalArrows(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives an LEFT_ARROW or
|
||||
RIGHT_ARROW keystroke it moves the cursor left or right
|
||||
in the current line buffer, respectively.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
for ch in 'xyz':
|
||||
kR(ch)
|
||||
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.RIGHT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xy', 'z'))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('x', 'yz'))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', 'xyz'))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', 'xyz'))
|
||||
|
||||
kR(self.pt.RIGHT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('x', 'yz'))
|
||||
|
||||
kR(self.pt.RIGHT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xy', 'z'))
|
||||
|
||||
kR(self.pt.RIGHT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.RIGHT_ARROW)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
def test_newline(self):
|
||||
"""
|
||||
When {HistoricRecvLine} receives a newline, it adds the current
|
||||
line buffer to the end of its history buffer.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'xyz\nabc\n123\n':
|
||||
kR(ch)
|
||||
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz', 'abc', '123'), ()))
|
||||
|
||||
kR('c')
|
||||
kR('b')
|
||||
kR('a')
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz', 'abc', '123'), ()))
|
||||
|
||||
kR('\n')
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz', 'abc', '123', 'cba'), ()))
|
||||
|
||||
def test_verticalArrows(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives UP_ARROW or DOWN_ARROW
|
||||
keystrokes it move the current index in the current history
|
||||
buffer up or down, and resets the current line buffer to the
|
||||
previous or next line in history, respectively for each.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'xyz\nabc\n123\n':
|
||||
kR(ch)
|
||||
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz', 'abc', '123'), ()))
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
|
||||
|
||||
kR(self.pt.UP_ARROW)
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz', 'abc'), ('123',)))
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('123', ''))
|
||||
|
||||
kR(self.pt.UP_ARROW)
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz',), ('abc', '123')))
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('abc', ''))
|
||||
|
||||
kR(self.pt.UP_ARROW)
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
((), ('xyz', 'abc', '123')))
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.UP_ARROW)
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
((), ('xyz', 'abc', '123')))
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
for i in range(4):
|
||||
kR(self.pt.DOWN_ARROW)
|
||||
self.assertEqual(self.p.currentHistoryBuffer(),
|
||||
(('xyz', 'abc', '123'), ()))
|
||||
|
||||
def test_home(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives a HOME keystroke it moves the
|
||||
cursor to the beginning of the current line buffer.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'hello, world':
|
||||
kR(ch)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('hello, world', ''))
|
||||
|
||||
kR(self.pt.HOME)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', 'hello, world'))
|
||||
|
||||
def test_end(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives a END keystroke it moves the cursor
|
||||
to the end of the current line buffer.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'hello, world':
|
||||
kR(ch)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('hello, world', ''))
|
||||
|
||||
kR(self.pt.HOME)
|
||||
kR(self.pt.END)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('hello, world', ''))
|
||||
|
||||
def test_backspace(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives a BACKSPACE keystroke it deletes
|
||||
the character immediately before the cursor.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'xyz':
|
||||
kR(ch)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.BACKSPACE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xy', ''))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR(self.pt.BACKSPACE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', 'y'))
|
||||
|
||||
kR(self.pt.BACKSPACE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', 'y'))
|
||||
|
||||
def test_delete(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives a DELETE keystroke, it
|
||||
delets the character immediately after the cursor.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'xyz':
|
||||
kR(ch)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.DELETE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR(self.pt.DELETE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xy', ''))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR(self.pt.DELETE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('x', ''))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR(self.pt.DELETE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
|
||||
|
||||
kR(self.pt.DELETE)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
|
||||
|
||||
def test_insert(self):
|
||||
"""
|
||||
When not in INSERT mode, L{HistoricRecvLine} inserts the typed
|
||||
character at the cursor before the next character.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'xyz':
|
||||
kR(ch)
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR('A')
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyA', 'z'))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR('B')
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyB', 'Az'))
|
||||
|
||||
def test_typeover(self):
|
||||
"""
|
||||
When in INSERT mode and upon receiving a keystroke with a printable
|
||||
character, L{HistoricRecvLine} replaces the character at
|
||||
the cursor with the typed character rather than inserting before.
|
||||
Ah, the ironies of INSERT mode.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
|
||||
for ch in 'xyz':
|
||||
kR(ch)
|
||||
|
||||
kR(self.pt.INSERT)
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR('A')
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyA', ''))
|
||||
|
||||
kR(self.pt.LEFT_ARROW)
|
||||
kR('B')
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('xyB', ''))
|
||||
|
||||
|
||||
def test_unprintableCharacters(self):
|
||||
"""
|
||||
When L{HistoricRecvLine} receives a keystroke for an unprintable
|
||||
function key with no assigned behavior, the line buffer is unmodified.
|
||||
"""
|
||||
kR = lambda ch: self.p.keystrokeReceived(ch, None)
|
||||
pt = self.pt
|
||||
|
||||
for ch in (pt.F1, pt.F2, pt.F3, pt.F4, pt.F5, pt.F6, pt.F7, pt.F8,
|
||||
pt.F9, pt.F10, pt.F11, pt.F12, pt.PGUP, pt.PGDN):
|
||||
kR(ch)
|
||||
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
|
||||
|
||||
|
||||
from twisted.conch import telnet
|
||||
from twisted.conch.insults import helper
|
||||
from twisted.protocols import loopback
|
||||
|
||||
class EchoServer(recvline.HistoricRecvLine):
|
||||
def lineReceived(self, line):
|
||||
self.terminal.write(line + '\n' + self.ps[self.pn])
|
||||
|
||||
# An insults API for this would be nice.
|
||||
left = "\x1b[D"
|
||||
right = "\x1b[C"
|
||||
up = "\x1b[A"
|
||||
down = "\x1b[B"
|
||||
insert = "\x1b[2~"
|
||||
home = "\x1b[1~"
|
||||
delete = "\x1b[3~"
|
||||
end = "\x1b[4~"
|
||||
backspace = "\x7f"
|
||||
|
||||
from twisted.cred import checkers
|
||||
|
||||
try:
|
||||
from twisted.conch.ssh import userauth, transport, channel, connection, session
|
||||
from twisted.conch.manhole_ssh import TerminalUser, TerminalSession, TerminalRealm, TerminalSessionTransport, ConchFactory
|
||||
except ImportError:
|
||||
ssh = False
|
||||
else:
|
||||
ssh = True
|
||||
class SessionChannel(channel.SSHChannel):
|
||||
name = 'session'
|
||||
|
||||
def __init__(self, protocolFactory, protocolArgs, protocolKwArgs, width, height, *a, **kw):
|
||||
channel.SSHChannel.__init__(self, *a, **kw)
|
||||
|
||||
self.protocolFactory = protocolFactory
|
||||
self.protocolArgs = protocolArgs
|
||||
self.protocolKwArgs = protocolKwArgs
|
||||
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def channelOpen(self, data):
|
||||
term = session.packRequest_pty_req("vt102", (self.height, self.width, 0, 0), '')
|
||||
self.conn.sendRequest(self, 'pty-req', term)
|
||||
self.conn.sendRequest(self, 'shell', '')
|
||||
|
||||
self._protocolInstance = self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs)
|
||||
self._protocolInstance.factory = self
|
||||
self._protocolInstance.makeConnection(self)
|
||||
|
||||
def closed(self):
|
||||
self._protocolInstance.connectionLost(error.ConnectionDone())
|
||||
|
||||
def dataReceived(self, data):
|
||||
self._protocolInstance.dataReceived(data)
|
||||
|
||||
class TestConnection(connection.SSHConnection):
|
||||
def __init__(self, protocolFactory, protocolArgs, protocolKwArgs, width, height, *a, **kw):
|
||||
connection.SSHConnection.__init__(self, *a, **kw)
|
||||
|
||||
self.protocolFactory = protocolFactory
|
||||
self.protocolArgs = protocolArgs
|
||||
self.protocolKwArgs = protocolKwArgs
|
||||
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def serviceStarted(self):
|
||||
self.__channel = SessionChannel(self.protocolFactory, self.protocolArgs, self.protocolKwArgs, self.width, self.height)
|
||||
self.openChannel(self.__channel)
|
||||
|
||||
def write(self, bytes):
|
||||
return self.__channel.write(bytes)
|
||||
|
||||
class TestAuth(userauth.SSHUserAuthClient):
|
||||
def __init__(self, username, password, *a, **kw):
|
||||
userauth.SSHUserAuthClient.__init__(self, username, *a, **kw)
|
||||
self.password = password
|
||||
|
||||
def getPassword(self):
|
||||
return defer.succeed(self.password)
|
||||
|
||||
class TestTransport(transport.SSHClientTransport):
|
||||
def __init__(self, protocolFactory, protocolArgs, protocolKwArgs, username, password, width, height, *a, **kw):
|
||||
# transport.SSHClientTransport.__init__(self, *a, **kw)
|
||||
self.protocolFactory = protocolFactory
|
||||
self.protocolArgs = protocolArgs
|
||||
self.protocolKwArgs = protocolKwArgs
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def verifyHostKey(self, hostKey, fingerprint):
|
||||
return defer.succeed(True)
|
||||
|
||||
def connectionSecure(self):
|
||||
self.__connection = TestConnection(self.protocolFactory, self.protocolArgs, self.protocolKwArgs, self.width, self.height)
|
||||
self.requestService(
|
||||
TestAuth(self.username, self.password, self.__connection))
|
||||
|
||||
def write(self, bytes):
|
||||
return self.__connection.write(bytes)
|
||||
|
||||
class TestSessionTransport(TerminalSessionTransport):
|
||||
def protocolFactory(self):
|
||||
return self.avatar.conn.transport.factory.serverProtocol()
|
||||
|
||||
class TestSession(TerminalSession):
|
||||
transportFactory = TestSessionTransport
|
||||
|
||||
class TestUser(TerminalUser):
|
||||
pass
|
||||
|
||||
components.registerAdapter(TestSession, TestUser, session.ISession)
|
||||
|
||||
|
||||
class LoopbackRelay(loopback.LoopbackRelay):
|
||||
clearCall = None
|
||||
|
||||
def logPrefix(self):
|
||||
return "LoopbackRelay(%r)" % (self.target.__class__.__name__,)
|
||||
|
||||
def write(self, bytes):
|
||||
loopback.LoopbackRelay.write(self, bytes)
|
||||
if self.clearCall is not None:
|
||||
self.clearCall.cancel()
|
||||
|
||||
from twisted.internet import reactor
|
||||
self.clearCall = reactor.callLater(0, self._clearBuffer)
|
||||
|
||||
def _clearBuffer(self):
|
||||
self.clearCall = None
|
||||
loopback.LoopbackRelay.clearBuffer(self)
|
||||
|
||||
|
||||
class NotifyingExpectableBuffer(helper.ExpectableBuffer):
|
||||
def __init__(self):
|
||||
self.onConnection = defer.Deferred()
|
||||
self.onDisconnection = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
helper.ExpectableBuffer.connectionMade(self)
|
||||
self.onConnection.callback(self)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.onDisconnection.errback(reason)
|
||||
|
||||
|
||||
class _BaseMixin:
|
||||
WIDTH = 80
|
||||
HEIGHT = 24
|
||||
|
||||
def _assertBuffer(self, lines):
|
||||
receivedLines = str(self.recvlineClient).splitlines()
|
||||
expectedLines = lines + ([''] * (self.HEIGHT - len(lines) - 1))
|
||||
self.assertEqual(len(receivedLines), len(expectedLines))
|
||||
for i in range(len(receivedLines)):
|
||||
self.assertEqual(
|
||||
receivedLines[i], expectedLines[i],
|
||||
str(receivedLines[max(0, i-1):i+1]) +
|
||||
" != " +
|
||||
str(expectedLines[max(0, i-1):i+1]))
|
||||
|
||||
def _trivialTest(self, input, output):
|
||||
done = self.recvlineClient.expect("done")
|
||||
|
||||
self._testwrite(input)
|
||||
|
||||
def finished(ign):
|
||||
self._assertBuffer(output)
|
||||
|
||||
return done.addCallback(finished)
|
||||
|
||||
|
||||
class _SSHMixin(_BaseMixin):
|
||||
def setUp(self):
|
||||
if not ssh:
|
||||
raise unittest.SkipTest("Crypto requirements missing, can't run historic recvline tests over ssh")
|
||||
|
||||
u, p = 'testuser', 'testpass'
|
||||
rlm = TerminalRealm()
|
||||
rlm.userFactory = TestUser
|
||||
rlm.chainedProtocolFactory = lambda: insultsServer
|
||||
|
||||
ptl = portal.Portal(
|
||||
rlm,
|
||||
[checkers.InMemoryUsernamePasswordDatabaseDontUse(**{u: p})])
|
||||
sshFactory = ConchFactory(ptl)
|
||||
sshFactory.serverProtocol = self.serverProtocol
|
||||
sshFactory.startFactory()
|
||||
|
||||
recvlineServer = self.serverProtocol()
|
||||
insultsServer = insults.ServerProtocol(lambda: recvlineServer)
|
||||
sshServer = sshFactory.buildProtocol(None)
|
||||
clientTransport = LoopbackRelay(sshServer)
|
||||
|
||||
recvlineClient = NotifyingExpectableBuffer()
|
||||
insultsClient = insults.ClientProtocol(lambda: recvlineClient)
|
||||
sshClient = TestTransport(lambda: insultsClient, (), {}, u, p, self.WIDTH, self.HEIGHT)
|
||||
serverTransport = LoopbackRelay(sshClient)
|
||||
|
||||
sshClient.makeConnection(clientTransport)
|
||||
sshServer.makeConnection(serverTransport)
|
||||
|
||||
self.recvlineClient = recvlineClient
|
||||
self.sshClient = sshClient
|
||||
self.sshServer = sshServer
|
||||
self.clientTransport = clientTransport
|
||||
self.serverTransport = serverTransport
|
||||
|
||||
return recvlineClient.onConnection
|
||||
|
||||
def _testwrite(self, bytes):
|
||||
self.sshClient.write(bytes)
|
||||
|
||||
from twisted.conch.test import test_telnet
|
||||
|
||||
class TestInsultsClientProtocol(insults.ClientProtocol,
|
||||
test_telnet.TestProtocol):
|
||||
pass
|
||||
|
||||
|
||||
class TestInsultsServerProtocol(insults.ServerProtocol,
|
||||
test_telnet.TestProtocol):
|
||||
pass
|
||||
|
||||
class _TelnetMixin(_BaseMixin):
|
||||
def setUp(self):
|
||||
recvlineServer = self.serverProtocol()
|
||||
insultsServer = TestInsultsServerProtocol(lambda: recvlineServer)
|
||||
telnetServer = telnet.TelnetTransport(lambda: insultsServer)
|
||||
clientTransport = LoopbackRelay(telnetServer)
|
||||
|
||||
recvlineClient = NotifyingExpectableBuffer()
|
||||
insultsClient = TestInsultsClientProtocol(lambda: recvlineClient)
|
||||
telnetClient = telnet.TelnetTransport(lambda: insultsClient)
|
||||
serverTransport = LoopbackRelay(telnetClient)
|
||||
|
||||
telnetClient.makeConnection(clientTransport)
|
||||
telnetServer.makeConnection(serverTransport)
|
||||
|
||||
serverTransport.clearBuffer()
|
||||
clientTransport.clearBuffer()
|
||||
|
||||
self.recvlineClient = recvlineClient
|
||||
self.telnetClient = telnetClient
|
||||
self.clientTransport = clientTransport
|
||||
self.serverTransport = serverTransport
|
||||
|
||||
return recvlineClient.onConnection
|
||||
|
||||
def _testwrite(self, bytes):
|
||||
self.telnetClient.write(bytes)
|
||||
|
||||
try:
|
||||
from twisted.conch import stdio
|
||||
except ImportError:
|
||||
stdio = None
|
||||
|
||||
class _StdioMixin(_BaseMixin):
|
||||
def setUp(self):
|
||||
# A memory-only terminal emulator, into which the server will
|
||||
# write things and make other state changes. What ends up
|
||||
# here is basically what a user would have seen on their
|
||||
# screen.
|
||||
testTerminal = NotifyingExpectableBuffer()
|
||||
|
||||
# An insults client protocol which will translate bytes
|
||||
# received from the child process into keystroke commands for
|
||||
# an ITerminalProtocol.
|
||||
insultsClient = insults.ClientProtocol(lambda: testTerminal)
|
||||
|
||||
# A process protocol which will translate stdout and stderr
|
||||
# received from the child process to dataReceived calls and
|
||||
# error reporting on an insults client protocol.
|
||||
processClient = stdio.TerminalProcessProtocol(insultsClient)
|
||||
|
||||
# Run twisted/conch/stdio.py with the name of a class
|
||||
# implementing ITerminalProtocol. This class will be used to
|
||||
# handle bytes we send to the child process.
|
||||
exe = sys.executable
|
||||
module = stdio.__file__
|
||||
if module.endswith('.pyc') or module.endswith('.pyo'):
|
||||
module = module[:-1]
|
||||
args = [exe, module, reflect.qual(self.serverProtocol)]
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = os.pathsep.join(sys.path)
|
||||
|
||||
from twisted.internet import reactor
|
||||
clientTransport = reactor.spawnProcess(processClient, exe, args,
|
||||
env=env, usePTY=True)
|
||||
|
||||
self.recvlineClient = self.testTerminal = testTerminal
|
||||
self.processClient = processClient
|
||||
self.clientTransport = clientTransport
|
||||
|
||||
# Wait for the process protocol and test terminal to become
|
||||
# connected before proceeding. The former should always
|
||||
# happen first, but it doesn't hurt to be safe.
|
||||
return defer.gatherResults(filter(None, [
|
||||
processClient.onConnection,
|
||||
testTerminal.expect(">>> ")]))
|
||||
|
||||
def tearDown(self):
|
||||
# Kill the child process. We're done with it.
|
||||
try:
|
||||
self.clientTransport.signalProcess("KILL")
|
||||
except (error.ProcessExitedAlready, OSError):
|
||||
pass
|
||||
def trap(failure):
|
||||
failure.trap(error.ProcessTerminated)
|
||||
self.assertEqual(failure.value.exitCode, None)
|
||||
self.assertEqual(failure.value.status, 9)
|
||||
return self.testTerminal.onDisconnection.addErrback(trap)
|
||||
|
||||
def _testwrite(self, bytes):
|
||||
self.clientTransport.write(bytes)
|
||||
|
||||
class RecvlineLoopbackMixin:
|
||||
serverProtocol = EchoServer
|
||||
|
||||
def testSimple(self):
|
||||
return self._trivialTest(
|
||||
"first line\ndone",
|
||||
[">>> first line",
|
||||
"first line",
|
||||
">>> done"])
|
||||
|
||||
def testLeftArrow(self):
|
||||
return self._trivialTest(
|
||||
insert + 'first line' + left * 4 + "xxxx\ndone",
|
||||
[">>> first xxxx",
|
||||
"first xxxx",
|
||||
">>> done"])
|
||||
|
||||
def testRightArrow(self):
|
||||
return self._trivialTest(
|
||||
insert + 'right line' + left * 4 + right * 2 + "xx\ndone",
|
||||
[">>> right lixx",
|
||||
"right lixx",
|
||||
">>> done"])
|
||||
|
||||
def testBackspace(self):
|
||||
return self._trivialTest(
|
||||
"second line" + backspace * 4 + "xxxx\ndone",
|
||||
[">>> second xxxx",
|
||||
"second xxxx",
|
||||
">>> done"])
|
||||
|
||||
def testDelete(self):
|
||||
return self._trivialTest(
|
||||
"delete xxxx" + left * 4 + delete * 4 + "line\ndone",
|
||||
[">>> delete line",
|
||||
"delete line",
|
||||
">>> done"])
|
||||
|
||||
def testInsert(self):
|
||||
return self._trivialTest(
|
||||
"third ine" + left * 3 + "l\ndone",
|
||||
[">>> third line",
|
||||
"third line",
|
||||
">>> done"])
|
||||
|
||||
def testTypeover(self):
|
||||
return self._trivialTest(
|
||||
"fourth xine" + left * 4 + insert + "l\ndone",
|
||||
[">>> fourth line",
|
||||
"fourth line",
|
||||
">>> done"])
|
||||
|
||||
def testHome(self):
|
||||
return self._trivialTest(
|
||||
insert + "blah line" + home + "home\ndone",
|
||||
[">>> home line",
|
||||
"home line",
|
||||
">>> done"])
|
||||
|
||||
def testEnd(self):
|
||||
return self._trivialTest(
|
||||
"end " + left * 4 + end + "line\ndone",
|
||||
[">>> end line",
|
||||
"end line",
|
||||
">>> done"])
|
||||
|
||||
class RecvlineLoopbackTelnet(_TelnetMixin, unittest.TestCase, RecvlineLoopbackMixin):
|
||||
pass
|
||||
|
||||
class RecvlineLoopbackSSH(_SSHMixin, unittest.TestCase, RecvlineLoopbackMixin):
|
||||
pass
|
||||
|
||||
class RecvlineLoopbackStdio(_StdioMixin, unittest.TestCase, RecvlineLoopbackMixin):
|
||||
if stdio is None:
|
||||
skip = "Terminal requirements missing, can't run recvline tests over stdio"
|
||||
|
||||
|
||||
class HistoricRecvlineLoopbackMixin:
|
||||
serverProtocol = EchoServer
|
||||
|
||||
def testUpArrow(self):
|
||||
return self._trivialTest(
|
||||
"first line\n" + up + "\ndone",
|
||||
[">>> first line",
|
||||
"first line",
|
||||
">>> first line",
|
||||
"first line",
|
||||
">>> done"])
|
||||
|
||||
def testDownArrow(self):
|
||||
return self._trivialTest(
|
||||
"first line\nsecond line\n" + up * 2 + down + "\ndone",
|
||||
[">>> first line",
|
||||
"first line",
|
||||
">>> second line",
|
||||
"second line",
|
||||
">>> second line",
|
||||
"second line",
|
||||
">>> done"])
|
||||
|
||||
class HistoricRecvlineLoopbackTelnet(_TelnetMixin, unittest.TestCase, HistoricRecvlineLoopbackMixin):
|
||||
pass
|
||||
|
||||
class HistoricRecvlineLoopbackSSH(_SSHMixin, unittest.TestCase, HistoricRecvlineLoopbackMixin):
|
||||
pass
|
||||
|
||||
class HistoricRecvlineLoopbackStdio(_StdioMixin, unittest.TestCase, HistoricRecvlineLoopbackMixin):
|
||||
if stdio is None:
|
||||
skip = "Terminal requirements missing, can't run historic recvline tests over stdio"
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for the command-line interfaces to conch.
|
||||
"""
|
||||
|
||||
try:
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
pyasn1Skip = "Cannot run without PyASN1"
|
||||
else:
|
||||
pyasn1Skip = None
|
||||
|
||||
try:
|
||||
import Crypto
|
||||
except ImportError:
|
||||
cryptoSkip = "can't run w/o PyCrypto"
|
||||
else:
|
||||
cryptoSkip = None
|
||||
|
||||
try:
|
||||
import tty
|
||||
except ImportError:
|
||||
ttySkip = "can't run w/o tty"
|
||||
else:
|
||||
ttySkip = None
|
||||
|
||||
try:
|
||||
import Tkinter
|
||||
except ImportError:
|
||||
tkskip = "can't run w/o Tkinter"
|
||||
else:
|
||||
try:
|
||||
Tkinter.Tk().destroy()
|
||||
except Tkinter.TclError, e:
|
||||
tkskip = "Can't test Tkinter: " + str(e)
|
||||
else:
|
||||
tkskip = None
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.scripts.test.test_scripts import ScriptTestsMixin
|
||||
from twisted.python.test.test_shellcomp import ZshScriptTestMixin
|
||||
|
||||
|
||||
|
||||
class ScriptTests(TestCase, ScriptTestsMixin):
|
||||
"""
|
||||
Tests for the Conch scripts.
|
||||
"""
|
||||
skip = pyasn1Skip or cryptoSkip
|
||||
|
||||
|
||||
def test_conch(self):
|
||||
self.scriptTest("conch/conch")
|
||||
test_conch.skip = ttySkip or skip
|
||||
|
||||
|
||||
def test_cftp(self):
|
||||
self.scriptTest("conch/cftp")
|
||||
test_cftp.skip = ttySkip or skip
|
||||
|
||||
|
||||
def test_ckeygen(self):
|
||||
self.scriptTest("conch/ckeygen")
|
||||
|
||||
|
||||
def test_tkconch(self):
|
||||
self.scriptTest("conch/tkconch")
|
||||
test_tkconch.skip = tkskip or skip
|
||||
|
||||
|
||||
|
||||
class ZshIntegrationTestCase(TestCase, ZshScriptTestMixin):
|
||||
"""
|
||||
Test that zsh completion functions are generated without error
|
||||
"""
|
||||
generateFor = [('conch', 'twisted.conch.scripts.conch.ClientOptions'),
|
||||
('cftp', 'twisted.conch.scripts.cftp.ClientOptions'),
|
||||
('ckeygen', 'twisted.conch.scripts.ckeygen.GeneralOptions'),
|
||||
('tkconch', 'twisted.conch.scripts.tkconch.GeneralOptions'),
|
||||
]
|
||||
1254
Linux/lib/python2.7/site-packages/twisted/conch/test/test_session.py
Normal file
1254
Linux/lib/python2.7/site-packages/twisted/conch/test/test_session.py
Normal file
File diff suppressed because it is too large
Load diff
995
Linux/lib/python2.7/site-packages/twisted/conch/test/test_ssh.py
Normal file
995
Linux/lib/python2.7/site-packages/twisted/conch/test/test_ssh.py
Normal file
|
|
@ -0,0 +1,995 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh}.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
except ImportError:
|
||||
Crypto = None
|
||||
|
||||
try:
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
pyasn1 = None
|
||||
|
||||
from twisted.conch.ssh import common, session, forwarding
|
||||
from twisted.conch import avatar, error
|
||||
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
|
||||
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
|
||||
from twisted.cred import portal
|
||||
from twisted.cred.error import UnauthorizedLogin
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
from twisted.internet.error import ProcessTerminated
|
||||
from twisted.python import failure, log
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.conch.test.test_recvline import LoopbackRelay
|
||||
|
||||
|
||||
|
||||
class ConchTestRealm(object):
|
||||
"""
|
||||
A realm which expects a particular avatarId to log in once and creates a
|
||||
L{ConchTestAvatar} for that request.
|
||||
|
||||
@ivar expectedAvatarID: The only avatarID that this realm will produce an
|
||||
avatar for.
|
||||
|
||||
@ivar avatar: A reference to the avatar after it is requested.
|
||||
"""
|
||||
avatar = None
|
||||
|
||||
def __init__(self, expectedAvatarID):
|
||||
self.expectedAvatarID = expectedAvatarID
|
||||
|
||||
|
||||
def requestAvatar(self, avatarID, mind, *interfaces):
|
||||
"""
|
||||
Return a new L{ConchTestAvatar} if the avatarID matches the expected one
|
||||
and this is the first avatar request.
|
||||
"""
|
||||
if avatarID == self.expectedAvatarID:
|
||||
if self.avatar is not None:
|
||||
raise UnauthorizedLogin("Only one login allowed")
|
||||
self.avatar = ConchTestAvatar()
|
||||
return interfaces[0], self.avatar, self.avatar.logout
|
||||
raise UnauthorizedLogin(
|
||||
"Only %r may log in, not %r" % (self.expectedAvatarID, avatarID))
|
||||
|
||||
|
||||
|
||||
class ConchTestAvatar(avatar.ConchUser):
|
||||
"""
|
||||
An avatar against which various SSH features can be tested.
|
||||
|
||||
@ivar loggedOut: A flag indicating whether the avatar logout method has been
|
||||
called.
|
||||
"""
|
||||
loggedOut = False
|
||||
|
||||
def __init__(self):
|
||||
avatar.ConchUser.__init__(self)
|
||||
self.listeners = {}
|
||||
self.globalRequests = {}
|
||||
self.channelLookup.update({'session': session.SSHSession,
|
||||
'direct-tcpip':forwarding.openConnectForwardingClient})
|
||||
self.subsystemLookup.update({'crazy': CrazySubsystem})
|
||||
|
||||
|
||||
def global_foo(self, data):
|
||||
self.globalRequests['foo'] = data
|
||||
return 1
|
||||
|
||||
|
||||
def global_foo_2(self, data):
|
||||
self.globalRequests['foo_2'] = data
|
||||
return 1, 'data'
|
||||
|
||||
|
||||
def global_tcpip_forward(self, data):
|
||||
host, port = forwarding.unpackGlobal_tcpip_forward(data)
|
||||
try:
|
||||
listener = reactor.listenTCP(
|
||||
port, forwarding.SSHListenForwardingFactory(
|
||||
self.conn, (host, port),
|
||||
forwarding.SSHListenServerForwardingChannel),
|
||||
interface=host)
|
||||
except:
|
||||
log.err(None, "something went wrong with remote->local forwarding")
|
||||
return 0
|
||||
else:
|
||||
self.listeners[(host, port)] = listener
|
||||
return 1
|
||||
|
||||
|
||||
def global_cancel_tcpip_forward(self, data):
|
||||
host, port = forwarding.unpackGlobal_tcpip_forward(data)
|
||||
listener = self.listeners.get((host, port), None)
|
||||
if not listener:
|
||||
return 0
|
||||
del self.listeners[(host, port)]
|
||||
listener.stopListening()
|
||||
return 1
|
||||
|
||||
|
||||
def logout(self):
|
||||
self.loggedOut = True
|
||||
for listener in self.listeners.values():
|
||||
log.msg('stopListening %s' % listener)
|
||||
listener.stopListening()
|
||||
|
||||
|
||||
|
||||
class ConchSessionForTestAvatar(object):
|
||||
"""
|
||||
An ISession adapter for ConchTestAvatar.
|
||||
"""
|
||||
def __init__(self, avatar):
|
||||
"""
|
||||
Initialize the session and create a reference to it on the avatar for
|
||||
later inspection.
|
||||
"""
|
||||
self.avatar = avatar
|
||||
self.avatar._testSession = self
|
||||
self.cmd = None
|
||||
self.proto = None
|
||||
self.ptyReq = False
|
||||
self.eof = 0
|
||||
self.onClose = defer.Deferred()
|
||||
|
||||
|
||||
def getPty(self, term, windowSize, attrs):
|
||||
log.msg('pty req')
|
||||
self._terminalType = term
|
||||
self._windowSize = windowSize
|
||||
self.ptyReq = True
|
||||
|
||||
|
||||
def openShell(self, proto):
|
||||
log.msg('opening shell')
|
||||
self.proto = proto
|
||||
EchoTransport(proto)
|
||||
self.cmd = 'shell'
|
||||
|
||||
|
||||
def execCommand(self, proto, cmd):
|
||||
self.cmd = cmd
|
||||
self.proto = proto
|
||||
f = cmd.split()[0]
|
||||
if f == 'false':
|
||||
t = FalseTransport(proto)
|
||||
# Avoid disconnecting this immediately. If the channel is closed
|
||||
# before execCommand even returns the caller gets confused.
|
||||
reactor.callLater(0, t.loseConnection)
|
||||
elif f == 'echo':
|
||||
t = EchoTransport(proto)
|
||||
t.write(cmd[5:])
|
||||
t.loseConnection()
|
||||
elif f == 'secho':
|
||||
t = SuperEchoTransport(proto)
|
||||
t.write(cmd[6:])
|
||||
t.loseConnection()
|
||||
elif f == 'eecho':
|
||||
t = ErrEchoTransport(proto)
|
||||
t.write(cmd[6:])
|
||||
t.loseConnection()
|
||||
else:
|
||||
raise error.ConchError('bad exec')
|
||||
self.avatar.conn.transport.expectedLoseConnection = 1
|
||||
|
||||
|
||||
def eofReceived(self):
|
||||
self.eof = 1
|
||||
|
||||
|
||||
def closed(self):
|
||||
log.msg('closed cmd "%s"' % self.cmd)
|
||||
self.remoteWindowLeftAtClose = self.proto.session.remoteWindowLeft
|
||||
self.onClose.callback(None)
|
||||
|
||||
from twisted.python import components
|
||||
components.registerAdapter(ConchSessionForTestAvatar, ConchTestAvatar, session.ISession)
|
||||
|
||||
class CrazySubsystem(protocol.Protocol):
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
pass
|
||||
|
||||
def connectionMade(self):
|
||||
"""
|
||||
good ... good
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class FalseTransport:
|
||||
"""
|
||||
False transport should act like a /bin/false execution, i.e. just exit with
|
||||
nonzero status, writing nothing to the terminal.
|
||||
|
||||
@ivar proto: The protocol associated with this transport.
|
||||
@ivar closed: A flag tracking whether C{loseConnection} has been called yet.
|
||||
"""
|
||||
|
||||
def __init__(self, p):
|
||||
"""
|
||||
@type p L{twisted.conch.ssh.session.SSHSessionProcessProtocol} instance
|
||||
"""
|
||||
self.proto = p
|
||||
p.makeConnection(self)
|
||||
self.closed = 0
|
||||
|
||||
|
||||
def loseConnection(self):
|
||||
"""
|
||||
Disconnect the protocol associated with this transport.
|
||||
"""
|
||||
if self.closed:
|
||||
return
|
||||
self.closed = 1
|
||||
self.proto.inConnectionLost()
|
||||
self.proto.outConnectionLost()
|
||||
self.proto.errConnectionLost()
|
||||
self.proto.processEnded(failure.Failure(ProcessTerminated(255, None, None)))
|
||||
|
||||
|
||||
|
||||
class EchoTransport:
|
||||
|
||||
def __init__(self, p):
|
||||
self.proto = p
|
||||
p.makeConnection(self)
|
||||
self.closed = 0
|
||||
|
||||
def write(self, data):
|
||||
log.msg(repr(data))
|
||||
self.proto.outReceived(data)
|
||||
self.proto.outReceived('\r\n')
|
||||
if '\x00' in data: # mimic 'exit' for the shell test
|
||||
self.loseConnection()
|
||||
|
||||
def loseConnection(self):
|
||||
if self.closed: return
|
||||
self.closed = 1
|
||||
self.proto.inConnectionLost()
|
||||
self.proto.outConnectionLost()
|
||||
self.proto.errConnectionLost()
|
||||
self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)))
|
||||
|
||||
class ErrEchoTransport:
|
||||
|
||||
def __init__(self, p):
|
||||
self.proto = p
|
||||
p.makeConnection(self)
|
||||
self.closed = 0
|
||||
|
||||
def write(self, data):
|
||||
self.proto.errReceived(data)
|
||||
self.proto.errReceived('\r\n')
|
||||
|
||||
def loseConnection(self):
|
||||
if self.closed: return
|
||||
self.closed = 1
|
||||
self.proto.inConnectionLost()
|
||||
self.proto.outConnectionLost()
|
||||
self.proto.errConnectionLost()
|
||||
self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)))
|
||||
|
||||
class SuperEchoTransport:
|
||||
|
||||
def __init__(self, p):
|
||||
self.proto = p
|
||||
p.makeConnection(self)
|
||||
self.closed = 0
|
||||
|
||||
def write(self, data):
|
||||
self.proto.outReceived(data)
|
||||
self.proto.outReceived('\r\n')
|
||||
self.proto.errReceived(data)
|
||||
self.proto.errReceived('\r\n')
|
||||
|
||||
def loseConnection(self):
|
||||
if self.closed: return
|
||||
self.closed = 1
|
||||
self.proto.inConnectionLost()
|
||||
self.proto.outConnectionLost()
|
||||
self.proto.errConnectionLost()
|
||||
self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)))
|
||||
|
||||
|
||||
if Crypto is not None and pyasn1 is not None:
|
||||
from twisted.conch import checkers
|
||||
from twisted.conch.ssh import channel, connection, factory, keys
|
||||
from twisted.conch.ssh import transport, userauth
|
||||
|
||||
class UtilityTestCase(unittest.TestCase):
|
||||
def testCounter(self):
|
||||
c = transport._Counter('\x00\x00', 2)
|
||||
for i in xrange(256 * 256):
|
||||
self.assertEqual(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
|
||||
# It should wrap around, too.
|
||||
for i in xrange(256 * 256):
|
||||
self.assertEqual(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
|
||||
|
||||
|
||||
class ConchTestPublicKeyChecker(checkers.SSHPublicKeyDatabase):
|
||||
def checkKey(self, credentials):
|
||||
blob = keys.Key.fromString(publicDSA_openssh).blob()
|
||||
if credentials.username == 'testuser' and credentials.blob == blob:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ConchTestPasswordChecker:
|
||||
credentialInterfaces = checkers.IUsernamePassword,
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
if credentials.username == 'testuser' and credentials.password == 'testpass':
|
||||
return defer.succeed(credentials.username)
|
||||
return defer.fail(Exception("Bad credentials"))
|
||||
|
||||
|
||||
class ConchTestSSHChecker(checkers.SSHProtocolChecker):
|
||||
|
||||
def areDone(self, avatarId):
|
||||
if avatarId != 'testuser' or len(self.successfulCredentials[avatarId]) < 2:
|
||||
return False
|
||||
return True
|
||||
|
||||
class ConchTestServerFactory(factory.SSHFactory):
|
||||
noisy = 0
|
||||
|
||||
services = {
|
||||
'ssh-userauth':userauth.SSHUserAuthServer,
|
||||
'ssh-connection':connection.SSHConnection
|
||||
}
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
proto = ConchTestServer()
|
||||
proto.supportedPublicKeys = self.privateKeys.keys()
|
||||
proto.factory = self
|
||||
|
||||
if hasattr(self, 'expectedLoseConnection'):
|
||||
proto.expectedLoseConnection = self.expectedLoseConnection
|
||||
|
||||
self.proto = proto
|
||||
return proto
|
||||
|
||||
def getPublicKeys(self):
|
||||
return {
|
||||
'ssh-rsa': keys.Key.fromString(publicRSA_openssh),
|
||||
'ssh-dss': keys.Key.fromString(publicDSA_openssh)
|
||||
}
|
||||
|
||||
def getPrivateKeys(self):
|
||||
return {
|
||||
'ssh-rsa': keys.Key.fromString(privateRSA_openssh),
|
||||
'ssh-dss': keys.Key.fromString(privateDSA_openssh)
|
||||
}
|
||||
|
||||
def getPrimes(self):
|
||||
return {
|
||||
2048:[(transport.DH_GENERATOR, transport.DH_PRIME)]
|
||||
}
|
||||
|
||||
def getService(self, trans, name):
|
||||
return factory.SSHFactory.getService(self, trans, name)
|
||||
|
||||
class ConchTestBase:
|
||||
|
||||
done = 0
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.done:
|
||||
return
|
||||
if not hasattr(self,'expectedLoseConnection'):
|
||||
unittest.fail('unexpectedly lost connection %s\n%s' % (self, reason))
|
||||
self.done = 1
|
||||
|
||||
def receiveError(self, reasonCode, desc):
|
||||
self.expectedLoseConnection = 1
|
||||
# Some versions of OpenSSH (for example, OpenSSH_5.3p1) will
|
||||
# send a DISCONNECT_BY_APPLICATION error before closing the
|
||||
# connection. Other, older versions (for example,
|
||||
# OpenSSH_5.1p1), won't. So accept this particular error here,
|
||||
# but no others.
|
||||
if reasonCode != transport.DISCONNECT_BY_APPLICATION:
|
||||
log.err(
|
||||
Exception(
|
||||
'got disconnect for %s: reason %s, desc: %s' % (
|
||||
self, reasonCode, desc)))
|
||||
self.loseConnection()
|
||||
|
||||
def receiveUnimplemented(self, seqID):
|
||||
unittest.fail('got unimplemented: seqid %s' % seqID)
|
||||
self.expectedLoseConnection = 1
|
||||
self.loseConnection()
|
||||
|
||||
class ConchTestServer(ConchTestBase, transport.SSHServerTransport):
|
||||
|
||||
def connectionLost(self, reason):
|
||||
ConchTestBase.connectionLost(self, reason)
|
||||
transport.SSHServerTransport.connectionLost(self, reason)
|
||||
|
||||
|
||||
class ConchTestClient(ConchTestBase, transport.SSHClientTransport):
|
||||
"""
|
||||
@ivar _channelFactory: A callable which accepts an SSH connection and
|
||||
returns a channel which will be attached to a new channel on that
|
||||
connection.
|
||||
"""
|
||||
def __init__(self, channelFactory):
|
||||
self._channelFactory = channelFactory
|
||||
|
||||
def connectionLost(self, reason):
|
||||
ConchTestBase.connectionLost(self, reason)
|
||||
transport.SSHClientTransport.connectionLost(self, reason)
|
||||
|
||||
def verifyHostKey(self, key, fp):
|
||||
keyMatch = key == keys.Key.fromString(publicRSA_openssh).blob()
|
||||
fingerprintMatch = (
|
||||
fp == '3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af')
|
||||
if keyMatch and fingerprintMatch:
|
||||
return defer.succeed(1)
|
||||
return defer.fail(Exception("Key or fingerprint mismatch"))
|
||||
|
||||
def connectionSecure(self):
|
||||
self.requestService(ConchTestClientAuth('testuser',
|
||||
ConchTestClientConnection(self._channelFactory)))
|
||||
|
||||
|
||||
class ConchTestClientAuth(userauth.SSHUserAuthClient):
|
||||
|
||||
hasTriedNone = 0 # have we tried the 'none' auth yet?
|
||||
canSucceedPublicKey = 0 # can we succed with this yet?
|
||||
canSucceedPassword = 0
|
||||
|
||||
def ssh_USERAUTH_SUCCESS(self, packet):
|
||||
if not self.canSucceedPassword and self.canSucceedPublicKey:
|
||||
unittest.fail('got USERAUTH_SUCESS before password and publickey')
|
||||
userauth.SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
|
||||
|
||||
def getPassword(self):
|
||||
self.canSucceedPassword = 1
|
||||
return defer.succeed('testpass')
|
||||
|
||||
def getPrivateKey(self):
|
||||
self.canSucceedPublicKey = 1
|
||||
return defer.succeed(keys.Key.fromString(privateDSA_openssh))
|
||||
|
||||
def getPublicKey(self):
|
||||
return keys.Key.fromString(publicDSA_openssh)
|
||||
|
||||
|
||||
class ConchTestClientConnection(connection.SSHConnection):
|
||||
"""
|
||||
@ivar _completed: A L{Deferred} which will be fired when the number of
|
||||
results collected reaches C{totalResults}.
|
||||
"""
|
||||
name = 'ssh-connection'
|
||||
results = 0
|
||||
totalResults = 8
|
||||
|
||||
def __init__(self, channelFactory):
|
||||
connection.SSHConnection.__init__(self)
|
||||
self._channelFactory = channelFactory
|
||||
|
||||
def serviceStarted(self):
|
||||
self.openChannel(self._channelFactory(conn=self))
|
||||
|
||||
|
||||
class SSHTestChannel(channel.SSHChannel):
|
||||
|
||||
def __init__(self, name, opened, *args, **kwargs):
|
||||
self.name = name
|
||||
self._opened = opened
|
||||
self.received = []
|
||||
self.receivedExt = []
|
||||
self.onClose = defer.Deferred()
|
||||
channel.SSHChannel.__init__(self, *args, **kwargs)
|
||||
|
||||
|
||||
def openFailed(self, reason):
|
||||
self._opened.errback(reason)
|
||||
|
||||
|
||||
def channelOpen(self, ignore):
|
||||
self._opened.callback(self)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.received.append(data)
|
||||
|
||||
|
||||
def extReceived(self, dataType, data):
|
||||
if dataType == connection.EXTENDED_DATA_STDERR:
|
||||
self.receivedExt.append(data)
|
||||
else:
|
||||
log.msg("Unrecognized extended data: %r" % (dataType,))
|
||||
|
||||
|
||||
def request_exit_status(self, status):
|
||||
[self.status] = struct.unpack('>L', status)
|
||||
|
||||
|
||||
def eofReceived(self):
|
||||
self.eofCalled = True
|
||||
|
||||
|
||||
def closed(self):
|
||||
self.onClose.callback(None)
|
||||
|
||||
|
||||
|
||||
class SSHProtocolTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for communication between L{SSHServerTransport} and
|
||||
L{SSHClientTransport}.
|
||||
"""
|
||||
|
||||
if not Crypto:
|
||||
skip = "can't run w/o PyCrypto"
|
||||
|
||||
if not pyasn1:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
def _ourServerOurClientTest(self, name='session', **kwargs):
|
||||
"""
|
||||
Create a connected SSH client and server protocol pair and return a
|
||||
L{Deferred} which fires with an L{SSHTestChannel} instance connected to
|
||||
a channel on that SSH connection.
|
||||
"""
|
||||
result = defer.Deferred()
|
||||
self.realm = ConchTestRealm('testuser')
|
||||
p = portal.Portal(self.realm)
|
||||
sshpc = ConchTestSSHChecker()
|
||||
sshpc.registerChecker(ConchTestPasswordChecker())
|
||||
sshpc.registerChecker(ConchTestPublicKeyChecker())
|
||||
p.registerChecker(sshpc)
|
||||
fac = ConchTestServerFactory()
|
||||
fac.portal = p
|
||||
fac.startFactory()
|
||||
self.server = fac.buildProtocol(None)
|
||||
self.clientTransport = LoopbackRelay(self.server)
|
||||
self.client = ConchTestClient(
|
||||
lambda conn: SSHTestChannel(name, result, conn=conn, **kwargs))
|
||||
|
||||
self.serverTransport = LoopbackRelay(self.client)
|
||||
|
||||
self.server.makeConnection(self.serverTransport)
|
||||
self.client.makeConnection(self.clientTransport)
|
||||
return result
|
||||
|
||||
|
||||
def test_subsystemsAndGlobalRequests(self):
|
||||
"""
|
||||
Run the Conch server against the Conch client. Set up several different
|
||||
channels which exercise different behaviors and wait for them to
|
||||
complete. Verify that the channels with errors log them.
|
||||
"""
|
||||
channel = self._ourServerOurClientTest()
|
||||
|
||||
def cbSubsystem(channel):
|
||||
self.channel = channel
|
||||
return self.assertFailure(
|
||||
channel.conn.sendRequest(
|
||||
channel, 'subsystem', common.NS('not-crazy'), 1),
|
||||
Exception)
|
||||
channel.addCallback(cbSubsystem)
|
||||
|
||||
def cbNotCrazyFailed(ignored):
|
||||
channel = self.channel
|
||||
return channel.conn.sendRequest(
|
||||
channel, 'subsystem', common.NS('crazy'), 1)
|
||||
channel.addCallback(cbNotCrazyFailed)
|
||||
|
||||
def cbGlobalRequests(ignored):
|
||||
channel = self.channel
|
||||
d1 = channel.conn.sendGlobalRequest('foo', 'bar', 1)
|
||||
|
||||
d2 = channel.conn.sendGlobalRequest('foo-2', 'bar2', 1)
|
||||
d2.addCallback(self.assertEqual, 'data')
|
||||
|
||||
d3 = self.assertFailure(
|
||||
channel.conn.sendGlobalRequest('bar', 'foo', 1),
|
||||
Exception)
|
||||
|
||||
return defer.gatherResults([d1, d2, d3])
|
||||
channel.addCallback(cbGlobalRequests)
|
||||
|
||||
def disconnect(ignored):
|
||||
self.assertEqual(
|
||||
self.realm.avatar.globalRequests,
|
||||
{"foo": "bar", "foo_2": "bar2"})
|
||||
channel = self.channel
|
||||
channel.conn.transport.expectedLoseConnection = True
|
||||
channel.conn.serviceStopped()
|
||||
channel.loseConnection()
|
||||
channel.addCallback(disconnect)
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
def test_shell(self):
|
||||
"""
|
||||
L{SSHChannel.sendRequest} can open a shell with a I{pty-req} request,
|
||||
specifying a terminal type and window size.
|
||||
"""
|
||||
channel = self._ourServerOurClientTest()
|
||||
|
||||
data = session.packRequest_pty_req('conch-test-term', (24, 80, 0, 0), '')
|
||||
def cbChannel(channel):
|
||||
self.channel = channel
|
||||
return channel.conn.sendRequest(channel, 'pty-req', data, 1)
|
||||
channel.addCallback(cbChannel)
|
||||
|
||||
def cbPty(ignored):
|
||||
# The server-side object corresponding to our client side channel.
|
||||
session = self.realm.avatar.conn.channels[0].session
|
||||
self.assertIs(session.avatar, self.realm.avatar)
|
||||
self.assertEqual(session._terminalType, 'conch-test-term')
|
||||
self.assertEqual(session._windowSize, (24, 80, 0, 0))
|
||||
self.assertTrue(session.ptyReq)
|
||||
channel = self.channel
|
||||
return channel.conn.sendRequest(channel, 'shell', '', 1)
|
||||
channel.addCallback(cbPty)
|
||||
|
||||
def cbShell(ignored):
|
||||
self.channel.write('testing the shell!\x00')
|
||||
self.channel.conn.sendEOF(self.channel)
|
||||
return defer.gatherResults([
|
||||
self.channel.onClose,
|
||||
self.realm.avatar._testSession.onClose])
|
||||
channel.addCallback(cbShell)
|
||||
|
||||
def cbExited(ignored):
|
||||
if self.channel.status != 0:
|
||||
log.msg(
|
||||
'shell exit status was not 0: %i' % (self.channel.status,))
|
||||
self.assertEqual(
|
||||
"".join(self.channel.received),
|
||||
'testing the shell!\x00\r\n')
|
||||
self.assertTrue(self.channel.eofCalled)
|
||||
self.assertTrue(
|
||||
self.realm.avatar._testSession.eof)
|
||||
channel.addCallback(cbExited)
|
||||
return channel
|
||||
|
||||
|
||||
def test_failedExec(self):
|
||||
"""
|
||||
If L{SSHChannel.sendRequest} issues an exec which the server responds to
|
||||
with an error, the L{Deferred} it returns fires its errback.
|
||||
"""
|
||||
channel = self._ourServerOurClientTest()
|
||||
|
||||
def cbChannel(channel):
|
||||
self.channel = channel
|
||||
return self.assertFailure(
|
||||
channel.conn.sendRequest(
|
||||
channel, 'exec', common.NS('jumboliah'), 1),
|
||||
Exception)
|
||||
channel.addCallback(cbChannel)
|
||||
|
||||
def cbFailed(ignored):
|
||||
# The server logs this exception when it cannot perform the
|
||||
# requested exec.
|
||||
errors = self.flushLoggedErrors(error.ConchError)
|
||||
self.assertEqual(errors[0].value.args, ('bad exec', None))
|
||||
channel.addCallback(cbFailed)
|
||||
return channel
|
||||
|
||||
|
||||
def test_falseChannel(self):
|
||||
"""
|
||||
When the process started by a L{SSHChannel.sendRequest} exec request
|
||||
exits, the exit status is reported to the channel.
|
||||
"""
|
||||
channel = self._ourServerOurClientTest()
|
||||
|
||||
def cbChannel(channel):
|
||||
self.channel = channel
|
||||
return channel.conn.sendRequest(
|
||||
channel, 'exec', common.NS('false'), 1)
|
||||
channel.addCallback(cbChannel)
|
||||
|
||||
def cbExec(ignored):
|
||||
return self.channel.onClose
|
||||
channel.addCallback(cbExec)
|
||||
|
||||
def cbClosed(ignored):
|
||||
# No data is expected
|
||||
self.assertEqual(self.channel.received, [])
|
||||
self.assertNotEqual(self.channel.status, 0)
|
||||
channel.addCallback(cbClosed)
|
||||
return channel
|
||||
|
||||
|
||||
def test_errorChannel(self):
|
||||
"""
|
||||
Bytes sent over the extended channel for stderr data are delivered to
|
||||
the channel's C{extReceived} method.
|
||||
"""
|
||||
channel = self._ourServerOurClientTest(localWindow=4, localMaxPacket=5)
|
||||
|
||||
def cbChannel(channel):
|
||||
self.channel = channel
|
||||
return channel.conn.sendRequest(
|
||||
channel, 'exec', common.NS('eecho hello'), 1)
|
||||
channel.addCallback(cbChannel)
|
||||
|
||||
def cbExec(ignored):
|
||||
return defer.gatherResults([
|
||||
self.channel.onClose,
|
||||
self.realm.avatar._testSession.onClose])
|
||||
channel.addCallback(cbExec)
|
||||
|
||||
def cbClosed(ignored):
|
||||
self.assertEqual(self.channel.received, [])
|
||||
self.assertEqual("".join(self.channel.receivedExt), "hello\r\n")
|
||||
self.assertEqual(self.channel.status, 0)
|
||||
self.assertTrue(self.channel.eofCalled)
|
||||
self.assertEqual(self.channel.localWindowLeft, 4)
|
||||
self.assertEqual(
|
||||
self.channel.localWindowLeft,
|
||||
self.realm.avatar._testSession.remoteWindowLeftAtClose)
|
||||
channel.addCallback(cbClosed)
|
||||
return channel
|
||||
|
||||
|
||||
def test_unknownChannel(self):
|
||||
"""
|
||||
When an attempt is made to open an unknown channel type, the L{Deferred}
|
||||
returned by L{SSHChannel.sendRequest} fires its errback.
|
||||
"""
|
||||
d = self.assertFailure(
|
||||
self._ourServerOurClientTest('crazy-unknown-channel'), Exception)
|
||||
def cbFailed(ignored):
|
||||
errors = self.flushLoggedErrors(error.ConchError)
|
||||
self.assertEqual(errors[0].value.args, (3, 'unknown channel'))
|
||||
self.assertEqual(len(errors), 1)
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
|
||||
def test_maxPacket(self):
|
||||
"""
|
||||
An L{SSHChannel} can be configured with a maximum packet size to
|
||||
receive.
|
||||
"""
|
||||
# localWindow needs to be at least 11 otherwise the assertion about it
|
||||
# in cbClosed is invalid.
|
||||
channel = self._ourServerOurClientTest(
|
||||
localWindow=11, localMaxPacket=1)
|
||||
|
||||
def cbChannel(channel):
|
||||
self.channel = channel
|
||||
return channel.conn.sendRequest(
|
||||
channel, 'exec', common.NS('secho hello'), 1)
|
||||
channel.addCallback(cbChannel)
|
||||
|
||||
def cbExec(ignored):
|
||||
return self.channel.onClose
|
||||
channel.addCallback(cbExec)
|
||||
|
||||
def cbClosed(ignored):
|
||||
self.assertEqual(self.channel.status, 0)
|
||||
self.assertEqual("".join(self.channel.received), "hello\r\n")
|
||||
self.assertEqual("".join(self.channel.receivedExt), "hello\r\n")
|
||||
self.assertEqual(self.channel.localWindowLeft, 11)
|
||||
self.assertTrue(self.channel.eofCalled)
|
||||
channel.addCallback(cbClosed)
|
||||
return channel
|
||||
|
||||
|
||||
def test_echo(self):
|
||||
"""
|
||||
Normal standard out bytes are sent to the channel's C{dataReceived}
|
||||
method.
|
||||
"""
|
||||
channel = self._ourServerOurClientTest(localWindow=4, localMaxPacket=5)
|
||||
|
||||
def cbChannel(channel):
|
||||
self.channel = channel
|
||||
return channel.conn.sendRequest(
|
||||
channel, 'exec', common.NS('echo hello'), 1)
|
||||
channel.addCallback(cbChannel)
|
||||
|
||||
def cbEcho(ignored):
|
||||
return defer.gatherResults([
|
||||
self.channel.onClose,
|
||||
self.realm.avatar._testSession.onClose])
|
||||
channel.addCallback(cbEcho)
|
||||
|
||||
def cbClosed(ignored):
|
||||
self.assertEqual(self.channel.status, 0)
|
||||
self.assertEqual("".join(self.channel.received), "hello\r\n")
|
||||
self.assertEqual(self.channel.localWindowLeft, 4)
|
||||
self.assertTrue(self.channel.eofCalled)
|
||||
self.assertEqual(
|
||||
self.channel.localWindowLeft,
|
||||
self.realm.avatar._testSession.remoteWindowLeftAtClose)
|
||||
channel.addCallback(cbClosed)
|
||||
return channel
|
||||
|
||||
|
||||
|
||||
class TestSSHFactory(unittest.TestCase):
|
||||
|
||||
if not Crypto:
|
||||
skip = "can't run w/o PyCrypto"
|
||||
|
||||
if not pyasn1:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
def makeSSHFactory(self, primes=None):
|
||||
sshFactory = factory.SSHFactory()
|
||||
gpk = lambda: {'ssh-rsa' : keys.Key(None)}
|
||||
sshFactory.getPrimes = lambda: primes
|
||||
sshFactory.getPublicKeys = sshFactory.getPrivateKeys = gpk
|
||||
sshFactory.startFactory()
|
||||
return sshFactory
|
||||
|
||||
|
||||
def test_buildProtocol(self):
|
||||
"""
|
||||
By default, buildProtocol() constructs an instance of
|
||||
SSHServerTransport.
|
||||
"""
|
||||
factory = self.makeSSHFactory()
|
||||
protocol = factory.buildProtocol(None)
|
||||
self.assertIsInstance(protocol, transport.SSHServerTransport)
|
||||
|
||||
|
||||
def test_buildProtocolRespectsProtocol(self):
|
||||
"""
|
||||
buildProtocol() calls 'self.protocol()' to construct a protocol
|
||||
instance.
|
||||
"""
|
||||
calls = []
|
||||
def makeProtocol(*args):
|
||||
calls.append(args)
|
||||
return transport.SSHServerTransport()
|
||||
factory = self.makeSSHFactory()
|
||||
factory.protocol = makeProtocol
|
||||
factory.buildProtocol(None)
|
||||
self.assertEqual([()], calls)
|
||||
|
||||
|
||||
def test_multipleFactories(self):
|
||||
f1 = self.makeSSHFactory(primes=None)
|
||||
f2 = self.makeSSHFactory(primes={1:(2,3)})
|
||||
p1 = f1.buildProtocol(None)
|
||||
p2 = f2.buildProtocol(None)
|
||||
self.assertNotIn(
|
||||
'diffie-hellman-group-exchange-sha1', p1.supportedKeyExchanges)
|
||||
self.assertIn(
|
||||
'diffie-hellman-group-exchange-sha1', p2.supportedKeyExchanges)
|
||||
|
||||
|
||||
|
||||
class MPTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{common.getMP}.
|
||||
|
||||
@cvar getMP: a method providing a MP parser.
|
||||
@type getMP: C{callable}
|
||||
"""
|
||||
getMP = staticmethod(common.getMP)
|
||||
|
||||
if not Crypto:
|
||||
skip = "can't run w/o PyCrypto"
|
||||
|
||||
if not pyasn1:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
|
||||
def test_getMP(self):
|
||||
"""
|
||||
L{common.getMP} should parse the a multiple precision integer from a
|
||||
string: a 4-byte length followed by length bytes of the integer.
|
||||
"""
|
||||
self.assertEqual(
|
||||
self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'),
|
||||
(1, ''))
|
||||
|
||||
|
||||
def test_getMPBigInteger(self):
|
||||
"""
|
||||
L{common.getMP} should be able to parse a big enough integer
|
||||
(that doesn't fit on one byte).
|
||||
"""
|
||||
self.assertEqual(
|
||||
self.getMP('\x00\x00\x00\x04\x01\x02\x03\x04'),
|
||||
(16909060, ''))
|
||||
|
||||
|
||||
def test_multipleGetMP(self):
|
||||
"""
|
||||
L{common.getMP} has the ability to parse multiple integer in the same
|
||||
string.
|
||||
"""
|
||||
self.assertEqual(
|
||||
self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'
|
||||
'\x00\x00\x00\x04\x00\x00\x00\x02', 2),
|
||||
(1, 2, ''))
|
||||
|
||||
|
||||
def test_getMPRemainingData(self):
|
||||
"""
|
||||
When more data than needed is sent to L{common.getMP}, it should return
|
||||
the remaining data.
|
||||
"""
|
||||
self.assertEqual(
|
||||
self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01foo'),
|
||||
(1, 'foo'))
|
||||
|
||||
|
||||
def test_notEnoughData(self):
|
||||
"""
|
||||
When the string passed to L{common.getMP} doesn't even make 5 bytes,
|
||||
it should raise a L{struct.error}.
|
||||
"""
|
||||
self.assertRaises(struct.error, self.getMP, '\x02\x00')
|
||||
|
||||
|
||||
|
||||
class PyMPTestCase(MPTestCase):
|
||||
"""
|
||||
Tests for the python implementation of L{common.getMP}.
|
||||
"""
|
||||
getMP = staticmethod(common.getMP_py)
|
||||
|
||||
|
||||
|
||||
class GMPYMPTestCase(MPTestCase):
|
||||
"""
|
||||
Tests for the gmpy implementation of L{common.getMP}.
|
||||
"""
|
||||
getMP = staticmethod(common._fastgetMP)
|
||||
|
||||
|
||||
class BuiltinPowHackTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests that the builtin pow method is still correct after
|
||||
L{twisted.conch.ssh.common} monkeypatches it to use gmpy.
|
||||
"""
|
||||
|
||||
def test_floatBase(self):
|
||||
"""
|
||||
pow gives the correct result when passed a base of type float with a
|
||||
non-integer value.
|
||||
"""
|
||||
self.assertEqual(6.25, pow(2.5, 2))
|
||||
|
||||
def test_intBase(self):
|
||||
"""
|
||||
pow gives the correct result when passed a base of type int.
|
||||
"""
|
||||
self.assertEqual(81, pow(3, 4))
|
||||
|
||||
def test_longBase(self):
|
||||
"""
|
||||
pow gives the correct result when passed a base of type long.
|
||||
"""
|
||||
self.assertEqual(81, pow(3, 4))
|
||||
|
||||
def test_mpzBase(self):
|
||||
"""
|
||||
pow gives the correct result when passed a base of type gmpy.mpz.
|
||||
"""
|
||||
if gmpy is None:
|
||||
raise unittest.SkipTest('gmpy not available')
|
||||
self.assertEqual(81, pow(gmpy.mpz(3), 4))
|
||||
|
||||
|
||||
try:
|
||||
import gmpy
|
||||
except ImportError:
|
||||
GMPYMPTestCase.skip = "gmpy not available"
|
||||
gmpy = None
|
||||
183
Linux/lib/python2.7/site-packages/twisted/conch/test/test_tap.py
Normal file
183
Linux/lib/python2.7/site-packages/twisted/conch/test/test_tap.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.tap}.
|
||||
"""
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
except:
|
||||
Crypto = None
|
||||
|
||||
try:
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
pyasn1 = None
|
||||
|
||||
try:
|
||||
from twisted.conch import unix
|
||||
except ImportError:
|
||||
unix = None
|
||||
|
||||
if Crypto and pyasn1 and unix:
|
||||
from twisted.conch import tap
|
||||
from twisted.conch.openssh_compat.factory import OpenSSHFactory
|
||||
|
||||
from twisted.application.internet import StreamServerEndpointService
|
||||
from twisted.cred import error
|
||||
from twisted.cred.credentials import IPluggableAuthenticationModules
|
||||
from twisted.cred.credentials import ISSHPrivateKey
|
||||
from twisted.cred.credentials import IUsernamePassword, UsernamePassword
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
|
||||
class MakeServiceTest(TestCase):
|
||||
"""
|
||||
Tests for L{tap.makeService}.
|
||||
"""
|
||||
|
||||
if not Crypto:
|
||||
skip = "can't run w/o PyCrypto"
|
||||
|
||||
if not pyasn1:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
if not unix:
|
||||
skip = "can't run on non-posix computers"
|
||||
|
||||
usernamePassword = ('iamuser', 'thisispassword')
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a file with two users.
|
||||
"""
|
||||
self.filename = self.mktemp()
|
||||
f = open(self.filename, 'wb+')
|
||||
f.write(':'.join(self.usernamePassword))
|
||||
f.close()
|
||||
self.options = tap.Options()
|
||||
|
||||
|
||||
def test_basic(self):
|
||||
"""
|
||||
L{tap.makeService} returns a L{StreamServerEndpointService} instance
|
||||
running on TCP port 22, and the linked protocol factory is an instance
|
||||
of L{OpenSSHFactory}.
|
||||
"""
|
||||
config = tap.Options()
|
||||
service = tap.makeService(config)
|
||||
self.assertIsInstance(service, StreamServerEndpointService)
|
||||
self.assertEqual(service.endpoint._port, 22)
|
||||
self.assertIsInstance(service.factory, OpenSSHFactory)
|
||||
|
||||
|
||||
def test_defaultAuths(self):
|
||||
"""
|
||||
Make sure that if the C{--auth} command-line option is not passed,
|
||||
the default checkers are (for backwards compatibility): SSH, UNIX, and
|
||||
PAM if available
|
||||
"""
|
||||
numCheckers = 2
|
||||
try:
|
||||
from twisted.cred import pamauth
|
||||
self.assertIn(IPluggableAuthenticationModules,
|
||||
self.options['credInterfaces'],
|
||||
"PAM should be one of the modules")
|
||||
numCheckers += 1
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
self.assertIn(ISSHPrivateKey, self.options['credInterfaces'],
|
||||
"SSH should be one of the default checkers")
|
||||
self.assertIn(IUsernamePassword, self.options['credInterfaces'],
|
||||
"UNIX should be one of the default checkers")
|
||||
self.assertEqual(numCheckers, len(self.options['credCheckers']),
|
||||
"There should be %d checkers by default" % (numCheckers,))
|
||||
|
||||
|
||||
def test_authAdded(self):
|
||||
"""
|
||||
The C{--auth} command-line option will add a checker to the list of
|
||||
checkers, and it should be the only auth checker
|
||||
"""
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename])
|
||||
self.assertEqual(len(self.options['credCheckers']), 1)
|
||||
|
||||
|
||||
def test_multipleAuthAdded(self):
|
||||
"""
|
||||
Multiple C{--auth} command-line options will add all checkers specified
|
||||
to the list ofcheckers, and there should only be the specified auth
|
||||
checkers (no default checkers).
|
||||
"""
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename,
|
||||
'--auth', 'memory:testuser:testpassword'])
|
||||
self.assertEqual(len(self.options['credCheckers']), 2)
|
||||
|
||||
|
||||
def test_authFailure(self):
|
||||
"""
|
||||
The checker created by the C{--auth} command-line option returns a
|
||||
L{Deferred} that fails with L{UnauthorizedLogin} when
|
||||
presented with credentials that are unknown to that checker.
|
||||
"""
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename])
|
||||
checker = self.options['credCheckers'][-1]
|
||||
invalid = UsernamePassword(self.usernamePassword[0], 'fake')
|
||||
# Wrong password should raise error
|
||||
return self.assertFailure(
|
||||
checker.requestAvatarId(invalid), error.UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_authSuccess(self):
|
||||
"""
|
||||
The checker created by the C{--auth} command-line option returns a
|
||||
L{Deferred} that returns the avatar id when presented with credentials
|
||||
that are known to that checker.
|
||||
"""
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename])
|
||||
checker = self.options['credCheckers'][-1]
|
||||
correct = UsernamePassword(*self.usernamePassword)
|
||||
d = checker.requestAvatarId(correct)
|
||||
|
||||
def checkSuccess(username):
|
||||
self.assertEqual(username, correct.username)
|
||||
|
||||
return d.addCallback(checkSuccess)
|
||||
|
||||
|
||||
def test_checkersPamAuth(self):
|
||||
"""
|
||||
The L{OpenSSHFactory} built by L{tap.makeService} has a portal with
|
||||
L{IPluggableAuthenticationModules}, L{ISSHPrivateKey} and
|
||||
L{IUsernamePassword} interfaces registered as checkers if C{pamauth} is
|
||||
available.
|
||||
"""
|
||||
# Fake the presence of pamauth, even if PyPAM is not installed
|
||||
self.patch(tap, "pamauth", object())
|
||||
config = tap.Options()
|
||||
service = tap.makeService(config)
|
||||
portal = service.factory.portal
|
||||
self.assertEqual(
|
||||
set(portal.checkers.keys()),
|
||||
set([IPluggableAuthenticationModules, ISSHPrivateKey,
|
||||
IUsernamePassword]))
|
||||
|
||||
|
||||
def test_checkersWithoutPamAuth(self):
|
||||
"""
|
||||
The L{OpenSSHFactory} built by L{tap.makeService} has a portal with
|
||||
L{ISSHPrivateKey} and L{IUsernamePassword} interfaces registered as
|
||||
checkers if C{pamauth} is not available.
|
||||
"""
|
||||
# Fake the absence of pamauth, even if PyPAM is installed
|
||||
self.patch(tap, "pamauth", None)
|
||||
config = tap.Options()
|
||||
service = tap.makeService(config)
|
||||
portal = service.factory.portal
|
||||
self.assertEqual(
|
||||
set(portal.checkers.keys()),
|
||||
set([ISSHPrivateKey, IUsernamePassword]))
|
||||
|
|
@ -0,0 +1,767 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_telnet -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.telnet}.
|
||||
"""
|
||||
|
||||
from zope.interface import implements
|
||||
from zope.interface.verify import verifyObject
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from twisted.conch import telnet
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
|
||||
class TestProtocol:
|
||||
implements(telnet.ITelnetProtocol)
|
||||
|
||||
localEnableable = ()
|
||||
remoteEnableable = ()
|
||||
|
||||
def __init__(self):
|
||||
self.bytes = ''
|
||||
self.subcmd = ''
|
||||
self.calls = []
|
||||
|
||||
self.enabledLocal = []
|
||||
self.enabledRemote = []
|
||||
self.disabledLocal = []
|
||||
self.disabledRemote = []
|
||||
|
||||
def makeConnection(self, transport):
|
||||
d = transport.negotiationMap = {}
|
||||
d['\x12'] = self.neg_TEST_COMMAND
|
||||
|
||||
d = transport.commandMap = transport.commandMap.copy()
|
||||
for cmd in ('NOP', 'DM', 'BRK', 'IP', 'AO', 'AYT', 'EC', 'EL', 'GA'):
|
||||
d[getattr(telnet, cmd)] = lambda arg, cmd=cmd: self.calls.append(cmd)
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
self.bytes += bytes
|
||||
|
||||
def connectionLost(self, reason):
|
||||
pass
|
||||
|
||||
def neg_TEST_COMMAND(self, payload):
|
||||
self.subcmd = payload
|
||||
|
||||
def enableLocal(self, option):
|
||||
if option in self.localEnableable:
|
||||
self.enabledLocal.append(option)
|
||||
return True
|
||||
return False
|
||||
|
||||
def disableLocal(self, option):
|
||||
self.disabledLocal.append(option)
|
||||
|
||||
def enableRemote(self, option):
|
||||
if option in self.remoteEnableable:
|
||||
self.enabledRemote.append(option)
|
||||
return True
|
||||
return False
|
||||
|
||||
def disableRemote(self, option):
|
||||
self.disabledRemote.append(option)
|
||||
|
||||
|
||||
|
||||
class TestInterfaces(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
"""
|
||||
L{telnet.TelnetProtocol} implements L{telnet.ITelnetProtocol}
|
||||
"""
|
||||
p = telnet.TelnetProtocol()
|
||||
verifyObject(telnet.ITelnetProtocol, p)
|
||||
|
||||
|
||||
|
||||
class TelnetTransportTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{telnet.TelnetTransport}.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.p = telnet.TelnetTransport(TestProtocol)
|
||||
self.t = proto_helpers.StringTransport()
|
||||
self.p.makeConnection(self.t)
|
||||
|
||||
def testRegularBytes(self):
|
||||
# Just send a bunch of bytes. None of these do anything
|
||||
# with telnet. They should pass right through to the
|
||||
# application layer.
|
||||
h = self.p.protocol
|
||||
|
||||
L = ["here are some bytes la la la",
|
||||
"some more arrive here",
|
||||
"lots of bytes to play with",
|
||||
"la la la",
|
||||
"ta de da",
|
||||
"dum"]
|
||||
for b in L:
|
||||
self.p.dataReceived(b)
|
||||
|
||||
self.assertEqual(h.bytes, ''.join(L))
|
||||
|
||||
def testNewlineHandling(self):
|
||||
# Send various kinds of newlines and make sure they get translated
|
||||
# into \n.
|
||||
h = self.p.protocol
|
||||
|
||||
L = ["here is the first line\r\n",
|
||||
"here is the second line\r\0",
|
||||
"here is the third line\r\n",
|
||||
"here is the last line\r\0"]
|
||||
|
||||
for b in L:
|
||||
self.p.dataReceived(b)
|
||||
|
||||
self.assertEqual(h.bytes, L[0][:-2] + '\n' +
|
||||
L[1][:-2] + '\r' +
|
||||
L[2][:-2] + '\n' +
|
||||
L[3][:-2] + '\r')
|
||||
|
||||
def testIACEscape(self):
|
||||
# Send a bunch of bytes and a couple quoted \xFFs. Unquoted,
|
||||
# \xFF is a telnet command. Quoted, one of them from each pair
|
||||
# should be passed through to the application layer.
|
||||
h = self.p.protocol
|
||||
|
||||
L = ["here are some bytes\xff\xff with an embedded IAC",
|
||||
"and here is a test of a border escape\xff",
|
||||
"\xff did you get that IAC?"]
|
||||
|
||||
for b in L:
|
||||
self.p.dataReceived(b)
|
||||
|
||||
self.assertEqual(h.bytes, ''.join(L).replace('\xff\xff', '\xff'))
|
||||
|
||||
def _simpleCommandTest(self, cmdName):
|
||||
# Send a single simple telnet command and make sure
|
||||
# it gets noticed and the appropriate method gets
|
||||
# called.
|
||||
h = self.p.protocol
|
||||
|
||||
cmd = telnet.IAC + getattr(telnet, cmdName)
|
||||
L = ["Here's some bytes, tra la la",
|
||||
"But ono!" + cmd + " an interrupt"]
|
||||
|
||||
for b in L:
|
||||
self.p.dataReceived(b)
|
||||
|
||||
self.assertEqual(h.calls, [cmdName])
|
||||
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
|
||||
|
||||
def testInterrupt(self):
|
||||
self._simpleCommandTest("IP")
|
||||
|
||||
def testNoOperation(self):
|
||||
self._simpleCommandTest("NOP")
|
||||
|
||||
def testDataMark(self):
|
||||
self._simpleCommandTest("DM")
|
||||
|
||||
def testBreak(self):
|
||||
self._simpleCommandTest("BRK")
|
||||
|
||||
def testAbortOutput(self):
|
||||
self._simpleCommandTest("AO")
|
||||
|
||||
def testAreYouThere(self):
|
||||
self._simpleCommandTest("AYT")
|
||||
|
||||
def testEraseCharacter(self):
|
||||
self._simpleCommandTest("EC")
|
||||
|
||||
def testEraseLine(self):
|
||||
self._simpleCommandTest("EL")
|
||||
|
||||
def testGoAhead(self):
|
||||
self._simpleCommandTest("GA")
|
||||
|
||||
def testSubnegotiation(self):
|
||||
# Send a subnegotiation command and make sure it gets
|
||||
# parsed and that the correct method is called.
|
||||
h = self.p.protocol
|
||||
|
||||
cmd = telnet.IAC + telnet.SB + '\x12hello world' + telnet.IAC + telnet.SE
|
||||
L = ["These are some bytes but soon" + cmd,
|
||||
"there will be some more"]
|
||||
|
||||
for b in L:
|
||||
self.p.dataReceived(b)
|
||||
|
||||
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
|
||||
self.assertEqual(h.subcmd, list("hello world"))
|
||||
|
||||
def testSubnegotiationWithEmbeddedSE(self):
|
||||
# Send a subnegotiation command with an embedded SE. Make sure
|
||||
# that SE gets passed to the correct method.
|
||||
h = self.p.protocol
|
||||
|
||||
cmd = (telnet.IAC + telnet.SB +
|
||||
'\x12' + telnet.SE +
|
||||
telnet.IAC + telnet.SE)
|
||||
|
||||
L = ["Some bytes are here" + cmd + "and here",
|
||||
"and here"]
|
||||
|
||||
for b in L:
|
||||
self.p.dataReceived(b)
|
||||
|
||||
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
|
||||
self.assertEqual(h.subcmd, [telnet.SE])
|
||||
|
||||
def testBoundarySubnegotiation(self):
|
||||
# Send a subnegotiation command. Split it at every possible byte boundary
|
||||
# and make sure it always gets parsed and that it is passed to the correct
|
||||
# method.
|
||||
cmd = (telnet.IAC + telnet.SB +
|
||||
'\x12' + telnet.SE + 'hello' +
|
||||
telnet.IAC + telnet.SE)
|
||||
|
||||
for i in range(len(cmd)):
|
||||
h = self.p.protocol = TestProtocol()
|
||||
h.makeConnection(self.p)
|
||||
|
||||
a, b = cmd[:i], cmd[i:]
|
||||
L = ["first part" + a,
|
||||
b + "last part"]
|
||||
|
||||
for bytes in L:
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
|
||||
self.assertEqual(h.subcmd, [telnet.SE] + list('hello'))
|
||||
|
||||
def _enabledHelper(self, o, eL=[], eR=[], dL=[], dR=[]):
|
||||
self.assertEqual(o.enabledLocal, eL)
|
||||
self.assertEqual(o.enabledRemote, eR)
|
||||
self.assertEqual(o.disabledLocal, dL)
|
||||
self.assertEqual(o.disabledRemote, dR)
|
||||
|
||||
def testRefuseWill(self):
|
||||
# Try to enable an option. The server should refuse to enable it.
|
||||
cmd = telnet.IAC + telnet.WILL + '\x12'
|
||||
|
||||
bytes = "surrounding bytes" + cmd + "to spice things up"
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + '\x12')
|
||||
self._enabledHelper(self.p.protocol)
|
||||
|
||||
def testRefuseDo(self):
|
||||
# Try to enable an option. The server should refuse to enable it.
|
||||
cmd = telnet.IAC + telnet.DO + '\x12'
|
||||
|
||||
bytes = "surrounding bytes" + cmd + "to spice things up"
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.WONT + '\x12')
|
||||
self._enabledHelper(self.p.protocol)
|
||||
|
||||
def testAcceptDo(self):
|
||||
# Try to enable an option. The option is in our allowEnable
|
||||
# list, so we will allow it to be enabled.
|
||||
cmd = telnet.IAC + telnet.DO + '\x19'
|
||||
bytes = 'padding' + cmd + 'trailer'
|
||||
|
||||
h = self.p.protocol
|
||||
h.localEnableable = ('\x19',)
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.WILL + '\x19')
|
||||
self._enabledHelper(h, eL=['\x19'])
|
||||
|
||||
def testAcceptWill(self):
|
||||
# Same as testAcceptDo, but reversed.
|
||||
cmd = telnet.IAC + telnet.WILL + '\x91'
|
||||
bytes = 'header' + cmd + 'padding'
|
||||
|
||||
h = self.p.protocol
|
||||
h.remoteEnableable = ('\x91',)
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + '\x91')
|
||||
self._enabledHelper(h, eR=['\x91'])
|
||||
|
||||
def testAcceptWont(self):
|
||||
# Try to disable an option. The server must allow any option to
|
||||
# be disabled at any time. Make sure it disables it and sends
|
||||
# back an acknowledgement of this.
|
||||
cmd = telnet.IAC + telnet.WONT + '\x29'
|
||||
|
||||
# Jimmy it - after these two lines, the server will be in a state
|
||||
# such that it believes the option to have been previously enabled
|
||||
# via normal negotiation.
|
||||
s = self.p.getOptionState('\x29')
|
||||
s.him.state = 'yes'
|
||||
|
||||
bytes = "fiddle dee" + cmd
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + '\x29')
|
||||
self.assertEqual(s.him.state, 'no')
|
||||
self._enabledHelper(self.p.protocol, dR=['\x29'])
|
||||
|
||||
def testAcceptDont(self):
|
||||
# Try to disable an option. The server must allow any option to
|
||||
# be disabled at any time. Make sure it disables it and sends
|
||||
# back an acknowledgement of this.
|
||||
cmd = telnet.IAC + telnet.DONT + '\x29'
|
||||
|
||||
# Jimmy it - after these two lines, the server will be in a state
|
||||
# such that it believes the option to have beenp previously enabled
|
||||
# via normal negotiation.
|
||||
s = self.p.getOptionState('\x29')
|
||||
s.us.state = 'yes'
|
||||
|
||||
bytes = "fiddle dum " + cmd
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.WONT + '\x29')
|
||||
self.assertEqual(s.us.state, 'no')
|
||||
self._enabledHelper(self.p.protocol, dL=['\x29'])
|
||||
|
||||
def testIgnoreWont(self):
|
||||
# Try to disable an option. The option is already disabled. The
|
||||
# server should send nothing in response to this.
|
||||
cmd = telnet.IAC + telnet.WONT + '\x47'
|
||||
|
||||
bytes = "dum de dum" + cmd + "tra la la"
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), '')
|
||||
self._enabledHelper(self.p.protocol)
|
||||
|
||||
def testIgnoreDont(self):
|
||||
# Try to disable an option. The option is already disabled. The
|
||||
# server should send nothing in response to this. Doing so could
|
||||
# lead to a negotiation loop.
|
||||
cmd = telnet.IAC + telnet.DONT + '\x47'
|
||||
|
||||
bytes = "dum de dum" + cmd + "tra la la"
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), '')
|
||||
self._enabledHelper(self.p.protocol)
|
||||
|
||||
def testIgnoreWill(self):
|
||||
# Try to enable an option. The option is already enabled. The
|
||||
# server should send nothing in response to this. Doing so could
|
||||
# lead to a negotiation loop.
|
||||
cmd = telnet.IAC + telnet.WILL + '\x56'
|
||||
|
||||
# Jimmy it - after these two lines, the server will be in a state
|
||||
# such that it believes the option to have been previously enabled
|
||||
# via normal negotiation.
|
||||
s = self.p.getOptionState('\x56')
|
||||
s.him.state = 'yes'
|
||||
|
||||
bytes = "tra la la" + cmd + "dum de dum"
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), '')
|
||||
self._enabledHelper(self.p.protocol)
|
||||
|
||||
def testIgnoreDo(self):
|
||||
# Try to enable an option. The option is already enabled. The
|
||||
# server should send nothing in response to this. Doing so could
|
||||
# lead to a negotiation loop.
|
||||
cmd = telnet.IAC + telnet.DO + '\x56'
|
||||
|
||||
# Jimmy it - after these two lines, the server will be in a state
|
||||
# such that it believes the option to have been previously enabled
|
||||
# via normal negotiation.
|
||||
s = self.p.getOptionState('\x56')
|
||||
s.us.state = 'yes'
|
||||
|
||||
bytes = "tra la la" + cmd + "dum de dum"
|
||||
self.p.dataReceived(bytes)
|
||||
|
||||
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
|
||||
self.assertEqual(self.t.value(), '')
|
||||
self._enabledHelper(self.p.protocol)
|
||||
|
||||
def testAcceptedEnableRequest(self):
|
||||
# Try to enable an option through the user-level API. This
|
||||
# returns a Deferred that fires when negotiation about the option
|
||||
# finishes. Make sure it fires, make sure state gets updated
|
||||
# properly, make sure the result indicates the option was enabled.
|
||||
d = self.p.do('\x42')
|
||||
|
||||
h = self.p.protocol
|
||||
h.remoteEnableable = ('\x42',)
|
||||
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + '\x42')
|
||||
|
||||
self.p.dataReceived(telnet.IAC + telnet.WILL + '\x42')
|
||||
|
||||
d.addCallback(self.assertEqual, True)
|
||||
d.addCallback(lambda _: self._enabledHelper(h, eR=['\x42']))
|
||||
return d
|
||||
|
||||
|
||||
def test_refusedEnableRequest(self):
|
||||
"""
|
||||
If the peer refuses to enable an option we request it to enable, the
|
||||
L{Deferred} returned by L{TelnetProtocol.do} fires with an
|
||||
L{OptionRefused} L{Failure}.
|
||||
"""
|
||||
# Try to enable an option through the user-level API. This returns a
|
||||
# Deferred that fires when negotiation about the option finishes. Make
|
||||
# sure it fires, make sure state gets updated properly, make sure the
|
||||
# result indicates the option was enabled.
|
||||
self.p.protocol.remoteEnableable = ('\x42',)
|
||||
d = self.p.do('\x42')
|
||||
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + '\x42')
|
||||
|
||||
s = self.p.getOptionState('\x42')
|
||||
self.assertEqual(s.him.state, 'no')
|
||||
self.assertEqual(s.us.state, 'no')
|
||||
self.assertEqual(s.him.negotiating, True)
|
||||
self.assertEqual(s.us.negotiating, False)
|
||||
|
||||
self.p.dataReceived(telnet.IAC + telnet.WONT + '\x42')
|
||||
|
||||
d = self.assertFailure(d, telnet.OptionRefused)
|
||||
d.addCallback(lambda ignored: self._enabledHelper(self.p.protocol))
|
||||
d.addCallback(
|
||||
lambda ignored: self.assertEqual(s.him.negotiating, False))
|
||||
return d
|
||||
|
||||
|
||||
def test_refusedEnableOffer(self):
|
||||
"""
|
||||
If the peer refuses to allow us to enable an option, the L{Deferred}
|
||||
returned by L{TelnetProtocol.will} fires with an L{OptionRefused}
|
||||
L{Failure}.
|
||||
"""
|
||||
# Try to offer an option through the user-level API. This returns a
|
||||
# Deferred that fires when negotiation about the option finishes. Make
|
||||
# sure it fires, make sure state gets updated properly, make sure the
|
||||
# result indicates the option was enabled.
|
||||
self.p.protocol.localEnableable = ('\x42',)
|
||||
d = self.p.will('\x42')
|
||||
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.WILL + '\x42')
|
||||
|
||||
s = self.p.getOptionState('\x42')
|
||||
self.assertEqual(s.him.state, 'no')
|
||||
self.assertEqual(s.us.state, 'no')
|
||||
self.assertEqual(s.him.negotiating, False)
|
||||
self.assertEqual(s.us.negotiating, True)
|
||||
|
||||
self.p.dataReceived(telnet.IAC + telnet.DONT + '\x42')
|
||||
|
||||
d = self.assertFailure(d, telnet.OptionRefused)
|
||||
d.addCallback(lambda ignored: self._enabledHelper(self.p.protocol))
|
||||
d.addCallback(
|
||||
lambda ignored: self.assertEqual(s.us.negotiating, False))
|
||||
return d
|
||||
|
||||
|
||||
def testAcceptedDisableRequest(self):
|
||||
# Try to disable an option through the user-level API. This
|
||||
# returns a Deferred that fires when negotiation about the option
|
||||
# finishes. Make sure it fires, make sure state gets updated
|
||||
# properly, make sure the result indicates the option was enabled.
|
||||
s = self.p.getOptionState('\x42')
|
||||
s.him.state = 'yes'
|
||||
|
||||
d = self.p.dont('\x42')
|
||||
|
||||
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + '\x42')
|
||||
|
||||
self.p.dataReceived(telnet.IAC + telnet.WONT + '\x42')
|
||||
|
||||
d.addCallback(self.assertEqual, True)
|
||||
d.addCallback(lambda _: self._enabledHelper(self.p.protocol,
|
||||
dR=['\x42']))
|
||||
return d
|
||||
|
||||
def testNegotiationBlocksFurtherNegotiation(self):
|
||||
# Try to disable an option, then immediately try to enable it, then
|
||||
# immediately try to disable it. Ensure that the 2nd and 3rd calls
|
||||
# fail quickly with the right exception.
|
||||
s = self.p.getOptionState('\x24')
|
||||
s.him.state = 'yes'
|
||||
d2 = self.p.dont('\x24') # fires after the first line of _final
|
||||
|
||||
def _do(x):
|
||||
d = self.p.do('\x24')
|
||||
return self.assertFailure(d, telnet.AlreadyNegotiating)
|
||||
|
||||
def _dont(x):
|
||||
d = self.p.dont('\x24')
|
||||
return self.assertFailure(d, telnet.AlreadyNegotiating)
|
||||
|
||||
def _final(x):
|
||||
self.p.dataReceived(telnet.IAC + telnet.WONT + '\x24')
|
||||
# an assertion that only passes if d2 has fired
|
||||
self._enabledHelper(self.p.protocol, dR=['\x24'])
|
||||
# Make sure we allow this
|
||||
self.p.protocol.remoteEnableable = ('\x24',)
|
||||
d = self.p.do('\x24')
|
||||
self.p.dataReceived(telnet.IAC + telnet.WILL + '\x24')
|
||||
d.addCallback(self.assertEqual, True)
|
||||
d.addCallback(lambda _: self._enabledHelper(self.p.protocol,
|
||||
eR=['\x24'],
|
||||
dR=['\x24']))
|
||||
return d
|
||||
|
||||
d = _do(None)
|
||||
d.addCallback(_dont)
|
||||
d.addCallback(_final)
|
||||
return d
|
||||
|
||||
def testSuperfluousDisableRequestRaises(self):
|
||||
# Try to disable a disabled option. Make sure it fails properly.
|
||||
d = self.p.dont('\xab')
|
||||
return self.assertFailure(d, telnet.AlreadyDisabled)
|
||||
|
||||
def testSuperfluousEnableRequestRaises(self):
|
||||
# Try to disable a disabled option. Make sure it fails properly.
|
||||
s = self.p.getOptionState('\xab')
|
||||
s.him.state = 'yes'
|
||||
d = self.p.do('\xab')
|
||||
return self.assertFailure(d, telnet.AlreadyEnabled)
|
||||
|
||||
def testLostConnectionFailsDeferreds(self):
|
||||
d1 = self.p.do('\x12')
|
||||
d2 = self.p.do('\x23')
|
||||
d3 = self.p.do('\x34')
|
||||
|
||||
class TestException(Exception):
|
||||
pass
|
||||
|
||||
self.p.connectionLost(TestException("Total failure!"))
|
||||
|
||||
d1 = self.assertFailure(d1, TestException)
|
||||
d2 = self.assertFailure(d2, TestException)
|
||||
d3 = self.assertFailure(d3, TestException)
|
||||
return defer.gatherResults([d1, d2, d3])
|
||||
|
||||
|
||||
class TestTelnet(telnet.Telnet):
|
||||
"""
|
||||
A trivial extension of the telnet protocol class useful to unit tests.
|
||||
"""
|
||||
def __init__(self):
|
||||
telnet.Telnet.__init__(self)
|
||||
self.events = []
|
||||
|
||||
|
||||
def applicationDataReceived(self, bytes):
|
||||
"""
|
||||
Record the given data in C{self.events}.
|
||||
"""
|
||||
self.events.append(('bytes', bytes))
|
||||
|
||||
|
||||
def unhandledCommand(self, command, bytes):
|
||||
"""
|
||||
Record the given command in C{self.events}.
|
||||
"""
|
||||
self.events.append(('command', command, bytes))
|
||||
|
||||
|
||||
def unhandledSubnegotiation(self, command, bytes):
|
||||
"""
|
||||
Record the given subnegotiation command in C{self.events}.
|
||||
"""
|
||||
self.events.append(('negotiate', command, bytes))
|
||||
|
||||
|
||||
|
||||
class TelnetTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{telnet.Telnet}.
|
||||
|
||||
L{telnet.Telnet} implements the TELNET protocol (RFC 854), including option
|
||||
and suboption negotiation, and option state tracking.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Create an unconnected L{telnet.Telnet} to be used by tests.
|
||||
"""
|
||||
self.protocol = TestTelnet()
|
||||
|
||||
|
||||
def test_enableLocal(self):
|
||||
"""
|
||||
L{telnet.Telnet.enableLocal} should reject all options, since
|
||||
L{telnet.Telnet} does not know how to implement any options.
|
||||
"""
|
||||
self.assertFalse(self.protocol.enableLocal('\0'))
|
||||
|
||||
|
||||
def test_enableRemote(self):
|
||||
"""
|
||||
L{telnet.Telnet.enableRemote} should reject all options, since
|
||||
L{telnet.Telnet} does not know how to implement any options.
|
||||
"""
|
||||
self.assertFalse(self.protocol.enableRemote('\0'))
|
||||
|
||||
|
||||
def test_disableLocal(self):
|
||||
"""
|
||||
It is an error for L{telnet.Telnet.disableLocal} to be called, since
|
||||
L{telnet.Telnet.enableLocal} will never allow any options to be enabled
|
||||
locally. If a subclass overrides enableLocal, it must also override
|
||||
disableLocal.
|
||||
"""
|
||||
self.assertRaises(NotImplementedError, self.protocol.disableLocal, '\0')
|
||||
|
||||
|
||||
def test_disableRemote(self):
|
||||
"""
|
||||
It is an error for L{telnet.Telnet.disableRemote} to be called, since
|
||||
L{telnet.Telnet.enableRemote} will never allow any options to be
|
||||
enabled remotely. If a subclass overrides enableRemote, it must also
|
||||
override disableRemote.
|
||||
"""
|
||||
self.assertRaises(NotImplementedError, self.protocol.disableRemote, '\0')
|
||||
|
||||
|
||||
def test_requestNegotiation(self):
|
||||
"""
|
||||
L{telnet.Telnet.requestNegotiation} formats the feature byte and the
|
||||
payload bytes into the subnegotiation format and sends them.
|
||||
|
||||
See RFC 855.
|
||||
"""
|
||||
transport = proto_helpers.StringTransport()
|
||||
self.protocol.makeConnection(transport)
|
||||
self.protocol.requestNegotiation('\x01', '\x02\x03')
|
||||
self.assertEqual(
|
||||
transport.value(),
|
||||
# IAC SB feature bytes IAC SE
|
||||
'\xff\xfa\x01\x02\x03\xff\xf0')
|
||||
|
||||
|
||||
def test_requestNegotiationEscapesIAC(self):
|
||||
"""
|
||||
If the payload for a subnegotiation includes I{IAC}, it is escaped by
|
||||
L{telnet.Telnet.requestNegotiation} with another I{IAC}.
|
||||
|
||||
See RFC 855.
|
||||
"""
|
||||
transport = proto_helpers.StringTransport()
|
||||
self.protocol.makeConnection(transport)
|
||||
self.protocol.requestNegotiation('\x01', '\xff')
|
||||
self.assertEqual(
|
||||
transport.value(),
|
||||
'\xff\xfa\x01\xff\xff\xff\xf0')
|
||||
|
||||
|
||||
def _deliver(self, bytes, *expected):
|
||||
"""
|
||||
Pass the given bytes to the protocol's C{dataReceived} method and
|
||||
assert that the given events occur.
|
||||
"""
|
||||
received = self.protocol.events = []
|
||||
self.protocol.dataReceived(bytes)
|
||||
self.assertEqual(received, list(expected))
|
||||
|
||||
|
||||
def test_oneApplicationDataByte(self):
|
||||
"""
|
||||
One application-data byte in the default state gets delivered right
|
||||
away.
|
||||
"""
|
||||
self._deliver('a', ('bytes', 'a'))
|
||||
|
||||
|
||||
def test_twoApplicationDataBytes(self):
|
||||
"""
|
||||
Two application-data bytes in the default state get delivered
|
||||
together.
|
||||
"""
|
||||
self._deliver('bc', ('bytes', 'bc'))
|
||||
|
||||
|
||||
def test_threeApplicationDataBytes(self):
|
||||
"""
|
||||
Three application-data bytes followed by a control byte get
|
||||
delivered, but the control byte doesn't.
|
||||
"""
|
||||
self._deliver('def' + telnet.IAC, ('bytes', 'def'))
|
||||
|
||||
|
||||
def test_escapedControl(self):
|
||||
"""
|
||||
IAC in the escaped state gets delivered and so does another
|
||||
application-data byte following it.
|
||||
"""
|
||||
self._deliver(telnet.IAC)
|
||||
self._deliver(telnet.IAC + 'g', ('bytes', telnet.IAC + 'g'))
|
||||
|
||||
|
||||
def test_carriageReturn(self):
|
||||
"""
|
||||
A carriage return only puts the protocol into the newline state. A
|
||||
linefeed in the newline state causes just the newline to be
|
||||
delivered. A nul in the newline state causes a carriage return to
|
||||
be delivered. An IAC in the newline state causes a carriage return
|
||||
to be delivered and puts the protocol into the escaped state.
|
||||
Anything else causes a carriage return and that thing to be
|
||||
delivered.
|
||||
"""
|
||||
self._deliver('\r')
|
||||
self._deliver('\n', ('bytes', '\n'))
|
||||
self._deliver('\r\n', ('bytes', '\n'))
|
||||
|
||||
self._deliver('\r')
|
||||
self._deliver('\0', ('bytes', '\r'))
|
||||
self._deliver('\r\0', ('bytes', '\r'))
|
||||
|
||||
self._deliver('\r')
|
||||
self._deliver('a', ('bytes', '\ra'))
|
||||
self._deliver('\ra', ('bytes', '\ra'))
|
||||
|
||||
self._deliver('\r')
|
||||
self._deliver(
|
||||
telnet.IAC + telnet.IAC + 'x', ('bytes', '\r' + telnet.IAC + 'x'))
|
||||
|
||||
|
||||
def test_applicationDataBeforeSimpleCommand(self):
|
||||
"""
|
||||
Application bytes received before a command are delivered before the
|
||||
command is processed.
|
||||
"""
|
||||
self._deliver(
|
||||
'x' + telnet.IAC + telnet.NOP,
|
||||
('bytes', 'x'), ('command', telnet.NOP, None))
|
||||
|
||||
|
||||
def test_applicationDataBeforeCommand(self):
|
||||
"""
|
||||
Application bytes received before a WILL/WONT/DO/DONT are delivered
|
||||
before the command is processed.
|
||||
"""
|
||||
self.protocol.commandMap = {}
|
||||
self._deliver(
|
||||
'y' + telnet.IAC + telnet.WILL + '\x00',
|
||||
('bytes', 'y'), ('command', telnet.WILL, '\x00'))
|
||||
|
||||
|
||||
def test_applicationDataBeforeSubnegotiation(self):
|
||||
"""
|
||||
Application bytes received before a subnegotiation command are
|
||||
delivered before the negotiation is processed.
|
||||
"""
|
||||
self._deliver(
|
||||
'z' + telnet.IAC + telnet.SB + 'Qx' + telnet.IAC + telnet.SE,
|
||||
('bytes', 'z'), ('negotiate', 'Q', ['x']))
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
# -*- test-case-name: twisted.conch.test.test_text -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.conch.insults import helper, text
|
||||
from twisted.conch.insults.text import attributes as A
|
||||
|
||||
|
||||
|
||||
class FormattedTextTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for assembling formatted text.
|
||||
"""
|
||||
def test_trivial(self):
|
||||
"""
|
||||
Using no formatting attributes produces no VT102 control sequences in
|
||||
the flattened output.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(A.normal['Hello, world.']),
|
||||
'Hello, world.')
|
||||
|
||||
|
||||
def test_bold(self):
|
||||
"""
|
||||
The bold formatting attribute, L{A.bold}, emits the VT102 control
|
||||
sequence to enable bold when flattened.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(A.bold['Hello, world.']),
|
||||
'\x1b[1mHello, world.')
|
||||
|
||||
|
||||
def test_underline(self):
|
||||
"""
|
||||
The underline formatting attribute, L{A.underline}, emits the VT102
|
||||
control sequence to enable underlining when flattened.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(A.underline['Hello, world.']),
|
||||
'\x1b[4mHello, world.')
|
||||
|
||||
|
||||
def test_blink(self):
|
||||
"""
|
||||
The blink formatting attribute, L{A.blink}, emits the VT102 control
|
||||
sequence to enable blinking when flattened.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(A.blink['Hello, world.']),
|
||||
'\x1b[5mHello, world.')
|
||||
|
||||
|
||||
def test_reverseVideo(self):
|
||||
"""
|
||||
The reverse-video formatting attribute, L{A.reverseVideo}, emits the
|
||||
VT102 control sequence to enable reversed video when flattened.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(A.reverseVideo['Hello, world.']),
|
||||
'\x1b[7mHello, world.')
|
||||
|
||||
|
||||
def test_minus(self):
|
||||
"""
|
||||
Formatting attributes prefixed with a minus (C{-}) temporarily disable
|
||||
the prefixed attribute, emitting no VT102 control sequence to enable
|
||||
it in the flattened output.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(
|
||||
A.bold[A.blink['Hello', -A.bold[' world'], '.']]),
|
||||
'\x1b[1;5mHello\x1b[0;5m world\x1b[1;5m.')
|
||||
|
||||
|
||||
def test_foreground(self):
|
||||
"""
|
||||
The foreground color formatting attribute, L{A.fg}, emits the VT102
|
||||
control sequence to set the selected foreground color when flattened.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(
|
||||
A.normal[A.fg.red['Hello, '], A.fg.green['world!']]),
|
||||
'\x1b[31mHello, \x1b[32mworld!')
|
||||
|
||||
|
||||
def test_background(self):
|
||||
"""
|
||||
The background color formatting attribute, L{A.bg}, emits the VT102
|
||||
control sequence to set the selected background color when flattened.
|
||||
"""
|
||||
self.assertEqual(
|
||||
text.assembleFormattedText(
|
||||
A.normal[A.bg.red['Hello, '], A.bg.green['world!']]),
|
||||
'\x1b[41mHello, \x1b[42mworld!')
|
||||
|
||||
|
||||
def test_flattenDeprecated(self):
|
||||
"""
|
||||
L{twisted.conch.insults.text.flatten} emits a deprecation warning when
|
||||
imported or accessed.
|
||||
"""
|
||||
warningsShown = self.flushWarnings([self.test_flattenDeprecated])
|
||||
self.assertEqual(len(warningsShown), 0)
|
||||
|
||||
# Trigger the deprecation warning.
|
||||
text.flatten
|
||||
|
||||
warningsShown = self.flushWarnings([self.test_flattenDeprecated])
|
||||
self.assertEqual(len(warningsShown), 1)
|
||||
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
|
||||
self.assertEqual(
|
||||
warningsShown[0]['message'],
|
||||
'twisted.conch.insults.text.flatten was deprecated in Twisted '
|
||||
'13.1.0: Use twisted.conch.insults.text.assembleFormattedText '
|
||||
'instead.')
|
||||
|
||||
|
||||
|
||||
class EfficiencyTestCase(unittest.TestCase):
|
||||
todo = ("flatten() isn't quite stateful enough to avoid emitting a few extra bytes in "
|
||||
"certain circumstances, so these tests fail. The failures take the form of "
|
||||
"additional elements in the ;-delimited character attribute lists. For example, "
|
||||
"\\x1b[0;31;46m might be emitted instead of \\x[46m, even if 31 has already been "
|
||||
"activated and no conflicting attributes are set which need to be cleared.")
|
||||
|
||||
def setUp(self):
|
||||
self.attrs = helper._FormattingState()
|
||||
|
||||
def testComplexStructure(self):
|
||||
output = A.normal[
|
||||
A.bold[
|
||||
A.bg.cyan[
|
||||
A.fg.red[
|
||||
"Foreground Red, Background Cyan, Bold",
|
||||
A.blink[
|
||||
"Blinking"],
|
||||
-A.bold[
|
||||
"Foreground Red, Background Cyan, normal"]],
|
||||
A.fg.green[
|
||||
"Foreground Green, Background Cyan, Bold"]]]]
|
||||
|
||||
self.assertEqual(
|
||||
text.flatten(output, self.attrs),
|
||||
"\x1b[1;31;46mForeground Red, Background Cyan, Bold"
|
||||
"\x1b[5mBlinking"
|
||||
"\x1b[0;31;46mForeground Red, Background Cyan, normal"
|
||||
"\x1b[1;32;46mForeground Green, Background Cyan, Bold")
|
||||
|
||||
def testNesting(self):
|
||||
self.assertEqual(
|
||||
text.flatten(A.bold['Hello, ', A.underline['world.']], self.attrs),
|
||||
'\x1b[1mHello, \x1b[4mworld.')
|
||||
|
||||
self.assertEqual(
|
||||
text.flatten(
|
||||
A.bold[A.reverseVideo['Hello, ', A.normal['world'], '.']],
|
||||
self.attrs),
|
||||
'\x1b[1;7mHello, \x1b[0mworld\x1b[1;7m.')
|
||||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
"""
|
||||
Tests for the insults windowing module, L{twisted.conch.insults.window}.
|
||||
"""
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
from twisted.conch.insults.window import TopWindow, ScrolledArea, TextOutput
|
||||
|
||||
|
||||
class TopWindowTests(TestCase):
|
||||
"""
|
||||
Tests for L{TopWindow}, the root window container class.
|
||||
"""
|
||||
|
||||
def test_paintScheduling(self):
|
||||
"""
|
||||
Verify that L{TopWindow.repaint} schedules an actual paint to occur
|
||||
using the scheduling object passed to its initializer.
|
||||
"""
|
||||
paints = []
|
||||
scheduled = []
|
||||
root = TopWindow(lambda: paints.append(None), scheduled.append)
|
||||
|
||||
# Nothing should have happened yet.
|
||||
self.assertEqual(paints, [])
|
||||
self.assertEqual(scheduled, [])
|
||||
|
||||
# Cause a paint to be scheduled.
|
||||
root.repaint()
|
||||
self.assertEqual(paints, [])
|
||||
self.assertEqual(len(scheduled), 1)
|
||||
|
||||
# Do another one to verify nothing else happens as long as the previous
|
||||
# one is still pending.
|
||||
root.repaint()
|
||||
self.assertEqual(paints, [])
|
||||
self.assertEqual(len(scheduled), 1)
|
||||
|
||||
# Run the actual paint call.
|
||||
scheduled.pop()()
|
||||
self.assertEqual(len(paints), 1)
|
||||
self.assertEqual(scheduled, [])
|
||||
|
||||
# Do one more to verify that now that the previous one is finished
|
||||
# future paints will succeed.
|
||||
root.repaint()
|
||||
self.assertEqual(len(paints), 1)
|
||||
self.assertEqual(len(scheduled), 1)
|
||||
|
||||
|
||||
|
||||
class ScrolledAreaTests(TestCase):
|
||||
"""
|
||||
Tests for L{ScrolledArea}, a widget which creates a viewport containing
|
||||
another widget and can reposition that viewport using scrollbars.
|
||||
"""
|
||||
def test_parent(self):
|
||||
"""
|
||||
The parent of the widget passed to L{ScrolledArea} is set to a new
|
||||
L{Viewport} created by the L{ScrolledArea} which itself has the
|
||||
L{ScrolledArea} instance as its parent.
|
||||
"""
|
||||
widget = TextOutput()
|
||||
scrolled = ScrolledArea(widget)
|
||||
self.assertIs(widget.parent, scrolled._viewport)
|
||||
self.assertIs(scrolled._viewport.parent, scrolled)
|
||||
Loading…
Add table
Add a link
Reference in a new issue