Open Media Library Platform

This commit is contained in:
j 2013-10-11 19:28:32 +02:00
commit 411ad5b16f
5849 changed files with 1778641 additions and 0 deletions

View file

@ -0,0 +1 @@
'conch tests'

View file

@ -0,0 +1,208 @@
# -*- test-case-name: twisted.conch.test.test_keys -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Data used by test_keys as well as others.
"""
RSAData = {
'n':long('1062486685755247411169438309495398947372127791189432809481'
'382072971106157632182084539383569281493520117634129557550415277'
'516685881326038852354459895734875625093273594925884531272867425'
'864910490065695876046999646807138717162833156501L'),
'e':35L,
'd':long('6678487739032983727350755088256793383481946116047863373882'
'973030104095847973715959961839578340816412167985957218887914482'
'713602371850869127033494910375212470664166001439410214474266799'
'85974425203903884190893469297150446322896587555L'),
'q':long('3395694744258061291019136154000709371890447462086362702627'
'9704149412726577280741108645721676968699696898960891593323L'),
'p':long('3128922844292337321766351031842562691837301298995834258844'
'4720539204069737532863831050930719431498338835415515173887L')}
DSAData = {
'y':long('2300663509295750360093768159135720439490120577534296730713'
'348508834878775464483169644934425336771277908527130096489120714'
'610188630979820723924744291603865L'),
'g':long('4451569990409370769930903934104221766858515498655655091803'
'866645719060300558655677517139568505649468378587802312867198352'
'1161998270001677664063945776405L'),
'p':long('7067311773048598659694590252855127633397024017439939353776'
'608320410518694001356789646664502838652272205440894335303988504'
'978724817717069039110940675621677L'),
'q':1184501645189849666738820838619601267690550087703L,
'x':863951293559205482820041244219051653999559962819L}
publicRSA_openssh = ("ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBE"
"vLi8DVPrJ3/c9k2I/Az64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYL"
"h5KmRpslkYHRivcJSkbh/C+BR3utDS555mV comment")
privateRSA_openssh = """-----BEGIN RSA PRIVATE KEY-----
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
-----END RSA PRIVATE KEY-----"""
# some versions of OpenSSH generate these (slightly different keys)
privateRSA_openssh_alternate = """-----BEGIN RSA PRIVATE KEY-----
MIIBzjCCAcgCAQACYQCvMnHw5g6cmbN/i18ES8uLwNU+snf9z2TYj8DPrh/GMd/2
KbJEluLG1CGUf2V82NQjH7guaskflA1GwWmitwcMo5PBNNguHkqZGmyWRgdGK9wl
KRuH8L4FHe60NLnnmZUCASMCYG4ftVWX6+1n7SuZbuzB7ahNUtbz1mUGBN/lVJ/M
iQA8m2eH7GWgq81vZZCKl5BNxiGPqI3YWYZDtYGxtNdfLCIKYcElikcStJr4ehEc
SqiLdcSRCTu+BMpF2VeKDSfLIwIxANyfa9mYIVYRjelfA50K05NuE3dBPIVPAHD9
BVT/vD0Jv4P2l39kEJEE/qJnR1RCawIxAMtKS9BAR+hFUvfHrwwgbUMNtjmU+dql
5QMGdoMk64ihVaKo3hI7d0mSiqlx0gKT/wIwS6Rffc3CSWUatmm4GJX/Zb9XIZK1
qgx1Lg2bbZmCXhH4hQQWr1WCBdXTpWU9BvIzAjAXO7DkmaHRZwIq8j/kIPaLUgYy
d20DC6Uk6su3N2tgEnAv2MjsJA2iAh55w918ozMCMQC0c5dLUBCjF7OoR/E6FHZS
0TgqzxIUNMGoVEwpNYCgOLjw+kzEwoWr24eCutzr2yowAA==
------END RSA PRIVATE KEY------"""
# encrypted with the passphrase 'encrypted'
privateRSA_openssh_encrypted = """-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
30qUR7DYY/rpVJu159paRM1mUqt/IMibfEMTKWSjNhCVD21hskftZCJROw/WgIFt
ncusHpJMkjgwEpho0KyKilcC7zxjpunTex24Meb5pCdXCrYft8AyUkRdq3dugMqT
4nuWuWxziluBhKQ2M9tPGcEOeulU4vVjceZt2pZhZQVBf08o3XUv5/7RYd24M9md
WIo+5zdj2YQkI6xMFTP954O/X32ME1KQt98wgNEy6mxhItbvf00mH3woALwEKP3v
PSMxxtx3VKeDKd9YTOm1giKkXZUf91vZWs0378tUBrU4U5qJxgryTjvvVKOtofj6
4qQy6+r6M6wtwVlXBgeRm2gBPvL3nv6MsROp3E6ztBd/e7A8fSec+UTq3ko/EbGP
0QG+IG5tg8FsdITxQ9WAIITZL3Rc6hA5Ymx1VNhySp3iSiso8Jof27lku4pyuvRV
ko/B3N2H7LnQrGV0GyrjeYocW/qZh/PCsY48JBFhlNQexn2mn44AJW3y5xgbhvKA
3mrmMD1hD17ZvZxi4fPHjbuAyM1vFqhQx63eT9ijbwJ91svKJl5O5MIv41mCRonm
hxvOXw8S0mjSasyofptzzQCtXxFLQigXbpQBltII+Ys=
-----END RSA PRIVATE KEY-----"""
# encrypted with the passphrase 'testxp'. NB: this key was generated by
# OpenSSH, so it doesn't use the same key data as the other keys here.
privateRSA_openssh_encrypted_aes = """-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----"""
publicRSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuOTc6AK8yc"
"fDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW4sbUIZR/ZXzY1CMfuC5qyR+UDUbB"
"aaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fwvgUd7rQ0ueeZlSkoMTplMTojKSkp}")
privateRSA_lsh = ("(11:private-key(9:rsa-pkcs1(1:n97:\x00\xaf2q\xf0\xe6\x0e"
"\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd\xcfd\xd8\x8f\xc0\xcf"
"\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|\xd8\xd4#\x1f\xb8.j"
"\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8.\x1eJ\x99\x1al\x96F"
"\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7\x99\x95)(1:e1:#)(1:d9"
"6:n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#)(1:p49:\x00"
"\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd"
"\x05T\xff\xbc=\t\xbf\x83\xf6\x97\x7fd\x10\x91\x04\xfe\xa2gGTBk)(1:q49:\x00"
"\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r\xb69\x94\xf9\xda\xa5\xe5\x03\x06v"
"\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92\x8a\xa9q\xd2\x02\x93\xff)(1:a48:K"
"\xa4_}\xcd\xc2Ie\x1a\xb6i\xb8\x18\x95\xffe\xbfW!\x92\xb5\xaa\x0cu.\r\x9bm"
"\x99\x82^\x11\xf8\x85\x04\x16\xafU\x82\x05\xd5\xd3\xa5e=\x06\xf23)(1:b48:"
"\x17;\xb0\xe4\x99\xa1\xd1g\x02*\xf2?\xe4 \xf6\x8bR\x062wm\x03\x0b\xa5$\xea"
"\xcb\xb77k`\x12p/\xd8\xc8\xec$\r\xa2\x02\x1ey\xc3\xdd|\xa33)(1:c49:\x00\xb4"
"s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*\xcf\x12\x144\xc1\xa8TL)5\x80"
"\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82\xba\xdc\xeb\xdb*)))")
privateRSA_agentv3 = ("\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00`"
"n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#\x00\x00\x00a"
"\x00\xaf2q\xf0\xe6\x0e\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd"
"\xcfd\xd8\x8f\xc0\xcf\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|"
"\xd8\xd4#\x1f\xb8.j\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8."
"\x1eJ\x99\x1al\x96F\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7"
"\x99\x95\x00\x00\x001\x00\xb4s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*"
"\xcf\x12\x144\xc1\xa8TL)5\x80\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82"
"\xba\xdc\xeb\xdb*\x00\x00\x001\x00\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r"
"\xb69\x94\xf9\xda\xa5\xe5\x03\x06v\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92"
"\x8a\xa9q\xd2\x02\x93\xff\x00\x00\x001\x00\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_"
"\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd\x05T\xff\xbc=\t\xbf\x83\xf6\x97"
"\x7fd\x10\x91\x04\xfe\xa2gGTBk")
publicDSA_openssh = ("ssh-dss AAAAB3NzaC1kc3MAAABBAIbwTOSsZ7Bl7U1KyMNqV13Tu7"
"yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0AAAAVAM965Akmo"
"6eAi7K+k9qDR4TotFAXAAAAQADZlpTW964haQWS4vC063NGdldT6xpUGDcDRqbm90CoPEa2RmNO"
"uOqi8lnbhYraEzypYH3K4Gzv/bxCBnKtHRUAAABAK+1osyWBS0+P90u/rAuko6chZ98thUSY2kL"
"SHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmQ== comment")
privateDSA_openssh = """-----BEGIN DSA PRIVATE KEY-----
MIH4AgEAAkEAhvBM5KxnsGXtTUrIw2pXXdO7vJEC1OvvQ9UjdCd+s+6Z/ZTMKCkv
WWNsrFJ8CLTle+sT/W98IUCwVhdFkuFDLQIVAM965Akmo6eAi7K+k9qDR4TotFAX
AkAA2ZaU1veuIWkFkuLwtOtzRnZXU+saVBg3A0am5vdAqDxGtkZjTrjqovJZ24WK
2hM8qWB9yuBs7/28QgZyrR0VAkAr7WizJYFLT4/3S7+sC6SjpyFn3y2FRJjaQtIe
nqEsrLLZuOdPb2HuFoQsT5cd+rZsz19yQPHYUs5IgnPLzdWZAhUAl1TqdmlAG/b4
nnVchGiO9sML8MM=
-----END DSA PRIVATE KEY-----"""
publicDSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMzpkc2EoMTpwNjU6AIbwTOSsZ7Bl7U1KyMNqV"
"13Tu7yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0pKDE6cTIx"
"OgDPeuQJJqOngIuyvpPag0eE6LRQFykoMTpnNjQ6ANmWlNb3riFpBZLi8LTrc0Z2V1PrGlQYNwN"
"Gpub3QKg8RrZGY0646qLyWduFitoTPKlgfcrgbO/9vEIGcq0dFSkoMTp5NjQ6K+1osyWBS0+P90"
"u/rAuko6chZ98thUSY2kLSHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmSkpK"
"Q==}")
privateDSA_lsh = ("(11:private-key(3:dsa(1:p65:\x00\x86\xf0L\xe4\xacg\xb0e"
"\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3\xee\x99\xfd"
"\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92\xe1C-)(1:q2"
"1:\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G\x84\xe8\xb4P\x17)(1"
":g64:\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2\xf0\xb4\xebsFvWS\xeb\x1aT"
"\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2\xf2Y\xdb\x85\x8a\xda\x13<"
"\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15)(1:y64:+\xedh\xb3%\x81KO\x8f"
"\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98\xdaB\xd2\x1e\x9e\xa1,\xac\xb2"
"\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6l\xcf_r@\xf1\xd8R\xceH\x82s"
"\xcb\xcd\xd5\x99)(1:x21:\x00\x97T\xeavi@\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6"
"\xc3\x0b\xf0\xc3)))")
privateDSA_agentv3 = ("\x00\x00\x00\x07ssh-dss\x00\x00\x00A\x00\x86\xf0L\xe4"
"\xacg\xb0e\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3"
"\xee\x99\xfd\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92"
"\xe1C-\x00\x00\x00\x15\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G"
"\x84\xe8\xb4P\x17\x00\x00\x00@\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2"
"\xf0\xb4\xebsFvWS\xeb\x1aT\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2"
"\xf2Y\xdb\x85\x8a\xda\x13<\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15\x00"
"\x00\x00@+\xedh\xb3%\x81KO\x8f\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98"
"\xdaB\xd2\x1e\x9e\xa1,\xac\xb2\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6"
"l\xcf_r@\xf1\xd8R\xceH\x82s\xcb\xcd\xd5\x99\x00\x00\x00\x15\x00\x97T\xeavi@"
"\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6\xc3\x0b\xf0\xc3")
__all__ = ['DSAData', 'RSAData', 'privateDSA_agentv3', 'privateDSA_lsh',
'privateDSA_openssh', 'privateRSA_agentv3', 'privateRSA_lsh',
'privateRSA_openssh', 'publicDSA_lsh', 'publicDSA_openssh',
'publicRSA_lsh', 'publicRSA_openssh', 'privateRSA_openssh_alternate']

View file

@ -0,0 +1,49 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{SSHTransportAddrress} in ssh/address.py
"""
from twisted.trial import unittest
from twisted.internet.address import IPv4Address
from twisted.internet.test.test_address import AddressTestCaseMixin
from twisted.conch.ssh.address import SSHTransportAddress
class SSHTransportAddressTestCase(unittest.TestCase, AddressTestCaseMixin):
"""
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
use to represent the other side of the SSH connection. This tests the
basic functionality of that class (string representation, comparison, &c).
"""
def _stringRepresentation(self, stringFunction):
"""
The string representation of C{SSHTransportAddress} should be
"SSHTransportAddress(<stringFunction on address>)".
"""
addr = self.buildAddress()
stringValue = stringFunction(addr)
addressValue = stringFunction(addr.address)
self.assertEqual(stringValue,
"SSHTransportAddress(%s)" % addressValue)
def buildAddress(self):
"""
Create an arbitrary new C{SSHTransportAddress}. A new instance is
created for each call, but always for the same address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
def buildDifferentAddress(self):
"""
Like C{buildAddress}, but with a different fixed address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))

View file

