mirror of
				https://gitee.com/gitea/gitea
				synced 2025-11-04 16:40:24 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			128 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			128 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
		
			Executable File
		
	
	
	
	
// Copyright 2011 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package ssh
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"crypto"
 | 
						|
	"crypto/aes"
 | 
						|
	"crypto/rand"
 | 
						|
	"testing"
 | 
						|
)
 | 
						|
 | 
						|
func TestDefaultCiphersExist(t *testing.T) {
 | 
						|
	for _, cipherAlgo := range supportedCiphers {
 | 
						|
		if _, ok := cipherModes[cipherAlgo]; !ok {
 | 
						|
			t.Errorf("default cipher %q is unknown", cipherAlgo)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPacketCiphers(t *testing.T) {
 | 
						|
	// Still test aes128cbc cipher althought it's commented out.
 | 
						|
	cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
 | 
						|
	defer delete(cipherModes, aes128cbcID)
 | 
						|
 | 
						|
	for cipher := range cipherModes {
 | 
						|
		kr := &kexResult{Hash: crypto.SHA1}
 | 
						|
		algs := directionAlgorithms{
 | 
						|
			Cipher:      cipher,
 | 
						|
			MAC:         "hmac-sha1",
 | 
						|
			Compression: "none",
 | 
						|
		}
 | 
						|
		client, err := newPacketCipher(clientKeys, algs, kr)
 | 
						|
		if err != nil {
 | 
						|
			t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		server, err := newPacketCipher(clientKeys, algs, kr)
 | 
						|
		if err != nil {
 | 
						|
			t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		want := "bla bla"
 | 
						|
		input := []byte(want)
 | 
						|
		buf := &bytes.Buffer{}
 | 
						|
		if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
 | 
						|
			t.Errorf("writePacket(%q): %v", cipher, err)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		packet, err := server.readPacket(0, buf)
 | 
						|
		if err != nil {
 | 
						|
			t.Errorf("readPacket(%q): %v", cipher, err)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if string(packet) != want {
 | 
						|
			t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCBCOracleCounterMeasure(t *testing.T) {
 | 
						|
	cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
 | 
						|
	defer delete(cipherModes, aes128cbcID)
 | 
						|
 | 
						|
	kr := &kexResult{Hash: crypto.SHA1}
 | 
						|
	algs := directionAlgorithms{
 | 
						|
		Cipher:      aes128cbcID,
 | 
						|
		MAC:         "hmac-sha1",
 | 
						|
		Compression: "none",
 | 
						|
	}
 | 
						|
	client, err := newPacketCipher(clientKeys, algs, kr)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("newPacketCipher(client): %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	want := "bla bla"
 | 
						|
	input := []byte(want)
 | 
						|
	buf := &bytes.Buffer{}
 | 
						|
	if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
 | 
						|
		t.Errorf("writePacket: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	packetSize := buf.Len()
 | 
						|
	buf.Write(make([]byte, 2*maxPacket))
 | 
						|
 | 
						|
	// We corrupt each byte, but this usually will only test the
 | 
						|
	// 'packet too large' or 'MAC failure' cases.
 | 
						|
	lastRead := -1
 | 
						|
	for i := 0; i < packetSize; i++ {
 | 
						|
		server, err := newPacketCipher(clientKeys, algs, kr)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("newPacketCipher(client): %v", err)
 | 
						|
		}
 | 
						|
 | 
						|
		fresh := &bytes.Buffer{}
 | 
						|
		fresh.Write(buf.Bytes())
 | 
						|
		fresh.Bytes()[i] ^= 0x01
 | 
						|
 | 
						|
		before := fresh.Len()
 | 
						|
		_, err = server.readPacket(0, fresh)
 | 
						|
		if err == nil {
 | 
						|
			t.Errorf("corrupt byte %d: readPacket succeeded ", i)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if _, ok := err.(cbcError); !ok {
 | 
						|
			t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		after := fresh.Len()
 | 
						|
		bytesRead := before - after
 | 
						|
		if bytesRead < maxPacket {
 | 
						|
			t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if i > 0 && bytesRead != lastRead {
 | 
						|
			t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
 | 
						|
		}
 | 
						|
		lastRead = bytesRead
 | 
						|
	}
 | 
						|
}
 |