@ -0,0 +1,399 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.agent}.
"""
import struct
from twisted.trial import unittest
try:
import OpenSSL
except ImportError:
iosim = None
else:
from twisted.test import iosim
try:
import Crypto.Cipher.DES3
except ImportError:
Crypto = None
try:
import pyasn1
except ImportError:
pyasn1 = None
if Crypto and pyasn1:
from twisted.conch.ssh import keys, agent
else:
keys = agent = None
from twisted.conch.test import keydata
from twisted.conch.error import ConchError, MissingKeyStoreError
class StubFactory(object):
"""
Mock factory that provides the keys attribute required by the
SSHAgentServerProtocol
"""
def __init__(self):
self.keys = {}
class AgentTestBase(unittest.TestCase):
"""
Tests for SSHAgentServer/Client.
"""
if iosim is None:
skip = "iosim requires SSL, but SSL is not available"
elif agent is None or keys is None:
skip = "Cannot run without PyCrypto or PyASN1"
def setUp(self):
# wire up our client <-> server
self.client, self.server, self.pump = iosim.connectedServerAndClient(
agent.SSHAgentServer, agent.SSHAgentClient)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
# pub/priv keys of each kind
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
class TestServerProtocolContractWithFactory(AgentTestBase):
"""
The server protocol is stateful and so uses its factory to track state
across requests. This test asserts that the protocol raises if its factory
doesn't provide the necessary storage for that state.
"""
def test_factorySuppliesKeyStorageForServerProtocol(self):
# need a message to send into the server
msg = struct.pack('!LB',1, agent.AGENTC_REQUEST_IDENTITIES)
del self.server.factory.__dict__['keys']
self.assertRaises(MissingKeyStoreError,
self.server.dataReceived, msg)
class TestUnimplementedVersionOneServer(AgentTestBase):
"""
Tests for methods with no-op implementations on the server. We need these
for clients, such as openssh, that try v1 methods before going to v2.
Because the client doesn't expose these operations with nice method names,
we invoke sendRequest directly with an op code.
"""
def test_agentc_REQUEST_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA identities request
"""
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, '')
self.pump.flush()
def _cb(packet):
self.assertEqual(
agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0]))
return d.addCallback(_cb)
def test_agentc_REMOVE_RSA_IDENTITY(self):
"""
assert that we get the correct op code for an RSA remove identity request
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, '')
self.pump.flush()
return d.addCallback(self.assertEqual, '')
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA remove all identities
request.
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, '')
self.pump.flush()
return d.addCallback(self.assertEqual, '')
if agent is not None:
class CorruptServer(agent.SSHAgentServer):
"""
A misbehaving server that returns bogus response op codes so that we can
verify that our callbacks that deal with these op codes handle such
miscreants.
"""
def agentc_REQUEST_IDENTITIES(self, data):
self.sendResponse(254, '')
def agentc_SIGN_REQUEST(self, data):
self.sendResponse(254, '')
class TestClientWithBrokenServer(AgentTestBase):
"""
verify error handling code in the client using a misbehaving server
"""
def setUp(self):
AgentTestBase.setUp(self)
self.client, self.server, self.pump = iosim.connectedServerAndClient(
CorruptServer, agent.SSHAgentClient)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
def test_signDataCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.signData} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for data signing requests.
"""
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentitiesCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for identity requests.
"""
d = self.client.requestIdentities()
self.pump.flush()
return self.assertFailure(d, ConchError)
class TestAgentKeyAddition(AgentTestBase):
"""
Test adding different flavors of keys to an agent.
"""
def test_addRSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that ommitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual('', serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that ommitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual('', serverKey[1])
return d.addCallback(_check)
def test_addRSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.rsaPrivate.privateBlob(), comment='My special key')
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual('My special key', serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.dsaPrivate.privateBlob(), comment='My special key')
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual('My special key', serverKey[1])
return d.addCallback(_check)
class TestAgentClientFailure(AgentTestBase):
def test_agentFailure(self):
"""
verify that the client raises ConchError on AGENT_FAILURE
"""
d = self.client.sendRequest(254, '')
self.pump.flush()
return self.assertFailure(d, ConchError)
class TestAgentIdentityRequests(AgentTestBase):
"""
Test operations against a server with identities already loaded.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate, 'a comment')
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate, 'another comment')
def test_signDataRSA(self):
"""
Sign data with an RSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
self.pump.flush()
def _check(sig):
expected = self.rsaPrivate.sign("John Hancock")
self.assertEqual(expected, sig)
self.assertTrue(self.rsaPublic.verify(sig, "John Hancock"))
return d.addCallback(_check)
def test_signDataDSA(self):
"""
Sign data with a DSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.dsaPublic.blob(), "John Hancock")
self.pump.flush()
def _check(sig):
# Cannot do this b/c DSA uses random numbers when signing
# expected = self.dsaPrivate.sign("John Hancock")
# self.assertEqual(expected, sig)
self.assertTrue(self.dsaPublic.verify(sig, "John Hancock"))
return d.addCallback(_check)
def test_signDataRSAErrbackOnUnknownBlob(self):
"""
Assert that we get an errback if we try to sign data using a key that
wasn't added.
"""
del self.server.factory.keys[self.rsaPublic.blob()]
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentities(self):
"""
Assert that we get all of the keys/comments that we add when we issue a
request for all identities.
"""
d = self.client.requestIdentities()
self.pump.flush()
def _check(keyt):
expected = {}
expected[self.dsaPublic.blob()] = 'a comment'
expected[self.rsaPublic.blob()] = 'another comment'
received = {}
for k in keyt:
received[keys.Key.fromString(k[0], type='blob').blob()] = k[1]
self.assertEqual(expected, received)
return d.addCallback(_check)
class TestAgentKeyRemoval(AgentTestBase):
"""
Test support for removing keys in a remote server.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate, 'a comment')
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate, 'another comment')
def test_removeRSAIdentity(self):
"""
Assert that we can remove an RSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.rsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeDSAIdentity(self):
"""
Assert that we can remove a DSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.dsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeAllIdentities(self):
"""
Assert that we can remove all identities.
"""
d = self.client.removeAllIdentities()
self.pump.flush()
def _check(ignored):
self.assertEqual(0, len(self.server.factory.keys))
return d.addCallback(_check)

View file

@ -0,0 +1,992 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE file for details.
"""
Tests for L{twisted.conch.scripts.cftp}.
"""
import locale
import time, sys, os, operator, getpass, struct
from StringIO import StringIO
from twisted.conch.test.test_ssh import Crypto, pyasn1
_reason = None
if Crypto and pyasn1:
try:
from twisted.conch import unix
from twisted.conch.scripts import cftp
from twisted.conch.scripts.cftp import SSHSession
from twisted.conch.test.test_filetransfer import FileTransferForTestAvatar
except ImportError as e:
unix = None
_reason = str(e)
del e
else:
unix = None
from twisted.python.fakepwd import UserDatabase
from twisted.trial.unittest import TestCase
from twisted.cred import portal
from twisted.internet import reactor, protocol, interfaces, defer, error
from twisted.internet.utils import getProcessOutputAndValue
from twisted.python import log
from twisted.conch import ls
from twisted.test.proto_helpers import StringTransport
from twisted.internet.task import Clock
from twisted.conch.test import test_ssh, test_conch
from twisted.conch.test.test_filetransfer import SFTPTestBase
from twisted.conch.test.test_filetransfer import FileTransferTestAvatar
from twisted.conch.test.test_conch import FakeStdio
class SSHSessionTests(TestCase):
"""
Tests for L{twisted.conch.scripts.cftp.SSHSession}.
"""
def test_eofReceived(self):
"""
L{twisted.conch.scripts.cftp.SSHSession.eofReceived} loses the write
half of its stdio connection.
"""
stdio = FakeStdio()
channel = SSHSession()
channel.stdio = stdio
channel.eofReceived()
self.assertTrue(stdio.writeConnLost)
class ListingTests(TestCase):
"""
Tests for L{lsLine}, the function which generates an entry for a file or
directory in an SFTP I{ls} command's output.
"""
if getattr(time, 'tzset', None) is None:
skip = "Cannot test timestamp formatting code without time.tzset"
def setUp(self):
"""
Patch the L{ls} module's time function so the results of L{lsLine} are
deterministic.
"""
self.now = 123456789
def fakeTime():
return self.now
self.patch(ls, 'time', fakeTime)
# Make sure that the timezone ends up the same after these tests as
# it was before.
if 'TZ' in os.environ:
self.addCleanup(operator.setitem, os.environ, 'TZ', os.environ['TZ'])
self.addCleanup(time.tzset)
else:
def cleanup():
# os.environ.pop is broken! Don't use it! Ever! Or die!
try:
del os.environ['TZ']
except KeyError:
pass
time.tzset()
self.addCleanup(cleanup)
def _lsInTimezone(self, timezone, stat):
"""
Call L{ls.lsLine} after setting the timezone to C{timezone} and return
the result.
"""
# Set the timezone to a well-known value so the timestamps are
# predictable.
os.environ['TZ'] = timezone
time.tzset()
return ls.lsLine('foo', stat)
def test_oldFile(self):
"""
A file with an mtime six months (approximately) or more in the past has
a listing including a low-resolution timestamp.
"""
# Go with 7 months. That's more than 6 months.
then = self.now - (60 * 60 * 24 * 31 * 7)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Apr 26 1973 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Apr 27 1973 foo')
def test_oldSingleDigitDayOfMonth(self):
"""
A file with a high-resolution timestamp which falls on a day of the
month which can be represented by one decimal digit is formatted with
one padding 0 to preserve the columns which come after it.
"""
# A point about 7 months in the past, tweaked to fall on the first of a
# month so we test the case we want to test.
then = self.now - (60 * 60 * 24 * 31 * 7) + (60 * 60 * 24 * 5)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 May 01 1973 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 May 02 1973 foo')
def test_newFile(self):
"""
A file with an mtime fewer than six months (approximately) in the past
has a listing including a high-resolution timestamp excluding the year.
"""
# A point about three months in the past.
then = self.now - (60 * 60 * 24 * 31 * 3)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Aug 28 17:33 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Aug 29 09:33 foo')
def test_localeIndependent(self):
"""
The month name in the date is locale independent.
"""
# A point about three months in the past.
then = self.now - (60 * 60 * 24 * 31 * 3)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
# Fake that we're in a language where August is not Aug (e.g.: Spanish)
currentLocale = locale.getlocale()
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
self.addCleanup(locale.setlocale, locale.LC_ALL, currentLocale)
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Aug 28 17:33 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Aug 29 09:33 foo')
# if alternate locale is not available, the previous test will be
# skipped, please install this locale for it to run
currentLocale = locale.getlocale()
try:
try:
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
except locale.Error:
test_localeIndependent.skip = "The es_AR.UTF8 locale is not installed."
finally:
locale.setlocale(locale.LC_ALL, currentLocale)
def test_newSingleDigitDayOfMonth(self):
"""
A file with a high-resolution timestamp which falls on a day of the
month which can be represented by one decimal digit is formatted with
one padding 0 to preserve the columns which come after it.
"""
# A point about three months in the past, tweaked to fall on the first
# of a month so we test the case we want to test.
then = self.now - (60 * 60 * 24 * 31 * 3) + (60 * 60 * 24 * 4)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Sep 01 17:33 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Sep 02 09:33 foo')
class StdioClientTests(TestCase):
"""
Tests for L{cftp.StdioClient}.
"""
def setUp(self):
"""
Create a L{cftp.StdioClient} hooked up to dummy transport and a fake
user database.
"""
class Connection:
pass
conn = Connection()
conn.transport = StringTransport()
conn.transport.localClosed = False
self.client = cftp.StdioClient(conn)
self.database = self.client._pwd = UserDatabase()
# Intentionally bypassing makeConnection - that triggers some code
# which uses features not provided by our dumb Connection fake.
self.client.transport = StringTransport()
def test_exec(self):
"""
The I{exec} command runs its arguments locally in a child process
using the user's shell.
"""
self.database.addUser(
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
sys.executable)
d = self.client._dispatchCommand("exec print 1 + 2")
d.addCallback(self.assertEqual, "3\n")
return d
def test_execWithoutShell(self):
"""
If the local user has no shell, the I{exec} command runs its arguments
using I{/bin/sh}.
"""
self.database.addUser(
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar', '')
d = self.client._dispatchCommand("exec echo hello")
d.addCallback(self.assertEqual, "hello\n")
return d
def test_bang(self):
"""
The I{exec} command is run for lines which start with C{"!"}.
"""
self.database.addUser(
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
'/bin/sh')
d = self.client._dispatchCommand("!echo hello")
d.addCallback(self.assertEqual, "hello\n")
return d
def setKnownConsoleSize(self, width, height):
"""
For the duration of this test, patch C{cftp}'s C{fcntl} module to return
a fixed width and height.
@param width: the width in characters
@type width: C{int}
@param height: the height in characters
@type height: C{int}
"""
import tty # local import to avoid win32 issues
class FakeFcntl(object):
def ioctl(self, fd, opt, mutate):
if opt != tty.TIOCGWINSZ:
self.fail("Only window-size queries supported.")
return struct.pack("4H", height, width, 0, 0)
self.patch(cftp, "fcntl", FakeFcntl())
def test_progressReporting(self):
"""
L{StdioClient._printProgressBar} prints a progress description,
including percent done, amount transferred, transfer rate, and time
remaining, all based the given start time, the given L{FileWrapper}'s
progress information and the reactor's current time.
"""
# Use a short, known console width because this simple test doesn't need
# to test the console padding.
self.setKnownConsoleSize(10, 34)
clock = self.client.reactor = Clock()
wrapped = StringIO("x")
wrapped.name = "sample"
wrapper = cftp.FileWrapper(wrapped)
wrapper.size = 1024 * 10
startTime = clock.seconds()
clock.advance(2.0)
wrapper.total += 4096
self.client._printProgressBar(wrapper, startTime)
self.assertEqual(self.client.transport.value(),
"\rsample 40% 4.0kB 2.0kBps 00:03 ")
def test_reportNoProgress(self):
"""
L{StdioClient._printProgressBar} prints a progress description that
indicates 0 bytes transferred if no bytes have been transferred and no
time has passed.
"""
self.setKnownConsoleSize(10, 34)
clock = self.client.reactor = Clock()
wrapped = StringIO("x")
wrapped.name = "sample"
wrapper = cftp.FileWrapper(wrapped)
startTime = clock.seconds()
self.client._printProgressBar(wrapper, startTime)
self.assertEqual(self.client.transport.value(),
"\rsample 0% 0.0B 0.0Bps 00:00 ")
class FileTransferTestRealm:
def __init__(self, testDir):
self.testDir = testDir
def requestAvatar(self, avatarID, mind, *interfaces):
a = FileTransferTestAvatar(self.testDir)
return interfaces[0], a, lambda: None
class SFTPTestProcess(protocol.ProcessProtocol):
"""
Protocol for testing cftp. Provides an interface between Python (where all
the tests are) and the cftp client process (which does the work that is
being tested).
"""
def __init__(self, onOutReceived):
"""
@param onOutReceived: A L{Deferred} to be fired as soon as data is
received from stdout.
"""
self.clearBuffer()
self.onOutReceived = onOutReceived
self.onProcessEnd = None
self._expectingCommand = None
self._processEnded = False
def clearBuffer(self):
"""
Clear any buffered data received from stdout. Should be private.
"""
self.buffer = ''
self._linesReceived = []
self._lineBuffer = ''
def outReceived(self, data):
"""
Called by Twisted when the cftp client prints data to stdout.
"""
log.msg('got %s' % data)
lines = (self._lineBuffer + data).split('\n')
self._lineBuffer = lines.pop(-1)
self._linesReceived.extend(lines)
# XXX - not strictly correct.
# We really want onOutReceived to fire after the first 'cftp>' prompt
# has been received. (See use in TestOurServerCmdLineClient.setUp)
if self.onOutReceived is not None:
d, self.onOutReceived = self.onOutReceived, None
d.callback(data)
self.buffer += data
self._checkForCommand()
def _checkForCommand(self):
prompt = 'cftp> '
if self._expectingCommand and self._lineBuffer == prompt:
buf = '\n'.join(self._linesReceived)
if buf.startswith(prompt):
buf = buf[len(prompt):]
self.clearBuffer()
d, self._expectingCommand = self._expectingCommand, None
d.callback(buf)
def errReceived(self, data):
"""
Called by Twisted when the cftp client prints data to stderr.
"""
log.msg('err: %s' % data)
def getBuffer(self):
"""
Return the contents of the buffer of data received from stdout.
"""
return self.buffer
def runCommand(self, command):
"""
Issue the given command via the cftp client. Return a C{Deferred} that
fires when the server returns a result. Note that the C{Deferred} will
callback even if the server returns some kind of error.
@param command: A string containing an sftp command.
@return: A C{Deferred} that fires when the sftp server returns a
result. The payload is the server's response string.
"""
self._expectingCommand = defer.Deferred()
self.clearBuffer()
self.transport.write(command + '\n')
return self._expectingCommand
def runScript(self, commands):
"""
Run each command in sequence and return a Deferred that fires when all
commands are completed.
@param commands: A list of strings containing sftp commands.
@return: A C{Deferred} that fires when all commands are completed. The
payload is a list of response strings from the server, in the same
order as the commands.
"""
sem = defer.DeferredSemaphore(1)
dl = [sem.run(self.runCommand, command) for command in commands]
return defer.gatherResults(dl)
def killProcess(self):
"""
Kill the process if it is still running.
If the process is still running, sends a KILL signal to the transport
and returns a C{Deferred} which fires when L{processEnded} is called.
@return: a C{Deferred}.
"""
if self._processEnded:
return defer.succeed(None)
self.onProcessEnd = defer.Deferred()
self.transport.signalProcess('KILL')
return self.onProcessEnd
def processEnded(self, reason):
"""
Called by Twisted when the cftp client process ends.
"""
self._processEnded = True
if self.onProcessEnd:
d, self.onProcessEnd = self.onProcessEnd, None
d.callback(None)
class CFTPClientTestBase(SFTPTestBase):
def setUp(self):
f = open('dsa_test.pub','w')
f.write(test_ssh.publicDSA_openssh)
f.close()
f = open('dsa_test','w')
f.write(test_ssh.privateDSA_openssh)
f.close()
os.chmod('dsa_test', 33152)
f = open('kh_test','w')
f.write('127.0.0.1 ' + test_ssh.publicRSA_openssh)
f.close()
return SFTPTestBase.setUp(self)
def startServer(self):
realm = FileTransferTestRealm(self.testDir)
p = portal.Portal(realm)
p.registerChecker(test_ssh.ConchTestPublicKeyChecker())
fac = test_ssh.ConchTestServerFactory()
fac.portal = p
self.server = reactor.listenTCP(0, fac, interface="127.0.0.1")
def stopServer(self):
if not hasattr(self.server.factory, 'proto'):
return self._cbStopServer(None)
self.server.factory.proto.expectedLoseConnection = 1
d = defer.maybeDeferred(
self.server.factory.proto.transport.loseConnection)
d.addCallback(self._cbStopServer)
return d
def _cbStopServer(self, ignored):
return defer.maybeDeferred(self.server.stopListening)
def tearDown(self):
for f in ['dsa_test.pub', 'dsa_test', 'kh_test']:
try:
os.remove(f)
except:
pass
return SFTPTestBase.tearDown(self)
class TestOurServerCmdLineClient(CFTPClientTestBase):
def setUp(self):
CFTPClientTestBase.setUp(self)
self.startServer()
cmds = ('-p %i -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'--host-key-algorithms ssh-rsa '
'-i dsa_test '
'-a '
'-v '
'127.0.0.1')
port = self.server.getHost().port
cmds = test_conch._makeArgs((cmds % port).split(), mod='cftp')
log.msg('running %s %s' % (sys.executable, cmds))
d = defer.Deferred()
self.processProtocol = SFTPTestProcess(d)
d.addCallback(lambda _: self.processProtocol.clearBuffer())
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
reactor.spawnProcess(self.processProtocol, sys.executable, cmds,
env=env)
return d
def tearDown(self):
d = self.stopServer()
d.addCallback(lambda _: self.processProtocol.killProcess())
return d
def _killProcess(self, ignored):
try:
self.processProtocol.transport.signalProcess('KILL')
except error.ProcessExitedAlready:
pass
def runCommand(self, command):
"""
Run the given command with the cftp client. Return a C{Deferred} that
fires when the command is complete. Payload is the server's output for
that command.
"""
return self.processProtocol.runCommand(command)
def runScript(self, *commands):
"""
Run the given commands with the cftp client. Returns a C{Deferred}
that fires when the commands are all complete. The C{Deferred}'s
payload is a list of output for each command.
"""
return self.processProtocol.runScript(commands)
def testCdPwd(self):
"""
Test that 'pwd' reports the current remote directory, that 'lpwd'
reports the current local directory, and that changing to a
subdirectory then changing to its parent leaves you in the original
remote directory.
"""
# XXX - not actually a unit test, see docstring.
homeDir = os.path.join(os.getcwd(), self.testDir)
d = self.runScript('pwd', 'lpwd', 'cd testDirectory', 'cd ..', 'pwd')
d.addCallback(lambda xs: xs[:3] + xs[4:])
d.addCallback(self.assertEqual,
[homeDir, os.getcwd(), '', homeDir])
return d
def testChAttrs(self):
"""
Check that 'ls -l' output includes the access permissions and that
this output changes appropriately with 'chmod'.
"""
def _check(results):
self.flushLoggedErrors()
self.assertTrue(results[0].startswith('-rw-r--r--'))
self.assertEqual(results[1], '')
self.assertTrue(results[2].startswith('----------'), results[2])
self.assertEqual(results[3], '')
d = self.runScript('ls -l testfile1', 'chmod 0 testfile1',
'ls -l testfile1', 'chmod 644 testfile1')
return d.addCallback(_check)
# XXX test chgrp/own
def testList(self):
"""
Check 'ls' works as expected. Checks for wildcards, hidden files,
listing directories and listing empty directories.
"""
def _check(results):
self.assertEqual(results[0], ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1'])
self.assertEqual(results[1], ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1'])
self.assertEqual(results[2], ['testRemoveFile', 'testRenameFile'])
self.assertEqual(results[3], ['.testHiddenFile', 'testRemoveFile',
'testRenameFile'])
self.assertEqual(results[4], [''])
d = self.runScript('ls', 'ls ../' + os.path.basename(self.testDir),
'ls *File', 'ls -a *File', 'ls -l testDirectory')
d.addCallback(lambda xs: [x.split('\n') for x in xs])
return d.addCallback(_check)
def testHelp(self):
"""
Check that running the '?' command returns help.
"""
d = self.runCommand('?')
d.addCallback(self.assertEqual,
cftp.StdioClient(None).cmd_HELP('').strip())
return d
def assertFilesEqual(self, name1, name2, msg=None):
"""
Assert that the files at C{name1} and C{name2} contain exactly the
same data.
"""
f1 = file(name1).read()
f2 = file(name2).read()
self.assertEqual(f1, f2, msg)
def testGet(self):
"""
Test that 'get' saves the remote file to the correct local location,
that the output of 'get' is correct and that 'rm' actually removes
the file.
"""
# XXX - not actually a unit test
expectedOutput = ("Transferred %s/%s/testfile1 to %s/test file2"
% (os.getcwd(), self.testDir, self.testDir))
def _checkGet(result):
self.assertTrue(result.endswith(expectedOutput))
self.assertFilesEqual(self.testDir + '/testfile1',
self.testDir + '/test file2',
"get failed")
return self.runCommand('rm "test file2"')
d = self.runCommand('get testfile1 "%s/test file2"' % (self.testDir,))
d.addCallback(_checkGet)
d.addCallback(lambda _: self.assertFalse(
os.path.exists(self.testDir + '/test file2')))
return d
def testWildcardGet(self):
"""
Test that 'get' works correctly when given wildcard parameters.
"""
def _check(ignored):
self.assertFilesEqual(self.testDir + '/testRemoveFile',
'testRemoveFile',
'testRemoveFile get failed')
self.assertFilesEqual(self.testDir + '/testRenameFile',
'testRenameFile',
'testRenameFile get failed')
d = self.runCommand('get testR*')
return d.addCallback(_check)
def testPut(self):
"""
Check that 'put' uploads files correctly and that they can be
successfully removed. Also check the output of the put command.
"""
# XXX - not actually a unit test
expectedOutput = ('Transferred %s/testfile1 to %s/%s/test"file2'
% (self.testDir, os.getcwd(), self.testDir))
def _checkPut(result):
self.assertFilesEqual(self.testDir + '/testfile1',
self.testDir + '/test"file2')
self.assertTrue(result.endswith(expectedOutput))
return self.runCommand('rm "test\\"file2"')
d = self.runCommand('put %s/testfile1 "test\\"file2"'
% (self.testDir,))
d.addCallback(_checkPut)
d.addCallback(lambda _: self.assertFalse(
os.path.exists(self.testDir + '/test"file2')))
return d
def test_putOverLongerFile(self):
"""
Check that 'put' uploads files correctly when overwriting a longer
file.
"""
# XXX - not actually a unit test
f = file(os.path.join(self.testDir, 'shorterFile'), 'w')
f.write("a")
f.close()
f = file(os.path.join(self.testDir, 'longerFile'), 'w')
f.write("bb")
f.close()
def _checkPut(result):
self.assertFilesEqual(self.testDir + '/shorterFile',
self.testDir + '/longerFile')
d = self.runCommand('put %s/shorterFile longerFile'
% (self.testDir,))
d.addCallback(_checkPut)
return d
def test_putMultipleOverLongerFile(self):
"""
Check that 'put' uploads files correctly when overwriting a longer
file and you use a wildcard to specify the files to upload.
"""
# XXX - not actually a unit test
os.mkdir(os.path.join(self.testDir, 'dir'))
f = file(os.path.join(self.testDir, 'dir', 'file'), 'w')
f.write("a")
f.close()
f = file(os.path.join(self.testDir, 'file'), 'w')
f.write("bb")
f.close()
def _checkPut(result):
self.assertFilesEqual(self.testDir + '/dir/file',
self.testDir + '/file')
d = self.runCommand('put %s/dir/*'
% (self.testDir,))
d.addCallback(_checkPut)
return d
def testWildcardPut(self):
"""
What happens if you issue a 'put' command and include a wildcard (i.e.
'*') in parameter? Check that all files matching the wildcard are
uploaded to the correct directory.
"""
def check(results):
self.assertEqual(results[0], '')
self.assertEqual(results[2], '')
self.assertFilesEqual(self.testDir + '/testRemoveFile',
self.testDir + '/../testRemoveFile',
'testRemoveFile get failed')
self.assertFilesEqual(self.testDir + '/testRenameFile',
self.testDir + '/../testRenameFile',
'testRenameFile get failed')
d = self.runScript('cd ..',
'put %s/testR*' % (self.testDir,),
'cd %s' % os.path.basename(self.testDir))
d.addCallback(check)
return d
def testLink(self):
"""
Test that 'ln' creates a file which appears as a link in the output of
'ls'. Check that removing the new file succeeds without output.
"""
def _check(results):
self.flushLoggedErrors()
self.assertEqual(results[0], '')
self.assertTrue(results[1].startswith('l'), 'link failed')
return self.runCommand('rm testLink')
d = self.runScript('ln testLink testfile1', 'ls -l testLink')
d.addCallback(_check)
d.addCallback(self.assertEqual, '')
return d
def testRemoteDirectory(self):
"""
Test that we can create and remove directories with the cftp client.
"""
def _check(results):
self.assertEqual(results[0], '')
self.assertTrue(results[1].startswith('d'))
return self.runCommand('rmdir testMakeDirectory')
d = self.runScript('mkdir testMakeDirectory',
'ls -l testMakeDirector?')
d.addCallback(_check)
d.addCallback(self.assertEqual, '')
return d
def test_existingRemoteDirectory(self):
"""
Test that a C{mkdir} on an existing directory fails with the
appropriate error, and doesn't log an useless error server side.
"""
def _check(results):
self.assertEqual(results[0], '')
self.assertEqual(results[1],
'remote error 11: mkdir failed')
d = self.runScript('mkdir testMakeDirectory',
'mkdir testMakeDirectory')
d.addCallback(_check)
return d
def testLocalDirectory(self):
"""
Test that we can create a directory locally and remove it with the
cftp client. This test works because the 'remote' server is running
out of a local directory.
"""
d = self.runCommand('lmkdir %s/testLocalDirectory' % (self.testDir,))
d.addCallback(self.assertEqual, '')
d.addCallback(lambda _: self.runCommand('rmdir testLocalDirectory'))
d.addCallback(self.assertEqual, '')
return d
def testRename(self):
"""
Test that we can rename a file.
"""
def _check(results):
self.assertEqual(results[0], '')
self.assertEqual(results[1], 'testfile2')
return self.runCommand('rename testfile2 testfile1')
d = self.runScript('rename testfile1 testfile2', 'ls testfile?')
d.addCallback(_check)
d.addCallback(self.assertEqual, '')
return d
class TestOurServerBatchFile(CFTPClientTestBase):
def setUp(self):
CFTPClientTestBase.setUp(self)
self.startServer()
def tearDown(self):
CFTPClientTestBase.tearDown(self)
return self.stopServer()
def _getBatchOutput(self, f):
fn = self.mktemp()
open(fn, 'w').write(f)
port = self.server.getHost().port
cmds = ('-p %i -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'--host-key-algorithms ssh-rsa '
'-i dsa_test '
'-a '
'-v -b %s 127.0.0.1') % (port, fn)
cmds = test_conch._makeArgs(cmds.split(), mod='cftp')[1:]
log.msg('running %s %s' % (sys.executable, cmds))
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
self.server.factory.expectedLoseConnection = 1
d = getProcessOutputAndValue(sys.executable, cmds, env=env)
def _cleanup(res):
os.remove(fn)
return res
d.addCallback(lambda res: res[0])
d.addBoth(_cleanup)
return d
def testBatchFile(self):
"""Test whether batch file function of cftp ('cftp -b batchfile').
This works by treating the file as a list of commands to be run.
"""
cmds = """pwd
ls
exit
"""
def _cbCheckResult(res):
res = res.split('\n')
log.msg('RES %s' % str(res))
self.assertIn(self.testDir, res[1])
self.assertEqual(res[3:-2], ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1'])
d = self._getBatchOutput(cmds)
d.addCallback(_cbCheckResult)
return d
def testError(self):
"""Test that an error in the batch file stops running the batch.
"""
cmds = """chown 0 missingFile
pwd
exit
"""
def _cbCheckResult(res):
self.assertNotIn(self.testDir, res)
d = self._getBatchOutput(cmds)
d.addCallback(_cbCheckResult)
return d
def testIgnoredError(self):
"""Test that a minus sign '-' at the front of a line ignores
any errors.
"""
cmds = """-chown 0 missingFile
pwd
exit
"""
def _cbCheckResult(res):
self.assertIn(self.testDir, res)
d = self._getBatchOutput(cmds)
d.addCallback(_cbCheckResult)
return d
class TestOurServerSftpClient(CFTPClientTestBase):
"""
Test the sftp server against sftp command line client.
"""
def setUp(self):
CFTPClientTestBase.setUp(self)
return self.startServer()
def tearDown(self):
return self.stopServer()
def test_extendedAttributes(self):
"""
Test the return of extended attributes by the server: the sftp client
should ignore them, but still be able to parse the response correctly.
This test is mainly here to check that
L{filetransfer.FILEXFER_ATTR_EXTENDED} has the correct value.
"""
fn = self.mktemp()
open(fn, 'w').write("ls .\nexit")
port = self.server.getHost().port
oldGetAttr = FileTransferForTestAvatar._getAttrs
def _getAttrs(self, s):
attrs = oldGetAttr(self, s)
attrs["ext_foo"] = "bar"
return attrs
self.patch(FileTransferForTestAvatar, "_getAttrs", _getAttrs)
self.server.factory.expectedLoseConnection = True
cmds = ('-o', 'IdentityFile=dsa_test',
'-o', 'UserKnownHostsFile=kh_test',
'-o', 'HostKeyAlgorithms=ssh-rsa',
'-o', 'Port=%i' % (port,), '-b', fn, 'testuser@127.0.0.1')
d = getProcessOutputAndValue("sftp", cmds)
def check(result):
self.assertEqual(result[2], 0)
for i in ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1']:
self.assertIn(i, result[0])
return d.addCallback(check)
if unix is None or Crypto is None or pyasn1 is None or interfaces.IReactorProcess(reactor, None) is None:
if _reason is None:
_reason = "don't run w/o spawnProcess or PyCrypto or pyasn1"
TestOurServerCmdLineClient.skip = _reason
TestOurServerBatchFile.skip = _reason
TestOurServerSftpClient.skip = _reason
StdioClientTests.skip = _reason
SSHSessionTests.skip = _reason
else:
from twisted.python.procutils import which
if not which('sftp'):
TestOurServerSftpClient.skip = "no sftp command-line client available"

View file

@ -0,0 +1,279 @@
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
# See LICENSE for details
"""
Test ssh/channel.py.
"""
from twisted.conch.ssh import channel
from twisted.trial import unittest
class MockTransport(object):
"""
A mock Transport. All we use is the getPeer() and getHost() methods.
Channels implement the ITransport interface, and their getPeer() and
getHost() methods return ('SSH', <transport's getPeer/Host value>) so
we need to implement these methods so they have something to draw
from.
"""
def getPeer(self):
return ('MockPeer',)
def getHost(self):
return ('MockHost',)
class MockConnection(object):
"""
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
that channels send, and when they try to close the connection.
@ivar data: a C{dict} mapping channel id #s to lists of data sent by that
channel.
@ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples
(extended data type, data) sent by that channel.
@ivar closes: a C{dict} mapping channel id #s to True if that channel sent
a close message.
"""
transport = MockTransport()
def __init__(self):
self.data = {}
self.extData = {}
self.closes = {}
def logPrefix(self):
"""
Return our logging prefix.
"""
return "MockConnection"
def sendData(self, channel, data):
"""
Record the sent data.
"""
self.data.setdefault(channel, []).append(data)
def sendExtendedData(self, channel, type, data):
"""
Record the sent extended data.
"""
self.extData.setdefault(channel, []).append((type, data))
def sendClose(self, channel):
"""
Record that the channel sent a close message.
"""
self.closes[channel] = True
class ChannelTestCase(unittest.TestCase):
def setUp(self):
"""
Initialize the channel. remoteMaxPacket is 10 so that data is able
to be sent (the default of 0 means no data is sent because no packets
are made).
"""
self.conn = MockConnection()
self.channel = channel.SSHChannel(conn=self.conn,
remoteMaxPacket=10)
self.channel.name = 'channel'
def test_init(self):
"""
Test that SSHChannel initializes correctly. localWindowSize defaults
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
defaults (what OpenSSH uses for those variables).
The values in the second set of assertions are meaningless; they serve
only to verify that the instance variables are assigned in the correct
order.
"""
c = channel.SSHChannel(conn=self.conn)
self.assertEqual(c.localWindowSize, 131072)
self.assertEqual(c.localWindowLeft, 131072)
self.assertEqual(c.localMaxPacket, 32768)
self.assertEqual(c.remoteWindowLeft, 0)
self.assertEqual(c.remoteMaxPacket, 0)
self.assertEqual(c.conn, self.conn)
self.assertEqual(c.data, None)
self.assertEqual(c.avatar, None)
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(c2.localWindowSize, 1)
self.assertEqual(c2.localWindowLeft, 1)
self.assertEqual(c2.localMaxPacket, 2)
self.assertEqual(c2.remoteWindowLeft, 3)
self.assertEqual(c2.remoteMaxPacket, 4)
self.assertEqual(c2.conn, 5)
self.assertEqual(c2.data, 6)
self.assertEqual(c2.avatar, 7)
def test_str(self):
"""
Test that str(SSHChannel) works gives the channel name and local and
remote windows at a glance..
"""
self.assertEqual(str(self.channel), '<SSHChannel channel (lw 131072 '
'rw 0)>')
def test_logPrefix(self):
"""
Test that SSHChannel.logPrefix gives the name of the channel, the
local channel ID and the underlying connection.
"""
self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
'(unknown) on MockConnection')
def test_addWindowBytes(self):
"""
Test that addWindowBytes adds bytes to the window and resumes writing
if it was paused.
"""
cb = [False]
def stubStartWriting():
cb[0] = True
self.channel.startWriting = stubStartWriting
self.channel.write('test')
self.channel.writeExtended(1, 'test')
self.channel.addWindowBytes(50)
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
self.assertTrue(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(self.channel.buf, '')
self.assertEqual(self.conn.data[self.channel], ['test'])
self.assertEqual(self.channel.extBuf, [])
self.assertEqual(self.conn.extData[self.channel], [(1, 'test')])
cb[0] = False
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
self.channel.write('a'*80)
self.channel.loseConnection()
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
def test_requestReceived(self):
"""
Test that requestReceived handles requests by dispatching them to
request_* methods.
"""
self.channel.request_test_method = lambda data: data == ''
self.assertTrue(self.channel.requestReceived('test-method', ''))
self.assertFalse(self.channel.requestReceived('test-method', 'a'))
self.assertFalse(self.channel.requestReceived('bad-method', ''))
def test_closeReceieved(self):
"""
Test that the default closeReceieved closes the connection.
"""
self.assertFalse(self.channel.closing)
self.channel.closeReceived()
self.assertTrue(self.channel.closing)
def test_write(self):
"""
Test that write handles data correctly. Send data up to the size
of the remote window, splitting the data into packets of length
remoteMaxPacket.
"""
cb = [False]
def stubStopWriting():
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting
self.channel.write('d')
self.channel.write('a')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.write('ta')
data = self.conn.data[self.channel]
self.assertEqual(data, ['da', 'ta'])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.write('12345678901')
self.assertEqual(data, ['da', 'ta', '1234567890', '1'])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.write('123456')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, ['da', 'ta', '1234567890', '1', '12345'])
self.assertEqual(self.channel.buf, '6')
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeExtended(self):
"""
Test that writeExtended handles data correctly. Send extended data
up to the size of the window, splitting the extended data into packets
of length remoteMaxPacket.
"""
cb = [False]
def stubStopWriting():
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting
self.channel.writeExtended(1, 'd')
self.channel.writeExtended(1, 'a')
self.channel.writeExtended(2, 't')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.writeExtended(2, 'a')
data = self.conn.extData[self.channel]
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a')])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.writeExtended(3, '12345678901')
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
(3, '1234567890'), (3, '1')])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.writeExtended(4, '123456')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
(3, '1234567890'), (3, '1'), (4, '12345')])
self.assertEqual(self.channel.extBuf, [[4, '6']])
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeSequence(self):
"""
Test that writeSequence is equivalent to write(''.join(sequece)).
"""
self.channel.addWindowBytes(20)
self.channel.writeSequence(map(str, range(10)))
self.assertEqual(self.conn.data[self.channel], ['0123456789'])
def test_loseConnection(self):
"""
Tesyt that loseConnection() doesn't close the channel until all
the data is sent.
"""
self.channel.write('data')
self.channel.writeExtended(1, 'datadata')
self.channel.loseConnection()
self.assertEqual(self.conn.closes.get(self.channel), None)
self.channel.addWindowBytes(4) # send regular data
self.assertEqual(self.conn.closes.get(self.channel), None)
self.channel.addWindowBytes(8) # send extended data
self.assertTrue(self.conn.closes.get(self.channel))
def test_getPeer(self):
"""
Test that getPeer() returns ('SSH', <connection transport peer>).
"""
self.assertEqual(self.channel.getPeer(), ('SSH', 'MockPeer'))
def test_getHost(self):
"""
Test that getHost() returns ('SSH', <connection transport host>).
"""
self.assertEqual(self.channel.getHost(), ('SSH', 'MockHost'))

View file

@ -0,0 +1,603 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.checkers}.
"""
try:
import crypt
except ImportError:
cryptSkip = 'cannot run without crypt module'
else:
cryptSkip = None
import os, base64
from twisted.python import util
from twisted.python.failure import Failure
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.credentials import UsernamePassword, IUsernamePassword, \
SSHPrivateKey, ISSHPrivateKey
from twisted.cred.error import UnhandledCredentials, UnauthorizedLogin
from twisted.python.fakepwd import UserDatabase, ShadowDatabase
from twisted.test.test_process import MockOS
try:
import Crypto.Cipher.DES3
import pyasn1
except ImportError:
dependencySkip = "can't run without Crypto and PyASN1"
else:
dependencySkip = None
from twisted.conch.ssh import keys
from twisted.conch import checkers
from twisted.conch.error import NotEnoughAuthentication, ValidPublicKey
from twisted.conch.test import keydata
if getattr(os, 'geteuid', None) is None:
euidSkip = "Cannot run without effective UIDs (questionable)"
else:
euidSkip = None
class HelperTests(TestCase):
"""
Tests for helper functions L{verifyCryptedPassword}, L{_pwdGetByName} and
L{_shadowGetByName}.
"""
skip = cryptSkip or dependencySkip
def setUp(self):
self.mockos = MockOS()
def test_verifyCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{True} if the plaintext password
passed to it matches the encrypted password passed to it.
"""
password = 'secret string'
salt = 'salty'
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
'%r supposed to be valid encrypted password for %r' % (
crypted, password))
def test_verifyCryptedPasswordMD5(self):
"""
L{verifyCryptedPassword} returns True if the provided cleartext password
matches the provided MD5 password hash.
"""
password = 'password'
salt = '$1$salt'
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
'%r supposed to be valid encrypted password for %s' % (
crypted, password))
def test_refuteCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{False} if the plaintext password
passed to it does not match the encrypted password passed to it.
"""
password = 'string secret'
wrong = 'secret string'
crypted = crypt.crypt(password, password)
self.assertFalse(
checkers.verifyCryptedPassword(crypted, wrong),
'%r not supposed to be valid encrypted password for %s' % (
crypted, wrong))
def test_pwdGetByName(self):
"""
L{_pwdGetByName} returns a tuple of items from the UNIX /etc/passwd
database if the L{pwd} module is present.
"""
userdb = UserDatabase()
userdb.addUser(
'alice', 'secrit', 1, 2, 'first last', '/foo', '/bin/sh')
self.patch(checkers, 'pwd', userdb)
self.assertEqual(
checkers._pwdGetByName('alice'), userdb.getpwnam('alice'))
def test_pwdGetByNameWithoutPwd(self):
"""
If the C{pwd} module isn't present, L{_pwdGetByName} returns C{None}.
"""
self.patch(checkers, 'pwd', None)
self.assertIs(checkers._pwdGetByName('alice'), None)
def test_shadowGetByName(self):
"""
L{_shadowGetByName} returns a tuple of items from the UNIX /etc/shadow
database if the L{spwd} is present.
"""
userdb = ShadowDatabase()
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
self.patch(checkers, 'spwd', userdb)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(util, 'os', self.mockos)
self.assertEqual(
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
def test_shadowGetByNameWithoutSpwd(self):
"""
L{_shadowGetByName} uses the C{shadow} module to return a tuple of items
from the UNIX /etc/shadow database if the C{spwd} module is not present
and the C{shadow} module is.
"""
userdb = ShadowDatabase()
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
self.patch(checkers, 'spwd', None)
self.patch(checkers, 'shadow', userdb)
self.patch(util, 'os', self.mockos)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.assertEqual(
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
def test_shadowGetByNameWithoutEither(self):
"""
L{_shadowGetByName} returns C{None} if neither C{spwd} nor C{shadow} is
present.
"""
self.patch(checkers, 'spwd', None)
self.patch(checkers, 'shadow', None)
self.assertIs(checkers._shadowGetByName('bob'), None)
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
class SSHPublicKeyDatabaseTestCase(TestCase):
"""
Tests for L{SSHPublicKeyDatabase}.
"""
skip = euidSkip or dependencySkip
def setUp(self):
self.checker = checkers.SSHPublicKeyDatabase()
self.key1 = base64.encodestring("foobar")
self.key2 = base64.encodestring("eggspam")
self.content = "t1 %s foo\nt2 %s egg\n" % (self.key1, self.key2)
self.mockos = MockOS()
self.mockos.path = FilePath(self.mktemp())
self.mockos.path.makedirs()
self.patch(util, 'os', self.mockos)
self.sshDir = self.mockos.path.child('.ssh')
self.sshDir.makedirs()
userdb = UserDatabase()
userdb.addUser(
'user', 'password', 1, 2, 'first last',
self.mockos.path.path, '/bin/shell')
self.checker._userdb = userdb
def _testCheckKey(self, filename):
self.sshDir.child(filename).setContent(self.content)
user = UsernamePassword("user", "password")
user.blob = "foobar"
self.assertTrue(self.checker.checkKey(user))
user.blob = "eggspam"
self.assertTrue(self.checker.checkKey(user))
user.blob = "notallowed"
self.assertFalse(self.checker.checkKey(user))
def test_checkKey(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys file and check the keys against that file.
"""
self._testCheckKey("authorized_keys")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKey2(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys2 file and check the keys against that file.
"""
self._testCheckKey("authorized_keys2")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKeyAsRoot(self):
"""
If the key file is readable, L{SSHPublicKeyDatabase.checkKey} should
switch its uid/gid to the ones of the authenticated user.
"""
keyFile = self.sshDir.child("authorized_keys")
keyFile.setContent(self.content)
# Fake permission error by changing the mode
keyFile.chmod(0000)
self.addCleanup(keyFile.chmod, 0777)
# And restore the right mode when seteuid is called
savedSeteuid = self.mockos.seteuid
def seteuid(euid):
keyFile.chmod(0777)
return savedSeteuid(euid)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(self.mockos, "seteuid", seteuid)
self.patch(util, 'os', self.mockos)
user = UsernamePassword("user", "password")
user.blob = "foobar"
self.assertTrue(self.checker.checkKey(user))
self.assertEqual(self.mockos.seteuidCalls, [0, 1, 0, 2345])
self.assertEqual(self.mockos.setegidCalls, [2, 1234])
def test_requestAvatarId(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should return the avatar id
passed in if its C{_checkKey} method returns True.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
keys.Key.fromString(keydata.privateRSA_openssh).sign('foo'))
d = self.checker.requestAvatarId(credentials)
def _verify(avatarId):
self.assertEqual(avatarId, 'test')
return d.addCallback(_verify)
def test_requestAvatarIdWithoutSignature(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should raise L{ValidPublicKey}
if the credentials represent a valid key without a signature. This
tells the user that the key is valid for login, but does not actually
allow that user to do so without a signature.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
'test', 'ssh-rsa', keydata.publicRSA_openssh, None, None)
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, ValidPublicKey)
def test_requestAvatarIdInvalidKey(self):
"""
If L{SSHPublicKeyDatabase.checkKey} returns False,
C{_cbRequestAvatarId} should raise L{UnauthorizedLogin}.
"""
def _checkKey(ignored):
return False
self.patch(self.checker, 'checkKey', _checkKey)
d = self.checker.requestAvatarId(None);
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdInvalidSignature(self):
"""
Valid keys with invalid signatures should cause
L{SSHPublicKeyDatabase.requestAvatarId} to return a {UnauthorizedLogin}
failure
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
keys.Key.fromString(keydata.privateDSA_openssh).sign('foo'))
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdNormalizeException(self):
"""
Exceptions raised while verifying the key should be normalized into an
C{UnauthorizedLogin} failure.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey('test', None, 'blob', 'sigData', 'sig')
d = self.checker.requestAvatarId(credentials)
def _verifyLoggedException(failure):
errors = self.flushLoggedErrors(keys.BadKeyError)
self.assertEqual(len(errors), 1)
return failure
d.addErrback(_verifyLoggedException)
return self.assertFailure(d, UnauthorizedLogin)
class SSHProtocolCheckerTestCase(TestCase):
"""
Tests for L{SSHProtocolChecker}.
"""
skip = dependencySkip
def test_registerChecker(self):
"""
L{SSHProcotolChecker.registerChecker} should add the given checker to
the list of registered checkers.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(), )
self.assertEqual(checker.credentialInterfaces, [ISSHPrivateKey])
self.assertIsInstance(checker.checkers[ISSHPrivateKey],
checkers.SSHPublicKeyDatabase)
def test_registerCheckerWithInterface(self):
"""
If a apecific interface is passed into
L{SSHProtocolChecker.registerChecker}, that interface should be
registered instead of what the checker specifies in
credentialIntefaces.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(),
IUsernamePassword)
self.assertEqual(checker.credentialInterfaces, [IUsernamePassword])
self.assertIsInstance(checker.checkers[IUsernamePassword],
checkers.SSHPublicKeyDatabase)
def test_requestAvatarId(self):
"""
L{SSHProtocolChecker.requestAvatarId} should defer to one if its
registered checkers to authenticate a user.
"""
checker = checkers.SSHProtocolChecker()
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser('test', 'test')
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
def _callback(avatarId):
self.assertEqual(avatarId, 'test')
return d.addCallback(_callback)
def test_requestAvatarIdWithNotEnoughAuthentication(self):
"""
If the client indicates that it is never satisfied, by always returning
False from _areDone, then L{SSHProtocolChecker} should raise
L{NotEnoughAuthentication}.
"""
checker = checkers.SSHProtocolChecker()
def _areDone(avatarId):
return False
self.patch(checker, 'areDone', _areDone)
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser('test', 'test')
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
return self.assertFailure(d, NotEnoughAuthentication)
def test_requestAvatarIdInvalidCredential(self):
"""
If the passed credentials aren't handled by any registered checker,
L{SSHProtocolChecker} should raise L{UnhandledCredentials}.
"""
checker = checkers.SSHProtocolChecker()
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
return self.assertFailure(d, UnhandledCredentials)
def test_areDone(self):
"""
The default L{SSHProcotolChecker.areDone} should simply return True.
"""
self.assertEqual(checkers.SSHProtocolChecker().areDone(None), True)
class UNIXPasswordDatabaseTests(TestCase):
"""
Tests for L{UNIXPasswordDatabase}.
"""
skip = cryptSkip or dependencySkip
def assertLoggedIn(self, d, username):
"""
Assert that the L{Deferred} passed in is called back with the value
'username'. This represents a valid login for this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{Deferred}
"""
result = []
d.addBoth(result.append)
self.assertEqual(len(result), 1, "login incomplete")
if isinstance(result[0], Failure):
result[0].raiseException()
self.assertEqual(result[0], username)
def test_defaultCheckers(self):
"""
L{UNIXPasswordDatabase} with no arguments has checks the C{pwd} database
and then the C{spwd} database.
"""
checker = checkers.UNIXPasswordDatabase()
def crypted(username, password):
salt = crypt.crypt(password, username)
crypted = crypt.crypt(password, '$1$' + salt)
return crypted
pwd = UserDatabase()
pwd.addUser('alice', crypted('alice', 'password'),
1, 2, 'foo', '/foo', '/bin/sh')
# x and * are convention for "look elsewhere for the password"
pwd.addUser('bob', 'x', 1, 2, 'bar', '/bar', '/bin/sh')
spwd = ShadowDatabase()
spwd.addUser('alice', 'wrong', 1, 2, 3, 4, 5, 6, 7)
spwd.addUser('bob', crypted('bob', 'password'),
8, 9, 10, 11, 12, 13, 14)
self.patch(checkers, 'pwd', pwd)
self.patch(checkers, 'spwd', spwd)
mockos = MockOS()
self.patch(util, 'os', mockos)
mockos.euid = 2345
mockos.egid = 1234
cred = UsernamePassword("alice", "password")
self.assertLoggedIn(checker.requestAvatarId(cred), 'alice')
self.assertEqual(mockos.seteuidCalls, [])
self.assertEqual(mockos.setegidCalls, [])
cred.username = "bob"
self.assertLoggedIn(checker.requestAvatarId(cred), 'bob')
self.assertEqual(mockos.seteuidCalls, [0, 2345])
self.assertEqual(mockos.setegidCalls, [0, 1234])
def assertUnauthorizedLogin(self, d):
"""
Asserts that the L{Deferred} passed in is erred back with an
L{UnauthorizedLogin} L{Failure}. This reprsents an invalid login for
this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{None}
"""
self.assertRaises(
checkers.UnauthorizedLogin, self.assertLoggedIn, d, 'bogus value')
def test_passInCheckers(self):
"""
L{UNIXPasswordDatabase} takes a list of functions to check for UNIX
user information.
"""
password = crypt.crypt('secret', 'secret')
userdb = UserDatabase()
userdb.addUser('anybody', password, 1, 2, 'foo', '/bar', '/bin/sh')
checker = checkers.UNIXPasswordDatabase([userdb.getpwnam])
self.assertLoggedIn(
checker.requestAvatarId(UsernamePassword('anybody', 'secret')),
'anybody')
def test_verifyPassword(self):
"""
If the encrypted password provided by the getpwnam function is valid
(verified by the L{verifyCryptedPassword} function), we callback the
C{requestAvatarId} L{Deferred} with the username.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword('username', 'username')
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
def test_failOnKeyError(self):
"""
If the getpwnam function raises a KeyError, the login fails with an
L{UnauthorizedLogin} exception.
"""
def getpwnam(username):
raise KeyError(username)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword('username', 'username')
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_failOnBadPassword(self):
"""
If the verifyCryptedPassword function doesn't verify the password, the
login fails with an L{UnauthorizedLogin} exception.
"""
def verifyCryptedPassword(crypted, pw):
return False
def getpwnam(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword('username', 'username')
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_loopThroughFunctions(self):
"""
UNIXPasswordDatabase.requestAvatarId loops through each getpwnam
function associated with it and returns a L{Deferred} which fires with
the result of the first one which returns a value other than None.
ones do not verify the password.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam1(username):
return [username, 'not the password']
def getpwnam2(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam1, getpwnam2])
credential = UsernamePassword('username', 'username')
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
def test_failOnSpecial(self):
"""
If the password returned by any function is C{""}, C{"x"}, or C{"*"} it
is not compared against the supplied password. Instead it is skipped.
"""
pwd = UserDatabase()
pwd.addUser('alice', '', 1, 2, '', 'foo', 'bar')
pwd.addUser('bob', 'x', 1, 2, '', 'foo', 'bar')
pwd.addUser('carol', '*', 1, 2, '', 'foo', 'bar')
self.patch(checkers, 'pwd', pwd)
checker = checkers.UNIXPasswordDatabase([checkers._pwdGetByName])
cred = UsernamePassword('alice', '')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword('bob', 'x')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword('carol', '*')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))

View file

@ -0,0 +1,340 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.scripts.ckeygen}.
"""
import getpass
import sys
from StringIO import StringIO
try:
import Crypto
import pyasn1
except ImportError:
skip = "PyCrypto and pyasn1 required for twisted.conch.scripts.ckeygen."
else:
from twisted.conch.ssh.keys import Key, BadKeyError
from twisted.conch.scripts.ckeygen import (
changePassPhrase, displayPublicKey, printFingerprint, _saveKey)
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.conch.test.keydata import (
publicRSA_openssh, privateRSA_openssh, privateRSA_openssh_encrypted)
def makeGetpass(*passphrases):
"""
Return a callable to patch C{getpass.getpass}. Yields a passphrase each
time called. Use case is to provide an old, then new passphrase(s) as if
requested interactively.
@param passphrases: The list of passphrases returned, one per each call.
"""
passphrases = iter(passphrases)
def fakeGetpass(_):
return passphrases.next()
return fakeGetpass
class KeyGenTests(TestCase):
"""
Tests for various functions used to implement the I{ckeygen} script.
"""
def setUp(self):
"""
Patch C{sys.stdout} with a L{StringIO} instance to tests can make
assertions about what's printed.
"""
self.stdout = StringIO()
self.patch(sys, 'stdout', self.stdout)
def test_printFingerprint(self):
"""
L{printFingerprint} writes a line to standard out giving the number of
bits of the key, its fingerprint, and the basename of the file from it
was read.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
printFingerprint({'filename': filename})
self.assertEqual(
self.stdout.getvalue(),
'768 3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af temp\n')
def test_saveKey(self):
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key.keyObject,
{'filename': filename, 'pass': 'passphrase'})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint is:\n"
"3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af\n" % (
filename,
filename))
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, 'passphrase'),
key)
self.assertEqual(
Key.fromString(base.child('id_rsa.pub').getContent()),
key.public())
def test_saveKeyEmptyPassphrase(self):
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key.keyObject,
{'filename': filename, 'no-passphrase': True})
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, b''),
key)
def test_displayPublicKey(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh)
displayPublicKey({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
pubKey.toString('openssh'))
def test_displayPublicKeyEncrypted(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key using the given passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
displayPublicKey({'filename': filename, 'pass': 'encrypted'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
pubKey.toString('openssh'))
def test_displayPublicKeyEncryptedPassphrasePrompt(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key, asking for the passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.patch(getpass, 'getpass', lambda x: 'encrypted')
displayPublicKey({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
pubKey.toString('openssh'))
def test_displayPublicKeyWrongPassphrase(self):
"""
L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
an encrypted key with the wrong password.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.assertRaises(
BadKeyError, displayPublicKey,
{'filename': filename, 'pass': 'wrong'})
def test_changePassphrase(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key interactively.
"""
oldNewConfirm = makeGetpass('encrypted', 'newpass', 'newpass')
self.patch(getpass, 'getpass', oldNewConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWithOld(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key, providing the old passphrase and prompting for new one.
"""
newConfirm = makeGetpass('newpass', 'newpass')
self.patch(getpass, 'getpass', newConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename, 'pass': 'encrypted'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWithBoth(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a private
key by providing both old and new passphrases without prompting.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase(
{'filename': filename, 'pass': 'encrypted',
'newpass': 'newencrypt'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWrongPassphrase(self):
"""
L{changePassPhrase} exits if passed an invalid old passphrase when
trying to change the passphrase of a private key.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'pass': 'wrong'})
self.assertEqual('Could not change passphrase: old passphrase error',
str(error))
self.assertEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseEmptyGetPass(self):
"""
L{changePassPhrase} exits if no passphrase is specified for the
C{getpass} call and the key is encrypted.
"""
self.patch(getpass, 'getpass', makeGetpass(''))
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase, {'filename': filename})
self.assertEqual(
'Could not change passphrase: Passphrase must be provided '
'for an encrypted key',
str(error))
self.assertEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseBadKey(self):
"""
L{changePassPhrase} exits if the file specified points to an invalid
key.
"""
filename = self.mktemp()
FilePath(filename).setContent('foobar')
error = self.assertRaises(
SystemExit, changePassPhrase, {'filename': filename})
self.assertEqual(
"Could not change passphrase: cannot guess the type of 'foobar'",
str(error))
self.assertEqual('foobar', FilePath(filename).getContent())
def test_changePassphraseCreateError(self):
"""
L{changePassPhrase} doesn't modify the key file if an unexpected error
happens when trying to create the key with the new passphrase.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args, **kwargs):
raise RuntimeError('oops')
self.patch(Key, 'toString', toString)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename,
'newpass': 'newencrypt'})
self.assertEqual(
'Could not change passphrase: oops', str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphraseEmptyStringError(self):
"""
L{changePassPhrase} doesn't modify the key file if C{toString} returns
an empty string.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args, **kwargs):
return ''
self.patch(Key, 'toString', toString)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'newpass': 'newencrypt'})
self.assertEqual(
"Could not change passphrase: "
"cannot guess the type of ''", str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphrasePublicKey(self):
"""
L{changePassPhrase} exits when trying to change the passphrase on a
public key, and doesn't change the file.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'newpass': 'pass'})
self.assertEqual(
'Could not change passphrase: key not encrypted', str(error))
self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())

View file

@ -0,0 +1,564 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os, sys, socket
from itertools import count
from zope.interface import implements
from twisted.cred import portal
from twisted.internet import reactor, defer, protocol
from twisted.internet.error import ProcessExitedAlready
from twisted.internet.task import LoopingCall
from twisted.python import log, runtime
from twisted.trial import unittest
from twisted.conch.error import ConchError
from twisted.conch.avatar import ConchUser
from twisted.conch.ssh.session import ISession, SSHSession, wrapProtocol
try:
from twisted.conch.scripts.conch import SSHSession as StdioInteractingSession
except ImportError, e:
StdioInteractingSession = None
_reason = str(e)
del e
from twisted.conch.test.test_ssh import ConchTestRealm
from twisted.python.procutils import which
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
from twisted.conch.test.test_ssh import Crypto, pyasn1
try:
from twisted.conch.test.test_ssh import ConchTestServerFactory, \
ConchTestPublicKeyChecker
except ImportError:
pass
class FakeStdio(object):
"""
A fake for testing L{twisted.conch.scripts.conch.SSHSession.eofReceived} and
L{twisted.conch.scripts.cftp.SSHSession.eofReceived}.
@ivar writeConnLost: A flag which records whether L{loserWriteConnection}
has been called.
"""
writeConnLost = False
def loseWriteConnection(self):
"""
Record the call to loseWriteConnection.
"""
self.writeConnLost = True
class StdioInteractingSessionTests(unittest.TestCase):
"""
Tests for L{twisted.conch.scripts.conch.SSHSession}.
"""
if StdioInteractingSession is None:
skip = _reason
def test_eofReceived(self):
"""
L{twisted.conch.scripts.conch.SSHSession.eofReceived} loses the
write half of its stdio connection.
"""
stdio = FakeStdio()
channel = StdioInteractingSession()
channel.stdio = stdio
channel.eofReceived()
self.assertTrue(stdio.writeConnLost)
class Echo(protocol.Protocol):
def connectionMade(self):
log.msg('ECHO CONNECTION MADE')
def connectionLost(self, reason):
log.msg('ECHO CONNECTION DONE')
def dataReceived(self, data):
self.transport.write(data)
if '\n' in data:
self.transport.loseConnection()
class EchoFactory(protocol.Factory):
protocol = Echo
class ConchTestOpenSSHProcess(protocol.ProcessProtocol):
"""
Test protocol for launching an OpenSSH client process.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
buf = ''
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def outReceived(self, data):
self.buf += data
def processEnded(self, reason):
"""
Called when the process has ended.
@param reason: a Failure giving the reason for the process' end.
"""
if reason.value.exitCode != 0:
self._getDeferred().errback(
ConchError("exit code was not 0: %s" %
reason.value.exitCode))
else:
buf = self.buf.replace('\r\n', '\n')
self._getDeferred().callback(buf)
class ConchTestForwardingProcess(protocol.ProcessProtocol):
"""
Manages a third-party process which launches a server.
Uses L{ConchTestForwardingPort} to connect to the third-party server.
Once L{ConchTestForwardingPort} has disconnected, kill the process and fire
a Deferred with the data received by the L{ConchTestForwardingPort}.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
def __init__(self, port, data):
"""
@type port: C{int}
@param port: The port on which the third-party server is listening.
(it is assumed that the server is running on localhost).
@type data: C{str}
@param data: This is sent to the third-party server. Must end with '\n'
in order to trigger a disconnect.
"""
self.port = port
self.buffer = None
self.data = data
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def connectionMade(self):
self._connect()
def _connect(self):
"""
Connect to the server, which is often a third-party process.
Tries to reconnect if it fails because we have no way of determining
exactly when the port becomes available for listening -- we can only
know when the process starts.
"""
cc = protocol.ClientCreator(reactor, ConchTestForwardingPort, self,
self.data)
d = cc.connectTCP('127.0.0.1', self.port)
d.addErrback(self._ebConnect)
return d
def _ebConnect(self, f):
reactor.callLater(.1, self._connect)
def forwardingPortDisconnected(self, buffer):
"""
The network connection has died; save the buffer of output
from the network and attempt to quit the process gracefully,
and then (after the reactor has spun) send it a KILL signal.
"""
self.buffer = buffer
self.transport.write('\x03')
self.transport.loseConnection()
reactor.callLater(0, self._reallyDie)
def _reallyDie(self):
try:
self.transport.signalProcess('KILL')
except ProcessExitedAlready:
pass
def processEnded(self, reason):
"""
Fire the Deferred at self.deferred with the data collected
from the L{ConchTestForwardingPort} connection, if any.
"""
self._getDeferred().callback(self.buffer)
class ConchTestForwardingPort(protocol.Protocol):
"""
Connects to server launched by a third-party process (managed by
L{ConchTestForwardingProcess}) sends data, then reports whatever it
received back to the L{ConchTestForwardingProcess} once the connection
is ended.
"""
def __init__(self, protocol, data):
"""
@type protocol: L{ConchTestForwardingProcess}
@param protocol: The L{ProcessProtocol} which made this connection.
@type data: str
@param data: The data to be sent to the third-party server.
"""
self.protocol = protocol
self.data = data
def connectionMade(self):
self.buffer = ''
self.transport.write(self.data)
def dataReceived(self, data):
self.buffer += data
def connectionLost(self, reason):
self.protocol.forwardingPortDisconnected(self.buffer)
def _makeArgs(args, mod="conch"):
start = [sys.executable, '-c'
"""
### Twisted Preamble
import sys, os
path = os.path.abspath(sys.argv[0])
while os.path.dirname(path) != path:
if os.path.basename(path).startswith('Twisted'):
sys.path.insert(0, path)
break
path = os.path.dirname(path)
from twisted.conch.scripts.%s import run
run()""" % mod]
return start + list(args)
class ConchServerSetupMixin:
if not Crypto:
skip = "can't run w/o PyCrypto"
if not pyasn1:
skip = "Cannot run without PyASN1"
realmFactory = staticmethod(lambda: ConchTestRealm('testuser'))
def _createFiles(self):
for f in ['rsa_test','rsa_test.pub','dsa_test','dsa_test.pub',
'kh_test']:
if os.path.exists(f):
os.remove(f)
open('rsa_test','w').write(privateRSA_openssh)
open('rsa_test.pub','w').write(publicRSA_openssh)
open('dsa_test.pub','w').write(publicDSA_openssh)
open('dsa_test','w').write(privateDSA_openssh)
os.chmod('dsa_test', 33152)
os.chmod('rsa_test', 33152)
open('kh_test','w').write('127.0.0.1 '+publicRSA_openssh)
def _getFreePort(self):
s = socket.socket()
s.bind(('', 0))
port = s.getsockname()[1]
s.close()
return port
def _makeConchFactory(self):
"""
Make a L{ConchTestServerFactory}, which allows us to start a
L{ConchTestServer} -- i.e. an actually listening conch.
"""
realm = self.realmFactory()
p = portal.Portal(realm)
p.registerChecker(ConchTestPublicKeyChecker())
factory = ConchTestServerFactory()
factory.portal = p
return factory
def setUp(self):
self._createFiles()
self.conchFactory = self._makeConchFactory()
self.conchFactory.expectedLoseConnection = 1
self.conchServer = reactor.listenTCP(0, self.conchFactory,
interface="127.0.0.1")
self.echoServer = reactor.listenTCP(0, EchoFactory())
self.echoPort = self.echoServer.getHost().port
def tearDown(self):
try:
self.conchFactory.proto.done = 1
except AttributeError:
pass
else:
self.conchFactory.proto.transport.loseConnection()
return defer.gatherResults([
defer.maybeDeferred(self.conchServer.stopListening),
defer.maybeDeferred(self.echoServer.stopListening)])
class ForwardingMixin(ConchServerSetupMixin):
"""
Template class for tests of the Conch server's ability to forward arbitrary
protocols over SSH.
These tests are integration tests, not unit tests. They launch a Conch
server, a custom TCP server (just an L{EchoProtocol}) and then call
L{execute}.
L{execute} is implemented by subclasses of L{ForwardingMixin}. It should
cause an SSH client to connect to the Conch server, asking it to forward
data to the custom TCP server.
"""
def test_exec(self):
"""
Test that we can use whatever client to send the command "echo goodbye"
to the Conch server. Make sure we receive "goodbye" back from the
server.
"""
d = self.execute('echo goodbye', ConchTestOpenSSHProcess())
return d.addCallback(self.assertEqual, 'goodbye\n')
def test_localToRemoteForwarding(self):
"""
Test that we can use whatever client to forward a local port to a
specified port on the server.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, 'test\n')
d = self.execute('', process,
sshArgs='-N -L%i:127.0.0.1:%i'
% (localPort, self.echoPort))
d.addCallback(self.assertEqual, 'test\n')
return d
def test_remoteToLocalForwarding(self):
"""
Test that we can use whatever client to forward a port from the server
to a port locally.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, 'test\n')
d = self.execute('', process,
sshArgs='-N -R %i:127.0.0.1:%i'
% (localPort, self.echoPort))
d.addCallback(self.assertEqual, 'test\n')
return d
class RekeyAvatar(ConchUser):
"""
This avatar implements a shell which sends 60 numbered lines to whatever
connects to it, then closes the session with a 0 exit status.
60 lines is selected as being enough to send more than 2kB of traffic, the
amount the client is configured to initiate a rekey after.
"""
# Conventionally there is a separate adapter object which provides ISession
# for the user, but making the user provide ISession directly works too.
# This isn't a full implementation of ISession though, just enough to make
# these tests pass.
implements(ISession)
def __init__(self):
ConchUser.__init__(self)
self.channelLookup['session'] = SSHSession
def openShell(self, transport):
"""
Write 60 lines of data to the transport, then exit.
"""
proto = protocol.Protocol()
proto.makeConnection(transport)
transport.makeConnection(wrapProtocol(proto))
# Send enough bytes to the connection so that a rekey is triggered in
# the client.
def write(counter):
i = counter()
if i == 60:
call.stop()
transport.session.conn.sendRequest(
transport.session, 'exit-status', '\x00\x00\x00\x00')
transport.loseConnection()
else:
transport.write("line #%02d\n" % (i,))
# The timing for this loop is an educated guess (and/or the result of
# experimentation) to exercise the case where a packet is generated
# mid-rekey. Since the other side of the connection is (so far) the
# OpenSSH command line client, there's no easy way to determine when the
# rekey has been initiated. If there were, then generating a packet
# immediately at that time would be a better way to test the
# functionality being tested here.
call = LoopingCall(write, count().next)
call.start(0.01)
def closed(self):
"""
Ignore the close of the session.
"""
class RekeyRealm:
"""
This realm gives out new L{RekeyAvatar} instances for any avatar request.
"""
def requestAvatar(self, avatarID, mind, *interfaces):
return interfaces[0], RekeyAvatar(), lambda: None
class RekeyTestsMixin(ConchServerSetupMixin):
"""
TestCase mixin which defines tests exercising L{SSHTransportBase}'s handling
of rekeying messages.
"""
realmFactory = RekeyRealm
def test_clientRekey(self):
"""
After a client-initiated rekey is completed, application data continues
to be passed over the SSH connection.
"""
process = ConchTestOpenSSHProcess()
d = self.execute("", process, '-o RekeyLimit=2K')
def finished(result):
self.assertEqual(
result,
'\n'.join(['line #%02d' % (i,) for i in range(60)]) + '\n')
d.addCallback(finished)
return d
class OpenSSHClientMixin:
if not which('ssh'):
skip = "no ssh command-line client available"
def execute(self, remoteCommand, process, sshArgs=''):
"""
Connects to the SSH server started in L{ConchServerSetupMixin.setUp} by
running the 'ssh' command line tool.
@type remoteCommand: str
@param remoteCommand: The command (with arguments) to run on the
remote end.
@type process: L{ConchTestOpenSSHProcess}
@type sshArgs: str
@param sshArgs: Arguments to pass to the 'ssh' process.
@return: L{defer.Deferred}
"""
process.deferred = defer.Deferred()
cmdline = ('ssh -2 -l testuser -p %i '
'-oUserKnownHostsFile=kh_test '
'-oPasswordAuthentication=no '
# Always use the RSA key, since that's the one in kh_test.
'-oHostKeyAlgorithms=ssh-rsa '
'-a '
'-i dsa_test ') + sshArgs + \
' 127.0.0.1 ' + remoteCommand
port = self.conchServer.getHost().port
cmds = (cmdline % port).split()
reactor.spawnProcess(process, "ssh", cmds)
return process.deferred
class OpenSSHClientForwardingTestCase(ForwardingMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Connection forwarding tests run against the OpenSSL command line client.
"""
class OpenSSHClientRekeyTestCase(RekeyTestsMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Rekeying tests run against the OpenSSL command line client.
"""
class CmdLineClientTestCase(ForwardingMixin, unittest.TestCase):
"""
Connection forwarding tests run against the Conch command line client.
"""
if runtime.platformType == 'win32':
skip = "can't run cmdline client on win32"
def execute(self, remoteCommand, process, sshArgs=''):
"""
As for L{OpenSSHClientTestCase.execute}, except it runs the 'conch'
command line tool, not 'ssh'.
"""
process.deferred = defer.Deferred()
port = self.conchServer.getHost().port
cmd = ('-p %i -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'--host-key-algorithms ssh-rsa '
'-a '
'-i dsa_test '
'-v ') % port + sshArgs + \
' 127.0.0.1 ' + remoteCommand
cmds = _makeArgs(cmd.split())
log.msg(str(cmds))
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
reactor.spawnProcess(process, sys.executable, cmds, env=env)
return process.deferred

View file

@ -0,0 +1,730 @@
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details
"""
This module tests twisted.conch.ssh.connection.
"""
import struct
from twisted.conch import error
from twisted.conch.ssh import channel, common, connection
from twisted.trial import unittest
from twisted.conch.test import test_userauth
class TestChannel(channel.SSHChannel):
"""
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
@ivar gotOpen: True if channelOpen has been called.
@type gotOpen: C{bool}
@ivar specificData: the specific channel open data passed to channelOpen.
@type specificData: C{str}
@ivar openFailureReason: the reason passed to openFailed.
@type openFailed: C{error.ConchError}
@ivar inBuffer: a C{list} of strings received by the channel.
@type inBuffer: C{list}
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
the channel.
@type extBuffer: C{list}
@ivar numberRequests: the number of requests that have been made to this
channel.
@type numberRequests: C{int}
@ivar gotEOF: True if the other side sent EOF.
@type gotEOF: C{bool}
@ivar gotOneClose: True if the other side closed the connection.
@type gotOneClose: C{bool}
@ivar gotClosed: True if the channel is closed.
@type gotClosed: C{bool}
"""
name = "TestChannel"
gotOpen = False
def logPrefix(self):
return "TestChannel %i" % self.id
def channelOpen(self, specificData):
"""
The channel is open. Set up the instance variables.
"""
self.gotOpen = True
self.specificData = specificData
self.inBuffer = []
self.extBuffer = []
self.numberRequests = 0
self.gotEOF = False
self.gotOneClose = False
self.gotClosed = False
def openFailed(self, reason):
"""
Opening the channel failed. Store the reason why.
"""
self.openFailureReason = reason
def request_test(self, data):
"""
A test request. Return True if data is 'data'.
@type data: C{str}
"""
self.numberRequests += 1
return data == 'data'
def dataReceived(self, data):
"""
Data was received. Store it in the buffer.
"""
self.inBuffer.append(data)
def extReceived(self, code, data):
"""
Extended data was received. Store it in the buffer.
"""
self.extBuffer.append((code, data))
def eofReceived(self):
"""
EOF was received. Remember it.
"""
self.gotEOF = True
def closeReceived(self):
"""
Close was received. Remember it.
"""
self.gotOneClose = True
def closed(self):
"""
The channel is closed. Rembember it.
"""
self.gotClosed = True
class TestAvatar:
"""
A mocked-up version of twisted.conch.avatar.ConchUser
"""
_ARGS_ERROR_CODE = 123
def lookupChannel(self, channelType, windowSize, maxPacket, data):
"""
The server wants us to return a channel. If the requested channel is
our TestChannel, return it, otherwise return None.
"""
if channelType == TestChannel.name:
return TestChannel(remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data, avatar=self)
elif channelType == "conch-error-args":
# Raise a ConchError with backwards arguments to make sure the
# connection fixes it for us. This case should be deprecated and
# deleted eventually, but only after all of Conch gets the argument
# order right.
raise error.ConchError(
self._ARGS_ERROR_CODE, "error args in wrong order")
def gotGlobalRequest(self, requestType, data):
"""
The client has made a global request. If the global request is
'TestGlobal', return True. If the global request is 'TestData',
return True and the request-specific data we received. Otherwise,
return False.
"""
if requestType == 'TestGlobal':
return True
elif requestType == 'TestData':
return True, data
else:
return False
class TestConnection(connection.SSHConnection):
"""
A subclass of SSHConnection for testing.
@ivar channel: the current channel.
@type channel. C{TestChannel}
"""
def logPrefix(self):
return "TestConnection"
def global_TestGlobal(self, data):
"""
The other side made the 'TestGlobal' global request. Return True.
"""
return True
def global_Test_Data(self, data):
"""
The other side made the 'Test-Data' global request. Return True and
the data we received.
"""
return True, data
def channel_TestChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the TestChannel. Create a C{TestChannel}
instance, store it, and return it.
"""
self.channel = TestChannel(remoteWindow=windowSize,
remoteMaxPacket=maxPacket, data=data)
return self.channel
def channel_ErrorChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the ErrorChannel. Raise an exception.
"""
raise AssertionError('no such thing')
class ConnectionTestCase(unittest.TestCase):
if test_userauth.transport is None:
skip = "Cannot run without both PyCrypto and pyasn1"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
self.conn.serviceStarted()
def _openChannel(self, channel):
"""
Open the channel with the default connection.
"""
self.conn.openChannel(channel)
self.transport.packets = self.transport.packets[:-1]
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(struct.pack('>2L',
channel.id, 255) + '\x00\x02\x00\x00\x00\x00\x80\x00')
def tearDown(self):
self.conn.serviceStopped()
def test_linkAvatar(self):
"""
Test that the connection links itself to the avatar in the
transport.
"""
self.assertIs(self.transport.avatar.conn, self.conn)
def test_serviceStopped(self):
"""
Test that serviceStopped() closes any open channels.
"""
channel1 = TestChannel()
channel2 = TestChannel()
self.conn.openChannel(channel1)
self.conn.openChannel(channel2)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00' * 4)
self.assertTrue(channel1.gotOpen)
self.assertFalse(channel2.gotOpen)
self.conn.serviceStopped()
self.assertTrue(channel1.gotClosed)
def test_GLOBAL_REQUEST(self):
"""
Test that global request packets are dispatched to the global_*
methods and the return values are translated into success or failure
messages.
"""
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\xff')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_SUCCESS, '')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestData') + '\xff' +
'test data')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_SUCCESS, 'test data')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestBad') + '\xff')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_FAILURE, '')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\x00')
self.assertEqual(self.transport.packets, [])
def test_REQUEST_SUCCESS(self):
"""
Test that global request success packets cause the Deferred to be
called back.
"""
d = self.conn.sendGlobalRequest('request', 'data', True)
self.conn.ssh_REQUEST_SUCCESS('data')
def check(data):
self.assertEqual(data, 'data')
d.addCallback(check)
d.addErrback(self.fail)
return d
def test_REQUEST_FAILURE(self):
"""
Test that global request failure packets cause the Deferred to be
erred back.
"""
d = self.conn.sendGlobalRequest('request', 'data', True)
self.conn.ssh_REQUEST_FAILURE('data')
def check(f):
self.assertEqual(f.value.data, 'data')
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_OPEN(self):
"""
Test that open channel packets cause a channel to be created and
opened or a failure message to be returned.
"""
del self.transport.avatar
self.conn.ssh_CHANNEL_OPEN(common.NS('TestChannel') +
'\x00\x00\x00\x01' * 4)
self.assertTrue(self.conn.channel.gotOpen)
self.assertEqual(self.conn.channel.conn, self.conn)
self.assertEqual(self.conn.channel.data, '\x00\x00\x00\x01')
self.assertEqual(self.conn.channel.specificData, '\x00\x00\x00\x01')
self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
'\x00\x00\x80\x00')])
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS('BadChannel') +
'\x00\x00\x00\x02' * 4)
self.flushLoggedErrors()
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
'unknown channel') + common.NS(''))])
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS('ErrorChannel') +
'\x00\x00\x00\x02' * 4)
self.flushLoggedErrors()
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
'unknown failure') + common.NS(''))])
def _lookupChannelErrorTest(self, code):
"""
Deliver a request for a channel open which will result in an exception
being raised during channel lookup. Assert that an error response is
delivered as a result.
"""
self.transport.avatar._ARGS_ERROR_CODE = code
self.conn.ssh_CHANNEL_OPEN(
common.NS('conch-error-args') + '\x00\x00\x00\x01' * 4)
errors = self.flushLoggedErrors(error.ConchError)
self.assertEqual(
len(errors), 1, "Expected one error, got: %r" % (errors,))
self.assertEqual(errors[0].value.args, (123, "error args in wrong order"))
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
# The response includes some bytes which identifying the
# associated request, as well as the error code (7b in hex) and
# the error message.
'\x00\x00\x00\x01\x00\x00\x00\x7b' + common.NS(
'error args in wrong order') + common.NS(''))])
def test_lookupChannelError(self):
"""
If a C{lookupChannel} implementation raises L{error.ConchError} with the
arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
sent in response to the message.
This is a temporary work-around until L{error.ConchError} is given
better attributes and all of the Conch code starts constructing
instances of it properly. Eventually this functionality should be
deprecated and then removed.
"""
self._lookupChannelErrorTest(123)
def test_lookupChannelErrorLongCode(self):
"""
Like L{test_lookupChannelError}, but for the case where the failure code
is represented as a C{long} instead of a C{int}.
"""
self._lookupChannelErrorTest(123L)
def test_CHANNEL_OPEN_CONFIRMATION(self):
"""
Test that channel open confirmation packets cause the channel to be
notified that it's open.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00'*5)
self.assertEqual(channel.remoteWindowLeft, 0)
self.assertEqual(channel.remoteMaxPacket, 0)
self.assertEqual(channel.specificData, '\x00\x00\x00\x00')
self.assertEqual(self.conn.channelsToRemoteChannel[channel],
0)
self.assertEqual(self.conn.localToRemoteChannel[0], 0)
def test_CHANNEL_OPEN_FAILURE(self):
"""
Test that channel open failure packets cause the channel to be
notified that its opening failed.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_FAILURE('\x00\x00\x00\x00\x00\x00\x00'
'\x01' + common.NS('failure!'))
self.assertEqual(channel.openFailureReason.args, ('failure!', 1))
self.assertEqual(self.conn.channels.get(channel), None)
def test_CHANNEL_WINDOW_ADJUST(self):
"""
Test that channel window adjust messages add bytes to the channel
window.
"""
channel = TestChannel()
self._openChannel(channel)
oldWindowSize = channel.remoteWindowLeft
self.conn.ssh_CHANNEL_WINDOW_ADJUST('\x00\x00\x00\x00\x00\x00\x00'
'\x01')
self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)
def test_CHANNEL_DATA(self):
"""
Test that channel data messages are passed up to the channel, or
cause the channel to be closed if the data is too large.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS('data'))
self.assertEqual(channel.inBuffer, ['data'])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
'\x00\x00\x00\x04')])
self.transport.packets = []
longData = 'a' * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS(longData))
self.assertEqual(channel.inBuffer, ['data'])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
channel = TestChannel()
self._openChannel(channel)
bigData = 'a' * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x01' + common.NS(bigData))
self.assertEqual(channel.inBuffer, [])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
def test_CHANNEL_EXTENDED_DATA(self):
"""
Test that channel extended data messages are passed up to the channel,
or cause the channel to be closed if they're too big.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
'\x00' + common.NS('data'))
self.assertEqual(channel.extBuffer, [(0, 'data')])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
'\x00\x00\x00\x04')])
self.transport.packets = []
longData = 'a' * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
'\x00' + common.NS(longData))
self.assertEqual(channel.extBuffer, [(0, 'data')])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
channel = TestChannel()
self._openChannel(channel)
bigData = 'a' * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x01\x00\x00\x00'
'\x00' + common.NS(bigData))
self.assertEqual(channel.extBuffer, [])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
def test_CHANNEL_EOF(self):
"""
Test that channel eof messages are passed up to the channel.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_EOF('\x00\x00\x00\x00')
self.assertTrue(channel.gotEOF)
def test_CHANNEL_CLOSE(self):
"""
Test that channel close messages are passed up to the channel. Also,
test that channel.close() is called if both sides are closed when this
message is received.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendClose(channel)
self.conn.ssh_CHANNEL_CLOSE('\x00\x00\x00\x00')
self.assertTrue(channel.gotOneClose)
self.assertTrue(channel.gotClosed)
def test_CHANNEL_REQUEST_success(self):
"""
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS('test')
+ '\x00')
self.assertEqual(channel.numberRequests, 1)
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
'test') + '\xff' + 'data')
def check(result):
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_SUCCESS, '\x00\x00\x00\xff')])
d.addCallback(check)
return d
def test_CHANNEL_REQUEST_failure(self):
"""
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
'test') + '\xff')
def check(result):
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_FAILURE, '\x00\x00\x00\xff'
)])
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_REQUEST_SUCCESS(self):
"""
Test that channel request success messages cause the Deferred to be
called back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, 'test', 'data', True)
self.conn.ssh_CHANNEL_SUCCESS('\x00\x00\x00\x00')
def check(result):
self.assertTrue(result)
return d
def test_CHANNEL_REQUEST_FAILURE(self):
"""
Test that channel request failure messages cause the Deferred to be
erred back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, 'test', '', True)
self.conn.ssh_CHANNEL_FAILURE('\x00\x00\x00\x00')
def check(result):
self.assertEqual(result.value.value, 'channel request failed')
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_sendGlobalRequest(self):
"""
Test that global request messages are sent in the right format.
"""
d = self.conn.sendGlobalRequest('wantReply', 'data', True)
# must be added to prevent errbacking during teardown
d.addErrback(lambda failure: None)
self.conn.sendGlobalRequest('noReply', '', False)
self.assertEqual(self.transport.packets,
[(connection.MSG_GLOBAL_REQUEST, common.NS('wantReply') +
'\xffdata'),
(connection.MSG_GLOBAL_REQUEST, common.NS('noReply') +
'\x00')])
self.assertEqual(self.conn.deferreds, {'global':[d]})
def test_openChannel(self):
"""
Test that open channel messages are sent in the right format.
"""
channel = TestChannel()
self.conn.openChannel(channel, 'aaaa')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN, common.NS('TestChannel') +
'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
self.assertEqual(channel.id, 0)
self.assertEqual(self.conn.localChannelID, 1)
def test_sendRequest(self):
"""
Test that channel request messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, 'test', 'test', True)
# needed to prevent errbacks during teardown.
d.addErrback(lambda failure: None)
self.conn.sendRequest(channel, 'test2', '', False)
channel.localClosed = True # emulate sending a close message
self.conn.sendRequest(channel, 'test3', '', True)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
common.NS('test') + '\x01test'),
(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
common.NS('test2') + '\x00')])
self.assertEqual(self.conn.deferreds[0], [d])
def test_adjustWindow(self):
"""
Test that channel window adjust messages cause bytes to be added
to the window.
"""
channel = TestChannel(localWindow=5)
self._openChannel(channel)
channel.localWindowLeft = 0
self.conn.adjustWindow(channel, 1)
self.assertEqual(channel.localWindowLeft, 1)
channel.localClosed = True
self.conn.adjustWindow(channel, 2)
self.assertEqual(channel.localWindowLeft, 1)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
'\x00\x00\x00\x01')])
def test_sendData(self):
"""
Test that channel data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendData(channel, 'a')
channel.localClosed = True
self.conn.sendData(channel, 'b')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_DATA, '\x00\x00\x00\xff' +
common.NS('a'))])
def test_sendExtendedData(self):
"""
Test that channel extended data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendExtendedData(channel, 1, 'test')
channel.localClosed = True
self.conn.sendExtendedData(channel, 2, 'test2')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EXTENDED_DATA, '\x00\x00\x00\xff' +
'\x00\x00\x00\x01' + common.NS('test'))])
def test_sendEOF(self):
"""
Test that channel EOF messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendEOF(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
channel.localClosed = True
self.conn.sendEOF(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
def test_sendClose(self):
"""
Test that channel close messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendClose(channel)
self.assertTrue(channel.localClosed)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
self.conn.sendClose(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
channel2 = TestChannel()
self._openChannel(channel2)
channel2.remoteClosed = True
self.conn.sendClose(channel2)
self.assertTrue(channel2.gotClosed)
def test_getChannelWithAvatar(self):
"""
Test that getChannel dispatches to the avatar when an avatar is
present. Correct functioning without the avatar is verified in
test_CHANNEL_OPEN.
"""
channel = self.conn.getChannel('TestChannel', 50, 30, 'data')
self.assertEqual(channel.data, 'data')
self.assertEqual(channel.remoteWindowLeft, 50)
self.assertEqual(channel.remoteMaxPacket, 30)
self.assertRaises(error.ConchError, self.conn.getChannel,
'BadChannel', 50, 30, 'data')
def test_gotGlobalRequestWithoutAvatar(self):
"""
Test that gotGlobalRequests dispatches to global_* without an avatar.
"""
del self.transport.avatar
self.assertTrue(self.conn.gotGlobalRequest('TestGlobal', 'data'))
self.assertEqual(self.conn.gotGlobalRequest('Test-Data', 'data'),
(True, 'data'))
self.assertFalse(self.conn.gotGlobalRequest('BadGlobal', 'data'))
def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
"""
Whenever an SSH channel gets closed any Deferred that was returned by a
sendRequest() on its parent connection must be errbacked.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(
channel, "dummyrequest", "dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.channelClosed(channel)
return d
class TestCleanConnectionShutdown(unittest.TestCase):
"""
Check whether correct cleanup is performed on connection shutdown.
"""
if test_userauth.transport is None:
skip = "Cannot run without both PyCrypto and pyasn1"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
"""
Once the service is stopped any leftover global deferred returned by
a sendGlobalRequest() call must be errbacked.
"""
self.conn.serviceStarted()
d = self.conn.sendGlobalRequest(
"dummyrequest", "dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.serviceStopped()
return d

View file

@ -0,0 +1,171 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.client.default}.
"""
try:
import Crypto.Cipher.DES3
import pyasn1
except ImportError:
skip = "PyCrypto and PyASN1 required for twisted.conch.client.default."
else:
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import SSHUserAuthClient
from twisted.conch.client.options import ConchOptions
from twisted.conch.ssh.keys import Key
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.conch.test import keydata
from twisted.test.proto_helpers import StringTransport
class SSHUserAuthClientTest(TestCase):
"""
Tests for L{SSHUserAuthClient}.
@type rsaPublic: L{Key}
@ivar rsaPublic: A public RSA key.
"""
def setUp(self):
self.rsaPublic = Key.fromString(keydata.publicRSA_openssh)
self.tmpdir = FilePath(self.mktemp())
self.tmpdir.makedirs()
self.rsaFile = self.tmpdir.child('id_rsa')
self.rsaFile.setContent(keydata.privateRSA_openssh)
self.tmpdir.child('id_rsa.pub').setContent(keydata.publicRSA_openssh)
def test_signDataWithAgent(self):
"""
When connected to an agent, L{SSHUserAuthClient} can use it to
request signatures of particular data with a particular L{Key}.
"""
client = SSHUserAuthClient("user", ConchOptions(), None)
agent = SSHAgentClient()
transport = StringTransport()
agent.makeConnection(transport)
client.keyAgent = agent
cleartext = "Sign here"
client.signData(self.rsaPublic, cleartext)
self.assertEqual(
transport.value(),
"\x00\x00\x00\x8b\r\x00\x00\x00u" + self.rsaPublic.blob() +
"\x00\x00\x00\t" + cleartext +
"\x00\x00\x00\x00")
def test_agentGetPublicKey(self):
"""
L{SSHUserAuthClient} looks up public keys from the agent using the
L{SSHAgentClient} class. That L{SSHAgentClient.getPublicKey} returns a
L{Key} object with one of the public keys in the agent. If no more
keys are present, it returns C{None}.
"""
agent = SSHAgentClient()
agent.blobs = [self.rsaPublic.blob()]
key = agent.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, self.rsaPublic)
self.assertEqual(agent.getPublicKey(), None)
def test_getPublicKeyFromFile(self):
"""
L{SSHUserAuthClient.getPublicKey()} is able to get a public key from
the first file described by its options' C{identitys} list, and return
the corresponding public L{Key} object.
"""
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient("user", options, None)
key = client.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, self.rsaPublic)
def test_getPublicKeyAgentFallback(self):
"""
If an agent is present, but doesn't return a key,
L{SSHUserAuthClient.getPublicKey} continue with the normal key lookup.
"""
options = ConchOptions()
options.identitys = [self.rsaFile.path]
agent = SSHAgentClient()
client = SSHUserAuthClient("user", options, None)
client.keyAgent = agent
key = client.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, self.rsaPublic)
def test_getPublicKeyBadKeyError(self):
"""
If L{keys.Key.fromFile} raises a L{keys.BadKeyError}, the
L{SSHUserAuthClient.getPublicKey} tries again to get a public key by
calling itself recursively.
"""
options = ConchOptions()
self.tmpdir.child('id_dsa.pub').setContent(keydata.publicDSA_openssh)
dsaFile = self.tmpdir.child('id_dsa')
dsaFile.setContent(keydata.privateDSA_openssh)
options.identitys = [self.rsaFile.path, dsaFile.path]
self.tmpdir.child('id_rsa.pub').setContent('not a key!')
client = SSHUserAuthClient("user", options, None)
key = client.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, Key.fromString(keydata.publicDSA_openssh))
self.assertEqual(client.usedFiles, [self.rsaFile.path, dsaFile.path])
def test_getPrivateKey(self):
"""
L{SSHUserAuthClient.getPrivateKey} will load a private key from the
last used file populated by L{SSHUserAuthClient.getPublicKey}, and
return a L{Deferred} which fires with the corresponding private L{Key}.
"""
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient("user", options, None)
# Populate the list of used files
client.getPublicKey()
def _cbGetPrivateKey(key):
self.assertEqual(key.isPublic(), False)
self.assertEqual(key, rsaPrivate)
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
def test_getPrivateKeyPassphrase(self):
"""
L{SSHUserAuthClient} can get a private key from a file, and return a
Deferred called back with a private L{Key} object, even if the key is
encrypted.
"""
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
passphrase = 'this is the passphrase'
self.rsaFile.setContent(rsaPrivate.toString('openssh', passphrase))
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient("user", options, None)
# Populate the list of used files
client.getPublicKey()
def _getPassword(prompt):
self.assertEqual(prompt,
"Enter passphrase for key '%s': " % (
self.rsaFile.path,))
return passphrase
def _cbGetPrivateKey(key):
self.assertEqual(key.isPublic(), False)
self.assertEqual(key, rsaPrivate)
self.patch(client, '_getPassword', _getPassword)
return client.getPrivateKey().addCallback(_cbGetPrivateKey)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,771 @@
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE file for details.
"""
Tests for L{twisted.conch.ssh.filetransfer}.
"""
import os
import re
import struct
import sys
from twisted.trial import unittest
try:
from twisted.conch import unix
unix # shut up pyflakes
except ImportError:
unix = None
from twisted.conch import avatar
from twisted.conch.ssh import common, connection, filetransfer, session
from twisted.internet import defer
from twisted.protocols import loopback
from twisted.python import components
class TestAvatar(avatar.ConchUser):
def __init__(self):
avatar.ConchUser.__init__(self)
self.channelLookup['session'] = session.SSHSession
self.subsystemLookup['sftp'] = filetransfer.FileTransferServer
def _runAsUser(self, f, *args, **kw):
try:
f = iter(f)
except TypeError:
f = [(f, args, kw)]
for i in f:
func = i[0]
args = len(i)>1 and i[1] or ()
kw = len(i)>2 and i[2] or {}
r = func(*args, **kw)
return r
class FileTransferTestAvatar(TestAvatar):
def __init__(self, homeDir):
TestAvatar.__init__(self)
self.homeDir = homeDir
def getHomeDir(self):
return os.path.join(os.getcwd(), self.homeDir)
class ConchSessionForTestAvatar:
def __init__(self, avatar):
self.avatar = avatar
if unix:
if not hasattr(unix, 'SFTPServerForUnixConchUser'):
# unix should either be a fully working module, or None. I'm not sure
# how this happens, but on win32 it does. Try to cope. --spiv.
import warnings
warnings.warn(("twisted.conch.unix imported %r, "
"but doesn't define SFTPServerForUnixConchUser'")
% (unix,))
unix = None
else:
class FileTransferForTestAvatar(unix.SFTPServerForUnixConchUser):
def gotVersion(self, version, otherExt):
return {'conchTest' : 'ext data'}
def extendedRequest(self, extName, extData):
if extName == 'testExtendedRequest':
return 'bar'
raise NotImplementedError
components.registerAdapter(FileTransferForTestAvatar,
TestAvatar,
filetransfer.ISFTPServer)
class SFTPTestBase(unittest.TestCase):
def setUp(self):
self.testDir = self.mktemp()
# Give the testDir another level so we can safely "cd .." from it in
# tests.
self.testDir = os.path.join(self.testDir, 'extra')
os.makedirs(os.path.join(self.testDir, 'testDirectory'))
f = file(os.path.join(self.testDir, 'testfile1'),'w')
f.write('a'*10+'b'*10)
f.write(file('/dev/urandom').read(1024*64)) # random data
os.chmod(os.path.join(self.testDir, 'testfile1'), 0644)
file(os.path.join(self.testDir, 'testRemoveFile'), 'w').write('a')
file(os.path.join(self.testDir, 'testRenameFile'), 'w').write('a')
file(os.path.join(self.testDir, '.testHiddenFile'), 'w').write('a')
class TestOurServerOurClient(SFTPTestBase):
if not unix:
skip = "can't run on non-posix computers"
def setUp(self):
SFTPTestBase.setUp(self)
self.avatar = FileTransferTestAvatar(self.testDir)
self.server = filetransfer.FileTransferServer(avatar=self.avatar)
clientTransport = loopback.LoopbackRelay(self.server)
self.client = filetransfer.FileTransferClient()
self._serverVersion = None
self._extData = None
def _(serverVersion, extData):
self._serverVersion = serverVersion
self._extData = extData
self.client.gotServerVersion = _
serverTransport = loopback.LoopbackRelay(self.client)
self.client.makeConnection(clientTransport)
self.server.makeConnection(serverTransport)
self.clientTransport = clientTransport
self.serverTransport = serverTransport
self._emptyBuffers()
def _emptyBuffers(self):
while self.serverTransport.buffer or self.clientTransport.buffer:
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
def tearDown(self):
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
def testServerVersion(self):
self.assertEqual(self._serverVersion, 3)
self.assertEqual(self._extData, {'conchTest' : 'ext data'})
def test_interface_implementation(self):
"""
It implements the ISFTPServer interface.
"""
self.assertTrue(
filetransfer.ISFTPServer.providedBy(self.server.client),
"ISFTPServer not provided by %r" % (self.server.client,))
def test_openedFileClosedWithConnection(self):
"""
A file opened with C{openFile} is close when the connection is lost.
"""
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
oldClose = os.close
closed = []
def close(fd):
closed.append(fd)
oldClose(fd)
self.patch(os, "close", close)
def _fileOpened(openFile):
fd = self.server.openFiles[openFile.handle[4:]].fd
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self.assertEqual(self.server.openFiles, {})
self.assertIn(fd, closed)
d.addCallback(_fileOpened)
return d
def test_openedDirectoryClosedWithConnection(self):
"""
A directory opened with C{openDirectory} is close when the connection
is lost.
"""
d = self.client.openDirectory('')
self._emptyBuffers()
def _getFiles(openDir):
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self.assertEqual(self.server.openDirs, {})
d.addCallback(_getFiles)
return d
def testOpenFileIO(self):
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _fileOpened(openFile):
self.assertEqual(openFile, filetransfer.ISFTPFile(openFile))
d = _readChunk(openFile)
d.addCallback(_writeChunk, openFile)
return d
def _readChunk(openFile):
d = openFile.readChunk(0, 20)
self._emptyBuffers()
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10)
return d
def _writeChunk(_, openFile):
d = openFile.writeChunk(20, 'c'*10)
self._emptyBuffers()
d.addCallback(_readChunk2, openFile)
return d
def _readChunk2(_, openFile):
d = openFile.readChunk(0, 30)
self._emptyBuffers()
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10 + 'c'*10)
return d
d.addCallback(_fileOpened)
return d
def testClosedFileGetAttrs(self):
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _getAttrs(_, openFile):
d = openFile.getAttrs()
self._emptyBuffers()
return d
def _err(f):
self.flushLoggedErrors()
return f
def _close(openFile):
d = openFile.close()
self._emptyBuffers()
d.addCallback(_getAttrs, openFile)
d.addErrback(_err)
return self.assertFailure(d, filetransfer.SFTPError)
d.addCallback(_close)
return d
def testOpenFileAttributes(self):
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _getAttrs(openFile):
d = openFile.getAttrs()
self._emptyBuffers()
d.addCallback(_getAttrs2)
return d
def _getAttrs2(attrs1):
d = self.client.getAttrs('testfile1')
self._emptyBuffers()
d.addCallback(self.assertEqual, attrs1)
return d
return d.addCallback(_getAttrs)
def testOpenFileSetAttrs(self):
# XXX test setAttrs
# Ok, how about this for a start? It caught a bug :) -- spiv.
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _getAttrs(openFile):
d = openFile.getAttrs()
self._emptyBuffers()
d.addCallback(_setAttrs)
return d
def _setAttrs(attrs):
attrs['atime'] = 0
d = self.client.setAttrs('testfile1', attrs)
self._emptyBuffers()
d.addCallback(_getAttrs2)
d.addCallback(self.assertEqual, attrs)
return d
def _getAttrs2(_):
d = self.client.getAttrs('testfile1')
self._emptyBuffers()
return d
d.addCallback(_getAttrs)
return d
def test_openFileExtendedAttributes(self):
"""
Check that L{filetransfer.FileTransferClient.openFile} can send
extended attributes, that should be extracted server side. By default,
they are ignored, so we just verify they are correctly parsed.
"""
savedAttributes = {}
oldOpenFile = self.server.client.openFile
def openFile(filename, flags, attrs):
savedAttributes.update(attrs)
return oldOpenFile(filename, flags, attrs)
self.server.client.openFile = openFile
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {"ext_foo": "bar"})
self._emptyBuffers()
def check(ign):
self.assertEqual(savedAttributes, {"ext_foo": "bar"})
return d.addCallback(check)
def testRemoveFile(self):
d = self.client.getAttrs("testRemoveFile")
self._emptyBuffers()
def _removeFile(ignored):
d = self.client.removeFile("testRemoveFile")
self._emptyBuffers()
return d
d.addCallback(_removeFile)
d.addCallback(_removeFile)
return self.assertFailure(d, filetransfer.SFTPError)
def testRenameFile(self):
d = self.client.getAttrs("testRenameFile")
self._emptyBuffers()
def _rename(attrs):
d = self.client.renameFile("testRenameFile", "testRenamedFile")
self._emptyBuffers()
d.addCallback(_testRenamed, attrs)
return d
def _testRenamed(_, attrs):
d = self.client.getAttrs("testRenamedFile")
self._emptyBuffers()
d.addCallback(self.assertEqual, attrs)
return d.addCallback(_rename)
def testDirectoryBad(self):
d = self.client.getAttrs("testMakeDirectory")
self._emptyBuffers()
return self.assertFailure(d, filetransfer.SFTPError)
def testDirectoryCreation(self):
d = self.client.makeDirectory("testMakeDirectory", {})
self._emptyBuffers()
def _getAttrs(_):
d = self.client.getAttrs("testMakeDirectory")
self._emptyBuffers()
return d
# XXX not until version 4/5
# self.assertEqual(filetransfer.FILEXFER_TYPE_DIRECTORY&attrs['type'],
# filetransfer.FILEXFER_TYPE_DIRECTORY)
def _removeDirectory(_):
d = self.client.removeDirectory("testMakeDirectory")
self._emptyBuffers()
return d
d.addCallback(_getAttrs)
d.addCallback(_removeDirectory)
d.addCallback(_getAttrs)
return self.assertFailure(d, filetransfer.SFTPError)
def testOpenDirectory(self):
d = self.client.openDirectory('')
self._emptyBuffers()
files = []
def _getFiles(openDir):
def append(f):
files.append(f)
return openDir
d = defer.maybeDeferred(openDir.next)
self._emptyBuffers()
d.addCallback(append)
d.addCallback(_getFiles)
d.addErrback(_close, openDir)
return d
def _checkFiles(ignored):
fs = list(zip(*files)[0])
fs.sort()
self.assertEqual(fs,
['.testHiddenFile', 'testDirectory',
'testRemoveFile', 'testRenameFile',
'testfile1'])
def _close(_, openDir):
d = openDir.close()
self._emptyBuffers()
return d
d.addCallback(_getFiles)
d.addCallback(_checkFiles)
return d
def testLinkDoesntExist(self):
d = self.client.getAttrs('testLink')
self._emptyBuffers()
return self.assertFailure(d, filetransfer.SFTPError)
def testLinkSharesAttrs(self):
d = self.client.makeLink('testLink', 'testfile1')
self._emptyBuffers()
def _getFirstAttrs(_):
d = self.client.getAttrs('testLink', 1)
self._emptyBuffers()
return d
def _getSecondAttrs(firstAttrs):
d = self.client.getAttrs('testfile1')
self._emptyBuffers()
d.addCallback(self.assertEqual, firstAttrs)
return d
d.addCallback(_getFirstAttrs)
return d.addCallback(_getSecondAttrs)
def testLinkPath(self):
d = self.client.makeLink('testLink', 'testfile1')
self._emptyBuffers()
def _readLink(_):
d = self.client.readLink('testLink')
self._emptyBuffers()
d.addCallback(self.assertEqual,
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
return d
def _realPath(_):
d = self.client.realPath('testLink')
self._emptyBuffers()
d.addCallback(self.assertEqual,
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
return d
d.addCallback(_readLink)
d.addCallback(_realPath)
return d
def testExtendedRequest(self):
d = self.client.extendedRequest('testExtendedRequest', 'foo')
self._emptyBuffers()
d.addCallback(self.assertEqual, 'bar')
d.addCallback(self._cbTestExtendedRequest)
return d
def _cbTestExtendedRequest(self, ignored):
d = self.client.extendedRequest('testBadRequest', '')
self._emptyBuffers()
return self.assertFailure(d, NotImplementedError)
class FakeConn:
def sendClose(self, channel):
pass
class TestFileTransferClose(unittest.TestCase):
if not unix:
skip = "can't run on non-posix computers"
def setUp(self):
self.avatar = TestAvatar()
def buildServerConnection(self):
# make a server connection
conn = connection.SSHConnection()
# server connections have a 'self.transport.avatar'.
class DummyTransport:
def __init__(self):
self.transport = self
def sendPacket(self, kind, data):
pass
def logPrefix(self):
return 'dummy transport'
conn.transport = DummyTransport()
conn.transport.avatar = self.avatar
return conn
def interceptConnectionLost(self, sftpServer):
self.connectionLostFired = False
origConnectionLost = sftpServer.connectionLost
def connectionLost(reason):
self.connectionLostFired = True
origConnectionLost(reason)
sftpServer.connectionLost = connectionLost
def assertSFTPConnectionLost(self):
self.assertTrue(self.connectionLostFired,
"sftpServer's connectionLost was not called")
def test_sessionClose(self):
"""
Closing a session should notify an SFTP subsystem launched by that
session.
"""
# make a session
testSession = session.SSHSession(conn=FakeConn(), avatar=self.avatar)
# start an SFTP subsystem on the session
testSession.request_subsystem(common.NS('sftp'))
sftpServer = testSession.client.transport.proto
# intercept connectionLost so we can check that it's called
self.interceptConnectionLost(sftpServer)
# close session
testSession.closeReceived()
self.assertSFTPConnectionLost()
def test_clientClosesChannelOnConnnection(self):
"""
A client sending CHANNEL_CLOSE should trigger closeReceived on the
associated channel instance.
"""
conn = self.buildServerConnection()
# somehow get a session
packet = common.NS('session') + struct.pack('>L', 0) * 3
conn.ssh_CHANNEL_OPEN(packet)
sessionChannel = conn.channels[0]
sessionChannel.request_subsystem(common.NS('sftp'))
sftpServer = sessionChannel.client.transport.proto
self.interceptConnectionLost(sftpServer)
# intercept closeReceived
self.interceptConnectionLost(sftpServer)
# close the connection
conn.ssh_CHANNEL_CLOSE(struct.pack('>L', 0))
self.assertSFTPConnectionLost()
def test_stopConnectionServiceClosesChannel(self):
"""
Closing an SSH connection should close all sessions within it.
"""
conn = self.buildServerConnection()
# somehow get a session
packet = common.NS('session') + struct.pack('>L', 0) * 3
conn.ssh_CHANNEL_OPEN(packet)
sessionChannel = conn.channels[0]
sessionChannel.request_subsystem(common.NS('sftp'))
sftpServer = sessionChannel.client.transport.proto
self.interceptConnectionLost(sftpServer)
# close the connection
conn.serviceStopped()
self.assertSFTPConnectionLost()
class TestConstants(unittest.TestCase):
"""
Tests for the constants used by the SFTP protocol implementation.
@ivar filexferSpecExcerpts: Excerpts from the
draft-ietf-secsh-filexfer-02.txt (draft) specification of the SFTP
protocol. There are more recent drafts of the specification, but this
one describes version 3, which is what conch (and OpenSSH) implements.
"""
filexferSpecExcerpts = [
"""
The following values are defined for packet types.
#define SSH_FXP_INIT 1
#define SSH_FXP_VERSION 2
#define SSH_FXP_OPEN 3
#define SSH_FXP_CLOSE 4
#define SSH_FXP_READ 5
#define SSH_FXP_WRITE 6
#define SSH_FXP_LSTAT 7
#define SSH_FXP_FSTAT 8
#define SSH_FXP_SETSTAT 9
#define SSH_FXP_FSETSTAT 10
#define SSH_FXP_OPENDIR 11
#define SSH_FXP_READDIR 12
#define SSH_FXP_REMOVE 13
#define SSH_FXP_MKDIR 14
#define SSH_FXP_RMDIR 15
#define SSH_FXP_REALPATH 16
#define SSH_FXP_STAT 17
#define SSH_FXP_RENAME 18
#define SSH_FXP_READLINK 19
#define SSH_FXP_SYMLINK 20
#define SSH_FXP_STATUS 101
#define SSH_FXP_HANDLE 102
#define SSH_FXP_DATA 103
#define SSH_FXP_NAME 104
#define SSH_FXP_ATTRS 105
#define SSH_FXP_EXTENDED 200
#define SSH_FXP_EXTENDED_REPLY 201
Additional packet types should only be defined if the protocol
version number (see Section ``Protocol Initialization'') is
incremented, and their use MUST be negotiated using the version
number. However, the SSH_FXP_EXTENDED and SSH_FXP_EXTENDED_REPLY
packets can be used to implement vendor-specific extensions. See
Section ``Vendor-Specific-Extensions'' for more details.
""",
"""
The flags bits are defined to have the following values:
#define SSH_FILEXFER_ATTR_SIZE 0x00000001
#define SSH_FILEXFER_ATTR_UIDGID 0x00000002
#define SSH_FILEXFER_ATTR_PERMISSIONS 0x00000004
#define SSH_FILEXFER_ATTR_ACMODTIME 0x00000008
#define SSH_FILEXFER_ATTR_EXTENDED 0x80000000
""",
"""
The `pflags' field is a bitmask. The following bits have been
defined.
#define SSH_FXF_READ 0x00000001
#define SSH_FXF_WRITE 0x00000002
#define SSH_FXF_APPEND 0x00000004
#define SSH_FXF_CREAT 0x00000008
#define SSH_FXF_TRUNC 0x00000010
#define SSH_FXF_EXCL 0x00000020
""",
"""
Currently, the following values are defined (other values may be
defined by future versions of this protocol):
#define SSH_FX_OK 0
#define SSH_FX_EOF 1
#define SSH_FX_NO_SUCH_FILE 2
#define SSH_FX_PERMISSION_DENIED 3
#define SSH_FX_FAILURE 4
#define SSH_FX_BAD_MESSAGE 5
#define SSH_FX_NO_CONNECTION 6
#define SSH_FX_CONNECTION_LOST 7
#define SSH_FX_OP_UNSUPPORTED 8
"""]
def test_constantsAgainstSpec(self):
"""
The constants used by the SFTP protocol implementation match those
found by searching through the spec.
"""
constants = {}
for excerpt in self.filexferSpecExcerpts:
for line in excerpt.splitlines():
m = re.match('^\s*#define SSH_([A-Z_]+)\s+([0-9x]*)\s*$', line)
if m:
constants[m.group(1)] = long(m.group(2), 0)
self.assertTrue(
len(constants) > 0, "No constants found (the test must be buggy).")
for k, v in constants.items():
self.assertEqual(v, getattr(filetransfer, k))
class TestRawPacketData(unittest.TestCase):
"""
Tests for L{filetransfer.FileTransferClient} which explicitly craft certain
less common protocol messages to exercise their handling.
"""
def setUp(self):
self.ftc = filetransfer.FileTransferClient()
def test_packetSTATUS(self):
"""
A STATUS packet containing a result code, a message, and a language is
parsed to produce the result of an outstanding request L{Deferred}.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUS)
self.ftc.openRequests[1] = d
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg') + common.NS('lang')
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUS(self, result):
"""
Assert that the result is a two-tuple containing the message and
language from the STATUS packet.
"""
self.assertEqual(result[0], 'msg')
self.assertEqual(result[1], 'lang')
def test_packetSTATUSShort(self):
"""
A STATUS packet containing only a result code can also be parsed to
produce the result of an outstanding request L{Deferred}. Such packets
are sent by some SFTP implementations, though not strictly legal.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUSShort)
self.ftc.openRequests[1] = d
data = struct.pack('!LL', 1, filetransfer.FX_OK)
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUSShort(self, result):
"""
Assert that the result is a two-tuple containing empty strings, since
the STATUS packet had neither a message nor a language.
"""
self.assertEqual(result[0], '')
self.assertEqual(result[1], '')
def test_packetSTATUSWithoutLang(self):
"""
A STATUS packet containing a result code and a message but no language
can also be parsed to produce the result of an outstanding request
L{Deferred}. Such packets are sent by some SFTP implementations, though
not strictly legal.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUSWithoutLang)
self.ftc.openRequests[1] = d
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg')
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUSWithoutLang(self, result):
"""
Assert that the result is a two-tuple containing the message from the
STATUS packet and an empty string, since the language was missing.
"""
self.assertEqual(result[0], 'msg')
self.assertEqual(result[1], '')

View file

@ -0,0 +1,614 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.conch.insults import helper
from twisted.conch.insults.insults import G0, G1, G2, G3
from twisted.conch.insults.insults import modes, privateModes
from twisted.conch.insults.insults import (
NORMAL, BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
from twisted.trial import unittest
WIDTH = 80
HEIGHT = 24
class BufferTestCase(unittest.TestCase):
def setUp(self):
self.term = helper.TerminalBuffer()
self.term.connectionMade()
def testInitialState(self):
self.assertEqual(self.term.width, WIDTH)
self.assertEqual(self.term.height, HEIGHT)
self.assertEqual(str(self.term),
'\n' * (HEIGHT - 1))
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def test_initialPrivateModes(self):
"""
Verify that only DEC Auto Wrap Mode (DECAWM) and DEC Text Cursor Enable
Mode (DECTCEM) are initially in the Set Mode (SM) state.
"""
self.assertEqual(
{privateModes.AUTO_WRAP: True,
privateModes.CURSOR_MODE: True},
self.term.privateModes)
def test_carriageReturn(self):
"""
C{"\r"} moves the cursor to the first column in the current row.
"""
self.term.cursorForward(5)
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
self.term.insertAtCursor("\r")
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
def test_linefeed(self):
"""
C{"\n"} moves the cursor to the next row without changing the column.
"""
self.term.cursorForward(5)
self.assertEqual(self.term.reportCursorPosition(), (5, 0))
self.term.insertAtCursor("\n")
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
def test_newline(self):
"""
C{write} transforms C{"\n"} into C{"\r\n"}.
"""
self.term.cursorForward(5)
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
self.term.write("\n")
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
def test_setPrivateModes(self):
"""
Verify that L{helper.TerminalBuffer.setPrivateModes} changes the Set
Mode (SM) state to "set" for the private modes it is passed.
"""
expected = self.term.privateModes.copy()
self.term.setPrivateModes([privateModes.SCROLL, privateModes.SCREEN])
expected[privateModes.SCROLL] = True
expected[privateModes.SCREEN] = True
self.assertEqual(expected, self.term.privateModes)
def test_resetPrivateModes(self):
"""
Verify that L{helper.TerminalBuffer.resetPrivateModes} changes the Set
Mode (SM) state to "reset" for the private modes it is passed.
"""
expected = self.term.privateModes.copy()
self.term.resetPrivateModes([privateModes.AUTO_WRAP, privateModes.CURSOR_MODE])
del expected[privateModes.AUTO_WRAP]
del expected[privateModes.CURSOR_MODE]
self.assertEqual(expected, self.term.privateModes)
def testCursorDown(self):
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
self.term.cursorDown()
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
self.term.cursorDown(HEIGHT)
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
def testCursorUp(self):
self.term.cursorUp(5)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorDown(20)
self.term.cursorUp(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 19))
self.term.cursorUp(19)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def testCursorForward(self):
self.term.cursorForward(2)
self.assertEqual(self.term.reportCursorPosition(), (2, 0))
self.term.cursorForward(2)
self.assertEqual(self.term.reportCursorPosition(), (4, 0))
self.term.cursorForward(WIDTH)
self.assertEqual(self.term.reportCursorPosition(), (WIDTH, 0))
def testCursorBackward(self):
self.term.cursorForward(10)
self.term.cursorBackward(2)
self.assertEqual(self.term.reportCursorPosition(), (8, 0))
self.term.cursorBackward(7)
self.assertEqual(self.term.reportCursorPosition(), (1, 0))
self.term.cursorBackward(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorBackward(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def testCursorPositioning(self):
self.term.cursorPosition(3, 9)
self.assertEqual(self.term.reportCursorPosition(), (3, 9))
def testSimpleWriting(self):
s = "Hello, world."
self.term.write(s)
self.assertEqual(
str(self.term),
s + '\n' +
'\n' * (HEIGHT - 2))
def testOvertype(self):
s = "hello, world."
self.term.write(s)
self.term.cursorBackward(len(s))
self.term.resetModes([modes.IRM])
self.term.write("H")
self.assertEqual(
str(self.term),
("H" + s[1:]) + '\n' +
'\n' * (HEIGHT - 2))
def testInsert(self):
s = "ello, world."
self.term.write(s)
self.term.cursorBackward(len(s))
self.term.setModes([modes.IRM])
self.term.write("H")
self.assertEqual(
str(self.term),
("H" + s) + '\n' +
'\n' * (HEIGHT - 2))
def testWritingInTheMiddle(self):
s = "Hello, world."
self.term.cursorDown(5)
self.term.cursorForward(5)
self.term.write(s)
self.assertEqual(
str(self.term),
'\n' * 5 +
(self.term.fill * 5) + s + '\n' +
'\n' * (HEIGHT - 7))
def testWritingWrappedAtEndOfLine(self):
s = "Hello, world."
self.term.cursorForward(WIDTH - 5)
self.term.write(s)
self.assertEqual(
str(self.term),
s[:5].rjust(WIDTH) + '\n' +
s[5:] + '\n' +
'\n' * (HEIGHT - 3))
def testIndex(self):
self.term.index()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
self.term.cursorDown(HEIGHT)
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
self.term.index()
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
def testReverseIndex(self):
self.term.reverseIndex()
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorDown(2)
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
self.term.reverseIndex()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
def test_nextLine(self):
"""
C{nextLine} positions the cursor at the beginning of the row below the
current row.
"""
self.term.nextLine()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
self.term.cursorForward(5)
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
self.term.nextLine()
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
def testSaveCursor(self):
self.term.cursorDown(5)
self.term.cursorForward(7)
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
self.term.saveCursor()
self.term.cursorDown(7)
self.term.cursorBackward(3)
self.assertEqual(self.term.reportCursorPosition(), (4, 12))
self.term.restoreCursor()
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
def testSingleShifts(self):
self.term.singleShift2()
self.term.write('Hi')
ch = self.term.getCharacter(0, 0)
self.assertEqual(ch[0], 'H')
self.assertEqual(ch[1].charset, G2)
ch = self.term.getCharacter(1, 0)
self.assertEqual(ch[0], 'i')
self.assertEqual(ch[1].charset, G0)
self.term.singleShift3()
self.term.write('!!')
ch = self.term.getCharacter(2, 0)
self.assertEqual(ch[0], '!')
self.assertEqual(ch[1].charset, G3)
ch = self.term.getCharacter(3, 0)
self.assertEqual(ch[0], '!')
self.assertEqual(ch[1].charset, G0)
def testShifting(self):
s1 = "Hello"
s2 = "World"
s3 = "Bye!"
self.term.write("Hello\n")
self.term.shiftOut()
self.term.write("World\n")
self.term.shiftIn()
self.term.write("Bye!\n")
g = G0
h = 0
for s in (s1, s2, s3):
for i in range(len(s)):
ch = self.term.getCharacter(i, h)
self.assertEqual(ch[0], s[i])
self.assertEqual(ch[1].charset, g)
g = g == G0 and G1 or G0
h += 1
def testGraphicRendition(self):
self.term.selectGraphicRendition(BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
self.term.write('W')
self.term.selectGraphicRendition(NORMAL)
self.term.write('X')
self.term.selectGraphicRendition(BLINK)
self.term.write('Y')
self.term.selectGraphicRendition(BOLD)
self.term.write('Z')
ch = self.term.getCharacter(0, 0)
self.assertEqual(ch[0], 'W')
self.assertTrue(ch[1].bold)
self.assertTrue(ch[1].underline)
self.assertTrue(ch[1].blink)
self.assertTrue(ch[1].reverseVideo)
ch = self.term.getCharacter(1, 0)
self.assertEqual(ch[0], 'X')
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].blink)
self.assertFalse(ch[1].reverseVideo)
ch = self.term.getCharacter(2, 0)
self.assertEqual(ch[0], 'Y')
self.assertTrue(ch[1].blink)
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].reverseVideo)
ch = self.term.getCharacter(3, 0)
self.assertEqual(ch[0], 'Z')
self.assertTrue(ch[1].blink)
self.assertTrue(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].reverseVideo)
def testColorAttributes(self):
s1 = "Merry xmas"
s2 = "Just kidding"
self.term.selectGraphicRendition(helper.FOREGROUND + helper.RED,
helper.BACKGROUND + helper.GREEN)
self.term.write(s1 + "\n")
self.term.selectGraphicRendition(NORMAL)
self.term.write(s2 + "\n")
for i in range(len(s1)):
ch = self.term.getCharacter(i, 0)
self.assertEqual(ch[0], s1[i])
self.assertEqual(ch[1].charset, G0)
self.assertEqual(ch[1].bold, False)
self.assertEqual(ch[1].underline, False)
self.assertEqual(ch[1].blink, False)
self.assertEqual(ch[1].reverseVideo, False)
self.assertEqual(ch[1].foreground, helper.RED)
self.assertEqual(ch[1].background, helper.GREEN)
for i in range(len(s2)):
ch = self.term.getCharacter(i, 1)
self.assertEqual(ch[0], s2[i])
self.assertEqual(ch[1].charset, G0)
self.assertEqual(ch[1].bold, False)
self.assertEqual(ch[1].underline, False)
self.assertEqual(ch[1].blink, False)
self.assertEqual(ch[1].reverseVideo, False)
self.assertEqual(ch[1].foreground, helper.WHITE)
self.assertEqual(ch[1].background, helper.BLACK)
def testEraseLine(self):
s1 = 'line 1'
s2 = 'line 2'
s3 = 'line 3'
self.term.write('\n'.join((s1, s2, s3)) + '\n')
self.term.cursorPosition(1, 1)
self.term.eraseLine()
self.assertEqual(
str(self.term),
s1 + '\n' +
'\n' +
s3 + '\n' +
'\n' * (HEIGHT - 4))
def testEraseToLineEnd(self):
s = 'Hello, world.'
self.term.write(s)
self.term.cursorBackward(5)
self.term.eraseToLineEnd()
self.assertEqual(
str(self.term),
s[:-5] + '\n' +
'\n' * (HEIGHT - 2))
def testEraseToLineBeginning(self):
s = 'Hello, world.'
self.term.write(s)
self.term.cursorBackward(5)
self.term.eraseToLineBeginning()
self.assertEqual(
str(self.term),
s[-4:].rjust(len(s)) + '\n' +
'\n' * (HEIGHT - 2))
def testEraseDisplay(self):
self.term.write('Hello world\n')
self.term.write('Goodbye world\n')
self.term.eraseDisplay()
self.assertEqual(
str(self.term),
'\n' * (HEIGHT - 1))
def testEraseToDisplayEnd(self):
s1 = "Hello world"
s2 = "Goodbye world"
self.term.write('\n'.join((s1, s2, '')))
self.term.cursorPosition(5, 1)
self.term.eraseToDisplayEnd()
self.assertEqual(
str(self.term),
s1 + '\n' +
s2[:5] + '\n' +
'\n' * (HEIGHT - 3))
def testEraseToDisplayBeginning(self):
s1 = "Hello world"
s2 = "Goodbye world"
self.term.write('\n'.join((s1, s2)))
self.term.cursorPosition(5, 1)
self.term.eraseToDisplayBeginning()
self.assertEqual(
str(self.term),
'\n' +
s2[6:].rjust(len(s2)) + '\n' +
'\n' * (HEIGHT - 3))
def testLineInsertion(self):
s1 = "Hello world"
s2 = "Goodbye world"
self.term.write('\n'.join((s1, s2)))
self.term.cursorPosition(7, 1)
self.term.insertLine()
self.assertEqual(
str(self.term),
s1 + '\n' +
'\n' +
s2 + '\n' +
'\n' * (HEIGHT - 4))
def testLineDeletion(self):
s1 = "Hello world"
s2 = "Middle words"
s3 = "Goodbye world"
self.term.write('\n'.join((s1, s2, s3)))
self.term.cursorPosition(9, 1)
self.term.deleteLine()
self.assertEqual(
str(self.term),
s1 + '\n' +
s3 + '\n' +
'\n' * (HEIGHT - 3))
class FakeDelayedCall:
called = False
cancelled = False
def __init__(self, fs, timeout, f, a, kw):
self.fs = fs
self.timeout = timeout
self.f = f
self.a = a
self.kw = kw
def active(self):
return not (self.cancelled or self.called)
def cancel(self):
self.cancelled = True
# self.fs.calls.remove(self)
def call(self):
self.called = True
self.f(*self.a, **self.kw)
class FakeScheduler:
def __init__(self):
self.calls = []
def callLater(self, timeout, f, *a, **kw):
self.calls.append(FakeDelayedCall(self, timeout, f, a, kw))
return self.calls[-1]
class ExpectTestCase(unittest.TestCase):
def setUp(self):
self.term = helper.ExpectableBuffer()
self.term.connectionMade()
self.fs = FakeScheduler()
def testSimpleString(self):
result = []
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
d.addCallback(result.append)
self.term.write("greeting puny earthlings\n")
self.assertFalse(result)
self.term.write("hello world\n")
self.assertTrue(result)
self.assertEqual(result[0].group(), "hello world")
self.assertEqual(len(self.fs.calls), 1)
self.assertFalse(self.fs.calls[0].active())
def testBrokenUpString(self):
result = []
d = self.term.expect("hello world")
d.addCallback(result.append)
self.assertFalse(result)
self.term.write("hello ")
self.assertFalse(result)
self.term.write("worl")
self.assertFalse(result)
self.term.write("d")
self.assertTrue(result)
self.assertEqual(result[0].group(), "hello world")
def testMultiple(self):
result = []
d1 = self.term.expect("hello ")
d1.addCallback(result.append)
d2 = self.term.expect("world")
d2.addCallback(result.append)
self.assertFalse(result)
self.term.write("hello")
self.assertFalse(result)
self.term.write(" ")
self.assertEqual(len(result), 1)
self.term.write("world")
self.assertEqual(len(result), 2)
self.assertEqual(result[0].group(), "hello ")
self.assertEqual(result[1].group(), "world")
def testSynchronous(self):
self.term.write("hello world")
result = []
d = self.term.expect("hello world")
d.addCallback(result.append)
self.assertTrue(result)
self.assertEqual(result[0].group(), "hello world")
def testMultipleSynchronous(self):
self.term.write("goodbye world")
result = []
d1 = self.term.expect("bye")
d1.addCallback(result.append)
d2 = self.term.expect("world")
d2.addCallback(result.append)
self.assertEqual(len(result), 2)
self.assertEqual(result[0].group(), "bye")
self.assertEqual(result[1].group(), "world")
def _cbTestTimeoutFailure(self, res):
self.assertTrue(hasattr(res, 'type'))
self.assertEqual(res.type, helper.ExpectationTimeout)
def testTimeoutFailure(self):
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
d.addBoth(self._cbTestTimeoutFailure)
self.fs.calls[0].call()
def testOverlappingTimeout(self):
self.term.write("not zoomtastic")
result = []
d1 = self.term.expect("hello world", timeout=1, scheduler=self.fs)
d1.addBoth(self._cbTestTimeoutFailure)
d2 = self.term.expect("zoom")
d2.addCallback(result.append)
self.fs.calls[0].call()
self.assertEqual(len(result), 1)
self.assertEqual(result[0].group(), "zoom")
class CharacterAttributeTests(unittest.TestCase):
"""
Tests for L{twisted.conch.insults.helper.CharacterAttribute}.
"""
def test_equality(self):
"""
L{CharacterAttribute}s must have matching character attribute values
(bold, blink, underline, etc) with the same values to be considered
equal.
"""
self.assertEqual(
helper.CharacterAttribute(),
helper.CharacterAttribute())
self.assertEqual(
helper.CharacterAttribute(),
helper.CharacterAttribute(charset=G0))
self.assertEqual(
helper.CharacterAttribute(
bold=True, underline=True, blink=False, reverseVideo=True,
foreground=helper.BLUE),
helper.CharacterAttribute(
bold=True, underline=True, blink=False, reverseVideo=True,
foreground=helper.BLUE))
self.assertNotEqual(
helper.CharacterAttribute(),
helper.CharacterAttribute(charset=G1))
self.assertNotEqual(
helper.CharacterAttribute(bold=True),
helper.CharacterAttribute(bold=False))
def test_wantOneDeprecated(self):
"""
L{twisted.conch.insults.helper.CharacterAttribute.wantOne} emits
a deprecation warning when invoked.
"""
# Trigger the deprecation warning.
helper._FormattingState().wantOne(bold=True)
warningsShown = self.flushWarnings([self.test_wantOneDeprecated])
self.assertEqual(len(warningsShown), 1)
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
'twisted.conch.insults.helper.wantOne was deprecated in '
'Twisted 13.1.0')

View file

@ -0,0 +1,496 @@
# -*- test-case-name: twisted.conch.test.test_insults -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
from twisted.conch.insults.insults import ServerProtocol, ClientProtocol
from twisted.conch.insults.insults import CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL
from twisted.conch.insults.insults import G0, G1
from twisted.conch.insults.insults import modes
def _getattr(mock, name):
return super(Mock, mock).__getattribute__(name)
def occurrences(mock):
return _getattr(mock, 'occurrences')
def methods(mock):
return _getattr(mock, 'methods')
def _append(mock, obj):
occurrences(mock).append(obj)
default = object()
class Mock(object):
callReturnValue = default
def __init__(self, methods=None, callReturnValue=default):
"""
@param methods: Mapping of names to return values
@param callReturnValue: object __call__ should return
"""
self.occurrences = []
if methods is None:
methods = {}
self.methods = methods
if callReturnValue is not default:
self.callReturnValue = callReturnValue
def __call__(self, *a, **kw):
returnValue = _getattr(self, 'callReturnValue')
if returnValue is default:
returnValue = Mock()
# _getattr(self, 'occurrences').append(('__call__', returnValue, a, kw))
_append(self, ('__call__', returnValue, a, kw))
return returnValue
def __getattribute__(self, name):
methods = _getattr(self, 'methods')
if name in methods:
attrValue = Mock(callReturnValue=methods[name])
else:
attrValue = Mock()
# _getattr(self, 'occurrences').append((name, attrValue))
_append(self, (name, attrValue))
return attrValue
class MockMixin:
def assertCall(self, occurrence, methodName, expectedPositionalArgs=(),
expectedKeywordArgs={}):
attr, mock = occurrence
self.assertEqual(attr, methodName)
self.assertEqual(len(occurrences(mock)), 1)
[(call, result, args, kw)] = occurrences(mock)
self.assertEqual(call, "__call__")
self.assertEqual(args, expectedPositionalArgs)
self.assertEqual(kw, expectedKeywordArgs)
return result
_byteGroupingTestTemplate = """\
def testByte%(groupName)s(self):
transport = StringTransport()
proto = Mock()
parser = self.protocolFactory(lambda: proto)
parser.factory = self
parser.makeConnection(transport)
bytes = self.TEST_BYTES
while bytes:
chunk = bytes[:%(bytesPer)d]
bytes = bytes[%(bytesPer)d:]
parser.dataReceived(chunk)
self.verifyResults(transport, proto, parser)
"""
class ByteGroupingsMixin(MockMixin):
protocolFactory = None
for word, n in [('Pairs', 2), ('Triples', 3), ('Quads', 4), ('Quints', 5), ('Sexes', 6)]:
exec _byteGroupingTestTemplate % {'groupName': word, 'bytesPer': n}
del word, n
def verifyResults(self, transport, proto, parser):
result = self.assertCall(occurrences(proto).pop(0), "makeConnection", (parser,))
self.assertEqual(occurrences(result), [])
del _byteGroupingTestTemplate
class ServerArrowKeys(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ServerProtocol
# All the arrow keys once
TEST_BYTES = '\x1b[A\x1b[B\x1b[C\x1b[D'
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for arrow in (parser.UP_ARROW, parser.DOWN_ARROW,
parser.RIGHT_ARROW, parser.LEFT_ARROW):
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (arrow, None))
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class PrintableCharacters(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ServerProtocol
# Some letters and digits, first on their own, then capitalized,
# then modified with alt
TEST_BYTES = 'abc123ABC!@#\x1ba\x1bb\x1bc\x1b1\x1b2\x1b3'
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for char in 'abc123ABC!@#':
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, None))
self.assertEqual(occurrences(result), [])
for char in 'abc123':
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, parser.ALT))
self.assertEqual(occurrences(result), [])
occs = occurrences(proto)
self.assertFalse(occs, "%r should have been []" % (occs,))
class ServerFunctionKeys(ByteGroupingsMixin, unittest.TestCase):
"""Test for parsing and dispatching function keys (F1 - F12)
"""
protocolFactory = ServerProtocol
byteList = []
for bytes in ('OP', 'OQ', 'OR', 'OS', # F1 - F4
'15~', '17~', '18~', '19~', # F5 - F8
'20~', '21~', '23~', '24~'): # F9 - F12
byteList.append('\x1b[' + bytes)
TEST_BYTES = ''.join(byteList)
del byteList, bytes
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for funcNum in range(1, 13):
funcArg = getattr(parser, 'F%d' % (funcNum,))
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (funcArg, None))
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class ClientCursorMovement(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ClientProtocol
d2 = "\x1b[2B"
r4 = "\x1b[4C"
u1 = "\x1b[A"
l2 = "\x1b[2D"
# Move the cursor down two, right four, up one, left two, up one, left two
TEST_BYTES = d2 + r4 + u1 + l2 + u1 + l2
del d2, r4, u1, l2
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for (method, count) in [('Down', 2), ('Forward', 4), ('Up', 1),
('Backward', 2), ('Up', 1), ('Backward', 2)]:
result = self.assertCall(occurrences(proto).pop(0), "cursor" + method, (count,))
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class ClientControlSequences(unittest.TestCase, MockMixin):
def setUp(self):
self.transport = StringTransport()
self.proto = Mock()
self.parser = ClientProtocol(lambda: self.proto)
self.parser.factory = self
self.parser.makeConnection(self.transport)
result = self.assertCall(occurrences(self.proto).pop(0), "makeConnection", (self.parser,))
self.assertFalse(occurrences(result))
def testSimpleCardinals(self):
self.parser.dataReceived(
''.join([''.join(['\x1b[' + str(n) + ch for n in ('', 2, 20, 200)]) for ch in 'BACD']))
occs = occurrences(self.proto)
for meth in ("Down", "Up", "Forward", "Backward"):
for count in (1, 2, 20, 200):
result = self.assertCall(occs.pop(0), "cursor" + meth, (count,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testScrollRegion(self):
self.parser.dataReceived('\x1b[5;22r\x1b[r')
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "setScrollRegion", (5, 22))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "setScrollRegion", (None, None))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testHeightAndWidth(self):
self.parser.dataReceived("\x1b#3\x1b#4\x1b#5\x1b#6")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "doubleHeightLine", (True,))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "doubleHeightLine", (False,))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "singleWidthLine")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "doubleWidthLine")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCharacterSet(self):
self.parser.dataReceived(
''.join([''.join(['\x1b' + g + n for n in 'AB012']) for g in '()']))
occs = occurrences(self.proto)
for which in (G0, G1):
for charset in (CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL):
result = self.assertCall(occs.pop(0), "selectCharacterSet", (charset, which))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testShifting(self):
self.parser.dataReceived("\x15\x14")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "shiftIn")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "shiftOut")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testSingleShifts(self):
self.parser.dataReceived("\x1bN\x1bO")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "singleShift2")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "singleShift3")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testKeypadMode(self):
self.parser.dataReceived("\x1b=\x1b>")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "applicationKeypadMode")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "numericKeypadMode")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCursor(self):
self.parser.dataReceived("\x1b7\x1b8")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "saveCursor")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "restoreCursor")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testReset(self):
self.parser.dataReceived("\x1bc")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "reset")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testIndex(self):
self.parser.dataReceived("\x1bD\x1bM\x1bE")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "index")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "reverseIndex")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "nextLine")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testModes(self):
self.parser.dataReceived(
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "h")
self.parser.dataReceived(
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "l")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "setModes", ([modes.KAM, modes.IRM, modes.LNM],))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "resetModes", ([modes.KAM, modes.IRM, modes.LNM],))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testErasure(self):
self.parser.dataReceived(
"\x1b[K\x1b[1K\x1b[2K\x1b[J\x1b[1J\x1b[2J\x1b[3P")
occs = occurrences(self.proto)
for meth in ("eraseToLineEnd", "eraseToLineBeginning", "eraseLine",
"eraseToDisplayEnd", "eraseToDisplayBeginning",
"eraseDisplay"):
result = self.assertCall(occs.pop(0), meth)
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "deleteCharacter", (3,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testLineDeletion(self):
self.parser.dataReceived("\x1b[M\x1b[3M")
occs = occurrences(self.proto)
for arg in (1, 3):
result = self.assertCall(occs.pop(0), "deleteLine", (arg,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testLineInsertion(self):
self.parser.dataReceived("\x1b[L\x1b[3L")
occs = occurrences(self.proto)
for arg in (1, 3):
result = self.assertCall(occs.pop(0), "insertLine", (arg,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCursorPosition(self):
methods(self.proto)['reportCursorPosition'] = (6, 7)
self.parser.dataReceived("\x1b[6n")
self.assertEqual(self.transport.value(), "\x1b[7;8R")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "reportCursorPosition")
# This isn't really an interesting assert, since it only tests that
# our mock setup is working right, but I'll include it anyway.
self.assertEqual(result, (6, 7))
def test_applicationDataBytes(self):
"""
Contiguous non-control bytes are passed to a single call to the
C{write} method of the terminal to which the L{ClientProtocol} is
connected.
"""
occs = occurrences(self.proto)
self.parser.dataReceived('a')
self.assertCall(occs.pop(0), "write", ("a",))
self.parser.dataReceived('bc')
self.assertCall(occs.pop(0), "write", ("bc",))
def _applicationDataTest(self, data, calls):
occs = occurrences(self.proto)
self.parser.dataReceived(data)
while calls:
self.assertCall(occs.pop(0), *calls.pop(0))
self.assertFalse(occs, "No other calls should happen: %r" % (occs,))
def test_shiftInAfterApplicationData(self):
"""
Application data bytes followed by a shift-in command are passed to a
call to C{write} before the terminal's C{shiftIn} method is called.
"""
self._applicationDataTest(
'ab\x15', [
("write", ("ab",)),
("shiftIn",)])
def test_shiftOutAfterApplicationData(self):
"""
Application data bytes followed by a shift-out command are passed to a
call to C{write} before the terminal's C{shiftOut} method is called.
"""
self._applicationDataTest(
'ab\x14', [
("write", ("ab",)),
("shiftOut",)])
def test_cursorBackwardAfterApplicationData(self):
"""
Application data bytes followed by a cursor-backward command are passed
to a call to C{write} before the terminal's C{cursorBackward} method is
called.
"""
self._applicationDataTest(
'ab\x08', [
("write", ("ab",)),
("cursorBackward",)])
def test_escapeAfterApplicationData(self):
"""
Application data bytes followed by an escape character are passed to a
call to C{write} before the terminal's handler method for the escape is
called.
"""
# Test a short escape
self._applicationDataTest(
'ab\x1bD', [
("write", ("ab",)),
("index",)])
# And a long escape
self._applicationDataTest(
'ab\x1b[4h', [
("write", ("ab",)),
("setModes", ([4],))])
# There's some other cases too, but they're all handled by the same
# codepaths as above.
class ServerProtocolOutputTests(unittest.TestCase):
"""
Tests for the bytes L{ServerProtocol} writes to its transport when its
methods are called.
"""
def test_nextLine(self):
"""
L{ServerProtocol.nextLine} writes C{"\r\n"} to its transport.
"""
# Why doesn't it write ESC E? Because ESC E is poorly supported. For
# example, gnome-terminal (many different versions) fails to scroll if
# it receives ESC E and the cursor is already on the last row.
protocol = ServerProtocol()
transport = StringTransport()
protocol.makeConnection(transport)
protocol.nextLine()
self.assertEqual(transport.value(), "\r\n")
class Deprecations(unittest.TestCase):
"""
Tests to ensure deprecation of L{insults.colors} and L{insults.client}
"""
def ensureDeprecated(self, message):
"""
Ensures that the correct deprecation warning was issued.
"""
warnings = self.flushWarnings()
self.assertIs(warnings[0]['category'], DeprecationWarning)
self.assertEqual(warnings[0]['message'], message)
self.assertEqual(len(warnings), 1)
def test_colors(self):
"""
The L{insults.colors} module is deprecated
"""
from twisted.conch.insults import colors
self.ensureDeprecated("twisted.conch.insults.colors was deprecated "
"in Twisted 10.1.0: Please use "
"twisted.conch.insults.helper instead.")
def test_client(self):
"""
The L{insults.client} module is deprecated
"""
from twisted.conch.insults import client
self.ensureDeprecated("twisted.conch.insults.client was deprecated "
"in Twisted 10.1.0: Please use "
"twisted.conch.insults.insults instead.")

View file

@ -0,0 +1,644 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.keys}.
"""
try:
import Crypto.Cipher.DES3
except ImportError:
# we'll have to skip these tests without PyCypto and pyasn1
Crypto = None
try:
import pyasn1
except ImportError:
pyasn1 = None
if Crypto and pyasn1:
from twisted.conch.ssh import keys, common, sexpy
import os, base64
from hashlib import sha1
from twisted.conch.test import keydata
from twisted.python import randbytes
from twisted.trial import unittest
class HelpersTestCase(unittest.TestCase):
if Crypto is None:
skip = "cannot run w/o PyCrypto"
if pyasn1 is None:
skip = "Cannot run without PyASN1"
def setUp(self):
self._secureRandom = randbytes.secureRandom
randbytes.secureRandom = lambda x: '\x55' * x
def tearDown(self):
randbytes.secureRandom = self._secureRandom
self._secureRandom = None
def test_pkcs1(self):
"""
Test Public Key Cryptographic Standard #1 functions.
"""
data = 'ABC'
messageSize = 6
self.assertEqual(keys.pkcs1Pad(data, messageSize),
'\x01\xff\x00ABC')
hash = sha1().digest()
messageSize = 40
self.assertEqual(keys.pkcs1Digest('', messageSize),
'\x01\xff\xff\xff\x00' + keys.ID_SHA1 + hash)
def _signRSA(self, data):
key = keys.Key.fromString(keydata.privateRSA_openssh)
sig = key.sign(data)
return key.keyObject, sig
def _signDSA(self, data):
key = keys.Key.fromString(keydata.privateDSA_openssh)
sig = key.sign(data)
return key.keyObject, sig
def test_signRSA(self):
"""
Test that RSA keys return appropriate signatures.
"""
data = 'data'
key, sig = self._signRSA(data)
sigData = keys.pkcs1Digest(data, keys.lenSig(key))
v = key.sign(sigData, '')[0]
self.assertEqual(sig, common.NS('ssh-rsa') + common.MP(v))
return key, sig
def test_signDSA(self):
"""
Test that DSA keys return appropriate signatures.
"""
data = 'data'
key, sig = self._signDSA(data)
sigData = sha1(data).digest()
v = key.sign(sigData, '\x55' * 19)
self.assertEqual(sig, common.NS('ssh-dss') + common.NS(
Crypto.Util.number.long_to_bytes(v[0], 20) +
Crypto.Util.number.long_to_bytes(v[1], 20)))
return key, sig
def test_objectType(self):
"""
Test that objectType, returns the correct type for objects.
"""
self.assertEqual(keys.objectType(keys.Key.fromString(
keydata.privateRSA_openssh).keyObject), 'ssh-rsa')
self.assertEqual(keys.objectType(keys.Key.fromString(
keydata.privateDSA_openssh).keyObject), 'ssh-dss')
self.assertRaises(keys.BadKeyError, keys.objectType, None)
class KeyTestCase(unittest.TestCase):
if Crypto is None:
skip = "cannot run w/o PyCrypto"
if pyasn1 is None:
skip = "Cannot run without PyASN1"
def setUp(self):
self.rsaObj = Crypto.PublicKey.RSA.construct((1L, 2L, 3L, 4L, 5L))
self.dsaObj = Crypto.PublicKey.DSA.construct((1L, 2L, 3L, 4L, 5L))
self.rsaSignature = ('\x00\x00\x00\x07ssh-rsa\x00'
'\x00\x00`N\xac\xb4@qK\xa0(\xc3\xf2h \xd3\xdd\xee6Np\x9d_'
'\xb0>\xe3\x0c(L\x9d{\txUd|!\xf6m\x9c\xd3\x93\x842\x7fU'
'\x05\xf4\xf7\xfaD\xda\xce\x81\x8ea\x7f=Y\xed*\xb7\xba\x81'
'\xf2\xad\xda\xeb(\x97\x03S\x08\x81\xc7\xb1\xb7\xe6\xe3'
'\xcd*\xd4\xbd\xc0wt\xf7y\xcd\xf0\xb7\x7f\xfb\x1e>\xf9r'
'\x8c\xba')
self.dsaSignature = ('\x00\x00\x00\x07ssh-dss\x00\x00'
'\x00(\x18z)H\x8a\x1b\xc6\r\xbbq\xa2\xd7f\x7f$\xa7\xbf'
'\xe8\x87\x8c\x88\xef\xd9k\x1a\x98\xdd{=\xdec\x18\t\xe3'
'\x87\xa9\xc72h\x95')
self.oldSecureRandom = randbytes.secureRandom
randbytes.secureRandom = lambda x: '\xff' * x
self.keyFile = self.mktemp()
file(self.keyFile, 'wb').write(keydata.privateRSA_lsh)
def tearDown(self):
randbytes.secureRandom = self.oldSecureRandom
del self.oldSecureRandom
os.unlink(self.keyFile)
def test__guessStringType(self):
"""
Test that the _guessStringType method guesses string types
correctly.
"""
self.assertEqual(keys.Key._guessStringType(keydata.publicRSA_openssh),
'public_openssh')
self.assertEqual(keys.Key._guessStringType(keydata.publicDSA_openssh),
'public_openssh')
self.assertEqual(keys.Key._guessStringType(
keydata.privateRSA_openssh), 'private_openssh')
self.assertEqual(keys.Key._guessStringType(
keydata.privateDSA_openssh), 'private_openssh')
self.assertEqual(keys.Key._guessStringType(keydata.publicRSA_lsh),
'public_lsh')
self.assertEqual(keys.Key._guessStringType(keydata.publicDSA_lsh),
'public_lsh')
self.assertEqual(keys.Key._guessStringType(keydata.privateRSA_lsh),
'private_lsh')
self.assertEqual(keys.Key._guessStringType(keydata.privateDSA_lsh),
'private_lsh')
self.assertEqual(keys.Key._guessStringType(
keydata.privateRSA_agentv3), 'agentv3')
self.assertEqual(keys.Key._guessStringType(
keydata.privateDSA_agentv3), 'agentv3')
self.assertEqual(keys.Key._guessStringType(
'\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01\x01'),
'blob')
self.assertEqual(keys.Key._guessStringType(
'\x00\x00\x00\x07ssh-dss\x00\x00\x00\x01\x01'),
'blob')
self.assertEqual(keys.Key._guessStringType('not a key'),
None)
def _testPublicPrivateFromString(self, public, private, type, data):
self._testPublicFromString(public, type, data)
self._testPrivateFromString(private, type, data)
def _testPublicFromString(self, public, type, data):
publicKey = keys.Key.fromString(public)
self.assertTrue(publicKey.isPublic())
self.assertEqual(publicKey.type(), type)
for k, v in publicKey.data().items():
self.assertEqual(data[k], v)
def _testPrivateFromString(self, private, type, data):
privateKey = keys.Key.fromString(private)
self.assertFalse(privateKey.isPublic())
self.assertEqual(privateKey.type(), type)
for k, v in data.items():
self.assertEqual(privateKey.data()[k], v)
def test_fromOpenSSH(self):
"""
Test that keys are correctly generated from OpenSSH strings.
"""
self._testPublicPrivateFromString(keydata.publicRSA_openssh,
keydata.privateRSA_openssh, 'RSA', keydata.RSAData)
self.assertEqual(keys.Key.fromString(
keydata.privateRSA_openssh_encrypted,
passphrase='encrypted'),
keys.Key.fromString(keydata.privateRSA_openssh))
self.assertEqual(keys.Key.fromString(
keydata.privateRSA_openssh_alternate),
keys.Key.fromString(keydata.privateRSA_openssh))
self._testPublicPrivateFromString(keydata.publicDSA_openssh,
keydata.privateDSA_openssh, 'DSA', keydata.DSAData)
def test_fromOpenSSH_with_whitespace(self):
"""
If key strings have trailing whitespace, it should be ignored.
"""
# from bug #3391, since our test key data doesn't have
# an issue with appended newlines
privateDSAData = """-----BEGIN DSA PRIVATE KEY-----
MIIBuwIBAAKBgQDylESNuc61jq2yatCzZbenlr9llG+p9LhIpOLUbXhhHcwC6hrh
EZIdCKqTO0USLrGoP5uS9UHAUoeN62Z0KXXWTwOWGEQn/syyPzNJtnBorHpNUT9D
Qzwl1yUa53NNgEctpo4NoEFOx8PuU6iFLyvgHCjNn2MsuGuzkZm7sI9ZpQIVAJiR
9dPc08KLdpJyRxz8T74b4FQRAoGAGBc4Z5Y6R/HZi7AYM/iNOM8su6hrk8ypkBwR
a3Dbhzk97fuV3SF1SDrcQu4zF7c4CtH609N5nfZs2SUjLLGPWln83Ysb8qhh55Em
AcHXuROrHS/sDsnqu8FQp86MaudrqMExCOYyVPE7jaBWW+/JWFbKCxmgOCSdViUJ
esJpBFsCgYEA7+jtVvSt9yrwsS/YU1QGP5wRAiDYB+T5cK4HytzAqJKRdC5qS4zf
C7R0eKcDHHLMYO39aPnCwXjscisnInEhYGNblTDyPyiyNxAOXuC8x7luTmwzMbNJ
/ow0IqSj0VF72VJN9uSoPpFd4lLT0zN8v42RWja0M8ohWNf+YNJluPgCFE0PT4Vm
SUrCyZXsNh6VXwjs3gKQ
-----END DSA PRIVATE KEY-----"""
self.assertEqual(keys.Key.fromString(privateDSAData),
keys.Key.fromString(privateDSAData + '\n'))
def test_fromNewerOpenSSH(self):
"""
Newer versions of OpenSSH generate encrypted keys which have a longer
IV than the older versions. These newer keys are also loaded.
"""
key = keys.Key.fromString(keydata.privateRSA_openssh_encrypted_aes,
passphrase='testxp')
self.assertEqual(key.type(), 'RSA')
key2 = keys.Key.fromString(
keydata.privateRSA_openssh_encrypted_aes + '\n',
passphrase='testxp')
self.assertEqual(key, key2)
def test_fromLSH(self):
"""
Test that keys are correctly generated from LSH strings.
"""
self._testPublicPrivateFromString(keydata.publicRSA_lsh,
keydata.privateRSA_lsh, 'RSA', keydata.RSAData)
self._testPublicPrivateFromString(keydata.publicDSA_lsh,
keydata.privateDSA_lsh, 'DSA', keydata.DSAData)
sexp = sexpy.pack([['public-key', ['bad-key', ['p', '2']]]])
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
data='{'+base64.encodestring(sexp)+'}')
sexp = sexpy.pack([['private-key', ['bad-key', ['p', '2']]]])
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
sexp)
def test_fromAgentv3(self):
"""
Test that keys are correctly generated from Agent v3 strings.
"""
self._testPrivateFromString(keydata.privateRSA_agentv3, 'RSA',
keydata.RSAData)
self._testPrivateFromString(keydata.privateDSA_agentv3, 'DSA',
keydata.DSAData)
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
'\x00\x00\x00\x07ssh-foo'+'\x00\x00\x00\x01\x01'*5)
def test_fromStringErrors(self):
"""
keys.Key.fromString should raise BadKeyError when the key is invalid.
"""
self.assertRaises(keys.BadKeyError, keys.Key.fromString, '')
# no key data with a bad key type
self.assertRaises(keys.BadKeyError, keys.Key.fromString, '',
'bad_type')
# trying to decrypt a key which doesn't support encryption
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
keydata.publicRSA_lsh, passphrase = 'unencrypted')
# trying to decrypt a key with the wrong passphrase
self.assertRaises(keys.EncryptedKeyError, keys.Key.fromString,
keys.Key(self.rsaObj).toString('openssh', 'encrypted'))
# key with no key data
self.assertRaises(keys.BadKeyError, keys.Key.fromString,
'-----BEGIN RSA KEY-----\nwA==\n')
# key with invalid DEK Info
self.assertRaises(
keys.BadKeyError, keys.Key.fromString,
"""-----BEGIN ENCRYPTED RSA KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: weird type
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
# key with invalid encryption type
self.assertRaises(
keys.BadKeyError, keys.Key.fromString,
"""-----BEGIN ENCRYPTED RSA KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: FOO-123-BAR,01234567
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
# key with bad IV (AES)
self.assertRaises(
keys.BadKeyError, keys.Key.fromString,
"""-----BEGIN ENCRYPTED RSA KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,01234
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
# key with bad IV (DES3)
self.assertRaises(
keys.BadKeyError, keys.Key.fromString,
"""-----BEGIN ENCRYPTED RSA KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,01234
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----""", passphrase='encrypted')
def test_fromFile(self):
"""
Test that fromFile works correctly.
"""
self.assertEqual(keys.Key.fromFile(self.keyFile),
keys.Key.fromString(keydata.privateRSA_lsh))
self.assertRaises(keys.BadKeyError, keys.Key.fromFile,
self.keyFile, 'bad_type')
self.assertRaises(keys.BadKeyError, keys.Key.fromFile,
self.keyFile, passphrase='unencrypted')
def test_init(self):
"""
Test that the PublicKey object is initialized correctly.
"""
obj = Crypto.PublicKey.RSA.construct((1L, 2L))
key = keys.Key(obj)
self.assertEqual(key.keyObject, obj)
def test_equal(self):
"""
Test that Key objects are compared correctly.
"""
rsa1 = keys.Key(self.rsaObj)
rsa2 = keys.Key(self.rsaObj)
rsa3 = keys.Key(Crypto.PublicKey.RSA.construct((1L, 2L)))
dsa = keys.Key(self.dsaObj)
self.assertTrue(rsa1 == rsa2)
self.assertFalse(rsa1 == rsa3)
self.assertFalse(rsa1 == dsa)
self.assertFalse(rsa1 == object)
self.assertFalse(rsa1 == None)
def test_notEqual(self):
"""
Test that Key objects are not-compared correctly.
"""
rsa1 = keys.Key(self.rsaObj)
rsa2 = keys.Key(self.rsaObj)
rsa3 = keys.Key(Crypto.PublicKey.RSA.construct((1L, 2L)))
dsa = keys.Key(self.dsaObj)
self.assertFalse(rsa1 != rsa2)
self.assertTrue(rsa1 != rsa3)
self.assertTrue(rsa1 != dsa)
self.assertTrue(rsa1 != object)
self.assertTrue(rsa1 != None)
def test_type(self):
"""
Test that the type method returns the correct type for an object.
"""
self.assertEqual(keys.Key(self.rsaObj).type(), 'RSA')
self.assertEqual(keys.Key(self.rsaObj).sshType(), 'ssh-rsa')
self.assertEqual(keys.Key(self.dsaObj).type(), 'DSA')
self.assertEqual(keys.Key(self.dsaObj).sshType(), 'ssh-dss')
self.assertRaises(RuntimeError, keys.Key(None).type)
self.assertRaises(RuntimeError, keys.Key(None).sshType)
self.assertRaises(RuntimeError, keys.Key(self).type)
self.assertRaises(RuntimeError, keys.Key(self).sshType)
def test_fromBlob(self):
"""
Test that a public key is correctly generated from a public key blob.
"""
rsaBlob = common.NS('ssh-rsa') + common.MP(2) + common.MP(3)
rsaKey = keys.Key.fromString(rsaBlob)
dsaBlob = (common.NS('ssh-dss') + common.MP(2) + common.MP(3) +
common.MP(4) + common.MP(5))
dsaKey = keys.Key.fromString(dsaBlob)
badBlob = common.NS('ssh-bad')
self.assertTrue(rsaKey.isPublic())
self.assertEqual(rsaKey.data(), {'e':2L, 'n':3L})
self.assertTrue(dsaKey.isPublic())
self.assertEqual(dsaKey.data(), {'p':2L, 'q':3L, 'g':4L, 'y':5L})
self.assertRaises(keys.BadKeyError,
keys.Key.fromString, badBlob)
def test_fromPrivateBlob(self):
"""
Test that a private key is correctly generated from a private key blob.
"""
rsaBlob = (common.NS('ssh-rsa') + common.MP(2) + common.MP(3) +
common.MP(4) + common.MP(5) + common.MP(6) + common.MP(7))
rsaKey = keys.Key._fromString_PRIVATE_BLOB(rsaBlob)
dsaBlob = (common.NS('ssh-dss') + common.MP(2) + common.MP(3) +
common.MP(4) + common.MP(5) + common.MP(6))
dsaKey = keys.Key._fromString_PRIVATE_BLOB(dsaBlob)
badBlob = common.NS('ssh-bad')
self.assertFalse(rsaKey.isPublic())
self.assertEqual(
rsaKey.data(), {'n':2L, 'e':3L, 'd':4L, 'u':5L, 'p':6L, 'q':7L})
self.assertFalse(dsaKey.isPublic())
self.assertEqual(dsaKey.data(), {'p':2L, 'q':3L, 'g':4L, 'y':5L, 'x':6L})
self.assertRaises(
keys.BadKeyError, keys.Key._fromString_PRIVATE_BLOB, badBlob)
def test_blob(self):
"""
Test that the Key object generates blobs correctly.
"""
self.assertEqual(keys.Key(self.rsaObj).blob(),
'\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01\x02'
'\x00\x00\x00\x01\x01')
self.assertEqual(keys.Key(self.dsaObj).blob(),
'\x00\x00\x00\x07ssh-dss\x00\x00\x00\x01\x03'
'\x00\x00\x00\x01\x04\x00\x00\x00\x01\x02'
'\x00\x00\x00\x01\x01')
badKey = keys.Key(None)
self.assertRaises(RuntimeError, badKey.blob)
def test_privateBlob(self):
"""
L{Key.privateBlob} returns the SSH protocol-level format of the private
key and raises L{RuntimeError} if the underlying key object is invalid.
"""
self.assertEqual(keys.Key(self.rsaObj).privateBlob(),
'\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01\x01'
'\x00\x00\x00\x01\x02\x00\x00\x00\x01\x03\x00'
'\x00\x00\x01\x04\x00\x00\x00\x01\x04\x00\x00'
'\x00\x01\x05')
self.assertEqual(keys.Key(self.dsaObj).privateBlob(),
'\x00\x00\x00\x07ssh-dss\x00\x00\x00\x01\x03'
'\x00\x00\x00\x01\x04\x00\x00\x00\x01\x02\x00'
'\x00\x00\x01\x01\x00\x00\x00\x01\x05')
badKey = keys.Key(None)
self.assertRaises(RuntimeError, badKey.privateBlob)
def test_toOpenSSH(self):
"""
Test that the Key object generates OpenSSH keys correctly.
"""
key = keys.Key.fromString(keydata.privateRSA_lsh)
self.assertEqual(key.toString('openssh'), keydata.privateRSA_openssh)
self.assertEqual(key.toString('openssh', 'encrypted'),
keydata.privateRSA_openssh_encrypted)
self.assertEqual(key.public().toString('openssh'),
keydata.publicRSA_openssh[:-8]) # no comment
self.assertEqual(key.public().toString('openssh', 'comment'),
keydata.publicRSA_openssh)
key = keys.Key.fromString(keydata.privateDSA_lsh)
self.assertEqual(key.toString('openssh'), keydata.privateDSA_openssh)
self.assertEqual(key.public().toString('openssh', 'comment'),
keydata.publicDSA_openssh)
self.assertEqual(key.public().toString('openssh'),
keydata.publicDSA_openssh[:-8]) # no comment
def test_toLSH(self):
"""
Test that the Key object generates LSH keys correctly.
"""
key = keys.Key.fromString(keydata.privateRSA_openssh)
self.assertEqual(key.toString('lsh'), keydata.privateRSA_lsh)
self.assertEqual(key.public().toString('lsh'),
keydata.publicRSA_lsh)
key = keys.Key.fromString(keydata.privateDSA_openssh)
self.assertEqual(key.toString('lsh'), keydata.privateDSA_lsh)
self.assertEqual(key.public().toString('lsh'),
keydata.publicDSA_lsh)
def test_toAgentv3(self):
"""
Test that the Key object generates Agent v3 keys correctly.
"""
key = keys.Key.fromString(keydata.privateRSA_openssh)
self.assertEqual(key.toString('agentv3'), keydata.privateRSA_agentv3)
key = keys.Key.fromString(keydata.privateDSA_openssh)
self.assertEqual(key.toString('agentv3'), keydata.privateDSA_agentv3)
def test_toStringErrors(self):
"""
Test that toString raises errors appropriately.
"""
self.assertRaises(keys.BadKeyError, keys.Key(self.rsaObj).toString,
'bad_type')
def test_sign(self):
"""
Test that the Key object generates correct signatures.
"""
key = keys.Key.fromString(keydata.privateRSA_openssh)
self.assertEqual(key.sign(''), self.rsaSignature)
key = keys.Key.fromString(keydata.privateDSA_openssh)
self.assertEqual(key.sign(''), self.dsaSignature)
def test_verify(self):
"""
Test that the Key object correctly verifies signatures.
"""
key = keys.Key.fromString(keydata.publicRSA_openssh)
self.assertTrue(key.verify(self.rsaSignature, ''))
self.assertFalse(key.verify(self.rsaSignature, 'a'))
self.assertFalse(key.verify(self.dsaSignature, ''))
key = keys.Key.fromString(keydata.publicDSA_openssh)
self.assertTrue(key.verify(self.dsaSignature, ''))
self.assertFalse(key.verify(self.dsaSignature, 'a'))
self.assertFalse(key.verify(self.rsaSignature, ''))
def test_verifyDSANoPrefix(self):
"""
Some commercial SSH servers send DSA keys as 2 20-byte numbers;
they are still verified as valid keys.
"""
key = keys.Key.fromString(keydata.publicDSA_openssh)
self.assertTrue(key.verify(self.dsaSignature[-40:], ''))
def test_repr(self):
"""
Test the pretty representation of Key.
"""
self.assertEqual(repr(keys.Key(self.rsaObj)),
"""<RSA Private Key (0 bits)
attr d:
\t03
attr e:
\t02
attr n:
\t01
attr p:
\t04
attr q:
\t05
attr u:
\t04>""")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,372 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.manhole}.
"""
import traceback
from twisted.trial import unittest
from twisted.internet import error, defer
from twisted.test.proto_helpers import StringTransport
from twisted.conch.test.test_recvline import _TelnetMixin, _SSHMixin, _StdioMixin, stdio, ssh
from twisted.conch import manhole
from twisted.conch.insults import insults
def determineDefaultFunctionName():
"""
Return the string used by Python as the name for code objects which are
compiled from interactive input or at the top-level of modules.
"""
try:
1 // 0
except:
# The last frame is this function. The second to last frame is this
# function's caller, which is module-scope, which is what we want,
# so -2.
return traceback.extract_stack()[-2][2]
defaultFunctionName = determineDefaultFunctionName()
class ManholeInterpreterTests(unittest.TestCase):
"""
Tests for L{manhole.ManholeInterpreter}.
"""
def test_resetBuffer(self):
"""
L{ManholeInterpreter.resetBuffer} should empty the input buffer.
"""
interpreter = manhole.ManholeInterpreter(None)
interpreter.buffer.extend(["1", "2"])
interpreter.resetBuffer()
self.assertFalse(interpreter.buffer)
class ManholeProtocolTests(unittest.TestCase):
"""
Tests for L{manhole.Manhole}.
"""
def test_interruptResetsInterpreterBuffer(self):
"""
L{manhole.Manhole.handle_INT} should cause the interpreter input buffer
to be reset.
"""
transport = StringTransport()
terminal = insults.ServerProtocol(manhole.Manhole)
terminal.makeConnection(transport)
protocol = terminal.terminalProtocol
interpreter = protocol.interpreter
interpreter.buffer.extend(["1", "2"])
protocol.handle_INT()
self.assertFalse(interpreter.buffer)
class WriterTestCase(unittest.TestCase):
def testInteger(self):
manhole.lastColorizedLine("1")
def testDoubleQuoteString(self):
manhole.lastColorizedLine('"1"')
def testSingleQuoteString(self):
manhole.lastColorizedLine("'1'")
def testTripleSingleQuotedString(self):
manhole.lastColorizedLine("'''1'''")
def testTripleDoubleQuotedString(self):
manhole.lastColorizedLine('"""1"""')
def testFunctionDefinition(self):
manhole.lastColorizedLine("def foo():")
def testClassDefinition(self):
manhole.lastColorizedLine("class foo:")
class ManholeLoopbackMixin:
serverProtocol = manhole.ColoredManhole
def wfd(self, d):
return defer.waitForDeferred(d)
def testSimpleExpression(self):
done = self.recvlineClient.expect("done")
self._testwrite(
"1 + 1\n"
"done")
def finished(ign):
self._assertBuffer(
[">>> 1 + 1",
"2",
">>> done"])
return done.addCallback(finished)
def testTripleQuoteLineContinuation(self):
done = self.recvlineClient.expect("done")
self._testwrite(
"'''\n'''\n"
"done")
def finished(ign):
self._assertBuffer(
[">>> '''",
"... '''",
"'\\n'",
">>> done"])
return done.addCallback(finished)
def testFunctionDefinition(self):
done = self.recvlineClient.expect("done")
self._testwrite(
"def foo(bar):\n"
"\tprint bar\n\n"
"foo(42)\n"
"done")
def finished(ign):
self._assertBuffer(
[">>> def foo(bar):",
"... print bar",
"... ",
">>> foo(42)",
"42",
">>> done"])
return done.addCallback(finished)
def testClassDefinition(self):
done = self.recvlineClient.expect("done")
self._testwrite(
"class Foo:\n"
"\tdef bar(self):\n"
"\t\tprint 'Hello, world!'\n\n"
"Foo().bar()\n"
"done")
def finished(ign):
self._assertBuffer(
[">>> class Foo:",
"... def bar(self):",
"... print 'Hello, world!'",
"... ",
">>> Foo().bar()",
"Hello, world!",
">>> done"])
return done.addCallback(finished)
def testException(self):
done = self.recvlineClient.expect("done")
self._testwrite(
"raise Exception('foo bar baz')\n"
"done")
def finished(ign):
self._assertBuffer(
[">>> raise Exception('foo bar baz')",
"Traceback (most recent call last):",
' File "<console>", line 1, in ' + defaultFunctionName,
"Exception: foo bar baz",
">>> done"])
return done.addCallback(finished)
def testControlC(self):
done = self.recvlineClient.expect("done")
self._testwrite(
"cancelled line" + manhole.CTRL_C +
"done")
def finished(ign):
self._assertBuffer(
[">>> cancelled line",
"KeyboardInterrupt",
">>> done"])
return done.addCallback(finished)
def test_interruptDuringContinuation(self):
"""
Sending ^C to Manhole while in a state where more input is required to
complete a statement should discard the entire ongoing statement and
reset the input prompt to the non-continuation prompt.
"""
continuing = self.recvlineClient.expect("things")
self._testwrite("(\nthings")
def gotContinuation(ignored):
self._assertBuffer(
[">>> (",
"... things"])
interrupted = self.recvlineClient.expect(">>> ")
self._testwrite(manhole.CTRL_C)
return interrupted
continuing.addCallback(gotContinuation)
def gotInterruption(ignored):
self._assertBuffer(
[">>> (",
"... things",
"KeyboardInterrupt",
">>> "])
continuing.addCallback(gotInterruption)
return continuing
def testControlBackslash(self):
self._testwrite("cancelled line")
partialLine = self.recvlineClient.expect("cancelled line")
def gotPartialLine(ign):
self._assertBuffer(
[">>> cancelled line"])
self._testwrite(manhole.CTRL_BACKSLASH)
d = self.recvlineClient.onDisconnection
return self.assertFailure(d, error.ConnectionDone)
def gotClearedLine(ign):
self._assertBuffer(
[""])
return partialLine.addCallback(gotPartialLine).addCallback(gotClearedLine)
def testControlD(self):
self._testwrite("1 + 1")
helloWorld = self.wfd(self.recvlineClient.expect(r"\+ 1"))
yield helloWorld
helloWorld.getResult()
self._assertBuffer([">>> 1 + 1"])
self._testwrite(manhole.CTRL_D + " + 1")
cleared = self.wfd(self.recvlineClient.expect(r"\+ 1"))
yield cleared
cleared.getResult()
self._assertBuffer([">>> 1 + 1 + 1"])
self._testwrite("\n")
printed = self.wfd(self.recvlineClient.expect("3\n>>> "))
yield printed
printed.getResult()
self._testwrite(manhole.CTRL_D)
d = self.recvlineClient.onDisconnection
disconnected = self.wfd(self.assertFailure(d, error.ConnectionDone))
yield disconnected
disconnected.getResult()
testControlD = defer.deferredGenerator(testControlD)
def testControlL(self):
"""
CTRL+L is generally used as a redraw-screen command in terminal
applications. Manhole doesn't currently respect this usage of it,
but it should at least do something reasonable in response to this
event (rather than, say, eating your face).
"""
# Start off with a newline so that when we clear the display we can
# tell by looking for the missing first empty prompt line.
self._testwrite("\n1 + 1")
helloWorld = self.wfd(self.recvlineClient.expect(r"\+ 1"))
yield helloWorld
helloWorld.getResult()
self._assertBuffer([">>> ", ">>> 1 + 1"])
self._testwrite(manhole.CTRL_L + " + 1")
redrew = self.wfd(self.recvlineClient.expect(r"1 \+ 1 \+ 1"))
yield redrew
redrew.getResult()
self._assertBuffer([">>> 1 + 1 + 1"])
testControlL = defer.deferredGenerator(testControlL)
def test_controlA(self):
"""
CTRL-A can be used as HOME - returning cursor to beginning of
current line buffer.
"""
self._testwrite('rint "hello"' + '\x01' + 'p')
d = self.recvlineClient.expect('print "hello"')
def cb(ignore):
self._assertBuffer(['>>> print "hello"'])
return d.addCallback(cb)
def test_controlE(self):
"""
CTRL-E can be used as END - setting cursor to end of current
line buffer.
"""
self._testwrite('rint "hello' + '\x01' + 'p' + '\x05' + '"')
d = self.recvlineClient.expect('print "hello"')
def cb(ignore):
self._assertBuffer(['>>> print "hello"'])
return d.addCallback(cb)
def testDeferred(self):
self._testwrite(
"from twisted.internet import defer, reactor\n"
"d = defer.Deferred()\n"
"d\n")
deferred = self.wfd(self.recvlineClient.expect("<Deferred #0>"))
yield deferred
deferred.getResult()
self._testwrite(
"c = reactor.callLater(0.1, d.callback, 'Hi!')\n")
delayed = self.wfd(self.recvlineClient.expect(">>> "))
yield delayed
delayed.getResult()
called = self.wfd(self.recvlineClient.expect("Deferred #0 called back: 'Hi!'\n>>> "))
yield called
called.getResult()
self._assertBuffer(
[">>> from twisted.internet import defer, reactor",
">>> d = defer.Deferred()",
">>> d",
"<Deferred #0>",
">>> c = reactor.callLater(0.1, d.callback, 'Hi!')",
"Deferred #0 called back: 'Hi!'",
">>> "])
testDeferred = defer.deferredGenerator(testDeferred)
class ManholeLoopbackTelnet(_TelnetMixin, unittest.TestCase, ManholeLoopbackMixin):
pass
class ManholeLoopbackSSH(_SSHMixin, unittest.TestCase, ManholeLoopbackMixin):
if ssh is None:
skip = "Crypto requirements missing, can't run manhole tests over ssh"
class ManholeLoopbackStdio(_StdioMixin, unittest.TestCase, ManholeLoopbackMixin):
if stdio is None:
skip = "Terminal requirements missing, can't run manhole tests over stdio"
else:
serverProtocol = stdio.ConsoleManhole

View file

@ -0,0 +1,47 @@
# -*- twisted.conch.test.test_mixin -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import time
from twisted.internet import reactor, protocol
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
from twisted.conch import mixin
class TestBufferingProto(mixin.BufferingMixin):
scheduled = False
rescheduled = 0
def schedule(self):
self.scheduled = True
return object()
def reschedule(self, token):
self.rescheduled += 1
class BufferingTest(unittest.TestCase):
def testBuffering(self):
p = TestBufferingProto()
t = p.transport = StringTransport()
self.assertFalse(p.scheduled)
L = ['foo', 'bar', 'baz', 'quux']
p.write('foo')
self.assertTrue(p.scheduled)
self.assertFalse(p.rescheduled)
for s in L:
n = p.rescheduled
p.write(s)
self.assertEqual(p.rescheduled, n + 1)
self.assertEqual(t.value(), '')
p.flush()
self.assertEqual(t.value(), 'foo' + ''.join(L))

View file

@ -0,0 +1,101 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.openssh_compat}.
"""
import os
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
try:
import Crypto.Cipher.DES3
import pyasn1
except ImportError:
OpenSSHFactory = None
else:
from twisted.conch.openssh_compat.factory import OpenSSHFactory
from twisted.conch.test import keydata
from twisted.test.test_process import MockOS
class OpenSSHFactoryTests(TestCase):
"""
Tests for L{OpenSSHFactory}.
"""
if getattr(os, "geteuid", None) is None:
skip = "geteuid/seteuid not available"
elif OpenSSHFactory is None:
skip = "Cannot run without PyCrypto or PyASN1"
def setUp(self):
self.factory = OpenSSHFactory()
self.keysDir = FilePath(self.mktemp())
self.keysDir.makedirs()
self.factory.dataRoot = self.keysDir.path
self.keysDir.child("ssh_host_foo").setContent("foo")
self.keysDir.child("bar_key").setContent("foo")
self.keysDir.child("ssh_host_one_key").setContent(
keydata.privateRSA_openssh)
self.keysDir.child("ssh_host_two_key").setContent(
keydata.privateDSA_openssh)
self.keysDir.child("ssh_host_three_key").setContent(
"not a key content")
self.keysDir.child("ssh_host_one_key.pub").setContent(
keydata.publicRSA_openssh)
self.mockos = MockOS()
self.patch(os, "seteuid", self.mockos.seteuid)
self.patch(os, "setegid", self.mockos.setegid)
def test_getPublicKeys(self):
"""
L{OpenSSHFactory.getPublicKeys} should return the available public keys
in the data directory
"""
keys = self.factory.getPublicKeys()
self.assertEqual(len(keys), 1)
keyTypes = keys.keys()
self.assertEqual(keyTypes, ['ssh-rsa'])
def test_getPrivateKeys(self):
"""
L{OpenSSHFactory.getPrivateKeys} should return the available private
keys in the data directory.
"""
keys = self.factory.getPrivateKeys()
self.assertEqual(len(keys), 2)
keyTypes = keys.keys()
self.assertEqual(set(keyTypes), set(['ssh-rsa', 'ssh-dss']))
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_getPrivateKeysAsRoot(self):
"""
L{OpenSSHFactory.getPrivateKeys} should switch to root if the keys
aren't readable by the current user.
"""
keyFile = self.keysDir.child("ssh_host_two_key")
# Fake permission error by changing the mode
keyFile.chmod(0000)
self.addCleanup(keyFile.chmod, 0777)
# And restore the right mode when seteuid is called
savedSeteuid = os.seteuid
def seteuid(euid):
keyFile.chmod(0777)
return savedSeteuid(euid)
self.patch(os, "seteuid", seteuid)
keys = self.factory.getPrivateKeys()
self.assertEqual(len(keys), 2)
keyTypes = keys.keys()
self.assertEqual(set(keyTypes), set(['ssh-rsa', 'ssh-dss']))
self.assertEqual(self.mockos.seteuidCalls, [0, os.geteuid()])
self.assertEqual(self.mockos.setegidCalls, [0, os.getegid()])

View file

@ -0,0 +1,706 @@
# -*- test-case-name: twisted.conch.test.test_recvline -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.recvline} and fixtures for testing related
functionality.
"""
import sys, os
from twisted.conch.insults import insults
from twisted.conch import recvline
from twisted.python import reflect, components
from twisted.internet import defer, error
from twisted.trial import unittest
from twisted.cred import portal
from twisted.test.proto_helpers import StringTransport
class Arrows(unittest.TestCase):
def setUp(self):
self.underlyingTransport = StringTransport()
self.pt = insults.ServerProtocol()
self.p = recvline.HistoricRecvLine()
self.pt.protocolFactory = lambda: self.p
self.pt.factory = self
self.pt.makeConnection(self.underlyingTransport)
# self.p.makeConnection(self.pt)
def test_printableCharacters(self):
"""
When L{HistoricRecvLine} receives a printable character,
it adds it to the current line buffer.
"""
self.p.keystrokeReceived('x', None)
self.p.keystrokeReceived('y', None)
self.p.keystrokeReceived('z', None)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
def test_horizontalArrows(self):
"""
When L{HistoricRecvLine} receives an LEFT_ARROW or
RIGHT_ARROW keystroke it moves the cursor left or right
in the current line buffer, respectively.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz':
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('xy', 'z'))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('x', 'yz'))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('', 'xyz'))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('', 'xyz'))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('x', 'yz'))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('xy', 'z'))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
def test_newline(self):
"""
When {HistoricRecvLine} receives a newline, it adds the current
line buffer to the end of its history buffer.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz\nabc\n123\n':
kR(ch)
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz', 'abc', '123'), ()))
kR('c')
kR('b')
kR('a')
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz', 'abc', '123'), ()))
kR('\n')
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz', 'abc', '123', 'cba'), ()))
def test_verticalArrows(self):
"""
When L{HistoricRecvLine} receives UP_ARROW or DOWN_ARROW
keystrokes it move the current index in the current history
buffer up or down, and resets the current line buffer to the
previous or next line in history, respectively for each.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz\nabc\n123\n':
kR(ch)
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz', 'abc', '123'), ()))
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz', 'abc'), ('123',)))
self.assertEqual(self.p.currentLineBuffer(), ('123', ''))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz',), ('abc', '123')))
self.assertEqual(self.p.currentLineBuffer(), ('abc', ''))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(),
((), ('xyz', 'abc', '123')))
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(),
((), ('xyz', 'abc', '123')))
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
for i in range(4):
kR(self.pt.DOWN_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(),
(('xyz', 'abc', '123'), ()))
def test_home(self):
"""
When L{HistoricRecvLine} receives a HOME keystroke it moves the
cursor to the beginning of the current line buffer.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'hello, world':
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), ('hello, world', ''))
kR(self.pt.HOME)
self.assertEqual(self.p.currentLineBuffer(), ('', 'hello, world'))
def test_end(self):
"""
When L{HistoricRecvLine} receives a END keystroke it moves the cursor
to the end of the current line buffer.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'hello, world':
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), ('hello, world', ''))
kR(self.pt.HOME)
kR(self.pt.END)
self.assertEqual(self.p.currentLineBuffer(), ('hello, world', ''))
def test_backspace(self):
"""
When L{HistoricRecvLine} receives a BACKSPACE keystroke it deletes
the character immediately before the cursor.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz':
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.BACKSPACE)
self.assertEqual(self.p.currentLineBuffer(), ('xy', ''))
kR(self.pt.LEFT_ARROW)
kR(self.pt.BACKSPACE)
self.assertEqual(self.p.currentLineBuffer(), ('', 'y'))
kR(self.pt.BACKSPACE)
self.assertEqual(self.p.currentLineBuffer(), ('', 'y'))
def test_delete(self):
"""
When L{HistoricRecvLine} receives a DELETE keystroke, it
delets the character immediately after the cursor.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz':
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), ('xyz', ''))
kR(self.pt.LEFT_ARROW)
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), ('xy', ''))
kR(self.pt.LEFT_ARROW)
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), ('x', ''))
kR(self.pt.LEFT_ARROW)
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
def test_insert(self):
"""
When not in INSERT mode, L{HistoricRecvLine} inserts the typed
character at the cursor before the next character.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz':
kR(ch)
kR(self.pt.LEFT_ARROW)
kR('A')
self.assertEqual(self.p.currentLineBuffer(), ('xyA', 'z'))
kR(self.pt.LEFT_ARROW)
kR('B')
self.assertEqual(self.p.currentLineBuffer(), ('xyB', 'Az'))
def test_typeover(self):
"""
When in INSERT mode and upon receiving a keystroke with a printable
character, L{HistoricRecvLine} replaces the character at
the cursor with the typed character rather than inserting before.
Ah, the ironies of INSERT mode.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in 'xyz':
kR(ch)
kR(self.pt.INSERT)
kR(self.pt.LEFT_ARROW)
kR('A')
self.assertEqual(self.p.currentLineBuffer(), ('xyA', ''))
kR(self.pt.LEFT_ARROW)
kR('B')
self.assertEqual(self.p.currentLineBuffer(), ('xyB', ''))
def test_unprintableCharacters(self):
"""
When L{HistoricRecvLine} receives a keystroke for an unprintable
function key with no assigned behavior, the line buffer is unmodified.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
pt = self.pt
for ch in (pt.F1, pt.F2, pt.F3, pt.F4, pt.F5, pt.F6, pt.F7, pt.F8,
pt.F9, pt.F10, pt.F11, pt.F12, pt.PGUP, pt.PGDN):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), ('', ''))
from twisted.conch import telnet
from twisted.conch.insults import helper
from twisted.protocols import loopback
class EchoServer(recvline.HistoricRecvLine):
def lineReceived(self, line):
self.terminal.write(line + '\n' + self.ps[self.pn])
# An insults API for this would be nice.
left = "\x1b[D"
right = "\x1b[C"
up = "\x1b[A"
down = "\x1b[B"
insert = "\x1b[2~"
home = "\x1b[1~"
delete = "\x1b[3~"
end = "\x1b[4~"
backspace = "\x7f"
from twisted.cred import checkers
try:
from twisted.conch.ssh import userauth, transport, channel, connection, session
from twisted.conch.manhole_ssh import TerminalUser, TerminalSession, TerminalRealm, TerminalSessionTransport, ConchFactory
except ImportError:
ssh = False
else:
ssh = True
class SessionChannel(channel.SSHChannel):
name = 'session'
def __init__(self, protocolFactory, protocolArgs, protocolKwArgs, width, height, *a, **kw):
channel.SSHChannel.__init__(self, *a, **kw)
self.protocolFactory = protocolFactory
self.protocolArgs = protocolArgs
self.protocolKwArgs = protocolKwArgs
self.width = width
self.height = height
def channelOpen(self, data):
term = session.packRequest_pty_req("vt102", (self.height, self.width, 0, 0), '')
self.conn.sendRequest(self, 'pty-req', term)
self.conn.sendRequest(self, 'shell', '')
self._protocolInstance = self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs)
self._protocolInstance.factory = self
self._protocolInstance.makeConnection(self)
def closed(self):
self._protocolInstance.connectionLost(error.ConnectionDone())
def dataReceived(self, data):
self._protocolInstance.dataReceived(data)
class TestConnection(connection.SSHConnection):
def __init__(self, protocolFactory, protocolArgs, protocolKwArgs, width, height, *a, **kw):
connection.SSHConnection.__init__(self, *a, **kw)
self.protocolFactory = protocolFactory
self.protocolArgs = protocolArgs
self.protocolKwArgs = protocolKwArgs
self.width = width
self.height = height
def serviceStarted(self):
self.__channel = SessionChannel(self.protocolFactory, self.protocolArgs, self.protocolKwArgs, self.width, self.height)
self.openChannel(self.__channel)
def write(self, bytes):
return self.__channel.write(bytes)
class TestAuth(userauth.SSHUserAuthClient):
def __init__(self, username, password, *a, **kw):
userauth.SSHUserAuthClient.__init__(self, username, *a, **kw)
self.password = password
def getPassword(self):
return defer.succeed(self.password)
class TestTransport(transport.SSHClientTransport):
def __init__(self, protocolFactory, protocolArgs, protocolKwArgs, username, password, width, height, *a, **kw):
# transport.SSHClientTransport.__init__(self, *a, **kw)
self.protocolFactory = protocolFactory
self.protocolArgs = protocolArgs
self.protocolKwArgs = protocolKwArgs
self.username = username
self.password = password
self.width = width
self.height = height
def verifyHostKey(self, hostKey, fingerprint):
return defer.succeed(True)
def connectionSecure(self):
self.__connection = TestConnection(self.protocolFactory, self.protocolArgs, self.protocolKwArgs, self.width, self.height)
self.requestService(
TestAuth(self.username, self.password, self.__connection))
def write(self, bytes):
return self.__connection.write(bytes)
class TestSessionTransport(TerminalSessionTransport):
def protocolFactory(self):
return self.avatar.conn.transport.factory.serverProtocol()
class TestSession(TerminalSession):
transportFactory = TestSessionTransport
class TestUser(TerminalUser):
pass
components.registerAdapter(TestSession, TestUser, session.ISession)
class LoopbackRelay(loopback.LoopbackRelay):
clearCall = None
def logPrefix(self):
return "LoopbackRelay(%r)" % (self.target.__class__.__name__,)
def write(self, bytes):
loopback.LoopbackRelay.write(self, bytes)
if self.clearCall is not None:
self.clearCall.cancel()
from twisted.internet import reactor
self.clearCall = reactor.callLater(0, self._clearBuffer)
def _clearBuffer(self):
self.clearCall = None
loopback.LoopbackRelay.clearBuffer(self)
class NotifyingExpectableBuffer(helper.ExpectableBuffer):
def __init__(self):
self.onConnection = defer.Deferred()
self.onDisconnection = defer.Deferred()
def connectionMade(self):
helper.ExpectableBuffer.connectionMade(self)
self.onConnection.callback(self)
def connectionLost(self, reason):
self.onDisconnection.errback(reason)
class _BaseMixin:
WIDTH = 80
HEIGHT = 24
def _assertBuffer(self, lines):
receivedLines = str(self.recvlineClient).splitlines()
expectedLines = lines + ([''] * (self.HEIGHT - len(lines) - 1))
self.assertEqual(len(receivedLines), len(expectedLines))
for i in range(len(receivedLines)):
self.assertEqual(
receivedLines[i], expectedLines[i],
str(receivedLines[max(0, i-1):i+1]) +
" != " +
str(expectedLines[max(0, i-1):i+1]))
def _trivialTest(self, input, output):
done = self.recvlineClient.expect("done")
self._testwrite(input)
def finished(ign):
self._assertBuffer(output)
return done.addCallback(finished)
class _SSHMixin(_BaseMixin):
def setUp(self):
if not ssh:
raise unittest.SkipTest("Crypto requirements missing, can't run historic recvline tests over ssh")
u, p = 'testuser', 'testpass'
rlm = TerminalRealm()
rlm.userFactory = TestUser
rlm.chainedProtocolFactory = lambda: insultsServer
ptl = portal.Portal(
rlm,
[checkers.InMemoryUsernamePasswordDatabaseDontUse(**{u: p})])
sshFactory = ConchFactory(ptl)
sshFactory.serverProtocol = self.serverProtocol
sshFactory.startFactory()
recvlineServer = self.serverProtocol()
insultsServer = insults.ServerProtocol(lambda: recvlineServer)
sshServer = sshFactory.buildProtocol(None)
clientTransport = LoopbackRelay(sshServer)
recvlineClient = NotifyingExpectableBuffer()
insultsClient = insults.ClientProtocol(lambda: recvlineClient)
sshClient = TestTransport(lambda: insultsClient, (), {}, u, p, self.WIDTH, self.HEIGHT)
serverTransport = LoopbackRelay(sshClient)
sshClient.makeConnection(clientTransport)
sshServer.makeConnection(serverTransport)
self.recvlineClient = recvlineClient
self.sshClient = sshClient
self.sshServer = sshServer
self.clientTransport = clientTransport
self.serverTransport = serverTransport
return recvlineClient.onConnection
def _testwrite(self, bytes):
self.sshClient.write(bytes)
from twisted.conch.test import test_telnet
class TestInsultsClientProtocol(insults.ClientProtocol,
test_telnet.TestProtocol):
pass
class TestInsultsServerProtocol(insults.ServerProtocol,
test_telnet.TestProtocol):
pass
class _TelnetMixin(_BaseMixin):
def setUp(self):
recvlineServer = self.serverProtocol()
insultsServer = TestInsultsServerProtocol(lambda: recvlineServer)
telnetServer = telnet.TelnetTransport(lambda: insultsServer)
clientTransport = LoopbackRelay(telnetServer)
recvlineClient = NotifyingExpectableBuffer()
insultsClient = TestInsultsClientProtocol(lambda: recvlineClient)
telnetClient = telnet.TelnetTransport(lambda: insultsClient)
serverTransport = LoopbackRelay(telnetClient)
telnetClient.makeConnection(clientTransport)
telnetServer.makeConnection(serverTransport)
serverTransport.clearBuffer()
clientTransport.clearBuffer()
self.recvlineClient = recvlineClient
self.telnetClient = telnetClient
self.clientTransport = clientTransport
self.serverTransport = serverTransport
return recvlineClient.onConnection
def _testwrite(self, bytes):
self.telnetClient.write(bytes)
try:
from twisted.conch import stdio
except ImportError:
stdio = None
class _StdioMixin(_BaseMixin):
def setUp(self):
# A memory-only terminal emulator, into which the server will
# write things and make other state changes. What ends up
# here is basically what a user would have seen on their
# screen.
testTerminal = NotifyingExpectableBuffer()
# An insults client protocol which will translate bytes
# received from the child process into keystroke commands for
# an ITerminalProtocol.
insultsClient = insults.ClientProtocol(lambda: testTerminal)
# A process protocol which will translate stdout and stderr
# received from the child process to dataReceived calls and
# error reporting on an insults client protocol.
processClient = stdio.TerminalProcessProtocol(insultsClient)
# Run twisted/conch/stdio.py with the name of a class
# implementing ITerminalProtocol. This class will be used to
# handle bytes we send to the child process.
exe = sys.executable
module = stdio.__file__
if module.endswith('.pyc') or module.endswith('.pyo'):
module = module[:-1]
args = [exe, module, reflect.qual(self.serverProtocol)]
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join(sys.path)
from twisted.internet import reactor
clientTransport = reactor.spawnProcess(processClient, exe, args,
env=env, usePTY=True)
self.recvlineClient = self.testTerminal = testTerminal
self.processClient = processClient
self.clientTransport = clientTransport
# Wait for the process protocol and test terminal to become
# connected before proceeding. The former should always
# happen first, but it doesn't hurt to be safe.
return defer.gatherResults(filter(None, [
processClient.onConnection,
testTerminal.expect(">>> ")]))
def tearDown(self):
# Kill the child process. We're done with it.
try:
self.clientTransport.signalProcess("KILL")
except (error.ProcessExitedAlready, OSError):
pass
def trap(failure):
failure.trap(error.ProcessTerminated)
self.assertEqual(failure.value.exitCode, None)
self.assertEqual(failure.value.status, 9)
return self.testTerminal.onDisconnection.addErrback(trap)
def _testwrite(self, bytes):
self.clientTransport.write(bytes)
class RecvlineLoopbackMixin:
serverProtocol = EchoServer
def testSimple(self):
return self._trivialTest(
"first line\ndone",
[">>> first line",
"first line",
">>> done"])
def testLeftArrow(self):
return self._trivialTest(
insert + 'first line' + left * 4 + "xxxx\ndone",
[">>> first xxxx",
"first xxxx",
">>> done"])
def testRightArrow(self):
return self._trivialTest(
insert + 'right line' + left * 4 + right * 2 + "xx\ndone",
[">>> right lixx",
"right lixx",
">>> done"])
def testBackspace(self):
return self._trivialTest(
"second line" + backspace * 4 + "xxxx\ndone",
[">>> second xxxx",
"second xxxx",
">>> done"])
def testDelete(self):
return self._trivialTest(
"delete xxxx" + left * 4 + delete * 4 + "line\ndone",
[">>> delete line",
"delete line",
">>> done"])
def testInsert(self):
return self._trivialTest(
"third ine" + left * 3 + "l\ndone",
[">>> third line",
"third line",
">>> done"])
def testTypeover(self):
return self._trivialTest(
"fourth xine" + left * 4 + insert + "l\ndone",
[">>> fourth line",
"fourth line",
">>> done"])
def testHome(self):
return self._trivialTest(
insert + "blah line" + home + "home\ndone",
[">>> home line",
"home line",
">>> done"])
def testEnd(self):
return self._trivialTest(
"end " + left * 4 + end + "line\ndone",
[">>> end line",
"end line",
">>> done"])
class RecvlineLoopbackTelnet(_TelnetMixin, unittest.TestCase, RecvlineLoopbackMixin):
pass
class RecvlineLoopbackSSH(_SSHMixin, unittest.TestCase, RecvlineLoopbackMixin):
pass
class RecvlineLoopbackStdio(_StdioMixin, unittest.TestCase, RecvlineLoopbackMixin):
if stdio is None:
skip = "Terminal requirements missing, can't run recvline tests over stdio"
class HistoricRecvlineLoopbackMixin:
serverProtocol = EchoServer
def testUpArrow(self):
return self._trivialTest(
"first line\n" + up + "\ndone",
[">>> first line",
"first line",
">>> first line",
"first line",
">>> done"])
def testDownArrow(self):
return self._trivialTest(
"first line\nsecond line\n" + up * 2 + down + "\ndone",
[">>> first line",
"first line",
">>> second line",
"second line",
">>> second line",
"second line",
">>> done"])
class HistoricRecvlineLoopbackTelnet(_TelnetMixin, unittest.TestCase, HistoricRecvlineLoopbackMixin):
pass
class HistoricRecvlineLoopbackSSH(_SSHMixin, unittest.TestCase, HistoricRecvlineLoopbackMixin):
pass
class HistoricRecvlineLoopbackStdio(_StdioMixin, unittest.TestCase, HistoricRecvlineLoopbackMixin):
if stdio is None:
skip = "Terminal requirements missing, can't run historic recvline tests over stdio"

View file

@ -0,0 +1,82 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the command-line interfaces to conch.
"""
try:
import pyasn1
except ImportError:
pyasn1Skip = "Cannot run without PyASN1"
else:
pyasn1Skip = None
try:
import Crypto
except ImportError:
cryptoSkip = "can't run w/o PyCrypto"
else:
cryptoSkip = None
try:
import tty
except ImportError:
ttySkip = "can't run w/o tty"
else:
ttySkip = None
try:
import Tkinter
except ImportError:
tkskip = "can't run w/o Tkinter"
else:
try:
Tkinter.Tk().destroy()
except Tkinter.TclError, e:
tkskip = "Can't test Tkinter: " + str(e)
else:
tkskip = None
from twisted.trial.unittest import TestCase
from twisted.scripts.test.test_scripts import ScriptTestsMixin
from twisted.python.test.test_shellcomp import ZshScriptTestMixin
class ScriptTests(TestCase, ScriptTestsMixin):
"""
Tests for the Conch scripts.
"""
skip = pyasn1Skip or cryptoSkip
def test_conch(self):
self.scriptTest("conch/conch")
test_conch.skip = ttySkip or skip
def test_cftp(self):
self.scriptTest("conch/cftp")
test_cftp.skip = ttySkip or skip
def test_ckeygen(self):
self.scriptTest("conch/ckeygen")
def test_tkconch(self):
self.scriptTest("conch/tkconch")
test_tkconch.skip = tkskip or skip
class ZshIntegrationTestCase(TestCase, ZshScriptTestMixin):
"""
Test that zsh completion functions are generated without error
"""
generateFor = [('conch', 'twisted.conch.scripts.conch.ClientOptions'),
('cftp', 'twisted.conch.scripts.cftp.ClientOptions'),
('ckeygen', 'twisted.conch.scripts.ckeygen.GeneralOptions'),
('tkconch', 'twisted.conch.scripts.tkconch.GeneralOptions'),
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,995 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh}.
"""
import struct
try:
import Crypto.Cipher.DES3
except ImportError:
Crypto = None
try:
import pyasn1
except ImportError:
pyasn1 = None
from twisted.conch.ssh import common, session, forwarding
from twisted.conch import avatar, error
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
from twisted.cred import portal
from twisted.cred.error import UnauthorizedLogin
from twisted.internet import defer, protocol, reactor
from twisted.internet.error import ProcessTerminated
from twisted.python import failure, log
from twisted.trial import unittest
from twisted.conch.test.test_recvline import LoopbackRelay
class ConchTestRealm(object):
"""
A realm which expects a particular avatarId to log in once and creates a
L{ConchTestAvatar} for that request.
@ivar expectedAvatarID: The only avatarID that this realm will produce an
avatar for.
@ivar avatar: A reference to the avatar after it is requested.
"""
avatar = None
def __init__(self, expectedAvatarID):
self.expectedAvatarID = expectedAvatarID
def requestAvatar(self, avatarID, mind, *interfaces):
"""
Return a new L{ConchTestAvatar} if the avatarID matches the expected one
and this is the first avatar request.
"""
if avatarID == self.expectedAvatarID:
if self.avatar is not None:
raise UnauthorizedLogin("Only one login allowed")
self.avatar = ConchTestAvatar()
return interfaces[0], self.avatar, self.avatar.logout
raise UnauthorizedLogin(
"Only %r may log in, not %r" % (self.expectedAvatarID, avatarID))
class ConchTestAvatar(avatar.ConchUser):
"""
An avatar against which various SSH features can be tested.
@ivar loggedOut: A flag indicating whether the avatar logout method has been
called.
"""
loggedOut = False
def __init__(self):
avatar.ConchUser.__init__(self)
self.listeners = {}
self.globalRequests = {}
self.channelLookup.update({'session': session.SSHSession,
'direct-tcpip':forwarding.openConnectForwardingClient})
self.subsystemLookup.update({'crazy': CrazySubsystem})
def global_foo(self, data):
self.globalRequests['foo'] = data
return 1
def global_foo_2(self, data):
self.globalRequests['foo_2'] = data
return 1, 'data'
def global_tcpip_forward(self, data):
host, port = forwarding.unpackGlobal_tcpip_forward(data)
try:
listener = reactor.listenTCP(
port, forwarding.SSHListenForwardingFactory(
self.conn, (host, port),
forwarding.SSHListenServerForwardingChannel),
interface=host)
except:
log.err(None, "something went wrong with remote->local forwarding")
return 0
else:
self.listeners[(host, port)] = listener
return 1
def global_cancel_tcpip_forward(self, data):
host, port = forwarding.unpackGlobal_tcpip_forward(data)
listener = self.listeners.get((host, port), None)
if not listener:
return 0
del self.listeners[(host, port)]
listener.stopListening()
return 1
def logout(self):
self.loggedOut = True
for listener in self.listeners.values():
log.msg('stopListening %s' % listener)
listener.stopListening()
class ConchSessionForTestAvatar(object):
"""
An ISession adapter for ConchTestAvatar.
"""
def __init__(self, avatar):
"""
Initialize the session and create a reference to it on the avatar for
later inspection.
"""
self.avatar = avatar
self.avatar._testSession = self
self.cmd = None
self.proto = None
self.ptyReq = False
self.eof = 0
self.onClose = defer.Deferred()
def getPty(self, term, windowSize, attrs):
log.msg('pty req')
self._terminalType = term
self._windowSize = windowSize
self.ptyReq = True
def openShell(self, proto):
log.msg('opening shell')
self.proto = proto
EchoTransport(proto)
self.cmd = 'shell'
def execCommand(self, proto, cmd):
self.cmd = cmd
self.proto = proto
f = cmd.split()[0]
if f == 'false':
t = FalseTransport(proto)
# Avoid disconnecting this immediately. If the channel is closed
# before execCommand even returns the caller gets confused.
reactor.callLater(0, t.loseConnection)
elif f == 'echo':
t = EchoTransport(proto)
t.write(cmd[5:])
t.loseConnection()
elif f == 'secho':
t = SuperEchoTransport(proto)
t.write(cmd[6:])
t.loseConnection()
elif f == 'eecho':
t = ErrEchoTransport(proto)
t.write(cmd[6:])
t.loseConnection()
else:
raise error.ConchError('bad exec')
self.avatar.conn.transport.expectedLoseConnection = 1
def eofReceived(self):
self.eof = 1
def closed(self):
log.msg('closed cmd "%s"' % self.cmd)
self.remoteWindowLeftAtClose = self.proto.session.remoteWindowLeft
self.onClose.callback(None)
from twisted.python import components
components.registerAdapter(ConchSessionForTestAvatar, ConchTestAvatar, session.ISession)
class CrazySubsystem(protocol.Protocol):
def __init__(self, *args, **kw):
pass
def connectionMade(self):
"""
good ... good
"""
class FalseTransport:
"""
False transport should act like a /bin/false execution, i.e. just exit with
nonzero status, writing nothing to the terminal.
@ivar proto: The protocol associated with this transport.
@ivar closed: A flag tracking whether C{loseConnection} has been called yet.
"""
def __init__(self, p):
"""
@type p L{twisted.conch.ssh.session.SSHSessionProcessProtocol} instance
"""
self.proto = p
p.makeConnection(self)
self.closed = 0
def loseConnection(self):
"""
Disconnect the protocol associated with this transport.
"""
if self.closed:
return
self.closed = 1
self.proto.inConnectionLost()
self.proto.outConnectionLost()
self.proto.errConnectionLost()
self.proto.processEnded(failure.Failure(ProcessTerminated(255, None, None)))
class EchoTransport:
def __init__(self, p):
self.proto = p
p.makeConnection(self)
self.closed = 0
def write(self, data):
log.msg(repr(data))
self.proto.outReceived(data)
self.proto.outReceived('\r\n')
if '\x00' in data: # mimic 'exit' for the shell test
self.loseConnection()
def loseConnection(self):
if self.closed: return
self.closed = 1
self.proto.inConnectionLost()
self.proto.outConnectionLost()
self.proto.errConnectionLost()
self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)))
class ErrEchoTransport:
def __init__(self, p):
self.proto = p
p.makeConnection(self)
self.closed = 0
def write(self, data):
self.proto.errReceived(data)
self.proto.errReceived('\r\n')
def loseConnection(self):
if self.closed: return
self.closed = 1
self.proto.inConnectionLost()
self.proto.outConnectionLost()
self.proto.errConnectionLost()
self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)))
class SuperEchoTransport:
def __init__(self, p):
self.proto = p
p.makeConnection(self)
self.closed = 0
def write(self, data):
self.proto.outReceived(data)
self.proto.outReceived('\r\n')
self.proto.errReceived(data)
self.proto.errReceived('\r\n')
def loseConnection(self):
if self.closed: return
self.closed = 1
self.proto.inConnectionLost()
self.proto.outConnectionLost()
self.proto.errConnectionLost()
self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)))
if Crypto is not None and pyasn1 is not None:
from twisted.conch import checkers
from twisted.conch.ssh import channel, connection, factory, keys
from twisted.conch.ssh import transport, userauth
class UtilityTestCase(unittest.TestCase):
def testCounter(self):
c = transport._Counter('\x00\x00', 2)
for i in xrange(256 * 256):
self.assertEqual(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
# It should wrap around, too.
for i in xrange(256 * 256):
self.assertEqual(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
class ConchTestPublicKeyChecker(checkers.SSHPublicKeyDatabase):
def checkKey(self, credentials):
blob = keys.Key.fromString(publicDSA_openssh).blob()
if credentials.username == 'testuser' and credentials.blob == blob:
return True
return False
class ConchTestPasswordChecker:
credentialInterfaces = checkers.IUsernamePassword,
def requestAvatarId(self, credentials):
if credentials.username == 'testuser' and credentials.password == 'testpass':
return defer.succeed(credentials.username)
return defer.fail(Exception("Bad credentials"))
class ConchTestSSHChecker(checkers.SSHProtocolChecker):
def areDone(self, avatarId):
if avatarId != 'testuser' or len(self.successfulCredentials[avatarId]) < 2:
return False
return True
class ConchTestServerFactory(factory.SSHFactory):
noisy = 0
services = {
'ssh-userauth':userauth.SSHUserAuthServer,
'ssh-connection':connection.SSHConnection
}
def buildProtocol(self, addr):
proto = ConchTestServer()
proto.supportedPublicKeys = self.privateKeys.keys()
proto.factory = self
if hasattr(self, 'expectedLoseConnection'):
proto.expectedLoseConnection = self.expectedLoseConnection
self.proto = proto
return proto
def getPublicKeys(self):
return {
'ssh-rsa': keys.Key.fromString(publicRSA_openssh),
'ssh-dss': keys.Key.fromString(publicDSA_openssh)
}
def getPrivateKeys(self):
return {
'ssh-rsa': keys.Key.fromString(privateRSA_openssh),
'ssh-dss': keys.Key.fromString(privateDSA_openssh)
}
def getPrimes(self):
return {
2048:[(transport.DH_GENERATOR, transport.DH_PRIME)]
}
def getService(self, trans, name):
return factory.SSHFactory.getService(self, trans, name)
class ConchTestBase:
done = 0
def connectionLost(self, reason):
if self.done:
return
if not hasattr(self,'expectedLoseConnection'):
unittest.fail('unexpectedly lost connection %s\n%s' % (self, reason))
self.done = 1
def receiveError(self, reasonCode, desc):
self.expectedLoseConnection = 1
# Some versions of OpenSSH (for example, OpenSSH_5.3p1) will
# send a DISCONNECT_BY_APPLICATION error before closing the
# connection. Other, older versions (for example,
# OpenSSH_5.1p1), won't. So accept this particular error here,
# but no others.
if reasonCode != transport.DISCONNECT_BY_APPLICATION:
log.err(
Exception(
'got disconnect for %s: reason %s, desc: %s' % (
self, reasonCode, desc)))
self.loseConnection()
def receiveUnimplemented(self, seqID):
unittest.fail('got unimplemented: seqid %s' % seqID)
self.expectedLoseConnection = 1
self.loseConnection()
class ConchTestServer(ConchTestBase, transport.SSHServerTransport):
def connectionLost(self, reason):
ConchTestBase.connectionLost(self, reason)
transport.SSHServerTransport.connectionLost(self, reason)
class ConchTestClient(ConchTestBase, transport.SSHClientTransport):
"""
@ivar _channelFactory: A callable which accepts an SSH connection and
returns a channel which will be attached to a new channel on that
connection.
"""
def __init__(self, channelFactory):
self._channelFactory = channelFactory
def connectionLost(self, reason):
ConchTestBase.connectionLost(self, reason)
transport.SSHClientTransport.connectionLost(self, reason)
def verifyHostKey(self, key, fp):
keyMatch = key == keys.Key.fromString(publicRSA_openssh).blob()
fingerprintMatch = (
fp == '3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af')
if keyMatch and fingerprintMatch:
return defer.succeed(1)
return defer.fail(Exception("Key or fingerprint mismatch"))
def connectionSecure(self):
self.requestService(ConchTestClientAuth('testuser',
ConchTestClientConnection(self._channelFactory)))
class ConchTestClientAuth(userauth.SSHUserAuthClient):
hasTriedNone = 0 # have we tried the 'none' auth yet?
canSucceedPublicKey = 0 # can we succed with this yet?
canSucceedPassword = 0
def ssh_USERAUTH_SUCCESS(self, packet):
if not self.canSucceedPassword and self.canSucceedPublicKey:
unittest.fail('got USERAUTH_SUCESS before password and publickey')
userauth.SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
def getPassword(self):
self.canSucceedPassword = 1
return defer.succeed('testpass')
def getPrivateKey(self):
self.canSucceedPublicKey = 1
return defer.succeed(keys.Key.fromString(privateDSA_openssh))
def getPublicKey(self):
return keys.Key.fromString(publicDSA_openssh)
class ConchTestClientConnection(connection.SSHConnection):
"""
@ivar _completed: A L{Deferred} which will be fired when the number of
results collected reaches C{totalResults}.
"""
name = 'ssh-connection'
results = 0
totalResults = 8
def __init__(self, channelFactory):
connection.SSHConnection.__init__(self)
self._channelFactory = channelFactory
def serviceStarted(self):
self.openChannel(self._channelFactory(conn=self))
class SSHTestChannel(channel.SSHChannel):
def __init__(self, name, opened, *args, **kwargs):
self.name = name
self._opened = opened
self.received = []
self.receivedExt = []
self.onClose = defer.Deferred()
channel.SSHChannel.__init__(self, *args, **kwargs)
def openFailed(self, reason):
self._opened.errback(reason)
def channelOpen(self, ignore):
self._opened.callback(self)
def dataReceived(self, data):
self.received.append(data)
def extReceived(self, dataType, data):
if dataType == connection.EXTENDED_DATA_STDERR:
self.receivedExt.append(data)
else:
log.msg("Unrecognized extended data: %r" % (dataType,))
def request_exit_status(self, status):
[self.status] = struct.unpack('>L', status)
def eofReceived(self):
self.eofCalled = True
def closed(self):
self.onClose.callback(None)
class SSHProtocolTestCase(unittest.TestCase):
"""
Tests for communication between L{SSHServerTransport} and
L{SSHClientTransport}.
"""
if not Crypto:
skip = "can't run w/o PyCrypto"
if not pyasn1:
skip = "Cannot run without PyASN1"
def _ourServerOurClientTest(self, name='session', **kwargs):
"""
Create a connected SSH client and server protocol pair and return a
L{Deferred} which fires with an L{SSHTestChannel} instance connected to
a channel on that SSH connection.
"""
result = defer.Deferred()
self.realm = ConchTestRealm('testuser')
p = portal.Portal(self.realm)
sshpc = ConchTestSSHChecker()
sshpc.registerChecker(ConchTestPasswordChecker())
sshpc.registerChecker(ConchTestPublicKeyChecker())
p.registerChecker(sshpc)
fac = ConchTestServerFactory()
fac.portal = p
fac.startFactory()
self.server = fac.buildProtocol(None)
self.clientTransport = LoopbackRelay(self.server)
self.client = ConchTestClient(
lambda conn: SSHTestChannel(name, result, conn=conn, **kwargs))
self.serverTransport = LoopbackRelay(self.client)
self.server.makeConnection(self.serverTransport)
self.client.makeConnection(self.clientTransport)
return result
def test_subsystemsAndGlobalRequests(self):
"""
Run the Conch server against the Conch client. Set up several different
channels which exercise different behaviors and wait for them to
complete. Verify that the channels with errors log them.
"""
channel = self._ourServerOurClientTest()
def cbSubsystem(channel):
self.channel = channel
return self.assertFailure(
channel.conn.sendRequest(
channel, 'subsystem', common.NS('not-crazy'), 1),
Exception)
channel.addCallback(cbSubsystem)
def cbNotCrazyFailed(ignored):
channel = self.channel
return channel.conn.sendRequest(
channel, 'subsystem', common.NS('crazy'), 1)
channel.addCallback(cbNotCrazyFailed)
def cbGlobalRequests(ignored):
channel = self.channel
d1 = channel.conn.sendGlobalRequest('foo', 'bar', 1)
d2 = channel.conn.sendGlobalRequest('foo-2', 'bar2', 1)
d2.addCallback(self.assertEqual, 'data')
d3 = self.assertFailure(
channel.conn.sendGlobalRequest('bar', 'foo', 1),
Exception)
return defer.gatherResults([d1, d2, d3])
channel.addCallback(cbGlobalRequests)
def disconnect(ignored):
self.assertEqual(
self.realm.avatar.globalRequests,
{"foo": "bar", "foo_2": "bar2"})
channel = self.channel
channel.conn.transport.expectedLoseConnection = True
channel.conn.serviceStopped()
channel.loseConnection()
channel.addCallback(disconnect)
return channel
def test_shell(self):
"""
L{SSHChannel.sendRequest} can open a shell with a I{pty-req} request,
specifying a terminal type and window size.
"""
channel = self._ourServerOurClientTest()
data = session.packRequest_pty_req('conch-test-term', (24, 80, 0, 0), '')
def cbChannel(channel):
self.channel = channel
return channel.conn.sendRequest(channel, 'pty-req', data, 1)
channel.addCallback(cbChannel)
def cbPty(ignored):
# The server-side object corresponding to our client side channel.
session = self.realm.avatar.conn.channels[0].session
self.assertIs(session.avatar, self.realm.avatar)
self.assertEqual(session._terminalType, 'conch-test-term')
self.assertEqual(session._windowSize, (24, 80, 0, 0))
self.assertTrue(session.ptyReq)
channel = self.channel
return channel.conn.sendRequest(channel, 'shell', '', 1)
channel.addCallback(cbPty)
def cbShell(ignored):
self.channel.write('testing the shell!\x00')
self.channel.conn.sendEOF(self.channel)
return defer.gatherResults([
self.channel.onClose,
self.realm.avatar._testSession.onClose])
channel.addCallback(cbShell)
def cbExited(ignored):
if self.channel.status != 0:
log.msg(
'shell exit status was not 0: %i' % (self.channel.status,))
self.assertEqual(
"".join(self.channel.received),
'testing the shell!\x00\r\n')
self.assertTrue(self.channel.eofCalled)
self.assertTrue(
self.realm.avatar._testSession.eof)
channel.addCallback(cbExited)
return channel
def test_failedExec(self):
"""
If L{SSHChannel.sendRequest} issues an exec which the server responds to
with an error, the L{Deferred} it returns fires its errback.
"""
channel = self._ourServerOurClientTest()
def cbChannel(channel):
self.channel = channel
return self.assertFailure(
channel.conn.sendRequest(
channel, 'exec', common.NS('jumboliah'), 1),
Exception)
channel.addCallback(cbChannel)
def cbFailed(ignored):
# The server logs this exception when it cannot perform the
# requested exec.
errors = self.flushLoggedErrors(error.ConchError)
self.assertEqual(errors[0].value.args, ('bad exec', None))
channel.addCallback(cbFailed)
return channel
def test_falseChannel(self):
"""
When the process started by a L{SSHChannel.sendRequest} exec request
exits, the exit status is reported to the channel.
"""
channel = self._ourServerOurClientTest()
def cbChannel(channel):
self.channel = channel
return channel.conn.sendRequest(
channel, 'exec', common.NS('false'), 1)
channel.addCallback(cbChannel)
def cbExec(ignored):
return self.channel.onClose
channel.addCallback(cbExec)
def cbClosed(ignored):
# No data is expected
self.assertEqual(self.channel.received, [])
self.assertNotEqual(self.channel.status, 0)
channel.addCallback(cbClosed)
return channel
def test_errorChannel(self):
"""
Bytes sent over the extended channel for stderr data are delivered to
the channel's C{extReceived} method.
"""
channel = self._ourServerOurClientTest(localWindow=4, localMaxPacket=5)
def cbChannel(channel):
self.channel = channel
return channel.conn.sendRequest(
channel, 'exec', common.NS('eecho hello'), 1)
channel.addCallback(cbChannel)
def cbExec(ignored):
return defer.gatherResults([
self.channel.onClose,
self.realm.avatar._testSession.onClose])
channel.addCallback(cbExec)
def cbClosed(ignored):
self.assertEqual(self.channel.received, [])
self.assertEqual("".join(self.channel.receivedExt), "hello\r\n")
self.assertEqual(self.channel.status, 0)
self.assertTrue(self.channel.eofCalled)
self.assertEqual(self.channel.localWindowLeft, 4)
self.assertEqual(
self.channel.localWindowLeft,
self.realm.avatar._testSession.remoteWindowLeftAtClose)
channel.addCallback(cbClosed)
return channel
def test_unknownChannel(self):
"""
When an attempt is made to open an unknown channel type, the L{Deferred}
returned by L{SSHChannel.sendRequest} fires its errback.
"""
d = self.assertFailure(
self._ourServerOurClientTest('crazy-unknown-channel'), Exception)
def cbFailed(ignored):
errors = self.flushLoggedErrors(error.ConchError)
self.assertEqual(errors[0].value.args, (3, 'unknown channel'))
self.assertEqual(len(errors), 1)
d.addCallback(cbFailed)
return d
def test_maxPacket(self):
"""
An L{SSHChannel} can be configured with a maximum packet size to
receive.
"""
# localWindow needs to be at least 11 otherwise the assertion about it
# in cbClosed is invalid.
channel = self._ourServerOurClientTest(
localWindow=11, localMaxPacket=1)
def cbChannel(channel):
self.channel = channel
return channel.conn.sendRequest(
channel, 'exec', common.NS('secho hello'), 1)
channel.addCallback(cbChannel)
def cbExec(ignored):
return self.channel.onClose
channel.addCallback(cbExec)
def cbClosed(ignored):
self.assertEqual(self.channel.status, 0)
self.assertEqual("".join(self.channel.received), "hello\r\n")
self.assertEqual("".join(self.channel.receivedExt), "hello\r\n")
self.assertEqual(self.channel.localWindowLeft, 11)
self.assertTrue(self.channel.eofCalled)
channel.addCallback(cbClosed)
return channel
def test_echo(self):
"""
Normal standard out bytes are sent to the channel's C{dataReceived}
method.
"""
channel = self._ourServerOurClientTest(localWindow=4, localMaxPacket=5)
def cbChannel(channel):
self.channel = channel
return channel.conn.sendRequest(
channel, 'exec', common.NS('echo hello'), 1)
channel.addCallback(cbChannel)
def cbEcho(ignored):
return defer.gatherResults([
self.channel.onClose,
self.realm.avatar._testSession.onClose])
channel.addCallback(cbEcho)
def cbClosed(ignored):
self.assertEqual(self.channel.status, 0)
self.assertEqual("".join(self.channel.received), "hello\r\n")
self.assertEqual(self.channel.localWindowLeft, 4)
self.assertTrue(self.channel.eofCalled)
self.assertEqual(
self.channel.localWindowLeft,
self.realm.avatar._testSession.remoteWindowLeftAtClose)
channel.addCallback(cbClosed)
return channel
class TestSSHFactory(unittest.TestCase):
if not Crypto:
skip = "can't run w/o PyCrypto"
if not pyasn1:
skip = "Cannot run without PyASN1"
def makeSSHFactory(self, primes=None):
sshFactory = factory.SSHFactory()
gpk = lambda: {'ssh-rsa' : keys.Key(None)}
sshFactory.getPrimes = lambda: primes
sshFactory.getPublicKeys = sshFactory.getPrivateKeys = gpk
sshFactory.startFactory()
return sshFactory
def test_buildProtocol(self):
"""
By default, buildProtocol() constructs an instance of
SSHServerTransport.
"""
factory = self.makeSSHFactory()
protocol = factory.buildProtocol(None)
self.assertIsInstance(protocol, transport.SSHServerTransport)
def test_buildProtocolRespectsProtocol(self):
"""
buildProtocol() calls 'self.protocol()' to construct a protocol
instance.
"""
calls = []
def makeProtocol(*args):
calls.append(args)
return transport.SSHServerTransport()
factory = self.makeSSHFactory()
factory.protocol = makeProtocol
factory.buildProtocol(None)
self.assertEqual([()], calls)
def test_multipleFactories(self):
f1 = self.makeSSHFactory(primes=None)
f2 = self.makeSSHFactory(primes={1:(2,3)})
p1 = f1.buildProtocol(None)
p2 = f2.buildProtocol(None)
self.assertNotIn(
'diffie-hellman-group-exchange-sha1', p1.supportedKeyExchanges)
self.assertIn(
'diffie-hellman-group-exchange-sha1', p2.supportedKeyExchanges)
class MPTestCase(unittest.TestCase):
"""
Tests for L{common.getMP}.
@cvar getMP: a method providing a MP parser.
@type getMP: C{callable}
"""
getMP = staticmethod(common.getMP)
if not Crypto:
skip = "can't run w/o PyCrypto"
if not pyasn1:
skip = "Cannot run without PyASN1"
def test_getMP(self):
"""
L{common.getMP} should parse the a multiple precision integer from a
string: a 4-byte length followed by length bytes of the integer.
"""
self.assertEqual(
self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'),
(1, ''))
def test_getMPBigInteger(self):
"""
L{common.getMP} should be able to parse a big enough integer
(that doesn't fit on one byte).
"""
self.assertEqual(
self.getMP('\x00\x00\x00\x04\x01\x02\x03\x04'),
(16909060, ''))
def test_multipleGetMP(self):
"""
L{common.getMP} has the ability to parse multiple integer in the same
string.
"""
self.assertEqual(
self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'
'\x00\x00\x00\x04\x00\x00\x00\x02', 2),
(1, 2, ''))
def test_getMPRemainingData(self):
"""
When more data than needed is sent to L{common.getMP}, it should return
the remaining data.
"""
self.assertEqual(
self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01foo'),
(1, 'foo'))
def test_notEnoughData(self):
"""
When the string passed to L{common.getMP} doesn't even make 5 bytes,
it should raise a L{struct.error}.
"""
self.assertRaises(struct.error, self.getMP, '\x02\x00')
class PyMPTestCase(MPTestCase):
"""
Tests for the python implementation of L{common.getMP}.
"""
getMP = staticmethod(common.getMP_py)
class GMPYMPTestCase(MPTestCase):
"""
Tests for the gmpy implementation of L{common.getMP}.
"""
getMP = staticmethod(common._fastgetMP)
class BuiltinPowHackTestCase(unittest.TestCase):
"""
Tests that the builtin pow method is still correct after
L{twisted.conch.ssh.common} monkeypatches it to use gmpy.
"""
def test_floatBase(self):
"""
pow gives the correct result when passed a base of type float with a
non-integer value.
"""
self.assertEqual(6.25, pow(2.5, 2))
def test_intBase(self):
"""
pow gives the correct result when passed a base of type int.
"""
self.assertEqual(81, pow(3, 4))
def test_longBase(self):
"""
pow gives the correct result when passed a base of type long.
"""
self.assertEqual(81, pow(3, 4))
def test_mpzBase(self):
"""
pow gives the correct result when passed a base of type gmpy.mpz.
"""
if gmpy is None:
raise unittest.SkipTest('gmpy not available')
self.assertEqual(81, pow(gmpy.mpz(3), 4))
try:
import gmpy
except ImportError:
GMPYMPTestCase.skip = "gmpy not available"
gmpy = None

View file

@ -0,0 +1,183 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.tap}.
"""
try:
import Crypto.Cipher.DES3
except:
Crypto = None
try:
import pyasn1
except ImportError:
pyasn1 = None
try:
from twisted.conch import unix
except ImportError:
unix = None
if Crypto and pyasn1 and unix:
from twisted.conch import tap
from twisted.conch.openssh_compat.factory import OpenSSHFactory
from twisted.application.internet import StreamServerEndpointService
from twisted.cred import error
from twisted.cred.credentials import IPluggableAuthenticationModules
from twisted.cred.credentials import ISSHPrivateKey
from twisted.cred.credentials import IUsernamePassword, UsernamePassword
from twisted.trial.unittest import TestCase
class MakeServiceTest(TestCase):
"""
Tests for L{tap.makeService}.
"""
if not Crypto:
skip = "can't run w/o PyCrypto"
if not pyasn1:
skip = "Cannot run without PyASN1"
if not unix:
skip = "can't run on non-posix computers"
usernamePassword = ('iamuser', 'thisispassword')
def setUp(self):
"""
Create a file with two users.
"""
self.filename = self.mktemp()
f = open(self.filename, 'wb+')
f.write(':'.join(self.usernamePassword))
f.close()
self.options = tap.Options()
def test_basic(self):
"""
L{tap.makeService} returns a L{StreamServerEndpointService} instance
running on TCP port 22, and the linked protocol factory is an instance
of L{OpenSSHFactory}.
"""
config = tap.Options()
service = tap.makeService(config)
self.assertIsInstance(service, StreamServerEndpointService)
self.assertEqual(service.endpoint._port, 22)
self.assertIsInstance(service.factory, OpenSSHFactory)
def test_defaultAuths(self):
"""
Make sure that if the C{--auth} command-line option is not passed,
the default checkers are (for backwards compatibility): SSH, UNIX, and
PAM if available
"""
numCheckers = 2
try:
from twisted.cred import pamauth
self.assertIn(IPluggableAuthenticationModules,
self.options['credInterfaces'],
"PAM should be one of the modules")
numCheckers += 1
except ImportError:
pass
self.assertIn(ISSHPrivateKey, self.options['credInterfaces'],
"SSH should be one of the default checkers")
self.assertIn(IUsernamePassword, self.options['credInterfaces'],
"UNIX should be one of the default checkers")
self.assertEqual(numCheckers, len(self.options['credCheckers']),
"There should be %d checkers by default" % (numCheckers,))
def test_authAdded(self):
"""
The C{--auth} command-line option will add a checker to the list of
checkers, and it should be the only auth checker
"""
self.options.parseOptions(['--auth', 'file:' + self.filename])
self.assertEqual(len(self.options['credCheckers']), 1)
def test_multipleAuthAdded(self):
"""
Multiple C{--auth} command-line options will add all checkers specified
to the list ofcheckers, and there should only be the specified auth
checkers (no default checkers).
"""
self.options.parseOptions(['--auth', 'file:' + self.filename,
'--auth', 'memory:testuser:testpassword'])
self.assertEqual(len(self.options['credCheckers']), 2)
def test_authFailure(self):
"""
The checker created by the C{--auth} command-line option returns a
L{Deferred} that fails with L{UnauthorizedLogin} when
presented with credentials that are unknown to that checker.
"""
self.options.parseOptions(['--auth', 'file:' + self.filename])
checker = self.options['credCheckers'][-1]
invalid = UsernamePassword(self.usernamePassword[0], 'fake')
# Wrong password should raise error
return self.assertFailure(
checker.requestAvatarId(invalid), error.UnauthorizedLogin)
def test_authSuccess(self):
"""
The checker created by the C{--auth} command-line option returns a
L{Deferred} that returns the avatar id when presented with credentials
that are known to that checker.
"""
self.options.parseOptions(['--auth', 'file:' + self.filename])
checker = self.options['credCheckers'][-1]
correct = UsernamePassword(*self.usernamePassword)
d = checker.requestAvatarId(correct)
def checkSuccess(username):
self.assertEqual(username, correct.username)
return d.addCallback(checkSuccess)
def test_checkersPamAuth(self):
"""
The L{OpenSSHFactory} built by L{tap.makeService} has a portal with
L{IPluggableAuthenticationModules}, L{ISSHPrivateKey} and
L{IUsernamePassword} interfaces registered as checkers if C{pamauth} is
available.
"""
# Fake the presence of pamauth, even if PyPAM is not installed
self.patch(tap, "pamauth", object())
config = tap.Options()
service = tap.makeService(config)
portal = service.factory.portal
self.assertEqual(
set(portal.checkers.keys()),
set([IPluggableAuthenticationModules, ISSHPrivateKey,
IUsernamePassword]))
def test_checkersWithoutPamAuth(self):
"""
The L{OpenSSHFactory} built by L{tap.makeService} has a portal with
L{ISSHPrivateKey} and L{IUsernamePassword} interfaces registered as
checkers if C{pamauth} is not available.
"""
# Fake the absence of pamauth, even if PyPAM is installed
self.patch(tap, "pamauth", None)
config = tap.Options()
service = tap.makeService(config)
portal = service.factory.portal
self.assertEqual(
set(portal.checkers.keys()),
set([ISSHPrivateKey, IUsernamePassword]))

View file

@ -0,0 +1,767 @@
# -*- test-case-name: twisted.conch.test.test_telnet -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.telnet}.
"""
from zope.interface import implements
from zope.interface.verify import verifyObject
from twisted.internet import defer
from twisted.conch import telnet
from twisted.trial import unittest
from twisted.test import proto_helpers
class TestProtocol:
implements(telnet.ITelnetProtocol)
localEnableable = ()
remoteEnableable = ()
def __init__(self):
self.bytes = ''
self.subcmd = ''
self.calls = []
self.enabledLocal = []
self.enabledRemote = []
self.disabledLocal = []
self.disabledRemote = []
def makeConnection(self, transport):
d = transport.negotiationMap = {}
d['\x12'] = self.neg_TEST_COMMAND
d = transport.commandMap = transport.commandMap.copy()
for cmd in ('NOP', 'DM', 'BRK', 'IP', 'AO', 'AYT', 'EC', 'EL', 'GA'):
d[getattr(telnet, cmd)] = lambda arg, cmd=cmd: self.calls.append(cmd)
def dataReceived(self, bytes):
self.bytes += bytes
def connectionLost(self, reason):
pass
def neg_TEST_COMMAND(self, payload):
self.subcmd = payload
def enableLocal(self, option):
if option in self.localEnableable:
self.enabledLocal.append(option)
return True
return False
def disableLocal(self, option):
self.disabledLocal.append(option)
def enableRemote(self, option):
if option in self.remoteEnableable:
self.enabledRemote.append(option)
return True
return False
def disableRemote(self, option):
self.disabledRemote.append(option)
class TestInterfaces(unittest.TestCase):
def test_interface(self):
"""
L{telnet.TelnetProtocol} implements L{telnet.ITelnetProtocol}
"""
p = telnet.TelnetProtocol()
verifyObject(telnet.ITelnetProtocol, p)
class TelnetTransportTestCase(unittest.TestCase):
"""
Tests for L{telnet.TelnetTransport}.
"""
def setUp(self):
self.p = telnet.TelnetTransport(TestProtocol)
self.t = proto_helpers.StringTransport()
self.p.makeConnection(self.t)
def testRegularBytes(self):
# Just send a bunch of bytes. None of these do anything
# with telnet. They should pass right through to the
# application layer.
h = self.p.protocol
L = ["here are some bytes la la la",
"some more arrive here",
"lots of bytes to play with",
"la la la",
"ta de da",
"dum"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.bytes, ''.join(L))
def testNewlineHandling(self):
# Send various kinds of newlines and make sure they get translated
# into \n.
h = self.p.protocol
L = ["here is the first line\r\n",
"here is the second line\r\0",
"here is the third line\r\n",
"here is the last line\r\0"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.bytes, L[0][:-2] + '\n' +
L[1][:-2] + '\r' +
L[2][:-2] + '\n' +
L[3][:-2] + '\r')
def testIACEscape(self):
# Send a bunch of bytes and a couple quoted \xFFs. Unquoted,
# \xFF is a telnet command. Quoted, one of them from each pair
# should be passed through to the application layer.
h = self.p.protocol
L = ["here are some bytes\xff\xff with an embedded IAC",
"and here is a test of a border escape\xff",
"\xff did you get that IAC?"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.bytes, ''.join(L).replace('\xff\xff', '\xff'))
def _simpleCommandTest(self, cmdName):
# Send a single simple telnet command and make sure
# it gets noticed and the appropriate method gets
# called.
h = self.p.protocol
cmd = telnet.IAC + getattr(telnet, cmdName)
L = ["Here's some bytes, tra la la",
"But ono!" + cmd + " an interrupt"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.calls, [cmdName])
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
def testInterrupt(self):
self._simpleCommandTest("IP")
def testNoOperation(self):
self._simpleCommandTest("NOP")
def testDataMark(self):
self._simpleCommandTest("DM")
def testBreak(self):
self._simpleCommandTest("BRK")
def testAbortOutput(self):
self._simpleCommandTest("AO")
def testAreYouThere(self):
self._simpleCommandTest("AYT")
def testEraseCharacter(self):
self._simpleCommandTest("EC")
def testEraseLine(self):
self._simpleCommandTest("EL")
def testGoAhead(self):
self._simpleCommandTest("GA")
def testSubnegotiation(self):
# Send a subnegotiation command and make sure it gets
# parsed and that the correct method is called.
h = self.p.protocol
cmd = telnet.IAC + telnet.SB + '\x12hello world' + telnet.IAC + telnet.SE
L = ["These are some bytes but soon" + cmd,
"there will be some more"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
self.assertEqual(h.subcmd, list("hello world"))
def testSubnegotiationWithEmbeddedSE(self):
# Send a subnegotiation command with an embedded SE. Make sure
# that SE gets passed to the correct method.
h = self.p.protocol
cmd = (telnet.IAC + telnet.SB +
'\x12' + telnet.SE +
telnet.IAC + telnet.SE)
L = ["Some bytes are here" + cmd + "and here",
"and here"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
self.assertEqual(h.subcmd, [telnet.SE])
def testBoundarySubnegotiation(self):
# Send a subnegotiation command. Split it at every possible byte boundary
# and make sure it always gets parsed and that it is passed to the correct
# method.
cmd = (telnet.IAC + telnet.SB +
'\x12' + telnet.SE + 'hello' +
telnet.IAC + telnet.SE)
for i in range(len(cmd)):
h = self.p.protocol = TestProtocol()
h.makeConnection(self.p)
a, b = cmd[:i], cmd[i:]
L = ["first part" + a,
b + "last part"]
for bytes in L:
self.p.dataReceived(bytes)
self.assertEqual(h.bytes, ''.join(L).replace(cmd, ''))
self.assertEqual(h.subcmd, [telnet.SE] + list('hello'))
def _enabledHelper(self, o, eL=[], eR=[], dL=[], dR=[]):
self.assertEqual(o.enabledLocal, eL)
self.assertEqual(o.enabledRemote, eR)
self.assertEqual(o.disabledLocal, dL)
self.assertEqual(o.disabledRemote, dR)
def testRefuseWill(self):
# Try to enable an option. The server should refuse to enable it.
cmd = telnet.IAC + telnet.WILL + '\x12'
bytes = "surrounding bytes" + cmd + "to spice things up"
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + '\x12')
self._enabledHelper(self.p.protocol)
def testRefuseDo(self):
# Try to enable an option. The server should refuse to enable it.
cmd = telnet.IAC + telnet.DO + '\x12'
bytes = "surrounding bytes" + cmd + "to spice things up"
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), telnet.IAC + telnet.WONT + '\x12')
self._enabledHelper(self.p.protocol)
def testAcceptDo(self):
# Try to enable an option. The option is in our allowEnable
# list, so we will allow it to be enabled.
cmd = telnet.IAC + telnet.DO + '\x19'
bytes = 'padding' + cmd + 'trailer'
h = self.p.protocol
h.localEnableable = ('\x19',)
self.p.dataReceived(bytes)
self.assertEqual(self.t.value(), telnet.IAC + telnet.WILL + '\x19')
self._enabledHelper(h, eL=['\x19'])
def testAcceptWill(self):
# Same as testAcceptDo, but reversed.
cmd = telnet.IAC + telnet.WILL + '\x91'
bytes = 'header' + cmd + 'padding'
h = self.p.protocol
h.remoteEnableable = ('\x91',)
self.p.dataReceived(bytes)
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + '\x91')
self._enabledHelper(h, eR=['\x91'])
def testAcceptWont(self):
# Try to disable an option. The server must allow any option to
# be disabled at any time. Make sure it disables it and sends
# back an acknowledgement of this.
cmd = telnet.IAC + telnet.WONT + '\x29'
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have been previously enabled
# via normal negotiation.
s = self.p.getOptionState('\x29')
s.him.state = 'yes'
bytes = "fiddle dee" + cmd
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + '\x29')
self.assertEqual(s.him.state, 'no')
self._enabledHelper(self.p.protocol, dR=['\x29'])
def testAcceptDont(self):
# Try to disable an option. The server must allow any option to
# be disabled at any time. Make sure it disables it and sends
# back an acknowledgement of this.
cmd = telnet.IAC + telnet.DONT + '\x29'
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have beenp previously enabled
# via normal negotiation.
s = self.p.getOptionState('\x29')
s.us.state = 'yes'
bytes = "fiddle dum " + cmd
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), telnet.IAC + telnet.WONT + '\x29')
self.assertEqual(s.us.state, 'no')
self._enabledHelper(self.p.protocol, dL=['\x29'])
def testIgnoreWont(self):
# Try to disable an option. The option is already disabled. The
# server should send nothing in response to this.
cmd = telnet.IAC + telnet.WONT + '\x47'
bytes = "dum de dum" + cmd + "tra la la"
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), '')
self._enabledHelper(self.p.protocol)
def testIgnoreDont(self):
# Try to disable an option. The option is already disabled. The
# server should send nothing in response to this. Doing so could
# lead to a negotiation loop.
cmd = telnet.IAC + telnet.DONT + '\x47'
bytes = "dum de dum" + cmd + "tra la la"
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), '')
self._enabledHelper(self.p.protocol)
def testIgnoreWill(self):
# Try to enable an option. The option is already enabled. The
# server should send nothing in response to this. Doing so could
# lead to a negotiation loop.
cmd = telnet.IAC + telnet.WILL + '\x56'
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have been previously enabled
# via normal negotiation.
s = self.p.getOptionState('\x56')
s.him.state = 'yes'
bytes = "tra la la" + cmd + "dum de dum"
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), '')
self._enabledHelper(self.p.protocol)
def testIgnoreDo(self):
# Try to enable an option. The option is already enabled. The
# server should send nothing in response to this. Doing so could
# lead to a negotiation loop.
cmd = telnet.IAC + telnet.DO + '\x56'
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have been previously enabled
# via normal negotiation.
s = self.p.getOptionState('\x56')
s.us.state = 'yes'
bytes = "tra la la" + cmd + "dum de dum"
self.p.dataReceived(bytes)
self.assertEqual(self.p.protocol.bytes, bytes.replace(cmd, ''))
self.assertEqual(self.t.value(), '')
self._enabledHelper(self.p.protocol)
def testAcceptedEnableRequest(self):
# Try to enable an option through the user-level API. This
# returns a Deferred that fires when negotiation about the option
# finishes. Make sure it fires, make sure state gets updated
# properly, make sure the result indicates the option was enabled.
d = self.p.do('\x42')
h = self.p.protocol
h.remoteEnableable = ('\x42',)
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + '\x42')
self.p.dataReceived(telnet.IAC + telnet.WILL + '\x42')
d.addCallback(self.assertEqual, True)
d.addCallback(lambda _: self._enabledHelper(h, eR=['\x42']))
return d
def test_refusedEnableRequest(self):
"""
If the peer refuses to enable an option we request it to enable, the
L{Deferred} returned by L{TelnetProtocol.do} fires with an
L{OptionRefused} L{Failure}.
"""
# Try to enable an option through the user-level API. This returns a
# Deferred that fires when negotiation about the option finishes. Make
# sure it fires, make sure state gets updated properly, make sure the
# result indicates the option was enabled.
self.p.protocol.remoteEnableable = ('\x42',)
d = self.p.do('\x42')
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + '\x42')
s = self.p.getOptionState('\x42')
self.assertEqual(s.him.state, 'no')
self.assertEqual(s.us.state, 'no')
self.assertEqual(s.him.negotiating, True)
self.assertEqual(s.us.negotiating, False)
self.p.dataReceived(telnet.IAC + telnet.WONT + '\x42')
d = self.assertFailure(d, telnet.OptionRefused)
d.addCallback(lambda ignored: self._enabledHelper(self.p.protocol))
d.addCallback(
lambda ignored: self.assertEqual(s.him.negotiating, False))
return d
def test_refusedEnableOffer(self):
"""
If the peer refuses to allow us to enable an option, the L{Deferred}
returned by L{TelnetProtocol.will} fires with an L{OptionRefused}
L{Failure}.
"""
# Try to offer an option through the user-level API. This returns a
# Deferred that fires when negotiation about the option finishes. Make
# sure it fires, make sure state gets updated properly, make sure the
# result indicates the option was enabled.
self.p.protocol.localEnableable = ('\x42',)
d = self.p.will('\x42')
self.assertEqual(self.t.value(), telnet.IAC + telnet.WILL + '\x42')
s = self.p.getOptionState('\x42')
self.assertEqual(s.him.state, 'no')
self.assertEqual(s.us.state, 'no')
self.assertEqual(s.him.negotiating, False)
self.assertEqual(s.us.negotiating, True)
self.p.dataReceived(telnet.IAC + telnet.DONT + '\x42')
d = self.assertFailure(d, telnet.OptionRefused)
d.addCallback(lambda ignored: self._enabledHelper(self.p.protocol))
d.addCallback(
lambda ignored: self.assertEqual(s.us.negotiating, False))
return d
def testAcceptedDisableRequest(self):
# Try to disable an option through the user-level API. This
# returns a Deferred that fires when negotiation about the option
# finishes. Make sure it fires, make sure state gets updated
# properly, make sure the result indicates the option was enabled.
s = self.p.getOptionState('\x42')
s.him.state = 'yes'
d = self.p.dont('\x42')
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + '\x42')
self.p.dataReceived(telnet.IAC + telnet.WONT + '\x42')
d.addCallback(self.assertEqual, True)
d.addCallback(lambda _: self._enabledHelper(self.p.protocol,
dR=['\x42']))
return d
def testNegotiationBlocksFurtherNegotiation(self):
# Try to disable an option, then immediately try to enable it, then
# immediately try to disable it. Ensure that the 2nd and 3rd calls
# fail quickly with the right exception.
s = self.p.getOptionState('\x24')
s.him.state = 'yes'
d2 = self.p.dont('\x24') # fires after the first line of _final
def _do(x):
d = self.p.do('\x24')
return self.assertFailure(d, telnet.AlreadyNegotiating)
def _dont(x):
d = self.p.dont('\x24')
return self.assertFailure(d, telnet.AlreadyNegotiating)
def _final(x):
self.p.dataReceived(telnet.IAC + telnet.WONT + '\x24')
# an assertion that only passes if d2 has fired
self._enabledHelper(self.p.protocol, dR=['\x24'])
# Make sure we allow this
self.p.protocol.remoteEnableable = ('\x24',)
d = self.p.do('\x24')
self.p.dataReceived(telnet.IAC + telnet.WILL + '\x24')
d.addCallback(self.assertEqual, True)
d.addCallback(lambda _: self._enabledHelper(self.p.protocol,
eR=['\x24'],
dR=['\x24']))
return d
d = _do(None)
d.addCallback(_dont)
d.addCallback(_final)
return d
def testSuperfluousDisableRequestRaises(self):
# Try to disable a disabled option. Make sure it fails properly.
d = self.p.dont('\xab')
return self.assertFailure(d, telnet.AlreadyDisabled)
def testSuperfluousEnableRequestRaises(self):
# Try to disable a disabled option. Make sure it fails properly.
s = self.p.getOptionState('\xab')
s.him.state = 'yes'
d = self.p.do('\xab')
return self.assertFailure(d, telnet.AlreadyEnabled)
def testLostConnectionFailsDeferreds(self):
d1 = self.p.do('\x12')
d2 = self.p.do('\x23')
d3 = self.p.do('\x34')
class TestException(Exception):
pass
self.p.connectionLost(TestException("Total failure!"))
d1 = self.assertFailure(d1, TestException)
d2 = self.assertFailure(d2, TestException)
d3 = self.assertFailure(d3, TestException)
return defer.gatherResults([d1, d2, d3])
class TestTelnet(telnet.Telnet):
"""
A trivial extension of the telnet protocol class useful to unit tests.
"""
def __init__(self):
telnet.Telnet.__init__(self)
self.events = []
def applicationDataReceived(self, bytes):
"""
Record the given data in C{self.events}.
"""
self.events.append(('bytes', bytes))
def unhandledCommand(self, command, bytes):
"""
Record the given command in C{self.events}.
"""
self.events.append(('command', command, bytes))
def unhandledSubnegotiation(self, command, bytes):
"""
Record the given subnegotiation command in C{self.events}.
"""
self.events.append(('negotiate', command, bytes))
class TelnetTests(unittest.TestCase):
"""
Tests for L{telnet.Telnet}.
L{telnet.Telnet} implements the TELNET protocol (RFC 854), including option
and suboption negotiation, and option state tracking.
"""
def setUp(self):
"""
Create an unconnected L{telnet.Telnet} to be used by tests.
"""
self.protocol = TestTelnet()
def test_enableLocal(self):
"""
L{telnet.Telnet.enableLocal} should reject all options, since
L{telnet.Telnet} does not know how to implement any options.
"""
self.assertFalse(self.protocol.enableLocal('\0'))
def test_enableRemote(self):
"""
L{telnet.Telnet.enableRemote} should reject all options, since
L{telnet.Telnet} does not know how to implement any options.
"""
self.assertFalse(self.protocol.enableRemote('\0'))
def test_disableLocal(self):
"""
It is an error for L{telnet.Telnet.disableLocal} to be called, since
L{telnet.Telnet.enableLocal} will never allow any options to be enabled
locally. If a subclass overrides enableLocal, it must also override
disableLocal.
"""
self.assertRaises(NotImplementedError, self.protocol.disableLocal, '\0')
def test_disableRemote(self):
"""
It is an error for L{telnet.Telnet.disableRemote} to be called, since
L{telnet.Telnet.enableRemote} will never allow any options to be
enabled remotely. If a subclass overrides enableRemote, it must also
override disableRemote.
"""
self.assertRaises(NotImplementedError, self.protocol.disableRemote, '\0')
def test_requestNegotiation(self):
"""
L{telnet.Telnet.requestNegotiation} formats the feature byte and the
payload bytes into the subnegotiation format and sends them.
See RFC 855.
"""
transport = proto_helpers.StringTransport()
self.protocol.makeConnection(transport)
self.protocol.requestNegotiation('\x01', '\x02\x03')
self.assertEqual(
transport.value(),
# IAC SB feature bytes IAC SE
'\xff\xfa\x01\x02\x03\xff\xf0')
def test_requestNegotiationEscapesIAC(self):
"""
If the payload for a subnegotiation includes I{IAC}, it is escaped by
L{telnet.Telnet.requestNegotiation} with another I{IAC}.
See RFC 855.
"""
transport = proto_helpers.StringTransport()
self.protocol.makeConnection(transport)
self.protocol.requestNegotiation('\x01', '\xff')
self.assertEqual(
transport.value(),
'\xff\xfa\x01\xff\xff\xff\xf0')
def _deliver(self, bytes, *expected):
"""
Pass the given bytes to the protocol's C{dataReceived} method and
assert that the given events occur.
"""
received = self.protocol.events = []
self.protocol.dataReceived(bytes)
self.assertEqual(received, list(expected))
def test_oneApplicationDataByte(self):
"""
One application-data byte in the default state gets delivered right
away.
"""
self._deliver('a', ('bytes', 'a'))
def test_twoApplicationDataBytes(self):
"""
Two application-data bytes in the default state get delivered
together.
"""
self._deliver('bc', ('bytes', 'bc'))
def test_threeApplicationDataBytes(self):
"""
Three application-data bytes followed by a control byte get
delivered, but the control byte doesn't.
"""
self._deliver('def' + telnet.IAC, ('bytes', 'def'))
def test_escapedControl(self):
"""
IAC in the escaped state gets delivered and so does another
application-data byte following it.
"""
self._deliver(telnet.IAC)
self._deliver(telnet.IAC + 'g', ('bytes', telnet.IAC + 'g'))
def test_carriageReturn(self):
"""
A carriage return only puts the protocol into the newline state. A
linefeed in the newline state causes just the newline to be
delivered. A nul in the newline state causes a carriage return to
be delivered. An IAC in the newline state causes a carriage return
to be delivered and puts the protocol into the escaped state.
Anything else causes a carriage return and that thing to be
delivered.
"""
self._deliver('\r')
self._deliver('\n', ('bytes', '\n'))
self._deliver('\r\n', ('bytes', '\n'))
self._deliver('\r')
self._deliver('\0', ('bytes', '\r'))
self._deliver('\r\0', ('bytes', '\r'))
self._deliver('\r')
self._deliver('a', ('bytes', '\ra'))
self._deliver('\ra', ('bytes', '\ra'))
self._deliver('\r')
self._deliver(
telnet.IAC + telnet.IAC + 'x', ('bytes', '\r' + telnet.IAC + 'x'))
def test_applicationDataBeforeSimpleCommand(self):
"""
Application bytes received before a command are delivered before the
command is processed.
"""
self._deliver(
'x' + telnet.IAC + telnet.NOP,
('bytes', 'x'), ('command', telnet.NOP, None))
def test_applicationDataBeforeCommand(self):
"""
Application bytes received before a WILL/WONT/DO/DONT are delivered
before the command is processed.
"""
self.protocol.commandMap = {}
self._deliver(
'y' + telnet.IAC + telnet.WILL + '\x00',
('bytes', 'y'), ('command', telnet.WILL, '\x00'))
def test_applicationDataBeforeSubnegotiation(self):
"""
Application bytes received before a subnegotiation command are
delivered before the negotiation is processed.
"""
self._deliver(
'z' + telnet.IAC + telnet.SB + 'Qx' + telnet.IAC + telnet.SE,
('bytes', 'z'), ('negotiate', 'Q', ['x']))

View file

@ -0,0 +1,161 @@
# -*- test-case-name: twisted.conch.test.test_text -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.conch.insults import helper, text
from twisted.conch.insults.text import attributes as A
class FormattedTextTests(unittest.TestCase):
"""
Tests for assembling formatted text.
"""
def test_trivial(self):
"""
Using no formatting attributes produces no VT102 control sequences in
the flattened output.
"""
self.assertEqual(
text.assembleFormattedText(A.normal['Hello, world.']),
'Hello, world.')
def test_bold(self):
"""
The bold formatting attribute, L{A.bold}, emits the VT102 control
sequence to enable bold when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.bold['Hello, world.']),
'\x1b[1mHello, world.')
def test_underline(self):
"""
The underline formatting attribute, L{A.underline}, emits the VT102
control sequence to enable underlining when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.underline['Hello, world.']),
'\x1b[4mHello, world.')
def test_blink(self):
"""
The blink formatting attribute, L{A.blink}, emits the VT102 control
sequence to enable blinking when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.blink['Hello, world.']),
'\x1b[5mHello, world.')
def test_reverseVideo(self):
"""
The reverse-video formatting attribute, L{A.reverseVideo}, emits the
VT102 control sequence to enable reversed video when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.reverseVideo['Hello, world.']),
'\x1b[7mHello, world.')
def test_minus(self):
"""
Formatting attributes prefixed with a minus (C{-}) temporarily disable
the prefixed attribute, emitting no VT102 control sequence to enable
it in the flattened output.
"""
self.assertEqual(
text.assembleFormattedText(
A.bold[A.blink['Hello', -A.bold[' world'], '.']]),
'\x1b[1;5mHello\x1b[0;5m world\x1b[1;5m.')
def test_foreground(self):
"""
The foreground color formatting attribute, L{A.fg}, emits the VT102
control sequence to set the selected foreground color when flattened.
"""
self.assertEqual(
text.assembleFormattedText(
A.normal[A.fg.red['Hello, '], A.fg.green['world!']]),
'\x1b[31mHello, \x1b[32mworld!')
def test_background(self):
"""
The background color formatting attribute, L{A.bg}, emits the VT102
control sequence to set the selected background color when flattened.
"""
self.assertEqual(
text.assembleFormattedText(
A.normal[A.bg.red['Hello, '], A.bg.green['world!']]),
'\x1b[41mHello, \x1b[42mworld!')
def test_flattenDeprecated(self):
"""
L{twisted.conch.insults.text.flatten} emits a deprecation warning when
imported or accessed.
"""
warningsShown = self.flushWarnings([self.test_flattenDeprecated])
self.assertEqual(len(warningsShown), 0)
# Trigger the deprecation warning.
text.flatten
warningsShown = self.flushWarnings([self.test_flattenDeprecated])
self.assertEqual(len(warningsShown), 1)
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
'twisted.conch.insults.text.flatten was deprecated in Twisted '
'13.1.0: Use twisted.conch.insults.text.assembleFormattedText '
'instead.')
class EfficiencyTestCase(unittest.TestCase):
todo = ("flatten() isn't quite stateful enough to avoid emitting a few extra bytes in "
"certain circumstances, so these tests fail. The failures take the form of "
"additional elements in the ;-delimited character attribute lists. For example, "
"\\x1b[0;31;46m might be emitted instead of \\x[46m, even if 31 has already been "
"activated and no conflicting attributes are set which need to be cleared.")
def setUp(self):
self.attrs = helper._FormattingState()
def testComplexStructure(self):
output = A.normal[
A.bold[
A.bg.cyan[
A.fg.red[
"Foreground Red, Background Cyan, Bold",
A.blink[
"Blinking"],
-A.bold[
"Foreground Red, Background Cyan, normal"]],
A.fg.green[
"Foreground Green, Background Cyan, Bold"]]]]
self.assertEqual(
text.flatten(output, self.attrs),
"\x1b[1;31;46mForeground Red, Background Cyan, Bold"
"\x1b[5mBlinking"
"\x1b[0;31;46mForeground Red, Background Cyan, normal"
"\x1b[1;32;46mForeground Green, Background Cyan, Bold")
def testNesting(self):
self.assertEqual(
text.flatten(A.bold['Hello, ', A.underline['world.']], self.attrs),
'\x1b[1mHello, \x1b[4mworld.')
self.assertEqual(
text.flatten(
A.bold[A.reverseVideo['Hello, ', A.normal['world'], '.']],
self.attrs),
'\x1b[1;7mHello, \x1b[0mworld\x1b[1;7m.')

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,67 @@
"""
Tests for the insults windowing module, L{twisted.conch.insults.window}.
"""
from twisted.trial.unittest import TestCase
from twisted.conch.insults.window import TopWindow, ScrolledArea, TextOutput
class TopWindowTests(TestCase):
"""
Tests for L{TopWindow}, the root window container class.
"""
def test_paintScheduling(self):
"""
Verify that L{TopWindow.repaint} schedules an actual paint to occur
using the scheduling object passed to its initializer.
"""
paints = []
scheduled = []
root = TopWindow(lambda: paints.append(None), scheduled.append)
# Nothing should have happened yet.
self.assertEqual(paints, [])
self.assertEqual(scheduled, [])
# Cause a paint to be scheduled.
root.repaint()
self.assertEqual(paints, [])
self.assertEqual(len(scheduled), 1)
# Do another one to verify nothing else happens as long as the previous
# one is still pending.
root.repaint()
self.assertEqual(paints, [])
self.assertEqual(len(scheduled), 1)
# Run the actual paint call.
scheduled.pop()()
self.assertEqual(len(paints), 1)
self.assertEqual(scheduled, [])
# Do one more to verify that now that the previous one is finished
# future paints will succeed.
root.repaint()
self.assertEqual(len(paints), 1)
self.assertEqual(len(scheduled), 1)
class ScrolledAreaTests(TestCase):
"""
Tests for L{ScrolledArea}, a widget which creates a viewport containing
another widget and can reposition that viewport using scrollbars.
"""
def test_parent(self):
"""
The parent of the widget passed to L{ScrolledArea} is set to a new
L{Viewport} created by the L{ScrolledArea} which itself has the
L{ScrolledArea} instance as its parent.
"""
widget = TextOutput()
scrolled = ScrolledArea(widget)
self.assertIs(widget.parent, scrolled._viewport)
self.assertIs(scrolled._viewport.parent, scrolled)